# Spatial transformer network bilinear sampler, adapted from
# https://github.com/kevinzakka/spatial-transformer-network/blob/master/stn/transformer.py


def stn_bilinear_sampler(img, x, y):
    """
    Bilinearly sample a batch of images at normalized coordinates.

    The same sampling locations are used for every channel of the input.
    Normalized coordinates are mapped linearly so that -1 addresses the
    first pixel (index 0) and +1 addresses the last pixel (index W-1 for
    x, H-1 for y). Locations that fall outside the image are clamped to
    the nearest border pixel along each axis.

    Sanity check: sampling with an identity grid (coordinates evenly
    spaced over [-1, 1] with W / H samples) should reproduce the input.

    Input
    -----
    - img: batch of images in (B, H, W, C) layout.
    - x: normalized x (width) sampling coordinates, nominally in [-1, 1].
      Must be rank 3 (a channel axis is appended at axis=3); presumably
      (B, H_out, W_out) -- confirm against the grid generator used.
    - y: normalized y (height) sampling coordinates, same shape as x.

    Returns
    -------
    - out: interpolated images, one sample per (x, y) location.
      Presumably (B, H_out, W_out, C); the exact shape depends on what
      `get_pixel_value` returns.

    Notes
    -----
    Relies on a module-level helper `get_pixel_value(img, x, y)` that
    gathers pixel values at integer coordinates.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # Rescale x and y from [-1, 1] to [0, W-1] / [0, H-1].
    # max_x / max_y already hold W-1 / H-1, so they are the scale factors;
    # subtracting one more would shrink the sampled region to [0, W-2].
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    x = 0.5 * ((x + 1.0) * tf.cast(max_x, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))

    # Integer part selects the top-left corner; the fractional part is the
    # interpolation weight. The fraction is taken BEFORE clipping so the
    # four weights always sum to 1, even when a corner is clamped.
    x_floor = tf.floor(x)
    y_floor = tf.floor(y)
    dx = x - x_floor
    dy = y - y_floor

    # Grab the 4 nearest corner points for each (x_i, y_i).
    x0 = tf.cast(x_floor, 'int32')
    x1 = x0 + 1
    y0 = tf.cast(y_floor, 'int32')
    y1 = y0 + 1

    # Clip indices to [0, W-1] / [0, H-1] so the gather stays in bounds.
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)

    # Get pixel values at the corner coordinates.
    Ia = get_pixel_value(img, x0, y0)
    Ib = get_pixel_value(img, x0, y1)
    Ic = get_pixel_value(img, x1, y0)
    Id = get_pixel_value(img, x1, y1)

    # Bilinear weights; each corner is weighted by the area of the
    # opposite sub-rectangle.
    wa = (1.0 - dx) * (1.0 - dy)  # (x0, y0)
    wb = (1.0 - dx) * dy          # (x0, y1)
    wc = dx * (1.0 - dy)          # (x1, y0)
    wd = dx * dy                  # (x1, y1)

    # Add a trailing channel dimension so weights broadcast over C.
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # Compute output as the weighted sum of the four corners.
    out = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    return out