Format pep8, include pass shaped to mobilenet

This commit is contained in:
Piv
2021-03-25 21:56:50 +10:30
parent 9449ddef01
commit 3325ea0c0c
2 changed files with 17 additions and 18 deletions

View File

@@ -22,7 +22,8 @@ def fix_windows_gpu():
for gpu in gpus: for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True) tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices('GPU') logical_gpus = tf.config.experimental.list_logical_devices('GPU')
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") print(len(gpus), "Physical GPUs,", len(
logical_gpus), "Logical GPUs")
except RuntimeError as e: except RuntimeError as e:
# Memory growth must be set before GPUs have been initialized # Memory growth must be set before GPUs have been initialized
print(e) print(e)
@@ -47,7 +48,8 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
:return: FastDepth keras Model :return: FastDepth keras Model
""" """
input = keras.layers.Input(shape=shape) input = keras.layers.Input(shape=shape)
mobilenet = keras.applications.MobileNet(input_tensor=input, include_top=False, weights=weights) mobilenet = keras.applications.MobileNet(
input_shape=shape, input_tensor=input, include_top=False, weights=weights)
for layer in mobilenet.layers: for layer in mobilenet.layers:
layer.trainable = True layer.trainable = True
@@ -57,13 +59,16 @@ def mobilenet_nnconv5(weights=None, shape=(224, 224, 3)):
x = keras.layers.UpSampling2D()(x) x = keras.layers.UpSampling2D()(x)
x = FDDepthwiseBlock(x, 256, block_id=15) x = FDDepthwiseBlock(x, 256, block_id=15)
x = keras.layers.UpSampling2D()(x) x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_5_relu").output]) x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_5_relu").output])
x = FDDepthwiseBlock(x, 128, block_id=16) x = FDDepthwiseBlock(x, 128, block_id=16)
x = keras.layers.UpSampling2D()(x) x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_3_relu").output]) x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_3_relu").output])
x = FDDepthwiseBlock(x, 64, block_id=17) x = FDDepthwiseBlock(x, 64, block_id=17)
x = keras.layers.UpSampling2D()(x) x = keras.layers.UpSampling2D()(x)
x = keras.layers.Add()([x, mobilenet.get_layer(name="conv_pw_1_relu").output]) x = keras.layers.Add()(
[x, mobilenet.get_layer(name="conv_pw_1_relu").output])
x = FDDepthwiseBlock(x, 32, block_id=18) x = FDDepthwiseBlock(x, 32, block_id=18)
x = keras.layers.UpSampling2D()(x) x = keras.layers.UpSampling2D()(x)
@@ -163,8 +168,10 @@ def crop_and_resize(x):
def layer(): def layer():
return keras.Sequential([ return keras.Sequential([
keras.layers.experimental.preprocessing.CenterCrop(shape[1], shape[2]), keras.layers.experimental.preprocessing.CenterCrop(
keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest') shape[1], shape[2]),
keras.layers.experimental.preprocessing.Resizing(
224, 224, interpolation='nearest')
]) ])
# Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow # Reshape label to 4d, can't use array unwrap as it's unsupported by tensorflow

View File

@@ -15,7 +15,7 @@
"orig_nbformat": 2, "orig_nbformat": 2,
"kernelspec": { "kernelspec": {
"name": "python3", "name": "python3",
"display_name": "Python 3.8.8 64-bit", "display_name": "Python 3.8.8 64-bit ('tensorflow2': conda)",
"metadata": { "metadata": {
"interpreter": { "interpreter": {
"hash": "ee99f7bd678359d45d92ad289bdab8f6bcfaae579cfd1bff07d2bb16d7ba024f" "hash": "ee99f7bd678359d45d92ad289bdab8f6bcfaae579cfd1bff07d2bb16d7ba024f"
@@ -70,15 +70,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"output_type": "stream",
"name": "stdout",
"text": [
"WARNING:tensorflow:`input_shape` is undefined or non-square, or `rows` is not in [128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default.\n"
]
}
],
"source": [ "source": [
"model = fd.mobilenet_nnconv5(weights='imagenet')\n", "model = fd.mobilenet_nnconv5(weights='imagenet')\n",
"fd.compile(model)" "fd.compile(model)"
@@ -102,7 +94,7 @@
}, },
{ {
"source": [ "source": [
"## Evaluate the model" "## Evaluate the trained model"
], ],
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {} "metadata": {}