1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Locally-connected layers."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.keras import activations
23from tensorflow.python.keras import backend as K
24from tensorflow.python.keras import constraints
25from tensorflow.python.keras import initializers
26from tensorflow.python.keras import regularizers
27from tensorflow.python.keras.engine.base_layer import Layer
28from tensorflow.python.keras.engine.input_spec import InputSpec
29from tensorflow.python.keras.utils import conv_utils
30from tensorflow.python.keras.utils import tf_utils
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import gen_sparse_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.util.tf_export import keras_export
35
36
@keras_export('keras.layers.LocallyConnected1D')
class LocallyConnected1D(Layer):
  """Locally-connected layer for 1D inputs.

  The `LocallyConnected1D` layer works similarly to
  the `Conv1D` layer, except that weights are unshared,
  that is, a different set of filters is applied at each different patch
  of the input.

  Note: layer attributes cannot be modified after the layer has been called
  once (except the `trainable` attribute).

  Example:
  ```python
      # apply an unshared weight convolution 1d of length 3 to a sequence with
      # 10 timesteps, with 64 output filters
      model = Sequential()
      model.add(LocallyConnected1D(64, 3, input_shape=(10, 32)))
      # now model.output_shape == (None, 8, 64)
      # add a new conv1d on top
      model.add(LocallyConnected1D(32, 3))
      # now model.output_shape == (None, 6, 32)
  ```

  Args:
      filters: Integer, the dimensionality of the output space (i.e. the number
        of output filters in the convolution).
      kernel_size: An integer or tuple/list of a single integer, specifying the
        length of the 1D convolution window.
      strides: An integer or tuple/list of a single integer, specifying the
        stride length of the convolution.
      padding: Currently only supports `"valid"` (case-insensitive). `"same"`
        may be supported in the future. `"valid"` means no padding.
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape `(batch, length,
        channels)` while `channels_first` corresponds to inputs with shape
        `(batch, channels, length)`. It defaults to the `image_data_format`
        value found in your Keras config file at `~/.keras/keras.json`. If you
        never set it, then it will be "channels_last".
      activation: Activation function to use. If you don't specify anything, no
        activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the kernel matrix.
      bias_constraint: Constraint function applied to the bias vector.
      implementation: implementation mode, either `1`, `2`, or `3`. `1` loops
        over input spatial locations to perform the forward pass. It is
        memory-efficient but performs a lot of (small) ops.  `2` stores layer
        weights in a dense but sparsely-populated 2D matrix and implements the
        forward pass as a single matrix-multiply. It uses a lot of RAM but
        performs few (large) ops.  `3` stores layer weights in a sparse tensor
        and implements the forward pass as a single sparse matrix-multiply.
          How to choose:
          `1`: large, dense models,
          `2`: small models,
          `3`: large, sparse models,  where "large" stands for large
            input/output activations (i.e. many `filters`, `input_filters`,
            large `input_size`, `output_size`), and "sparse" stands for few
            connections between inputs and outputs, i.e. small ratio `filters *
            input_filters * kernel_size / (input_size * strides)`, where inputs
            to and outputs of the layer are assumed to have shapes `(input_size,
            input_filters)`, `(output_size, filters)` respectively.  It is
            recommended to benchmark each in the setting of interest to pick the
            most efficient one (in terms of speed and memory usage). Correct
            choice of implementation can lead to dramatic speed improvements
            (e.g. 50X), potentially at the expense of RAM.  Also, only
            `padding="valid"` is supported by `implementation=1`.
  Input shape:
      3D tensor with shape: `(batch_size, steps, input_dim)` if
        data_format='channels_last', or `(batch_size, input_dim, steps)` if
        data_format='channels_first'.
  Output shape:
      3D tensor with shape: `(batch_size, new_steps, filters)` if
        data_format='channels_last', or `(batch_size, filters, new_steps)` if
        data_format='channels_first'. `steps` value might have changed due to
        padding or strides.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               implementation=1,
               **kwargs):
    super(LocallyConnected1D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    # Reject unknown modes at construction time rather than at first call;
    # `%s` (not `%d`) keeps the message readable for non-integer values.
    if implementation not in (1, 2, 3):
      raise ValueError('Unrecognized implementation mode: %s.' %
                       (implementation,))
    if self.padding != 'valid' and implementation == 1:
      raise ValueError('Invalid border mode for LocallyConnected1D '
                       '(only "valid" is supported if implementation is 1): ' +
                       padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.implementation = implementation
    self.input_spec = InputSpec(ndim=3)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Creates the unshared kernel (and bias) for the given input shape.

    Raises:
      ValueError: if the channel or length dimension of `input_shape` is
        undefined, if the resulting output length is not positive, or if
        `implementation` is not one of 1, 2, 3.
    """
    if self.data_format == 'channels_first':
      input_dim, input_length = input_shape[1], input_shape[2]
    else:
      input_dim, input_length = input_shape[2], input_shape[1]

    if input_dim is None:
      raise ValueError('The channel dimension of the inputs to a '
                       'LocallyConnected1D layer should be defined. '
                       'Found shape: ' + str(input_shape))
    # Weights are unshared across positions, so every implementation needs a
    # concrete number of positions to size the kernel.
    if input_length is None:
      raise ValueError('The length dimension of the inputs to a '
                       'LocallyConnected1D layer should be fully-defined. '
                       'Found shape: ' + str(input_shape))
    self.output_length = conv_utils.conv_output_length(input_length,
                                                       self.kernel_size[0],
                                                       self.padding,
                                                       self.strides[0])
    if self.output_length <= 0:
      raise ValueError('One of the dimensions in the output is <= 0 '
                       'due to downsampling in ' + str(self.name) +
                       '. Consider increasing the input size. '
                       'Received input shape ' + str(input_shape) +
                       ' which would produce output shape with a zero or '
                       'negative value in a dimension.')

    # Each branch determines the logical `kernel_shape` plus the shape of the
    # variable actually stored; the variable is created once below.
    if self.implementation == 1:
      self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
                           self.filters)
      kernel_weight_shape = self.kernel_shape

    elif self.implementation == 2:
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_dim, input_length, self.filters,
                             self.output_length)
      else:
        self.kernel_shape = (input_length, input_dim, self.output_length,
                             self.filters)
      kernel_weight_shape = self.kernel_shape

      # 0/1 mask zeroing the weights between disconnected input/output pairs.
      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_length,),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
      )

    elif self.implementation == 3:
      self.kernel_shape = (self.output_length * self.filters,
                           input_length * input_dim)

      # Only the connected entries are stored: one scalar per index.
      self.kernel_idxs = sorted(
          conv_utils.conv_kernel_idxs(
              input_shape=(input_length,),
              kernel_shape=self.kernel_size,
              strides=self.strides,
              padding=self.padding,
              filters_in=input_dim,
              filters_out=self.filters,
              data_format=self.data_format))
      kernel_weight_shape = (len(self.kernel_idxs),)

    else:
      raise ValueError('Unrecognized implementation mode: %s.' %
                       (self.implementation,))

    self.kernel = self.add_weight(
        shape=kernel_weight_shape,
        initializer=self.kernel_initializer,
        name='kernel',
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(self.output_length, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None

    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
    else:
      self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns the output shape, ordered according to `data_format`."""
    if self.data_format == 'channels_first':
      input_length = input_shape[2]
    else:
      input_length = input_shape[1]

    length = conv_utils.conv_output_length(input_length, self.kernel_size[0],
                                           self.padding, self.strides[0])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, length)
    elif self.data_format == 'channels_last':
      return (input_shape[0], length, self.filters)

  def call(self, inputs):
    """Applies the selected unshared-convolution implementation, bias, activation."""
    if self.implementation == 1:
      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
                            (self.output_length,), self.data_format)

    elif self.implementation == 2:
      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                 self.compute_output_shape(inputs.shape))

    elif self.implementation == 3:
      output = local_conv_sparse_matmul(inputs, self.kernel, self.kernel_idxs,
                                        self.kernel_shape,
                                        self.compute_output_shape(inputs.shape))

    else:
      raise ValueError('Unrecognized implementation mode: %s.' %
                       (self.implementation,))

    if self.use_bias:
      output = K.bias_add(output, self.bias, data_format=self.data_format)

    output = self.activation(output)
    return output

  def get_config(self):
    """Returns the layer config; layer-specific keys override base keys."""
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'implementation':
            self.implementation
    }
    base_config = super(LocallyConnected1D, self).get_config()
    return dict(base_config, **config)
323
324
@keras_export('keras.layers.LocallyConnected2D')
class LocallyConnected2D(Layer):
  """Locally-connected layer for 2D inputs.

  The `LocallyConnected2D` layer works similarly
  to the `Conv2D` layer, except that weights are unshared,
  that is, a different set of filters is applied at each
  different patch of the input.

  Note: layer attributes cannot be modified after the layer has been called
  once (except the `trainable` attribute).

  Examples:
  ```python
      # apply a 3x3 unshared weights convolution with 64 output filters on a
      # 32x32 image with `data_format="channels_last"`:
      model = Sequential()
      model.add(LocallyConnected2D(64, (3, 3), input_shape=(32, 32, 3)))
      # now model.output_shape == (None, 30, 30, 64)
      # notice that this layer will consume (30*30)*(3*3*3*64) + (30*30)*64
      # parameters

      # add a 3x3 unshared weights convolution on top, with 32 output filters:
      model.add(LocallyConnected2D(32, (3, 3)))
      # now model.output_shape == (None, 28, 28, 32)
  ```

  Args:
      filters: Integer, the dimensionality of the output space (i.e. the number
        of output filters in the convolution).
      kernel_size: An integer or tuple/list of 2 integers, specifying the height
        and width of the 2D convolution window. Can be a single integer to
        specify the same value for all spatial dimensions.
      strides: An integer or tuple/list of 2 integers, specifying the strides of
        the convolution along the height and width. Can be a single integer to
        specify the same value for all spatial dimensions.
      padding: Currently only supports `"valid"` (case-insensitive). `"same"`
        may be supported in the future. `"valid"` means no padding.
      data_format: A string, one of `channels_last` (default) or
        `channels_first`. The ordering of the dimensions in the inputs.
        `channels_last` corresponds to inputs with shape `(batch, height, width,
        channels)` while `channels_first` corresponds to inputs with shape
        `(batch, channels, height, width)`. It defaults to the
        `image_data_format` value found in your Keras config file at
        `~/.keras/keras.json`. If you never set it, then it will be
        "channels_last".
      activation: Activation function to use. If you don't specify anything, no
        activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      kernel_initializer: Initializer for the `kernel` weights matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularizer function applied to the `kernel` weights
        matrix.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to the output of the
        layer (its "activation").
      kernel_constraint: Constraint function applied to the kernel matrix.
      bias_constraint: Constraint function applied to the bias vector.
      implementation: implementation mode, either `1`, `2`, or `3`. `1` loops
        over input spatial locations to perform the forward pass. It is
        memory-efficient but performs a lot of (small) ops.  `2` stores layer
        weights in a dense but sparsely-populated 2D matrix and implements the
        forward pass as a single matrix-multiply. It uses a lot of RAM but
        performs few (large) ops.  `3` stores layer weights in a sparse tensor
        and implements the forward pass as a single sparse matrix-multiply.
          How to choose:
          `1`: large, dense models,
          `2`: small models,
          `3`: large, sparse models,  where "large" stands for large
            input/output activations (i.e. many `filters`, `input_filters`,
            large `np.prod(input_size)`, `np.prod(output_size)`), and "sparse"
            stands for few connections between inputs and outputs, i.e. small
            ratio `filters * input_filters * np.prod(kernel_size) /
            (np.prod(input_size) * np.prod(strides))`, where inputs to and
            outputs of the layer are assumed to have shapes `input_size +
            (input_filters,)`, `output_size + (filters,)` respectively.  It is
            recommended to benchmark each in the setting of interest to pick the
            most efficient one (in terms of speed and memory usage). Correct
            choice of implementation can lead to dramatic speed improvements
            (e.g. 50X), potentially at the expense of RAM.  Also, only
            `padding="valid"` is supported by `implementation=1`.
  Input shape:
      4D tensor with shape: `(samples, channels, rows, cols)` if
        data_format='channels_first'
      or 4D tensor with shape: `(samples, rows, cols, channels)` if
        data_format='channels_last'.
  Output shape:
      4D tensor with shape: `(samples, filters, new_rows, new_cols)` if
        data_format='channels_first'
      or 4D tensor with shape: `(samples, new_rows, new_cols, filters)` if
        data_format='channels_last'. `rows` and `cols` values might have changed
        due to padding or strides.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               implementation=1,
               **kwargs):
    super(LocallyConnected2D, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    # Reject unknown modes at construction time rather than at first call;
    # `%s` (not `%d`) keeps the message readable for non-integer values.
    if implementation not in (1, 2, 3):
      raise ValueError('Unrecognized implementation mode: %s.' %
                       (implementation,))
    if self.padding != 'valid' and implementation == 1:
      raise ValueError('Invalid border mode for LocallyConnected2D '
                       '(only "valid" is supported if implementation is 1): ' +
                       padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.implementation = implementation
    self.input_spec = InputSpec(ndim=4)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Creates the unshared kernel (and bias) for the given input shape.

    Raises:
      ValueError: if the spatial or channel dimensions of `input_shape` are
        undefined, if either output spatial dimension is not positive, or if
        `implementation` is not one of 1, 2, 3.
    """
    if self.data_format == 'channels_last':
      input_row, input_col = input_shape[1:-1]
      input_filter = input_shape[3]
    else:
      input_row, input_col = input_shape[2:]
      input_filter = input_shape[1]
    # Weights are unshared across positions, so the spatial extent must be
    # concrete to size the kernel.
    if input_row is None or input_col is None:
      raise ValueError('The spatial dimensions of the inputs to '
                       'a LocallyConnected2D layer '
                       'should be fully-defined, but layer received '
                       'the inputs shape ' + str(input_shape))
    if input_filter is None:
      raise ValueError('The channel dimension of the inputs to '
                       'a LocallyConnected2D layer should be defined, but '
                       'layer received the inputs shape ' + str(input_shape))
    output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
                                               self.padding, self.strides[0])
    output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
                                               self.padding, self.strides[1])
    if output_row <= 0 or output_col <= 0:
      raise ValueError('One of the dimensions in the output is <= 0 '
                       'due to downsampling in ' + str(self.name) +
                       '. Consider increasing the input size. '
                       'Received input shape ' + str(input_shape) +
                       ' which would produce output shape with a zero or '
                       'negative value in a dimension.')
    self.output_row = output_row
    self.output_col = output_col

    # Each branch determines the logical `kernel_shape` plus the shape of the
    # variable actually stored; the variable is created once below.
    if self.implementation == 1:
      self.kernel_shape = (output_row * output_col, self.kernel_size[0] *
                           self.kernel_size[1] * input_filter, self.filters)
      kernel_weight_shape = self.kernel_shape

    elif self.implementation == 2:
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_filter, input_row, input_col, self.filters,
                             self.output_row, self.output_col)
      else:
        self.kernel_shape = (input_row, input_col, input_filter,
                             self.output_row, self.output_col, self.filters)
      kernel_weight_shape = self.kernel_shape

      # 0/1 mask zeroing the weights between disconnected input/output pairs.
      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_row, input_col),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
      )

    elif self.implementation == 3:
      self.kernel_shape = (self.output_row * self.output_col * self.filters,
                           input_row * input_col * input_filter)

      # Only the connected entries are stored: one scalar per index.
      self.kernel_idxs = sorted(
          conv_utils.conv_kernel_idxs(
              input_shape=(input_row, input_col),
              kernel_shape=self.kernel_size,
              strides=self.strides,
              padding=self.padding,
              filters_in=input_filter,
              filters_out=self.filters,
              data_format=self.data_format))
      kernel_weight_shape = (len(self.kernel_idxs),)

    else:
      raise ValueError('Unrecognized implementation mode: %s.' %
                       (self.implementation,))

    self.kernel = self.add_weight(
        shape=kernel_weight_shape,
        initializer=self.kernel_initializer,
        name='kernel',
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint)

    if self.use_bias:
      self.bias = self.add_weight(
          shape=(output_row, output_col, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None
    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
    else:
      self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
    self.built = True

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns the output shape, ordered according to `data_format`."""
    if self.data_format == 'channels_first':
      rows = input_shape[2]
      cols = input_shape[3]
    elif self.data_format == 'channels_last':
      rows = input_shape[1]
      cols = input_shape[2]

    rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
                                         self.padding, self.strides[0])
    cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
                                         self.padding, self.strides[1])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, rows, cols)
    elif self.data_format == 'channels_last':
      return (input_shape[0], rows, cols, self.filters)

  def call(self, inputs):
    """Applies the selected unshared-convolution implementation, bias, activation."""
    if self.implementation == 1:
      output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
                            (self.output_row, self.output_col),
                            self.data_format)

    elif self.implementation == 2:
      output = local_conv_matmul(inputs, self.kernel, self.kernel_mask,
                                 self.compute_output_shape(inputs.shape))

    elif self.implementation == 3:
      output = local_conv_sparse_matmul(inputs, self.kernel, self.kernel_idxs,
                                        self.kernel_shape,
                                        self.compute_output_shape(inputs.shape))

    else:
      raise ValueError('Unrecognized implementation mode: %s.' %
                       (self.implementation,))

    if self.use_bias:
      output = K.bias_add(output, self.bias, data_format=self.data_format)

    output = self.activation(output)
    return output

  def get_config(self):
    """Returns the layer config; layer-specific keys override base keys."""
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'implementation':
            self.implementation
    }
    base_config = super(LocallyConnected2D, self).get_config()
    return dict(base_config, **config)
