Refactor load/util, start fixing packnet to support NHWC format
This commit is contained in:
21
util.py
Normal file
21
util.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras as keras
|
||||
|
||||
|
||||
def crop_and_resize(x, out_shape=(224, 224)):
|
||||
shape = tf.shape(x['depth'])
|
||||
img_shape = tf.shape(x['image'])
|
||||
# Ensure we get a square for when we resize is later.
|
||||
# For horizontal images this is basically just cropping the sides off
|
||||
center_shape = min(shape[1], shape[2], img_shape[1], img_shape[2])
|
||||
|
||||
def layer():
|
||||
return keras.Sequential([
|
||||
keras.layers.experimental.preprocessing.CenterCrop(
|
||||
center_shape, center_shape),
|
||||
keras.layers.experimental.preprocessing.Resizing(
|
||||
out_shape[0], out_shape[1], interpolation='nearest')
|
||||
])
|
||||
|
||||
# Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
|
||||
return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))
|
||||
Reference in New Issue
Block a user