Add working train and eval functions for nyu_v2
This commit is contained in:
14
main.py
14
main.py
@@ -1,11 +1,7 @@
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
|
||||
def load_nyu():
|
||||
builder = tfds.builder('nyu_depth_v2')
|
||||
builder.download_and_prepare(download_dir='../nyu')
|
||||
return builder.as_dataset(split='train', shuffle_files=True)
|
||||
|
||||
import fast_depth_functional as fd
|
||||
|
||||
if __name__ == '__main__':
|
||||
load_nyu()
|
||||
fd.fix_windows_gpu()
|
||||
model = fd.load_model('fast_depth_nyu_v2_224_224_3_e1')
|
||||
fd.compile(model)
|
||||
fd.evaluate(model)
|
||||
|
||||
Reference in New Issue
Block a user