1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Locally-connected layers.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.keras import activations
22from tensorflow.python.keras import backend as K
23from tensorflow.python.keras import constraints
24from tensorflow.python.keras import initializers
25from tensorflow.python.keras import regularizers
26from tensorflow.python.keras.engine.base_layer import Layer
27from tensorflow.python.keras.engine.input_spec import InputSpec
28from tensorflow.python.keras.utils import conv_utils
29from tensorflow.python.keras.utils import tf_utils
30from tensorflow.python.util.tf_export import keras_export
31
32
@keras_export('keras.layers.LocallyConnected1D')
class LocallyConnected1D(Layer):
  """Locally-connected layer for 1D inputs.

  The `LocallyConnected1D` layer works similarly to
  the `Conv1D` layer, except that weights are unshared,
  that is, a different set of filters is applied at each different patch
  of the input.

  Example:
  ```python
      # apply a unshared weight convolution 1d of length 3 to a sequence with
      # 10 timesteps, with 64 output filters
      model = Sequential()
      model.add(LocallyConnected1D(64, 3, input_shape=(10, 32)))
      # now model.output_shape == (None, 8, 64)
      # add a new conv1d on top
      model.add(LocallyConnected1D(32, 3))
      # now model.output_shape == (None, 6, 32)
  ```

  Arguments:
      filters: Integer, the dimensionality of the output space
          (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of a single integer,
          specifying the length of the 1D convolution window.
      strides: An integer or tuple/list of a single integer,
          specifying the stride length of the convolution.
          Specifying any stride value != 1 is incompatible with specifying
          any `dilation_rate` value != 1.
      padding: Currently only supports `"valid"` (case-insensitive).
          `"same"` may be supported in the future.
      data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
          `(batch, length, channels)` while `channels_first`
          corresponds to inputs with shape
          `(batch, channels, length)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
      activation: Activation function to use.
          If you don't specify anything, no activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
          the `kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
          the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to the kernel matrix.
      bias_constraint: Constraint function applied to the bias vector.
      implementation: implementation mode, either `1` or `2`.
          `1` loops over input spatial locations to perform the forward pass.
          It is memory-efficient but performs a lot of (small) ops.

          `2` stores layer weights in a dense but sparsely-populated 2D matrix
          and implements the forward pass as a single matrix-multiply. It uses
          a lot of RAM but performs few (large) ops.

          Depending on the inputs, layer parameters, hardware, and
          `tf.executing_eagerly()` one implementation can be dramatically faster
          (e.g. 50X) than another.

          It is recommended to benchmark both in the setting of interest to pick
          the most efficient one (in terms of speed and memory usage).

          Following scenarios could benefit from setting `implementation=2`:
              - eager execution;
              - inference;
              - running on CPU;
              - large amount of RAM available;
              - small models (few filters, small kernel);
              - using `padding=same` (only possible with `implementation=2`).

  Input shape:
      3D tensor with shape: `(batch_size, steps, input_dim)`

  Output shape:
      3D tensor with shape: `(batch_size, new_steps, filters)`
      `steps` value might have changed due to padding or strides.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               implementation=1,
               **kwargs):
    super(LocallyConnected1D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    # `"same"` padding is only expressible through the masked-matmul path
    # (implementation=2); reject it early for implementation=1.
    if self.padding != 'valid' and implementation == 1:
      raise ValueError('Invalid border mode for LocallyConnected1D '
                       '(only "valid" is supported if implementation is 1): '
                       + padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.implementation = implementation
    self.input_spec = InputSpec(ndim=3)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Creates the (unshared) kernel and bias variables.

    The kernel layout depends on `implementation`; see the comments below.
    """
    if self.data_format == 'channels_first':
      input_dim, input_length = input_shape[1], input_shape[2]
    else:
      input_dim, input_length = input_shape[2], input_shape[1]

    if input_dim is None:
      # Fixed: previously this passed `input_shape` as a second ValueError
      # argument, which made the exception message render as a tuple.
      raise ValueError('Axis 2 of input should be fully-defined. '
                       'Found shape: ' + str(input_shape))
    self.output_length = conv_utils.conv_output_length(
        input_length, self.kernel_size[0], self.padding, self.strides[0])

    if self.implementation == 1:
      # One independent (receptive_field * input_dim, filters) weight slice
      # per output position; consumed by `K.local_conv`.
      self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
                           self.filters)

      self.kernel = self.add_weight(
          shape=self.kernel_shape,
          initializer=self.kernel_initializer,
          name='kernel',
          regularizer=self.kernel_regularizer,
          constraint=self.kernel_constraint)

    elif self.implementation == 2:
      # Dense fully-connected-style weight over all (input, output) location
      # pairs; `kernel_mask` below zeroes the disconnected pairs so a single
      # matmul implements the local convolution.
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_dim, input_length,
                             self.filters, self.output_length)
      else:
        self.kernel_shape = (input_length, input_dim,
                             self.output_length, self.filters)

      self.kernel = self.add_weight(shape=self.kernel_shape,
                                    initializer=self.kernel_initializer,
                                    name='kernel',
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_length,),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dtype=self.kernel.dtype
      )

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      # Bias is also unshared: one vector of `filters` per output position.
      self.bias = self.add_weight(
          shape=(self.output_length, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None

    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
    else:
      self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.data_format == 'channels_first':
      input_length = input_shape[2]
    else:
      input_length = input_shape[1]

    length = conv_utils.conv_output_length(input_length, self.kernel_size[0],
                                           self.padding, self.strides[0])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, length)
    elif self.data_format == 'channels_last':
      return (input_shape[0], length, self.filters)

  def call(self, inputs):
    if self.implementation == 1:
      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
                            (self.output_length,), self.data_format)

    elif self.implementation == 2:
      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                 self.compute_output_shape(inputs.shape))

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      output = K.bias_add(output, self.bias, data_format=self.data_format)

    output = self.activation(output)
    return output

  def get_config(self):
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'implementation':
            self.implementation
    }
    base_config = super(LocallyConnected1D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
294
295
@keras_export('keras.layers.LocallyConnected2D')
class LocallyConnected2D(Layer):
  """Locally-connected layer for 2D inputs.

  The `LocallyConnected2D` layer works similarly
  to the `Conv2D` layer, except that weights are unshared,
  that is, a different set of filters is applied at each
  different patch of the input.

  Examples:
  ```python
      # apply a 3x3 unshared weights convolution with 64 output filters on a
      32x32 image
      # with `data_format="channels_last"`:
      model = Sequential()
      model.add(LocallyConnected2D(64, (3, 3), input_shape=(32, 32, 3)))
      # now model.output_shape == (None, 30, 30, 64)
      # notice that this layer will consume (30*30)*(3*3*3*64) + (30*30)*64
      parameters

      # add a 3x3 unshared weights convolution on top, with 32 output filters:
      model.add(LocallyConnected2D(32, (3, 3)))
      # now model.output_shape == (None, 28, 28, 32)
  ```

  Arguments:
      filters: Integer, the dimensionality of the output space
          (i.e. the number of output filters in the convolution).
      kernel_size: An integer or tuple/list of 2 integers, specifying the
          width and height of the 2D convolution window.
          Can be a single integer to specify the same value for
          all spatial dimensions.
      strides: An integer or tuple/list of 2 integers,
          specifying the strides of the convolution along the width and height.
          Can be a single integer to specify the same value for
          all spatial dimensions.
      padding: Currently only supports `"valid"` (case-insensitive).
          `"same"` may be supported in the future.
      data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
      activation: Activation function to use.
          If you don't specify anything, no activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to
          the `kernel` weights matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
          the output of the layer (its "activation").
      kernel_constraint: Constraint function applied to the kernel matrix.
      bias_constraint: Constraint function applied to the bias vector.
      implementation: implementation mode, either `1` or `2`.
          `1` loops over input spatial locations to perform the forward pass.
          It is memory-efficient but performs a lot of (small) ops.

          `2` stores layer weights in a dense but sparsely-populated 2D matrix
          and implements the forward pass as a single matrix-multiply. It uses
          a lot of RAM but performs few (large) ops.

          Depending on the inputs, layer parameters, hardware, and
          `tf.executing_eagerly()` one implementation can be dramatically faster
          (e.g. 50X) than another.

          It is recommended to benchmark both in the setting of interest to pick
          the most efficient one (in terms of speed and memory usage).

          Following scenarios could benefit from setting `implementation=2`:
              - eager execution;
              - inference;
              - running on CPU;
              - large amount of RAM available;
              - small models (few filters, small kernel);
              - using `padding=same` (only possible with `implementation=2`).

  Input shape:
      4D tensor with shape:
      `(samples, channels, rows, cols)` if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, rows, cols, channels)` if data_format='channels_last'.

  Output shape:
      4D tensor with shape:
      `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
      or 4D tensor with shape:
      `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
      `rows` and `cols` values might have changed due to padding.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               implementation=1,
               **kwargs):
    super(LocallyConnected2D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    # `"same"` padding is only expressible through the masked-matmul path
    # (implementation=2); reject it early for implementation=1.
    if self.padding != 'valid' and implementation == 1:
      raise ValueError('Invalid border mode for LocallyConnected2D '
                       '(only "valid" is supported if implementation is 1): '
                       + padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.implementation = implementation
    self.input_spec = InputSpec(ndim=4)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    # Split the static input shape into spatial dims and channel count,
    # honoring `data_format`.
    if self.data_format == 'channels_last':
      input_row, input_col = input_shape[1:-1]
      input_filter = input_shape[3]
    else:
      input_row, input_col = input_shape[2:]
      input_filter = input_shape[1]
    # Unshared weights require static spatial dims: the parameter count
    # itself depends on the output grid size.
    if input_row is None or input_col is None:
      raise ValueError('The spatial dimensions of the inputs to '
                       ' a LocallyConnected2D layer '
                       'should be fully-defined, but layer received '
                       'the inputs shape ' + str(input_shape))
    output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
                                               self.padding, self.strides[0])
    output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
                                               self.padding, self.strides[1])
    self.output_row = output_row
    self.output_col = output_col

    if self.implementation == 1:
      # One independent (receptive_field * channels, filters) weight slice
      # per output position; consumed by `K.local_conv`.
      self.kernel_shape = (
          output_row * output_col,
          self.kernel_size[0] * self.kernel_size[1] * input_filter,
          self.filters)

      self.kernel = self.add_weight(
          shape=self.kernel_shape,
          initializer=self.kernel_initializer,
          name='kernel',
          regularizer=self.kernel_regularizer,
          constraint=self.kernel_constraint)

    elif self.implementation == 2:
      # Dense fully-connected-style weight over all (input, output) location
      # pairs; `kernel_mask` below zeroes the disconnected pairs so a single
      # matmul implements the local convolution.
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_filter, input_row, input_col,
                             self.filters, self.output_row, self.output_col)
      else:
        self.kernel_shape = (input_row, input_col, input_filter,
                             self.output_row, self.output_col, self.filters)

      self.kernel = self.add_weight(shape=self.kernel_shape,
                                    initializer=self.kernel_initializer,
                                    name='kernel',
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_row, input_col),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dtype=self.kernel.dtype
      )

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      # Bias is also unshared: one vector of `filters` per output position.
      self.bias = self.add_weight(
          shape=(output_row, output_col, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
    else:
      self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.data_format == 'channels_first':
      rows = input_shape[2]
      cols = input_shape[3]
    elif self.data_format == 'channels_last':
      rows = input_shape[1]
      cols = input_shape[2]

    rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
                                         self.padding, self.strides[0])
    cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
                                         self.padding, self.strides[1])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, rows, cols)
    elif self.data_format == 'channels_last':
      return (input_shape[0], rows, cols, self.filters)

  def call(self, inputs):
    # Dispatch to the per-location loop (1) or masked matmul (2) forward pass.
    if self.implementation == 1:
      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
                            (self.output_row, self.output_col),
                            self.data_format)

    elif self.implementation == 2:
      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                 self.compute_output_shape(inputs.shape))

    else:
      raise ValueError('Unrecognized implementation mode: %d.'
                       % self.implementation)

    if self.use_bias:
      output = K.bias_add(output, self.bias, data_format=self.data_format)

    output = self.activation(output)
    return output

  def get_config(self):
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'implementation':
            self.implementation
    }
    base_config = super(LocallyConnected2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
582
583
def get_locallyconnected_mask(input_shape,
                              kernel_shape,
                              strides,
                              padding,
                              data_format,
                              dtype):
  """Return a mask representing connectivity of a locally-connected operation.

  This method returns a masking tensor of 0s and 1s (of type `dtype`) that,
  when element-wise multiplied with a fully-connected weight tensor, masks out
  the weights between disconnected input-output pairs and thus implements local
  connectivity through a sparse fully-connected weight tensor.

  Assume an unshared convolution with given parameters is applied to an input
  having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
  to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
  by layer parameters such as `strides`).

  This method returns a mask which can be broadcast-multiplied (element-wise)
  with a 2*(N+1)-D weight matrix (equivalent to a fully-connected layer between
  (N+1)-D activations (N spatial + 1 channel dimensions for input and output)
  to make it perform an unshared convolution with given `kernel_shape`,
  `strides`, `padding` and `data_format`.

  Arguments:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
    data_format: a string, `"channels_first"` or `"channels_last"`.
    dtype: type of the layer operation, e.g. `tf.float64`.

  Returns:
    a `dtype`-tensor of shape
    `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
    if `data_format == `"channels_first"`, or
    `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
    if `data_format == "channels_last"`.

  Raises:
    ValueError: if `data_format` is neither `"channels_first"` nor
                `"channels_last"`.
  """
  # Purely spatial connectivity mask of shape
  # `(d_in1, ..., d_inN, d_out1, ..., d_outN)`.
  mask = conv_utils.conv_kernel_mask(
      input_shape=input_shape,
      kernel_shape=kernel_shape,
      strides=strides,
      padding=padding
  )

  # N spatial dims on each side of the mask; use floor division rather than
  # `int(... / 2)` (idiomatic, and consistent with `local_conv_matmul`).
  ndims = mask.ndim // 2
  mask = K.variable(mask, dtype)

  # Insert singleton channel axes so the mask broadcasts against the
  # 2*(N+1)-D kernel without constraining the channel dimensions.
  if data_format == 'channels_first':
    mask = K.expand_dims(mask, 0)
    mask = K.expand_dims(mask, - ndims - 1)

  elif data_format == 'channels_last':
    mask = K.expand_dims(mask, ndims)
    mask = K.expand_dims(mask, -1)

  else:
    raise ValueError('Unrecognized data_format: ' + str(data_format))

  return mask
651
652
def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
  """Apply N-D convolution with un-shared weights using a single matmul call.

  Computes `inputs . (kernel * kernel_mask)` (`.` is matrix-multiply, `*` is
  element-wise multiply): the precomputed `kernel_mask` zeroes out the entries
  of `kernel` connecting input-output location pairs that a local convolution
  would not connect, so the single (sparse-friendly) matmul reproduces an
  unshared convolution. Inputs and kernel are flattened to 2-D for the
  multiply and the result is reshaped back to the (N+2)-D `output_shape`.

  Arguments:
      inputs: (N+2)-D tensor with shape
          `(batch_size, channels_in, d_in1, ..., d_inN)`
          or
          `(batch_size, d_in1, ..., d_inN, channels_in)`.
      kernel: the unshared weights for N-D convolution,
          an (N+2)-D tensor of shape:
          `(d_in1, ..., d_inN, channels_in, d_out2, ..., d_outN, channels_out)`
          or
          `(channels_in, d_in1, ..., d_inN, channels_out, d_out2, ..., d_outN)`,
          with the ordering of channels and spatial dimensions matching
          that of the input.
          Each entry is the weight between a particular input and
          output location, similarly to a fully-connected weight matrix.
      kernel_mask: a float 0/1 mask tensor of shape:
           `(d_in1, ..., d_inN, 1, d_out2, ..., d_outN, 1)`
           or
           `(1, d_in1, ..., d_inN, 1, d_out2, ..., d_outN)`,
           with the ordering of singleton and spatial dimensions
           matching that of the input.
           Mask represents the connectivity pattern of the layer and is
           precomputed elsewhere based on layer parameters: stride,
           padding, and the receptive field shape.
      output_shape: a tuple of (N+2) elements representing the output shape:
          `(batch_size, channels_out, d_out1, ..., d_outN)`
          or
          `(batch_size, d_out1, ..., d_outN, channels_out)`,
          with the ordering of channels and spatial dimensions matching that of
          the input.

  Returns:
      Output (N+2)-D tensor with shape `output_shape`.
  """
  # Flatten each example to a single row: (batch_size, total_input_size).
  flat_inputs = K.reshape(inputs, (K.shape(inputs)[0], -1))

  # Zero out the disconnected weights, then collapse the kernel into a
  # (total_input_size, total_output_size) matrix. The first half of the
  # kernel's dimensions index the input side, the second half the output side.
  masked_kernel = kernel_mask * kernel
  flat_kernel = make_2d(masked_kernel, split_dim=K.ndim(masked_kernel) // 2)

  # The masked kernel is mostly zeros, so hint the sparse fast path.
  flat_output = K.math_ops.sparse_matmul(flat_inputs, flat_kernel,
                                         b_is_sparse=True)

  # Restore the (N+2)-D layout, keeping the batch dimension dynamic.
  return K.reshape(flat_output,
                   [K.shape(flat_output)[0]] + output_shape.as_list()[1:])
705
706
def make_2d(tensor, split_dim):
  """Flattens an N-dimensional tensor into a 2-D tensor at `split_dim`.

  All dimensions strictly before `split_dim` are merged into the first output
  dimension; `split_dim` and everything after it are merged into the second.

  Arguments:
    tensor: a tensor of shape `(d0, ..., d(N-1))`.
    split_dim: an integer from 1 to N-1, index of the dimension to group
        dimensions before (excluding) and after (including).

  Returns:
    Tensor of shape
    `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
  """
  # Dynamic shape: the split must work even when some dims are unknown
  # statically, so compute the two flattened sizes with runtime ops.
  dims = K.array_ops.shape(tensor)
  rows = K.math_ops.reduce_prod(dims[:split_dim])
  cols = K.math_ops.reduce_prod(dims[split_dim:])
  return K.array_ops.reshape(tensor, (rows, cols))
730