# MIT License
#
# Copyright (c) 2019 Somshubra Majumdar
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Taken from: https://github.com/titu1994/Keras-Group-Normalization/blob/master/group_norm.py

from tensorflow.keras import backend as K
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Layer, InputSpec


class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes within each group
    the mean and variance for normalization. GN's computation is independent of batch sizes,
    and its accuracy is stable across a wide range of batch sizes.
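
    In pseudo-code, each group is normalized as

        output = gamma * (input - mean) / sqrt(variance + epsilon) + beta

    where `mean` and `variance` are computed over the spatial positions and
    the channels belonging to that group, and `gamma`/`beta` are the learned
    per-channel scale and offset (see the `scale` and `center` arguments).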

    # Arguments
        groups: Integer, the number of groups for Group Normalization.
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `GroupNormalization`.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
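
    # Example
        A minimal usage sketch, mirroring the demo at the bottom of this file:

        ```python
        from tensorflow.keras.layers import Input
        from tensorflow.keras.models import Model

        ip = Input(shape=(None, None, 4))
        x = GroupNormalization(groups=2, axis=-1)(ip)
        model = Model(ip, x)
        ```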
    """

    def __init__(self,
                 groups=32,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.gamma = None
        self.beta = None
    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            raise ValueError('Number of channels (' + str(dim) + ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        self.built = True
    def call(self, inputs, **kwargs):
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)

        ndim = len(input_shape)
        axis = self.axis % ndim  # canonical, non-negative channel axis
        channels_per_group = input_shape[axis] // self.groups

        # Split the channel axis into (groups, channels_per_group) so that
        # each group covers a contiguous block of channels, wherever the
        # channel axis sits.
        group_axes = [tensor_input_shape[i] for i in range(ndim)]
        group_axes[axis] = channels_per_group
        group_axes.insert(axis, self.groups)
        group_shape = K.stack(group_axes)
        outputs = K.reshape(inputs, group_shape)

        # Normalize over every axis of the grouped tensor except the batch
        # axis and the group axis.
        group_reduction_axes = [i for i in range(ndim + 1)
                                if i not in (0, axis)]
        mean = K.mean(outputs, axis=group_reduction_axes, keepdims=True)
        variance = K.var(outputs, axis=group_reduction_axes, keepdims=True)
        outputs = (outputs - mean) / K.sqrt(variance + self.epsilon)

        # gamma and beta are per-channel vectors; reshape them so they
        # broadcast against the grouped tensor.
        broadcast_shape = [1] * (ndim + 1)
        broadcast_shape[axis] = self.groups
        broadcast_shape[axis + 1] = channels_per_group

        if self.scale:
            outputs = outputs * K.reshape(self.gamma, broadcast_shape)

        if self.center:
            outputs = outputs + K.reshape(self.beta, broadcast_shape)

        # Restore the original (possibly dynamic) input shape.
        outputs = K.reshape(outputs, tensor_input_shape)

        return outputs
    def get_config(self):
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape


if __name__ == '__main__':
    from tensorflow.keras.layers import Input
    from tensorflow.keras.models import Model

    ip = Input(shape=(None, None, 4))
    # ip = Input(batch_shape=(100, None, None, 2))
    x = GroupNormalization(groups=2, axis=-1, epsilon=0.1)(ip)
    model = Model(ip, x)
    model.summary()
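
    # Quick sanity checks (a minimal sketch; the batch and spatial sizes
    # below are arbitrary): push random data through the model and confirm
    # the output shape matches the input, then round-trip the config to
    # show the layer serializes cleanly.
    import numpy as np

    data = np.random.randn(8, 16, 16, 4).astype('float32')
    out = model.predict(data)
    assert out.shape == data.shape

    layer = GroupNormalization(groups=2, axis=-1, epsilon=0.1)
    restored = GroupNormalization.from_config(layer.get_config())
    assert restored.groups == layer.groups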