Files
fast-depth-tf/main.py

22 lines
803 B
Python

import fast_depth_functional as fd
from unsupervised.models import pose_net, wrap_mobilenet_nnconv5_for_utrain
from unsupervised.train import UnsupervisedPoseDepthLearner
def _train_supervised(save_file='../fast-depth-experimental'):
    """Build, train and evaluate the supervised depth model.

    The network is created with ``fd.mobilenet_nnconv5(weights='imagenet')``,
    compiled with ``fd.compile``, trained with ``fd.train`` and then scored
    with ``fd.evaluate``.

    Args:
        save_file: Forwarded unchanged to ``fd.train`` as ``save_file``.
            NOTE(review): how the path is used (checkpoint file, directory,
            prefix) is decided inside ``fast_depth_functional`` — confirm
            there.

    Returns:
        The model object returned by ``fd.mobilenet_nnconv5``, after
        training and evaluation.
    """
    model = fd.mobilenet_nnconv5(weights='imagenet')
    fd.compile(model)
    fd.train(existing_model=model, save_file=save_file)
    fd.evaluate(model)
    # To export in TensorFlow SavedModel format, uncomment the line below.
    # `tf` is not imported in this module, so add `import tensorflow as tf`
    # first or the call will raise NameError.
    # tf.saved_model.save(model, 'fast_depth_nyu_v2_224_224_3_e1_saved_model')
    return model


def _build_unsupervised():
    """Assemble and compile the unsupervised depth + pose learner.

    A fresh depth network from ``fd.mobilenet_nnconv5()`` (no ``weights``
    argument) is adapted with ``wrap_mobilenet_nnconv5_for_utrain``. It is
    then paired with ``pose_net()`` inside ``UnsupervisedPoseDepthLearner``,
    which is compiled with ``optimizer='adam'``.

    Returns:
        The compiled ``UnsupervisedPoseDepthLearner``. It has not been
        trained yet.
    """
    depth_model = fd.mobilenet_nnconv5()
    pose_model = pose_net()
    learner = UnsupervisedPoseDepthLearner(
        wrap_mobilenet_nnconv5_for_utrain(depth_model), pose_model)
    learner.compile(optimizer='adam')
    # TODO: Incorporate data generator, then train with learner.fit(...).
    return learner


def main():
    """Run the supervised pipeline, then set up the unsupervised learner."""
    # NOTE(review): presumably a platform-specific GPU setup workaround for
    # Windows hosts; see fast_depth_functional.fix_windows_gpu to confirm.
    fd.fix_windows_gpu()
    _train_supervised()
    _build_unsupervised()


if __name__ == '__main__':
    main()