Add spatial transformer network sampler
This commit is contained in:
72
unsupervised/third-party/utils.py
vendored
72
unsupervised/third-party/utils.py
vendored
@@ -280,3 +280,75 @@ def bilinear_sampler(imgs, coords):
|
|||||||
w10 * im10, w11 * im11
|
w10 * im10, w11 * im11
|
||||||
])
|
])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
# Spatial transformer network bilinear sampler, taken from https://github.com/kevinzakka/spatial-transformer-network/blob/master/stn/transformer.py
|
||||||
|
|
||||||
|
|
||||||
|
def stn_bilinear_sampler(img, x, y):
    """
    Bilinearly sample `img` at the normalized coordinates `x`, `y`.

    Coordinates are normalized so that -1 maps to the first row/column
    (pixel index 0) and +1 maps to the last row/column (pixel index
    H-1 / W-1). The sampling is done identically for each channel of
    the input. With a sampling grid produced by an identity transform,
    the output reproduces the input image.

    Corners of the interpolation cell that lie outside the image
    contribute nothing to the result, so samples taken outside the
    image evaluate to zero.

    Input
    -----
    - img: batch of images in (B, H, W, C) layout.
    - x, y: normalized sampling coordinates, e.g. the output of
      affine_grid_generator. They must be rank 3, because the weights
      derived from them are expanded on axis 3 to broadcast over the
      channels -- presumably (B, H_out, W_out); confirm against
      get_pixel_value.

    Returns
    -------
    - out: interpolated values, the weighted sum of the four corner
      pixels returned by get_pixel_value for every grid location.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # rescale x and y from [-1, 1] to [0, W-1] / [0, H-1].
    # max_x / max_y already hold W-1 / H-1, so they are the scale
    # factors; subtracting one more would shrink the range to
    # [0, W-2] and break the identity transform.
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    x = 0.5 * ((x + 1.0) * tf.cast(max_x, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))

    # corners of the cell enclosing each (x_i, y_i), kept unclipped so
    # that the interpolation weights below always sum to one
    x0 = tf.floor(x)
    x1 = x0 + 1.0
    y0 = tf.floor(y)
    y1 = y0 + 1.0

    # bilinear weights, computed BEFORE clipping: using clipped corners
    # here would collapse both corners onto the same index at the image
    # border and zero out a sample lying exactly on the last row/column
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    # integer corner indices, clipped to [0, W-1] / [0, H-1] so the
    # pixel lookup never violates the image boundaries
    x0_raw = tf.cast(x0, 'int32')
    x1_raw = x0_raw + 1
    y0_raw = tf.cast(y0, 'int32')
    y1_raw = y0_raw + 1
    x0_idx = tf.clip_by_value(x0_raw, zero, max_x)
    x1_idx = tf.clip_by_value(x1_raw, zero, max_x)
    y0_idx = tf.clip_by_value(y0_raw, zero, max_y)
    y1_idx = tf.clip_by_value(y1_raw, zero, max_y)

    # a corner is inside the image exactly when clipping left its index
    # unchanged; corners outside the image get zero weight
    x0_in = tf.cast(tf.equal(x0_raw, x0_idx), 'float32')
    x1_in = tf.cast(tf.equal(x1_raw, x1_idx), 'float32')
    y0_in = tf.cast(tf.equal(y0_raw, y0_idx), 'float32')
    y1_in = tf.cast(tf.equal(y1_raw, y1_idx), 'float32')
    wa = wa * x0_in * y0_in
    wb = wb * x0_in * y1_in
    wc = wc * x1_in * y0_in
    wd = wd * x1_in * y1_in

    # get pixel value at corner coords
    Ia = get_pixel_value(img, x0_idx, y0_idx)
    Ib = get_pixel_value(img, x0_idx, y1_idx)
    Ic = get_pixel_value(img, x1_idx, y0_idx)
    Id = get_pixel_value(img, x1_idx, y1_idx)

    # add a trailing dimension so the weights broadcast over channels
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # compute output
    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])

    return out
|
||||||
|
|||||||
Reference in New Issue
Block a user