# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Locally-connected layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.LocallyConnected1D')
class LocallyConnected1D(Layer):
  """Locally-connected layer for 1D inputs.

  The `LocallyConnected1D` layer works similarly to
  the `Conv1D` layer, except that weights are unshared,
  that is, a different set of filters is applied at each different patch
  of the input.

  Example:
  ```python
      # apply a unshared weight convolution 1d of length 3 to a sequence with
      # 10 timesteps, with 64 output filters
      model = Sequential()
      model.add(LocallyConnected1D(64, 3, input_shape=(10, 32)))
      # now model.output_shape == (None, 8, 64)
      # add a new conv1d on top
      model.add(LocallyConnected1D(32, 3))
      # now model.output_shape == (None, 6, 32)
  ```

  Arguments:
      filters: Integer, the dimensionality of the output space
          (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of a single integer,
          specifying the length of the 1D convolution window.
      strides: An integer or tuple/list of a single integer,
          specifying the stride length of the convolution.
          Specifying any stride value != 1 is incompatible with specifying
          any `dilation_rate` value != 1.
      padding: Currently only supports `"valid"` (case-insensitive).
          `"same"` may be supported in the future.
      data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
          `(batch, length, channels)` while `channels_first`
          corresponds to inputs with shape
          `(batch, channels, length)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
      activation: Activation function to use.
          If you don't specify anything, no activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
          the `kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
          the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to the kernel matrix.
      bias_constraint: Constraint function applied to the bias vector.
      implementation: implementation mode, either `1` or `2`.
          `1` loops over input spatial locations to perform the forward pass.
          It is memory-efficient but performs a lot of (small) ops.

          `2` stores layer weights in a dense but sparsely-populated 2D matrix
          and implements the forward pass as a single matrix-multiply. It uses
          a lot of RAM but performs few (large) ops.

          Depending on the inputs, layer parameters, hardware, and
          `tf.executing_eagerly()` one implementation can be dramatically faster
          (e.g. 50X) than another.

          It is recommended to benchmark both in the setting of interest to pick
          the most efficient one (in terms of speed and memory usage).

          Following scenarios could benefit from setting `implementation=2`:
              - eager execution;
              - inference;
              - running on CPU;
              - large amount of RAM available;
              - small models (few filters, small kernel);
              - using `padding=same` (only possible with `implementation=2`).

  Input shape:
      3D tensor with shape: `(batch_size, steps, input_dim)`

  Output shape:
      3D tensor with shape: `(batch_size, new_steps, filters)`
      `steps` value might have changed due to padding or strides.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               implementation=1,
               **kwargs):
    super(LocallyConnected1D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    # `implementation=1` only supports 'valid' padding; reject anything else.
    if self.padding != 'valid' and implementation == 1:
      raise ValueError('Invalid border mode for LocallyConnected1D '
                       '(only "valid" is supported if implementation is 1): '
                       + padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.implementation = implementation
    self.input_spec = InputSpec(ndim=3)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    if self.data_format == 'channels_first':
      input_dim, input_length = input_shape[1], input_shape[2]
    else:
      input_dim, input_length = input_shape[2], input_shape[1]

    if input_dim is None:
      raise ValueError('Axis 2 of input should be fully-defined. '
                       'Found shape:', input_shape)
    self.output_length = conv_utils.conv_output_length(
        input_length, self.kernel_size[0], self.padding, self.strides[0])

    if self.implementation == 1:
      # One independent (kernel_size * input_dim) -> filters weight slice
      # per output spatial location.
      self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
                           self.filters)

      self.kernel = self.add_weight(
          shape=self.kernel_shape,
          initializer=self.kernel_initializer,
          name='kernel',
          regularizer=self.kernel_regularizer,
          constraint=self.kernel_constraint)

    elif self.implementation == 2:
      # Fully-connected-style weight tensor between the (spatial, channel)
      # input and output axes; locality is enforced by `kernel_mask`.
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_dim, input_length,
                             self.filters, self.output_length)
      else:
        self.kernel_shape = (input_length, input_dim,
                             self.output_length, self.filters)

      self.kernel = self.add_weight(shape=self.kernel_shape,
                                    initializer=self.kernel_initializer,
                                    name='kernel',
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_length,),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dtype=self.kernel.dtype
      )

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.output_length, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None

    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
    else:
      self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.data_format == 'channels_first':
      input_length = input_shape[2]
    else:
      input_length = input_shape[1]

    length = conv_utils.conv_output_length(input_length, self.kernel_size[0],
                                           self.padding, self.strides[0])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, length)
    elif self.data_format == 'channels_last':
      return (input_shape[0], length, self.filters)

  def call(self, inputs):
    if self.implementation == 1:
      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
                            (self.output_length,), self.data_format)

    elif self.implementation == 2:
      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                 self.compute_output_shape(inputs.shape))

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      output = K.bias_add(output, self.bias, data_format=self.data_format)

    output = self.activation(output)
    return output

  def get_config(self):
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'implementation':
            self.implementation
    }
    base_config = super(LocallyConnected1D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.layers.LocallyConnected2D')
class LocallyConnected2D(Layer):
  """Locally-connected layer for 2D inputs.

  The `LocallyConnected2D` layer works similarly
  to the `Conv2D` layer, except that weights are unshared,
  that is, a different set of filters is applied at each
  different patch of the input.

  Examples:
  ```python
      # apply a 3x3 unshared weights convolution with 64 output filters on a
      32x32 image
      # with `data_format="channels_last"`:
      model = Sequential()
      model.add(LocallyConnected2D(64, (3, 3), input_shape=(32, 32, 3)))
      # now model.output_shape == (None, 30, 30, 64)
      # notice that this layer will consume (30*30)*(3*3*3*64) + (30*30)*64
      parameters

      # add a 3x3 unshared weights convolution on top, with 32 output filters:
      model.add(LocallyConnected2D(32, (3, 3)))
      # now model.output_shape == (None, 28, 28, 32)
  ```

  Arguments:
      filters: Integer, the dimensionality of the output space
          (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of 2 integers, specifying the
          width and height of the 2D convolution window.
          Can be a single integer to specify the same value for
          all spatial dimensions.
      strides: An integer or tuple/list of 2 integers,
          specifying the strides of the convolution along the width and height.
          Can be a single integer to specify the same value for
          all spatial dimensions.
      padding: Currently only support `"valid"` (case-insensitive).
          `"same"` will be supported in future.
      data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
      activation: Activation function to use.
          If you don't specify anything, no activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
          the `kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
          the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to the kernel matrix.
      bias_constraint: Constraint function applied to the bias vector.
      implementation: implementation mode, either `1` or `2`.
          `1` loops over input spatial locations to perform the forward pass.
          It is memory-efficient but performs a lot of (small) ops.

          `2` stores layer weights in a dense but sparsely-populated 2D matrix
          and implements the forward pass as a single matrix-multiply. It uses
          a lot of RAM but performs few (large) ops.

          Depending on the inputs, layer parameters, hardware, and
          `tf.executing_eagerly()` one implementation can be dramatically faster
          (e.g. 50X) than another.

          It is recommended to benchmark both in the setting of interest to pick
          the most efficient one (in terms of speed and memory usage).

          Following scenarios could benefit from setting `implementation=2`:
              - eager execution;
              - inference;
              - running on CPU;
              - large amount of RAM available;
              - small models (few filters, small kernel);
              - using `padding=same` (only possible with `implementation=2`).

  Input shape:
      4D tensor with shape:
      `(samples, channels, rows, cols)` if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, rows, cols, channels)` if data_format='channels_last'.

  Output shape:
      4D tensor with shape:
      `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               implementation=1,
               **kwargs):
    super(LocallyConnected2D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    # `implementation=1` only supports 'valid' padding; reject anything else.
    if self.padding != 'valid' and implementation == 1:
      raise ValueError('Invalid border mode for LocallyConnected2D '
                       '(only "valid" is supported if implementation is 1): '
                       + padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.implementation = implementation
    self.input_spec = InputSpec(ndim=4)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    if self.data_format == 'channels_last':
      input_row, input_col = input_shape[1:-1]
      input_filter = input_shape[3]
    else:
      input_row, input_col = input_shape[2:]
      input_filter = input_shape[1]
    if input_row is None or input_col is None:
      raise ValueError('The spatial dimensions of the inputs to '
                       ' a LocallyConnected2D layer '
                       'should be fully-defined, but layer received '
                       'the inputs shape ' + str(input_shape))
    output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
                                               self.padding, self.strides[0])
    output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
                                               self.padding, self.strides[1])
    self.output_row = output_row
    self.output_col = output_col

    if self.implementation == 1:
      # One independent weight slice per output spatial location
      # (rows * cols flattened into the leading axis).
      self.kernel_shape = (
          output_row * output_col,
          self.kernel_size[0] * self.kernel_size[1] * input_filter,
          self.filters)

      self.kernel = self.add_weight(
          shape=self.kernel_shape,
          initializer=self.kernel_initializer,
          name='kernel',
          regularizer=self.kernel_regularizer,
          constraint=self.kernel_constraint)

    elif self.implementation == 2:
      # Fully-connected-style weight tensor between the (spatial, channel)
      # input and output axes; locality is enforced by `kernel_mask`.
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_filter, input_row, input_col,
                             self.filters, self.output_row, self.output_col)
      else:
        self.kernel_shape = (input_row, input_col, input_filter,
                             self.output_row, self.output_col, self.filters)

      self.kernel = self.add_weight(shape=self.kernel_shape,
                                    initializer=self.kernel_initializer,
                                    name='kernel',
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_row, input_col),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dtype=self.kernel.dtype
      )

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(output_row, output_col, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
    else:
      self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.data_format == 'channels_first':
      rows = input_shape[2]
      cols = input_shape[3]
    elif self.data_format == 'channels_last':
      rows = input_shape[1]
      cols = input_shape[2]

    rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
                                         self.padding, self.strides[0])
    cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
                                         self.padding, self.strides[1])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, rows, cols)
    elif self.data_format == 'channels_last':
      return (input_shape[0], rows, cols, self.filters)

  def call(self, inputs):
    if self.implementation == 1:
      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
                            (self.output_row, self.output_col),
                            self.data_format)

    elif self.implementation == 2:
      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                 self.compute_output_shape(inputs.shape))

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      output = K.bias_add(output, self.bias, data_format=self.data_format)

    output = self.activation(output)
    return output

  def get_config(self):
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'implementation':
            self.implementation
    }
    base_config = super(LocallyConnected2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


def get_locallyconnected_mask(input_shape,
                              kernel_shape,
                              strides,
                              padding,
                              data_format):
  """Return a mask representing connectivity of a locally-connected operation.

  This method returns a masking tensor of 0s and 1s (of type `dtype`) that,
  when element-wise multiplied with a fully-connected weight tensor, masks out
  the weights between disconnected input-output pairs and thus implements local
  connectivity through a sparse fully-connected weight tensor.

  Assume an unshared convolution with given parameters is applied to an input
  having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
  to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
  by layer parameters such as `strides`).

  This method returns a mask which can be broadcast-multiplied (element-wise)
  with a 2*(N+1)-D weight matrix (equivalent to a fully-connected layer between
  (N+1)-D activations (N spatial + 1 channel dimensions for input and output)
  to make it perform an unshared convolution with given `kernel_shape`,
  `strides`, `padding` and `data_format`.

  Arguments:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
    data_format: a string, `"channels_first"` or `"channels_last"`.
    dtype: type of the layer operation, e.g. `tf.float64`.

  Returns:
    a `dtype`-tensor of shape
    `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
    if `data_format == `"channels_first"`, or
    `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
    if `data_format == "channels_last"`.

  Raises:
    ValueError: if `data_format` is neither `"channels_first"` nor
                `"channels_last"`.
  """
  mask = conv_utils.conv_kernel_mask(
      input_shape=input_shape,
      kernel_shape=kernel_shape,
      strides=strides,
      padding=padding
  )

  # The numpy mask has 2*N dims (N input-spatial + N output-spatial); insert
  # singleton channel axes so it broadcasts against the 2*(N+1)-D kernel.
  ndims = int(mask.ndim / 2)
  mask = K.variable(mask, dtype)

  if data_format == 'channels_first':
    mask = K.expand_dims(mask, 0)
    mask = K.expand_dims(mask, - ndims - 1)

  elif data_format == 'channels_last':
    mask = K.expand_dims(mask, ndims)
    mask = K.expand_dims(mask, -1)

  else:
    raise ValueError('Unrecognized data_format: ' + str(data_format))

  return mask


def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
  """Apply N-D convolution with un-shared weights using a single matmul call.

  This method outputs `inputs . (kernel * kernel_mask)`
  (with `.` standing for matrix-multiply and `*` for element-wise multiply)
  and requires a precomputed `kernel_mask` to zero-out weights in `kernel` and
  hence perform the same operation as a convolution with un-shared
  (the remaining entries in `kernel`) weights. It also does the necessary
  reshapes to make `inputs` and `kernel` 2-D and `output` (N+2)-D.

  Arguments:
      inputs: (N+2)-D tensor with shape
          `(batch_size, channels_in, d_in1, ..., d_inN)`
          or
          `(batch_size, d_in1, ..., d_inN, channels_in)`.
      kernel: the unshared weights for N-D convolution,
          an (N+2)-D tensor of shape:
          `(d_in1, ..., d_inN, channels_in, d_out2, ..., d_outN, channels_out)`
          or
          `(channels_in, d_in1, ..., d_inN, channels_out, d_out2, ..., d_outN)`,
          with the ordering of channels and spatial dimensions matching
          that of the input.
          Each entry is the weight between a particular input and
          output location, similarly to a fully-connected weight matrix.
      kernel_mask: a float 0/1 mask tensor of shape:
          `(d_in1, ..., d_inN, 1, d_out2, ..., d_outN, 1)`
          or
          `(1, d_in1, ..., d_inN, 1, d_out2, ..., d_outN)`,
          with the ordering of singleton and spatial dimensions
          matching that of the input.
          Mask represents the connectivity pattern of the layer and is
          precomputed elsewhere based on layer parameters: stride,
          padding, and the receptive field shape.
      output_shape: a tuple of (N+2) elements representing the output shape:
          `(batch_size, channels_out, d_out1, ..., d_outN)`
          or
          `(batch_size, d_out1, ..., d_outN, channels_out)`,
          with the ordering of channels and spatial dimensions matching that of
          the input.

  Returns:
      Output (N+2)-D tensor with shape `output_shape`.
  """
  # Flatten everything but the batch axis so the whole op is one matmul.
  inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))

  kernel = kernel_mask * kernel
  kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)

  output_flat = K.math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)
  output = K.reshape(output_flat,
                     [K.shape(output_flat)[0],] + output_shape.as_list()[1:])
  return output


def make_2d(tensor, split_dim):
  """Reshapes an N-dimensional tensor into a 2D tensor.

  Dimensions before (excluding) and after (including) `split_dim` are grouped
  together.

  Arguments:
    tensor: a tensor of shape `(d0, ..., d(N-1))`.
    split_dim: an integer from 1 to N-1, index of the dimension to group
        dimensions before (excluding) and after (including).

  Returns:
    Tensor of shape
    `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
  """
  shape = K.array_ops.shape(tensor)
  in_dims = shape[:split_dim]
  out_dims = shape[split_dim:]

  in_size = K.math_ops.reduce_prod(in_dims)
  out_size = K.math_ops.reduce_prod(out_dims)

  return K.array_ops.reshape(tensor, (in_size, out_size))