633
634
def get_locallyconnected_mask(input_shape, kernel_shape, strides, padding,
                              data_format):
  """Return a mask representing connectivity of a locally-connected operation.

  This method returns a masking numpy array of 0s and 1s (of type `np.float32`)
  that, when element-wise multiplied with a fully-connected weight tensor, masks
  out the weights between disconnected input-output pairs and thus implements
  local connectivity through a sparse fully-connected weight tensor.

  Assume an unshared convolution with given parameters is applied to an input
  having N spatial dimensions with `input_shape = (d_in1, ..., d_inN)`
  to produce an output with spatial shape `(d_out1, ..., d_outN)` (determined
  by layer parameters such as `strides`).

  This method returns a mask which can be broadcast-multiplied (element-wise)
  with a 2*(N+1)-D weight matrix (equivalent to a fully-connected layer between
  (N+1)-D activations (N spatial + 1 channel dimensions for input and output)
  to make it perform an unshared convolution with given `kernel_shape`,
  `strides`, `padding` and `data_format`.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)` spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
    data_format: a string, `"channels_first"` or `"channels_last"`.

  Returns:
    a `np.float32`-type `np.ndarray` of shape
    `(1, d_in1, ..., d_inN, 1, d_out1, ..., d_outN)`
    if `data_format == "channels_first"`, or
    `(d_in1, ..., d_inN, 1, d_out1, ..., d_outN, 1)`
    if `data_format == "channels_last"`.

  Raises:
    ValueError: if `data_format` is neither `"channels_first"` nor
                `"channels_last"`.
  """
  mask = conv_utils.conv_kernel_mask(
      input_shape=input_shape,
      kernel_shape=kernel_shape,
      strides=strides,
      padding=padding)

  # The raw mask has one axis per input spatial dim followed by one axis per
  # output spatial dim, so N is half its rank.
  ndims = mask.ndim // 2

  if data_format == 'channels_first':
    # Leading singleton for the input channel axis, then a singleton just
    # before the N output spatial axes for the output channel axis.
    mask = np.expand_dims(mask, 0)
    mask = np.expand_dims(mask, -ndims - 1)

  elif data_format == 'channels_last':
    # Singleton after the N input spatial axes (input channels) and a
    # trailing singleton (output channels).
    mask = np.expand_dims(mask, ndims)
    mask = np.expand_dims(mask, -1)

  else:
    raise ValueError('Unrecognized data_format: ' + str(data_format))

  return mask
