Files
fast-depth-tf/util.py

22 lines
896 B
Python

import tensorflow as tf
import tensorflow.keras as keras
def crop_and_resize(x, out_shape=(224, 224), image_interpolation='nearest'):
    """Center-crop an image/depth pair to a square and resize both to ``out_shape``.

    The square side is the smallest extent found on axes 1 and 2 of either
    the depth tensor or the image tensor, so both tensors are center-cropped
    to the same square size before resizing.

    Args:
        x: dict with ``'image'`` and ``'depth'`` tensors. Axes 1 and 2 of
            both shapes are used as the spatial extents, and depth is
            reshaped to ``[d0, d1, d2, 1]``. The function therefore
            presumably expects ``image`` as (batch, height, width, channels)
            and ``depth`` as (batch, height, width). TODO confirm against
            the dataset pipeline.
        out_shape: ``(height, width)`` passed to the resizing layers.
        image_interpolation: interpolation mode used when resizing the image.
            Defaults to ``'nearest'``, the previous hard-coded behavior. Any
            mode accepted by the Keras ``Resizing`` layer (e.g.
            ``'bilinear'``) may be given. The depth map is always resized
            with ``'nearest'``, so no blended depth values are introduced.

    Returns:
        Tuple ``(image, depth)`` of the cropped and resized tensors. The
        depth output carries a trailing axis of size 1.
    """
    depth_shape = tf.shape(x['depth'])
    img_shape = tf.shape(x['image'])
    # Use a square crop so the later resize starts from a square.
    # For horizontal (landscape) inputs this just trims the left/right sides.
    center_shape = tf.minimum(
        depth_shape[1],
        tf.minimum(depth_shape[2], tf.minimum(img_shape[1], img_shape[2])))

    def _pipeline(interpolation):
        # Build a center-crop + resize stack. The image and the depth map
        # each get their own instance, as in the original design.
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(
                center_shape, center_shape),
            keras.layers.experimental.preprocessing.Resizing(
                out_shape[0], out_shape[1], interpolation=interpolation),
        ])

    image_out = _pipeline(image_interpolation)(x['image'])

    # Reshape depth to 4-D with a trailing axis of size 1 so the image
    # preprocessing layers accept it. The target shape is spelled out element
    # by element; the original author noted that unpacking the shape tensor
    # into the list is not supported by TensorFlow.
    depth_4d = tf.reshape(
        x['depth'], [depth_shape[0], depth_shape[1], depth_shape[2], 1])
    # Depth is always resized with nearest-neighbor so no interpolated
    # (blended) depth values are created.
    depth_out = _pipeline('nearest')(depth_4d)
    return image_out, depth_out