# MIT License
#
# Copyright (c) 2019 Somshubra Majumdar
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Taken from: https://github.com/titu1994/Keras-Group-Normalization/blob/master/group_norm.py
# Modified: channel groups are now formed along `axis` itself, so the layer
# normalizes correctly for channels-last (`axis=-1`) as well as
# channels-first (`axis=1`) inputs.

from tensorflow.keras import backend as K
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Layer, InputSpec


class GroupNormalization(Layer):
    """Group normalization layer

    Group Normalization divides the channels into groups and computes
    within each group the mean and variance for normalization.
    GN's computation is independent of batch sizes, and its accuracy is
    stable in a wide range of batch sizes.

    # Arguments
        groups: Integer, the number of groups for Group Normalization.
            The size of `axis` must be a multiple of `groups`.
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `GroupNormalization`.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
    """

    def __init__(self,
                 groups=32,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        # Created in `build` when `scale` / `center` are enabled.
        self.gamma = None
        self.beta = None

    def build(self, input_shape):
        """Validate the channel axis and create `gamma` / `beta`.

        # Arguments
            input_shape: Shape tuple of the layer input.

        # Raises
            ValueError: if the size of `axis` is undefined, smaller than
                `groups`, or not a multiple of `groups`.
        """
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            # The channel count is what must divide evenly into groups.
            raise ValueError('Number of channels (' + str(dim) + ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        # One gamma / beta entry per channel.
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        self.built = True

    def call(self, inputs, **kwargs):
        """Normalize `inputs` per sample within each group of channels.

        # Arguments
            inputs: Input tensor; its `axis` dimension must have the static
                size validated in `build`.

        # Returns
            A tensor with the same shape as `inputs`.
        """
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)
        ndim = len(input_shape)

        # Non-negative index of the channel axis.
        axis = self.axis if self.axis >= 0 else self.axis + ndim
        group_size = input_shape[axis] // self.groups

        # Split the channel axis in place into (groups, group_size), e.g.
        # NHWC -> (N, H, W, G, C // G) and NCHW -> (N, G, C // G, H, W).
        # Splitting in place keeps every group made of consecutive channels
        # for any `axis`; inserting the group axis right after the batch axis
        # only produces channel groups when `axis == 1`.
        group_shape = [tensor_input_shape[i] for i in range(ndim)]
        group_shape[axis] = group_size
        group_shape.insert(axis, self.groups)
        grouped = K.reshape(inputs, K.stack(group_shape))

        # Statistics per sample and per group: reduce over every axis of the
        # grouped tensor except the batch axis (0) and the group axis.
        group_reduction_axes = [i for i in range(1, ndim + 1) if i != axis]
        mean = K.mean(grouped, axis=group_reduction_axes, keepdims=True)
        variance = K.var(grouped, axis=group_reduction_axes, keepdims=True)
        outputs = (grouped - mean) / K.sqrt(variance + self.epsilon)

        # gamma / beta have shape (channels,). Reshape them to
        # (1, ..., groups, group_size, ..., 1) so that channel c lines up
        # with position (c // group_size, c % group_size) of the grouped
        # tensor, matching the reshape of `inputs` above.
        broadcast_shape = [1] * (ndim + 1)
        broadcast_shape[axis] = self.groups
        broadcast_shape[axis + 1] = group_size

        if self.scale:
            outputs = outputs * K.reshape(self.gamma, broadcast_shape)

        if self.center:
            outputs = outputs + K.reshape(self.beta, broadcast_shape)

        # Restore the original (possibly dynamic) input shape.
        return K.reshape(outputs, tensor_input_shape)

    def get_config(self):
        """Return the layer configuration, merged with the base `Layer` config."""
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Normalization does not change the shape of its input."""
        return input_shape


if __name__ == '__main__':
    from tensorflow.keras.layers import Input
    from tensorflow.keras.models import Model

    ip = Input(shape=(None, None, 4))
    # ip = Input(batch_shape=(100, None, None, 2))
    x = GroupNormalization(groups=2, axis=-1, epsilon=0.1)(ip)
    model = Model(ip, x)
    model.summary()