1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Pooling layers.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.keras import backend
25from tensorflow.python.keras.engine.base_layer import Layer
26from tensorflow.python.keras.engine.input_spec import InputSpec
27from tensorflow.python.keras.utils import conv_utils
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import nn
31from tensorflow.python.util.tf_export import keras_export
32
33
class Pooling1D(Layer):
  """Shared implementation of pooling over 1D (temporal) inputs.

  The 3D input temporarily receives an extra size-1 axis so that a 2D pooling
  function can do the work; that axis is squeezed out of the result again.

  This class only exists for code reuse. It will never be an exposed API.

  Arguments:
    pool_function: The pooling function to apply. It is invoked as
      `pool_function(inputs, pool_size, strides=..., padding=...,
      data_format=...)` with the normalized `padding` and `data_format`
      values stored on the layer (e.g. a partial of `backend.pool2d`).
    pool_size: An integer or tuple/list of a single integer,
      representing the size of the pooling window.
    strides: An integer or tuple/list of a single integer, specifying the
      strides of the pooling operation. `None` means "use `pool_size`".
    padding: A string. The padding method, either 'valid' or 'same'.
      Case-insensitive.
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, steps, features)` while `channels_first`
      corresponds to inputs with shape
      `(batch, features, steps)`.
      `None` selects `backend.image_data_format()`.
    name: A string, the name of the layer.
  """

  def __init__(self, pool_function, pool_size, strides,
               padding='valid', data_format='channels_last',
               name=None, **kwargs):
    super(Pooling1D, self).__init__(name=name, **kwargs)
    # Resolve the "unspecified" sentinels before normalizing anything.
    data_format = (
        backend.image_data_format() if data_format is None else data_format)
    strides = pool_size if strides is None else strides
    self.pool_function = pool_function
    self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size')
    self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(ndim=3)

  def call(self, inputs):
    """Pools `inputs` along the steps axis via the 2D pooling function."""
    # Position of the dummy spatial axis: right after the steps axis.
    if self.data_format == 'channels_last':
      dummy_axis = 2
    else:
      dummy_axis = 3
    expanded = array_ops.expand_dims(inputs, dummy_axis)
    # The dummy axis gets a window and stride of 1, so it is left untouched.
    window = self.pool_size + (1,)
    step = self.strides + (1,)
    pooled = self.pool_function(
        expanded,
        window,
        strides=step,
        padding=self.padding,
        data_format=self.data_format)
    return array_ops.squeeze(pooled, dummy_axis)

  def compute_output_shape(self, input_shape):
    """Returns the output `TensorShape` for a 3D `input_shape`."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    channels_first = self.data_format == 'channels_first'
    steps_axis, features_axis = (2, 1) if channels_first else (1, 2)
    steps, features = dims[steps_axis], dims[features_axis]
    pooled_steps = conv_utils.conv_output_length(
        steps, self.pool_size[0], self.padding, self.strides[0])
    if channels_first:
      out_dims = [dims[0], features, pooled_steps]
    else:
      out_dims = [dims[0], pooled_steps, features]
    return tensor_shape.TensorShape(out_dims)

  def get_config(self):
    """Returns the base layer config extended with the pooling settings."""
    merged = dict(super(Pooling1D, self).get_config())
    merged.update({
        'strides': self.strides,
        'pool_size': self.pool_size,
        'padding': self.padding,
        'data_format': self.data_format,
    })
    return merged
109
110
@keras_export('keras.layers.MaxPool1D', 'keras.layers.MaxPooling1D')
class MaxPooling1D(Pooling1D):
  """Max pooling layer for 1D (temporal) data.

  Arguments:
    pool_size: Integer, size of the max pooling windows.
    strides: Integer, or None. Factor by which to downscale;
      for instance, 2 halves the input.
      When None, `pool_size` is used.
    padding: Either `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, steps, features)`, `channels_first` means inputs shaped
      `(batch, features, steps)`.

  Input shape:
    - If `data_format='channels_last'`:
      3D tensor with shape `(batch_size, steps, features)`.
    - If `data_format='channels_first'`:
      3D tensor with shape `(batch_size, features, steps)`.

  Output shape:
    - If `data_format='channels_last'`:
      3D tensor with shape `(batch_size, downsampled_steps, features)`.
    - If `data_format='channels_first'`:
      3D tensor with shape `(batch_size, features, downsampled_steps)`.
  """

  def __init__(self, pool_size=2, strides=None,
               padding='valid', data_format='channels_last', **kwargs):
    # 1D pooling is delegated to the backend's 2D pooling in 'max' mode.
    max_pool = functools.partial(backend.pool2d, pool_mode='max')
    super(MaxPooling1D, self).__init__(
        max_pool,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        **kwargs)
152
153
@keras_export('keras.layers.AveragePooling1D', 'keras.layers.AvgPool1D')
class AveragePooling1D(Pooling1D):
  """Average pooling layer for 1D (temporal) data.

  Arguments:
    pool_size: Integer, size of the average pooling windows.
    strides: Integer, or None. Factor by which to downscale;
      for instance, 2 halves the input.
      When None, `pool_size` is used.
    padding: Either `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, steps, features)`, `channels_first` means inputs shaped
      `(batch, features, steps)`.

  Input shape:
    - If `data_format='channels_last'`:
      3D tensor with shape `(batch_size, steps, features)`.
    - If `data_format='channels_first'`:
      3D tensor with shape `(batch_size, features, steps)`.

  Output shape:
    - If `data_format='channels_last'`:
      3D tensor with shape `(batch_size, downsampled_steps, features)`.
    - If `data_format='channels_first'`:
      3D tensor with shape `(batch_size, features, downsampled_steps)`.
  """

  def __init__(self, pool_size=2, strides=None,
               padding='valid', data_format='channels_last', **kwargs):
    # 1D pooling is delegated to the backend's 2D pooling in 'avg' mode.
    avg_pool = functools.partial(backend.pool2d, pool_mode='avg')
    super(AveragePooling1D, self).__init__(
        avg_pool,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        **kwargs)
194
195
class Pooling2D(Layer):
  """Shared implementation of pooling over 2D inputs (e.g. images).

  This class only exists for code reuse. It will never be an exposed API.

  Arguments:
    pool_function: The pooling function to apply, e.g. `tf.nn.max_pool`.
      It is invoked with `ksize`, `strides`, `padding` and `data_format`
      keyword arguments.
    pool_size: An integer or tuple/list of 2 integers: (pool_height, pool_width)
      specifying the size of the pooling window.
      Can be a single integer to specify the same value for
      all spatial dimensions.
    strides: An integer or tuple/list of 2 integers,
      specifying the strides of the pooling operation.
      Can be a single integer to specify the same value for
      all spatial dimensions. `None` means "use `pool_size`".
    padding: A string. The padding method, either 'valid' or 'same'.
      Case-insensitive.
    data_format: A string, one of `channels_last` or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, height, width)`.
      `None` (the default) selects `backend.image_data_format()`.
    name: A string, the name of the layer.
  """

  def __init__(self, pool_function, pool_size, strides,
               padding='valid', data_format=None,
               name=None, **kwargs):
    super(Pooling2D, self).__init__(name=name, **kwargs)
    # Resolve the "unspecified" sentinels before normalizing anything.
    data_format = (
        backend.image_data_format() if data_format is None else data_format)
    strides = pool_size if strides is None else strides
    self.pool_function = pool_function
    self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
    self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs):
    """Applies the pooling function over the two spatial axes of `inputs`."""
    # Batch and channel axes always get a window/stride of 1; only their
    # position relative to the spatial axes depends on the data format.
    if self.data_format == 'channels_last':
      window = (1,) + self.pool_size + (1,)
      step = (1,) + self.strides + (1,)
    else:
      window = (1, 1) + self.pool_size
      step = (1, 1) + self.strides
    return self.pool_function(
        inputs,
        ksize=window,
        strides=step,
        padding=self.padding.upper(),
        data_format=conv_utils.convert_data_format(self.data_format, 4))

  def compute_output_shape(self, input_shape):
    """Returns the output `TensorShape` for a 4D `input_shape`."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    channels_first = self.data_format == 'channels_first'
    # Index of the first spatial axis (rows); cols follows immediately.
    offset = 2 if channels_first else 1
    pooled = [
        conv_utils.conv_output_length(
            dims[offset + i], self.pool_size[i], self.padding, self.strides[i])
        for i in range(2)
    ]
    if channels_first:
      return tensor_shape.TensorShape([dims[0], dims[1]] + pooled)
    return tensor_shape.TensorShape([dims[0]] + pooled + [dims[3]])

  def get_config(self):
    """Returns the base layer config extended with the pooling settings."""
    merged = dict(super(Pooling2D, self).get_config())
    merged.update({
        'pool_size': self.pool_size,
        'padding': self.padding,
        'strides': self.strides,
        'data_format': self.data_format,
    })
    return merged
279
280
@keras_export('keras.layers.MaxPool2D', 'keras.layers.MaxPooling2D')
class MaxPooling2D(Pooling2D):
  """Max pooling layer for 2D (spatial) data.

  Arguments:
    pool_size: integer or tuple of 2 integers,
      factors by which to downscale (vertical, horizontal).
      `(2, 2)` will halve the input in both spatial dimensions.
      A single integer applies the same window length
      to both dimensions.
    strides: Integer, tuple of 2 integers, or None.
      Strides values.
      When None, `pool_size` is used.
    padding: Either `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, height, width, channels)`, `channels_first`
      means inputs shaped
      `(batch, channels, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    - If `data_format='channels_last'`:
      4D tensor with shape `(batch_size, rows, cols, channels)`.
    - If `data_format='channels_first'`:
      4D tensor with shape `(batch_size, channels, rows, cols)`.

  Output shape:
    - If `data_format='channels_last'`:
      4D tensor with shape `(batch_size, pooled_rows, pooled_cols, channels)`.
    - If `data_format='channels_first'`:
      4D tensor with shape `(batch_size, channels, pooled_rows, pooled_cols)`.
  """

  def __init__(self,
               pool_size=(2, 2),
               strides=None,
               padding='valid',
               data_format=None,
               **kwargs):
    # All of the work happens in Pooling2D; only the pooling op is chosen here.
    super(MaxPooling2D, self).__init__(
        nn.max_pool,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        **kwargs)
329
330
@keras_export('keras.layers.AveragePooling2D', 'keras.layers.AvgPool2D')
class AveragePooling2D(Pooling2D):
  """Average pooling layer for 2D (spatial) data.

  Arguments:
    pool_size: integer or tuple of 2 integers,
      factors by which to downscale (vertical, horizontal).
      `(2, 2)` will halve the input in both spatial dimensions.
      A single integer applies the same window length
      to both dimensions.
    strides: Integer, tuple of 2 integers, or None.
      Strides values.
      When None, `pool_size` is used.
    padding: Either `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, height, width, channels)`, `channels_first`
      means inputs shaped
      `(batch, channels, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    - If `data_format='channels_last'`:
      4D tensor with shape `(batch_size, rows, cols, channels)`.
    - If `data_format='channels_first'`:
      4D tensor with shape `(batch_size, channels, rows, cols)`.

  Output shape:
    - If `data_format='channels_last'`:
      4D tensor with shape `(batch_size, pooled_rows, pooled_cols, channels)`.
    - If `data_format='channels_first'`:
      4D tensor with shape `(batch_size, channels, pooled_rows, pooled_cols)`.
  """

  def __init__(self,
               pool_size=(2, 2),
               strides=None,
               padding='valid',
               data_format=None,
               **kwargs):
    # All of the work happens in Pooling2D; only the pooling op is chosen here.
    super(AveragePooling2D, self).__init__(
        nn.avg_pool,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        **kwargs)
379
380
class Pooling3D(Layer):
  """Shared implementation of pooling over 3D inputs.

  This class only exists for code reuse. It will never be an exposed API.

  Arguments:
    pool_function: The pooling function to apply, e.g. `tf.nn.max_pool3d`.
      It is invoked with `ksize`, `strides` and `padding` keyword arguments
      on channels-last data.
    pool_size: An integer or tuple/list of 3 integers:
      (pool_depth, pool_height, pool_width)
      specifying the size of the pooling window.
      Can be a single integer to specify the same value for
      all spatial dimensions.
    strides: An integer or tuple/list of 3 integers,
      specifying the strides of the pooling operation.
      Can be a single integer to specify the same value for
      all spatial dimensions. `None` means "use `pool_size`".
    padding: A string. The padding method, either 'valid' or 'same'.
      Case-insensitive.
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, depth, height, width, channels)`
      while `channels_first` corresponds to
      inputs with shape `(batch, channels, depth, height, width)`.
      `None` selects `backend.image_data_format()`.
    name: A string, the name of the layer.
  """

  def __init__(self, pool_function, pool_size, strides,
               padding='valid', data_format='channels_last',
               name=None, **kwargs):
    super(Pooling3D, self).__init__(name=name, **kwargs)
    # Resolve the "unspecified" sentinels before normalizing anything.
    data_format = (
        backend.image_data_format() if data_format is None else data_format)
    strides = pool_size if strides is None else strides
    self.pool_function = pool_function
    self.pool_size = conv_utils.normalize_tuple(pool_size, 3, 'pool_size')
    self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(ndim=5)

  def call(self, inputs):
    """Applies the pooling function over the three spatial axes of `inputs`."""
    needs_transpose = self.data_format == 'channels_first'
    # The pooling op always sees channels-last data, so batch and channel
    # axes (window/stride 1) bracket the spatial ones.
    window = (1,) + self.pool_size + (1,)
    step = (1,) + self.strides + (1,)

    if needs_transpose:
      # TF does not support `channels_first` with 3D pooling operations,
      # so we must handle this case manually.
      # TODO(fchollet): remove this when TF pooling is feature-complete.
      inputs = array_ops.transpose(inputs, (0, 2, 3, 4, 1))

    pooled = self.pool_function(
        inputs,
        ksize=window,
        strides=step,
        padding=self.padding.upper())

    if needs_transpose:
      # Move the channel axis back in front of the spatial axes.
      pooled = array_ops.transpose(pooled, (0, 4, 1, 2, 3))
    return pooled

  def compute_output_shape(self, input_shape):
    """Returns the output `TensorShape` for a 5D `input_shape`."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    channels_first = self.data_format == 'channels_first'
    # Index of the first spatial axis; the other two follow immediately.
    offset = 2 if channels_first else 1
    pooled = [
        conv_utils.conv_output_length(
            dims[offset + i], self.pool_size[i], self.padding, self.strides[i])
        for i in range(3)
    ]
    if channels_first:
      return tensor_shape.TensorShape([dims[0], dims[1]] + pooled)
    return tensor_shape.TensorShape([dims[0]] + pooled + [dims[4]])

  def get_config(self):
    """Returns the base layer config extended with the pooling settings."""
    merged = dict(super(Pooling3D, self).get_config())
    merged.update({
        'pool_size': self.pool_size,
        'padding': self.padding,
        'strides': self.strides,
        'data_format': self.data_format,
    })
    return merged
475
476
@keras_export('keras.layers.MaxPool3D', 'keras.layers.MaxPooling3D')
class MaxPooling3D(Pooling3D):
  """Max pooling layer for 3D data (spatial or spatio-temporal).

  Arguments:
    pool_size: Tuple of 3 integers,
      factors by which to downscale (dim1, dim2, dim3).
      `(2, 2, 2)` will halve the size of the 3D input in each dimension.
    strides: tuple of 3 integers, or None. Strides values.
      When None, `pool_size` is used.
    padding: Either `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`,
      `channels_first` means inputs shaped
      `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    - If `data_format='channels_last'`:
      5D tensor with shape:
      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
    - If `data_format='channels_first'`:
      5D tensor with shape:
      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

  Output shape:
    - If `data_format='channels_last'`:
      5D tensor with shape:
      `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
    - If `data_format='channels_first'`:
      5D tensor with shape:
      `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`
  """

  def __init__(self,
               pool_size=(2, 2, 2),
               strides=None,
               padding='valid',
               data_format=None,
               **kwargs):
    # All of the work happens in Pooling3D; only the pooling op is chosen here.
    super(MaxPooling3D, self).__init__(
        nn.max_pool3d,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        **kwargs)
525
526
@keras_export('keras.layers.AveragePooling3D', 'keras.layers.AvgPool3D')
class AveragePooling3D(Pooling3D):
  """Average pooling layer for 3D data (spatial or spatio-temporal).

  Arguments:
    pool_size: tuple of 3 integers,
      factors by which to downscale (dim1, dim2, dim3).
      `(2, 2, 2)` will halve the size of the 3D input in each dimension.
    strides: tuple of 3 integers, or None. Strides values.
      When None, `pool_size` is used.
    padding: Either `"valid"` or `"same"` (case-insensitive).
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`,
      `channels_first` means inputs shaped
      `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    - If `data_format='channels_last'`:
      5D tensor with shape:
      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
    - If `data_format='channels_first'`:
      5D tensor with shape:
      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

  Output shape:
    - If `data_format='channels_last'`:
      5D tensor with shape:
      `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`
    - If `data_format='channels_first'`:
      5D tensor with shape:
      `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`
  """

  def __init__(self,
               pool_size=(2, 2, 2),
               strides=None,
               padding='valid',
               data_format=None,
               **kwargs):
    # All of the work happens in Pooling3D; only the pooling op is chosen here.
    super(AveragePooling3D, self).__init__(
        nn.avg_pool3d,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        **kwargs)
575
576
class GlobalPooling1D(Layer):
  """Common base for global pooling layers over 1D (temporal) inputs.

  Subclasses must override `call`; this class only stores `data_format`,
  restricts inputs to 3D tensors and reports the pooled output shape.
  """

  def __init__(self, data_format='channels_last', **kwargs):
    super(GlobalPooling1D, self).__init__(**kwargs)
    self.input_spec = InputSpec(ndim=3)
    self.data_format = conv_utils.normalize_data_format(data_format)

  def compute_output_shape(self, input_shape):
    """Returns the 2D `(batch, features)` shape left after pooling."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    # The steps axis is reduced away; keep batch and features only.
    features_axis = 1 if self.data_format == 'channels_first' else 2
    return tensor_shape.TensorShape([dims[0], dims[features_axis]])

  def call(self, inputs):
    # Abstract: the concrete reduction is supplied by subclasses.
    raise NotImplementedError

  def get_config(self):
    """Returns the base layer config extended with `data_format`."""
    merged = dict(super(GlobalPooling1D, self).get_config())
    merged['data_format'] = self.data_format
    return merged
599
600
@keras_export('keras.layers.GlobalAveragePooling1D',
              'keras.layers.GlobalAvgPool1D')
class GlobalAveragePooling1D(GlobalPooling1D):
  """Global average pooling operation for temporal data.

  Arguments:
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, steps, features)` while `channels_first`
      corresponds to inputs with shape
      `(batch, features, steps)`.

  Call arguments:
    inputs: A 3D tensor.
    mask: Binary tensor of shape `(batch_size, steps)` indicating whether
      a given step should be masked (excluded from the average).

  Input shape:
    - If `data_format='channels_last'`:
      3D tensor with shape:
      `(batch_size, steps, features)`
    - If `data_format='channels_first'`:
      3D tensor with shape:
      `(batch_size, features, steps)`

  Output shape:
    2D tensor with shape `(batch_size, features)`.
  """

  def __init__(self, data_format='channels_last', **kwargs):
    super(GlobalAveragePooling1D, self).__init__(data_format=data_format,
                                                 **kwargs)
    self.supports_masking = True

  def call(self, inputs, mask=None):
    """Averages `inputs` over the steps axis, honoring an optional mask.

    Arguments:
      inputs: A 3D tensor laid out according to `self.data_format`.
      mask: Optional tensor of shape `(batch_size, steps)`; steps whose mask
        value is zero/False do not contribute to the average.

    Returns:
      A 2D tensor of shape `(batch_size, features)`.
    """
    channels_last = self.data_format == 'channels_last'
    steps_axis = 1 if channels_last else 2
    if mask is None:
      return backend.mean(inputs, axis=steps_axis)

    # Match the input dtype so the multiplication below does not fail when
    # the inputs are not of the backend's default float type.
    mask = math_ops.cast(mask, inputs.dtype)
    # Insert a size-1 axis where the features axis sits so the mask
    # broadcasts against the inputs in either data format:
    # `(batch, steps, 1)` for channels_last, `(batch, 1, steps)` for
    # channels_first. Unlike a reshape to a fixed shape, this also works
    # when the number of steps is not statically known.
    mask = array_ops.expand_dims(mask, 2 if channels_last else 1)
    masked_inputs = inputs * mask
    # Divide by the number of unmasked steps rather than the full length.
    return backend.sum(masked_inputs, axis=steps_axis) / math_ops.reduce_sum(
        mask, axis=steps_axis)

  def compute_mask(self, inputs, mask=None):
    # The steps axis is reduced away, so no mask is propagated downstream.
    return None
652
653
@keras_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D')
class GlobalMaxPooling1D(GlobalPooling1D):
  """Global max pooling layer for 1D (temporal) data.

  Arguments:
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, steps, features)`, `channels_first` means inputs shaped
      `(batch, features, steps)`.

  Input shape:
    - If `data_format='channels_last'`:
      3D tensor with shape:
      `(batch_size, steps, features)`
    - If `data_format='channels_first'`:
      3D tensor with shape:
      `(batch_size, features, steps)`

  Output shape:
    2D tensor with shape `(batch_size, features)`.
  """

  def call(self, inputs):
    """Takes the maximum of `inputs` over the steps axis."""
    if self.data_format == 'channels_last':
      steps_axis = 1
    else:
      steps_axis = 2
    return backend.max(inputs, axis=steps_axis)
682
683
class GlobalPooling2D(Layer):
  """Common base for global pooling layers over 2D (spatial) inputs.

  Subclasses must override `call`; this class only stores `data_format`,
  restricts inputs to 4D tensors and reports the pooled output shape.
  """

  def __init__(self, data_format=None, **kwargs):
    super(GlobalPooling2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(ndim=4)

  def compute_output_shape(self, input_shape):
    """Returns the 2D `(batch, channels)` shape left after pooling."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    # Both spatial axes are reduced away; keep batch and channels only.
    channels_axis = 3 if self.data_format == 'channels_last' else 1
    return tensor_shape.TensorShape([dims[0], dims[channels_axis]])

  def call(self, inputs):
    # Abstract: the concrete reduction is supplied by subclasses.
    raise NotImplementedError

  def get_config(self):
    """Returns the base layer config extended with `data_format`."""
    merged = dict(super(GlobalPooling2D, self).get_config())
    merged['data_format'] = self.data_format
    return merged
