diff --git a/fast_depth_functional.py b/fast_depth_functional.py index 4da5289..b27aceb 100644 --- a/fast_depth_functional.py +++ b/fast_depth_functional.py @@ -1,7 +1,8 @@ +import tensorflow as tf import tensorflow.keras as keras ''' -Functional version of fastdepth model. Note that this doesn't work at the moment +Functional version of fastdepth model ''' @@ -16,11 +17,11 @@ def FDDepthwiseBlock(inputs, return keras.layers.ReLU(6., name='conv_pw_%d_relu' % block_id)(x) -def make_fastdepth_functional(): +def make_fastdepth_functional(weights=None): # This doesn't work, at least right now... input = keras.layers.Input(shape=(224, 224, 3)) x = input - mobilenet = keras.applications.MobileNet(include_top=False, weights=None) + mobilenet = keras.applications.MobileNet(include_top=False, weights=weights) for layer in mobilenet.layers: x = layer(x) if layer.name == 'conv_pw_5_relu': @@ -30,6 +31,7 @@ def make_fastdepth_functional(): elif layer.name == 'conv_pw_1_relu': conv1 = x + # Fast depth decoder x = FDDepthwiseBlock(x, 512, block_id=14) # TODO: Bilinear interpolation # x = keras.layers.experimental.preprocessing.Resizing(14, 14) @@ -53,5 +55,43 @@ def make_fastdepth_functional(): return keras.Model(inputs=input, outputs=x, name="fast_depth") +def delta1_metric(y_true, y_pred): + maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred) + return float((maxRatio < 1.25).float().mean()) + + +def delta2_metric(y_true, y_pred): + maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred) + return float((maxRatio < 1.2 ** 25).float().mean()) + + +def delta3_metric(y_true, y_pred): + maxRatio = tf.maximum(y_pred / y_true, y_true / y_pred) + return float((maxRatio < 1.25 ** 3).float().mean()) + + +def fastdepth_for_training(): + # Pretrained mobilenet on imagenet dataset + model = make_fastdepth_functional('imagenet') + return model.compile(optimizer=keras.optimizers.SGD(momentum=0.9), + loss=keras.losses.MSE(), + metrics=[keras.metrics.RootMeanSquaredError(), + keras.metrics.MeanSquaredError(), + delta1_metric, + delta2_metric, + delta3_metric]) + + +def train_compiled_model(compiled_model, dataset): + """ + + :param compiled_model: Compiled model to train on + :param dataset: Dataset to train on (must be compatible with model + :return: + """ + # TODO: Use tf nyu_v2 dataset to train. + pass + + if __name__ == '__main__': make_fastdepth_functional().summary() diff --git a/main.py b/main.py index 62b1885..2bdb6ca 100644 --- a/main.py +++ b/main.py @@ -7,13 +7,5 @@ def load_nyu(): return builder.as_dataset(split='train', shuffle_files=True) -def print_hi(name): - # Use a breakpoint in the code line below to debug your script. - print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. - - -# Press the green button in the gutter to run the script. if __name__ == '__main__': load_nyu() - -# See PyCharm help at https://www.jetbrains.com/help/pycharm/