22 lines
896 B
Python
22 lines
896 B
Python
import tensorflow as tf
|
|
import tensorflow.keras as keras
|
|
|
|
|
|
def crop_and_resize(x, out_shape=(224, 224)):
    """Center-crop the image and depth entries of ``x`` to a square and resize them.

    Both tensors are cropped to the same square side length (the smallest
    extent found along axes 1 and 2 of either tensor) and then resized to
    ``out_shape`` using nearest-neighbour interpolation.

    Args:
        x: Mapping with ``'image'`` and ``'depth'`` tensors.
            NOTE(review): the indexing below suggests ``'image'`` is
            (batch, height, width, channels) and ``'depth'`` is
            (batch, height, width) -- confirm against the caller.
        out_shape: ``(height, width)`` passed to the ``Resizing`` layer.

    Returns:
        A 2-tuple ``(image, depth)`` holding the cropped and resized image
        and the cropped and resized depth map; the depth map is reshaped to
        carry a trailing axis of size 1 before being processed.
    """
    preprocessing = keras.layers.experimental.preprocessing

    depth_dims = tf.shape(x['depth'])
    image_dims = tf.shape(x['image'])

    # Pick a single square side that fits inside both tensors so the later
    # resize operates on a square crop. For landscape inputs this amounts to
    # trimming the left and right edges.
    depth_side = tf.minimum(depth_dims[1], depth_dims[2])
    image_side = tf.minimum(image_dims[1], image_dims[2])
    square_side = tf.minimum(depth_side, image_side)

    def build_pipeline():
        # A fresh crop-then-resize stack is built for every input tensor.
        crop = preprocessing.CenterCrop(square_side, square_side)
        resize = preprocessing.Resizing(
            out_shape[0], out_shape[1], interpolation='nearest')
        return keras.Sequential([crop, resize])

    image_out = build_pipeline()(x['image'])

    depth_pipeline = build_pipeline()
    # Give the depth map an explicit trailing axis of size 1. The target shape
    # is spelled out element by element because unpacking the shape tensor
    # into the list is not supported by TensorFlow.
    depth_4d = tf.reshape(
        x['depth'], [depth_dims[0], depth_dims[1], depth_dims[2], 1])

    return image_out, depth_pipeline(depth_4d)
|