Add spatial transformer network sampler
This commit is contained in:
72
unsupervised/third-party/utils.py
vendored
72
unsupervised/third-party/utils.py
vendored
@@ -280,3 +280,75 @@ def bilinear_sampler(imgs, coords):
|
||||
w10 * im10, w11 * im11
|
||||
])
|
||||
return output
|
||||
|
||||
# Spatial transformer network bilinear sampler, taken from https://github.com/kevinzakka/spatial-transformer-network/blob/master/stn/transformer.py
|
||||
|
||||
|
||||
def stn_bilinear_sampler(img, x, y):
|
||||
"""
|
||||
Performs bilinear sampling of the input images according to the
|
||||
normalized coordinates provided by the sampling grid. Note that
|
||||
the sampling is done identically for each channel of the input.
|
||||
To test if the function works properly, output image should be
|
||||
identical to input image when theta is initialized to identity
|
||||
transform.
|
||||
Input
|
||||
-----
|
||||
- img: batch of images in (B, H, W, C) layout.
|
||||
- grid: x, y which is the output of affine_grid_generator.
|
||||
Returns
|
||||
-------
|
||||
- out: interpolated images according to grids. Same size as grid.
|
||||
"""
|
||||
H = tf.shape(img)[1]
|
||||
W = tf.shape(img)[2]
|
||||
max_y = tf.cast(H - 1, 'int32')
|
||||
max_x = tf.cast(W - 1, 'int32')
|
||||
zero = tf.zeros([], dtype='int32')
|
||||
|
||||
# rescale x and y to [0, W-1/H-1]
|
||||
x = tf.cast(x, 'float32')
|
||||
y = tf.cast(y, 'float32')
|
||||
x = 0.5 * ((x + 1.0) * tf.cast(max_x-1, 'float32'))
|
||||
y = 0.5 * ((y + 1.0) * tf.cast(max_y-1, 'float32'))
|
||||
|
||||
# grab 4 nearest corner points for each (x_i, y_i)
|
||||
x0 = tf.cast(tf.floor(x), 'int32')
|
||||
x1 = x0 + 1
|
||||
y0 = tf.cast(tf.floor(y), 'int32')
|
||||
y1 = y0 + 1
|
||||
|
||||
# clip to range [0, H-1/W-1] to not violate img boundaries
|
||||
x0 = tf.clip_by_value(x0, zero, max_x)
|
||||
x1 = tf.clip_by_value(x1, zero, max_x)
|
||||
y0 = tf.clip_by_value(y0, zero, max_y)
|
||||
y1 = tf.clip_by_value(y1, zero, max_y)
|
||||
|
||||
# get pixel value at corner coords
|
||||
Ia = get_pixel_value(img, x0, y0)
|
||||
Ib = get_pixel_value(img, x0, y1)
|
||||
Ic = get_pixel_value(img, x1, y0)
|
||||
Id = get_pixel_value(img, x1, y1)
|
||||
|
||||
# recast as float for delta calculation
|
||||
x0 = tf.cast(x0, 'float32')
|
||||
x1 = tf.cast(x1, 'float32')
|
||||
y0 = tf.cast(y0, 'float32')
|
||||
y1 = tf.cast(y1, 'float32')
|
||||
|
||||
# calculate deltas
|
||||
wa = (x1-x) * (y1-y)
|
||||
wb = (x1-x) * (y-y0)
|
||||
wc = (x-x0) * (y1-y)
|
||||
wd = (x-x0) * (y-y0)
|
||||
|
||||
# add dimension for addition
|
||||
wa = tf.expand_dims(wa, axis=3)
|
||||
wb = tf.expand_dims(wb, axis=3)
|
||||
wc = tf.expand_dims(wc, axis=3)
|
||||
wd = tf.expand_dims(wd, axis=3)
|
||||
|
||||
# compute output
|
||||
out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user