From 19940c12b771dc6148d07e73506428b1fff9c434 Mon Sep 17 00:00:00 2001 From: Michael Pivato Date: Fri, 1 Mar 2019 15:49:01 +1030 Subject: [PATCH] Add example of inferencing with keras using plaidml. --- GestureRecognition/keras_ex.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 GestureRecognition/keras_ex.py diff --git a/GestureRecognition/keras_ex.py b/GestureRecognition/keras_ex.py new file mode 100644 index 0000000..150ed6f --- /dev/null +++ b/GestureRecognition/keras_ex.py @@ -0,0 +1,28 @@ +import time +import os + +import numpy as np + +os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" + +import keras +import keras.applications as kapp +from keras.datasets import cifar10 + +(x_train, y_train_cats), (x_test, y_test_cats) = cifar10.load_data() +batch_size = 8 +x_train = x_train[:batch_size] +x_train = np.repeat(np.repeat(x_train, 7, axis=1), 7, axis=2) +model = kapp.VGG19() +model.compile(optimizer='sgd', loss='categorical_crossentropy', + metrics=['accuracy']) + +print("Running initial batch (compiling tile program)") +y = model.predict(x=x_train, batch_size=batch_size) + +# Now start the clock and run 10 batches +print("Timing inference...") +start = time.time() +for i in range(10): + y = model.predict(x=x_train, batch_size=batch_size) +print("Ran in {} seconds".format(time.time() - start)) \ No newline at end of file