707
708
@keras_export('keras.layers.GlobalAveragePooling2D',
              'keras.layers.GlobalAvgPool2D')
class GlobalAveragePooling2D(GlobalPooling2D):
  """Global average pooling layer for 2D (spatial) data.

  Arguments:
      data_format: A string,
        one of `channels_last` (default) or `channels_first`.
        Describes how the input dimensions are ordered:
        `channels_last` means inputs shaped
        `(batch, height, width, channels)`, `channels_first`
        means inputs shaped
        `(batch, channels, height, width)`.
        It defaults to the `image_data_format` value found in your
        Keras config file at `~/.keras/keras.json`.
        If you never set it, then it will be "channels_last".

  Input shape:
    - If `data_format='channels_last'`:
      4D tensor with shape `(batch_size, rows, cols, channels)`.
    - If `data_format='channels_first'`:
      4D tensor with shape `(batch_size, channels, rows, cols)`.

  Output shape:
    2D tensor with shape `(batch_size, channels)`.
  """

  def call(self, inputs):
    """Averages `inputs` over both spatial axes."""
    spatial_axes = [1, 2] if self.data_format == 'channels_last' else [2, 3]
    return backend.mean(inputs, axis=spatial_axes)
741
742
@keras_export('keras.layers.GlobalMaxPool2D', 'keras.layers.GlobalMaxPooling2D')
class GlobalMaxPooling2D(GlobalPooling2D):
  """Global max pooling layer for 2D (spatial) data.

  Arguments:
    data_format: A string,
      one of `channels_last` (default) or `channels_first`.
      Describes how the input dimensions are ordered:
      `channels_last` means inputs shaped
      `(batch, height, width, channels)`, `channels_first`
      means inputs shaped
      `(batch, channels, height, width)`.
      It defaults to the `image_data_format` value found in your
      Keras config file at `~/.keras/keras.json`.
      If you never set it, then it will be "channels_last".

  Input shape:
    - If `data_format='channels_last'`:
      4D tensor with shape `(batch_size, rows, cols, channels)`.
    - If `data_format='channels_first'`:
      4D tensor with shape `(batch_size, channels, rows, cols)`.

  Output shape:
    2D tensor with shape `(batch_size, channels)`.
  """

  def call(self, inputs):
    """Takes the maximum of `inputs` over both spatial axes."""
    spatial_axes = [1, 2] if self.data_format == 'channels_last' else [2, 3]
    return backend.max(inputs, axis=spatial_axes)
