Refactor load/util, start fixing packnet to support NHWC format

This commit is contained in:
Piv
2021-07-19 12:32:56 +09:30
parent d8bf493999
commit 38e7ad069e
6 changed files with 85 additions and 93 deletions

21
util.py Normal file
View File

@@ -0,0 +1,21 @@
import tensorflow as tf
import tensorflow.keras as keras
def crop_and_resize(x, out_shape=(224, 224)):
shape = tf.shape(x['depth'])
img_shape = tf.shape(x['image'])
# Ensure we get a square for when we resize is later.
# For horizontal images this is basically just cropping the sides off
center_shape = min(shape[1], shape[2], img_shape[1], img_shape[2])
def layer():
return keras.Sequential([
keras.layers.experimental.preprocessing.CenterCrop(
center_shape, center_shape),
keras.layers.experimental.preprocessing.Resizing(
out_shape[0], out_shape[1], interpolation='nearest')
])
# Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow
return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))