695
696
def local_conv_matmul(inputs, kernel, kernel_mask, output_shape):
  """Apply N-D convolution with un-shared weights using a single matmul call.

  This method outputs `inputs . (kernel * kernel_mask)`
  (with `.` standing for matrix-multiply and `*` for element-wise multiply)
  and requires a precomputed `kernel_mask` to zero-out weights in `kernel` and
  hence perform the same operation as a convolution with un-shared
  (the remaining entries in `kernel`) weights. It also does the necessary
  reshapes to make `inputs` and `kernel` 2-D and `output` (N+2)-D.

  Args:
      inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ...,
        d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`.
      kernel: the unshared weights for N-D convolution, a 2*(N+1)-D tensor of
        shape: `(d_in1, ..., d_inN, channels_in, d_out1, ..., d_outN,
        channels_out)` or `(channels_in, d_in1, ..., d_inN, channels_out,
        d_out1, ..., d_outN)`, with the ordering of channels and spatial
        dimensions matching that of the input. Each entry is the weight
        between a particular input and output location, similarly to a
        fully-connected weight matrix.
      kernel_mask: a float 0/1 mask tensor of shape: `(d_in1, ..., d_inN, 1,
        d_out1, ..., d_outN, 1)` or `(1, d_in1, ..., d_inN, 1, d_out1, ...,
        d_outN)`, with the ordering of singleton and spatial dimensions
        matching that of the input. The mask represents the connectivity
        pattern of the layer and is precomputed elsewhere based on layer
        parameters: stride, padding, and the receptive field shape.
      output_shape: the (N+2)-element output shape, given either as a
        `TensorShape` or as a plain tuple/list: `(batch_size, channels_out,
        d_out1, ..., d_outN)` or `(batch_size, d_out1, ..., d_outN,
        channels_out)`, with the ordering of channels and spatial dimensions
        matching that of the input. The batch entry is ignored; the batch size
        is taken from `inputs` at run time.

  Returns:
      Output (N+2)-D tensor with shape `output_shape`.
  """
  inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))

  # Zero out weights outside each output's receptive field, then fold the
  # first half of the axes (input side) into rows and the second half
  # (output side) into columns, giving an `(input_size, output_size)` matrix.
  kernel = kernel_mask * kernel
  kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2)

  # The masked kernel is expected to be sparse, hence the `b_is_sparse` hint.
  output_flat = math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True)

  # Accept a `TensorShape` as well as the documented plain tuple/list; only
  # the non-batch entries are used.
  if hasattr(output_shape, 'as_list'):
    output_dims = output_shape.as_list()[1:]
  else:
    output_dims = list(output_shape)[1:]

  output = K.reshape(output_flat, [
      K.shape(output_flat)[0],
  ] + output_dims)
  return output