774
775
class GlobalPooling3D(Layer):
  """Common base for global pooling layers over 3D inputs.

  Subclasses must override `call`; this class only stores `data_format`,
  restricts inputs to 5D tensors and reports the pooled output shape.
  """

  def __init__(self, data_format=None, **kwargs):
    super(GlobalPooling3D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.input_spec = InputSpec(ndim=5)

  def compute_output_shape(self, input_shape):
    """Returns the 2D `(batch, channels)` shape left after pooling."""
    dims = tensor_shape.TensorShape(input_shape).as_list()
    # All three spatial axes are reduced away; keep batch and channels only.
    channels_axis = 4 if self.data_format == 'channels_last' else 1
    return tensor_shape.TensorShape([dims[0], dims[channels_axis]])

  def call(self, inputs):
    # Abstract: the concrete reduction is supplied by subclasses.
    raise NotImplementedError

  def get_config(self):
    """Returns the base layer config extended with `data_format`."""
    merged = dict(super(GlobalPooling3D, self).get_config())
    merged['data_format'] = self.data_format
    return merged
798
799
@keras_export('keras.layers.GlobalAveragePooling3D',
              'keras.layers.GlobalAvgPool3D')
class GlobalAveragePooling3D(GlobalPooling3D):
  """Global average pooling layer for 3D data.

  Each channel is reduced to one value: the mean over the three spatial
  axes of the input.

  Arguments:
    data_format: A string, either `channels_last` (default) or
      `channels_first`, giving the ordering of the input dimensions.
      With `channels_last` the inputs have shape
      `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`;
      with `channels_first` they have shape
      `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
      When not given, the `image_data_format` value from your Keras config
      file at `~/.keras/keras.json` is used. If you never set that value,
      it is "channels_last".

  Input shape:
    - With `data_format='channels_last'`:
      5D tensor with shape:
      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
    - With `data_format='channels_first'`:
      5D tensor with shape:
      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

  Output shape:
    2D tensor with shape `(batch_size, channels)`.
  """

  def call(self, inputs):
    """Returns the mean of `inputs` taken over the three spatial axes."""
    # Spatial axes are 1-3 for 'channels_last' and 2-4 otherwise.
    channels_last = self.data_format == 'channels_last'
    spatial_axes = [1, 2, 3] if channels_last else [2, 3, 4]
    return backend.mean(inputs, axis=spatial_axes)
834
835
@keras_export('keras.layers.GlobalMaxPool3D', 'keras.layers.GlobalMaxPooling3D')
class GlobalMaxPooling3D(GlobalPooling3D):
  """Global max pooling layer for 3D data.

  Each channel is reduced to one value: the maximum over the three spatial
  axes of the input.

  Arguments:
    data_format: A string, either `channels_last` (default) or
      `channels_first`, giving the ordering of the input dimensions.
      With `channels_last` the inputs have shape
      `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`;
      with `channels_first` they have shape
      `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
      When not given, the `image_data_format` value from your Keras config
      file at `~/.keras/keras.json` is used. If you never set that value,
      it is "channels_last".

  Input shape:
    - With `data_format='channels_last'`:
      5D tensor with shape:
      `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
    - With `data_format='channels_first'`:
      5D tensor with shape:
      `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`

  Output shape:
    2D tensor with shape `(batch_size, channels)`.
  """

  def call(self, inputs):
    """Returns the maximum of `inputs` taken over the three spatial axes."""
    # Spatial axes are 1-3 for 'channels_last' and 2-4 otherwise.
    channels_last = self.data_format == 'channels_last'
    spatial_axes = [1, 2, 3] if channels_last else [2, 3, 4]
    return backend.max(inputs, axis=spatial_axes)
869
870
# Aliases: shorter alternative names bound to the same layer classes.

# Windowed max pooling.
MaxPool1D = MaxPooling1D
MaxPool2D = MaxPooling2D
MaxPool3D = MaxPooling3D

# Windowed average pooling.
AvgPool1D = AveragePooling1D
AvgPool2D = AveragePooling2D
AvgPool3D = AveragePooling3D

# Global max pooling.
GlobalMaxPool1D = GlobalMaxPooling1D
GlobalMaxPool2D = GlobalMaxPooling2D
GlobalMaxPool3D = GlobalMaxPooling3D

# Global average pooling.
GlobalAvgPool1D = GlobalAveragePooling1D
GlobalAvgPool2D = GlobalAveragePooling2D
GlobalAvgPool3D = GlobalAveragePooling3D
885