Add spatial transformer network sampler

This commit is contained in:
Piv
2021-07-12 18:19:59 +09:30
parent f501beb6f2
commit 6f7da21977

View File

@@ -280,3 +280,75 @@ def bilinear_sampler(imgs, coords):
w10 * im10, w11 * im11
])
return output
# Spatial transformer network bilinear sampler, taken from https://github.com/kevinzakka/spatial-transformer-network/blob/master/stn/transformer.py
def stn_bilinear_sampler(img, x, y):
"""
Performs bilinear sampling of the input images according to the
normalized coordinates provided by the sampling grid. Note that
the sampling is done identically for each channel of the input.
To test if the function works properly, output image should be
identical to input image when theta is initialized to identity
transform.
Input
-----
- img: batch of images in (B, H, W, C) layout.
- grid: x, y which is the output of affine_grid_generator.
Returns
-------
- out: interpolated images according to grids. Same size as grid.
"""
H = tf.shape(img)[1]
W = tf.shape(img)[2]
max_y = tf.cast(H - 1, 'int32')
max_x = tf.cast(W - 1, 'int32')
zero = tf.zeros([], dtype='int32')
# rescale x and y to [0, W-1/H-1]
x = tf.cast(x, 'float32')
y = tf.cast(y, 'float32')
x = 0.5 * ((x + 1.0) * tf.cast(max_x-1, 'float32'))
y = 0.5 * ((y + 1.0) * tf.cast(max_y-1, 'float32'))
# grab 4 nearest corner points for each (x_i, y_i)
x0 = tf.cast(tf.floor(x), 'int32')
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), 'int32')
y1 = y0 + 1
# clip to range [0, H-1/W-1] to not violate img boundaries
x0 = tf.clip_by_value(x0, zero, max_x)
x1 = tf.clip_by_value(x1, zero, max_x)
y0 = tf.clip_by_value(y0, zero, max_y)
y1 = tf.clip_by_value(y1, zero, max_y)
# get pixel value at corner coords
Ia = get_pixel_value(img, x0, y0)
Ib = get_pixel_value(img, x0, y1)
Ic = get_pixel_value(img, x1, y0)
Id = get_pixel_value(img, x1, y1)
# recast as float for delta calculation
x0 = tf.cast(x0, 'float32')
x1 = tf.cast(x1, 'float32')
y0 = tf.cast(y0, 'float32')
y1 = tf.cast(y1, 'float32')
# calculate deltas
wa = (x1-x) * (y1-y)
wb = (x1-x) * (y-y0)
wc = (x-x0) * (y1-y)
wd = (x-x0) * (y-y0)
# add dimension for addition
wa = tf.expand_dims(wa, axis=3)
wb = tf.expand_dims(wb, axis=3)
wc = tf.expand_dims(wc, axis=3)
wd = tf.expand_dims(wd, axis=3)
# compute output
out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
return out