742
743
def local_conv_sparse_matmul(inputs, kernel, kernel_idxs, kernel_shape,
                             output_shape):
  """Apply N-D convolution with un-shared weights using a single sparse matmul.

  With `S = tf.sparse.SparseTensor(indices=kernel_idxs, values=kernel,
  dense_shape=kernel_shape)`, this method outputs `inputs . S^T`, computed as
  `(S . inputs^T)^T`, with `.` standing for matrix-multiply. It also reshapes
  `inputs` to 2-D and `output` to (N+2)-D.

  Args:
      inputs: (N+2)-D tensor with shape `(batch_size, channels_in, d_in1, ...,
        d_inN)` or `(batch_size, d_in1, ..., d_inN, channels_in)`.
      kernel: a 1-D tensor with shape `(len(kernel_idxs),)` containing all the
        weights of the layer.
      kernel_idxs: a list of integer pairs `(output_idx, input_idx)`: the
        indices of the non-zero entries of the sparse `kernel_shape` matrix
        performing the un-shared convolution as a matrix-multiply.
      kernel_shape: a tuple `(output_size, input_size)`, where `output_size =
        channels_out * d_out1 * ... * d_outN` and `input_size = channels_in *
        d_in1 * ... * d_inN`.
      output_shape: the (N+2)-element output shape, given either as a
        `TensorShape` or as a plain tuple/list: `(batch_size, channels_out,
        d_out1, ..., d_outN)` or `(batch_size, d_out1, ..., d_outN,
        channels_out)`, with the ordering of channels and spatial dimensions
        matching that of the input. The batch entry is ignored; the batch size
        is taken from `inputs` at run time.

  Returns:
      Output (N+2)-D dense tensor with shape `output_shape`.
  """
  inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1))

  # The sparse kernel is laid out as `(output_size, input_size)`, so multiply
  # it by the transpose (adjoint) of the `(batch_size, input_size)` inputs to
  # get `(output_size, batch_size)`, then transpose back to batch-major.
  output_flat = gen_sparse_ops.SparseTensorDenseMatMul(
      a_indices=kernel_idxs,
      a_values=kernel,
      a_shape=kernel_shape,
      b=inputs_flat,
      adjoint_b=True)
  output_flat_transpose = K.transpose(output_flat)

  # Accept a `TensorShape` as well as the documented plain tuple/list; only
  # the non-batch entries are used.
  if hasattr(output_shape, 'as_list'):
    output_dims = output_shape.as_list()[1:]
  else:
    output_dims = list(output_shape)[1:]

  output_reshaped = K.reshape(output_flat_transpose, [
      K.shape(output_flat_transpose)[0],
  ] + output_dims)
  return output_reshaped
783
784
def make_2d(tensor, split_dim):
  """Collapses an N-D tensor into a 2-D matrix around axis `split_dim`.

  Axes `0 .. split_dim - 1` are merged into the row axis, and axes
  `split_dim .. N - 1` are merged into the column axis.

  Args:
    tensor: a tensor of shape `(d0, ..., d(N-1))`.
    split_dim: an integer from 1 to N-1; the first axis that is merged into
      the column axis.

  Returns:
    Tensor of shape
    `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
  """
  # Work from the run-time shape so that dimensions which are not statically
  # known are still handled.
  dims = array_ops.shape(tensor)
  num_rows = math_ops.reduce_prod(dims[:split_dim])
  num_cols = math_ops.reduce_prod(dims[split_dim:])
  return array_ops.reshape(tensor, (num_rows, num_cols))
808