import tensorflow as tf
import tensorflow.keras as keras


def crop_and_resize(x, out_shape=(224, 224)):
    """Center-crop an image/depth pair to a common square and resize both.

    The crop side is the smallest of the two spatial dimensions (axes 1
    and 2) across ``x['image']`` and ``x['depth']``, so both tensors are
    cropped to the same square before being resized with nearest-neighbour
    interpolation.

    Args:
        x: Mapping with ``'image'`` and ``'depth'`` entries.
            NOTE(review): the indexing below assumes a batched layout,
            i.e. image is (batch, height, width, channels) and depth is
            (batch, height, width) -- confirm against the caller.
        out_shape: ``(height, width)`` of the resized outputs.

    Returns:
        A tuple ``(image, depth)``: the cropped and resized image, and the
        cropped and resized depth map with a trailing channel axis of
        size 1 added.
    """
    shape = tf.shape(x['depth'])
    img_shape = tf.shape(x['image'])

    # Ensure we get a square for when we resize it later.
    # For horizontal images this is basically just cropping the sides off.
    # tf.reduce_min is used instead of the Python builtin min(): the builtin
    # compares tensors through bool conversion, which is not reliably
    # available for symbolic tensors (e.g. inside tf.data.Dataset.map).
    center_shape = tf.reduce_min(
        tf.stack([shape[1], shape[2], img_shape[1], img_shape[2]]))

    def layer():
        # A fresh pipeline per call so image and depth do not share a
        # layer instance.
        return keras.Sequential([
            keras.layers.experimental.preprocessing.CenterCrop(
                center_shape, center_shape),
            keras.layers.experimental.preprocessing.Resizing(
                out_shape[0], out_shape[1], interpolation='nearest'),
        ])

    # Reshape label to 4d, can't use array unwrap as it's unsupported by
    # tensorflow.
    depth_4d = tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1])

    return layer()(x['image']), layer()(depth_4d)