1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16# pylint: disable=missing-docstring
17"""EfficientNet models for Keras.
18
19Reference:
20  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
21      https://arxiv.org/abs/1905.11946) (ICML 2019)
22"""
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import copy
28import math
29
30from tensorflow.python.keras import backend
31from tensorflow.python.keras.applications import imagenet_utils
32from tensorflow.python.keras.engine import training
33from tensorflow.python.keras.layers import VersionAwareLayers
34from tensorflow.python.keras.utils import data_utils
35from tensorflow.python.keras.utils import layer_utils
36from tensorflow.python.lib.io import file_io
37from tensorflow.python.util.tf_export import keras_export
38
39
# URL prefix under which the pretrained weight files are hosted; the weight
# file name is appended to it when the `imagenet` weights are downloaded.
BASE_WEIGHTS_PATH = 'https://storage.googleapis.com/keras-applications/'

# Hashes of the pretrained ImageNet weight files, keyed by model variant (the
# last two characters of `model_name`, e.g. 'b0'). Each value is a pair:
# index 0 is the hash of the full model file ('<model_name>.h5', used when
# `include_top=True`) and index 1 the hash of the headless file
# ('<model_name>_notop.h5'). The selected hash is passed as `file_hash` to
# `data_utils.get_file`.
WEIGHTS_HASHES = {
    'b0': ('902e53a9f72be733fc0bcb005b3ebbac',
           '50bc09e76180e00e4465e1a485ddc09d'),
    'b1': ('1d254153d4ab51201f1646940f018540',
           '74c4e6b3e1f6a1eea24c589628592432'),
    'b2': ('b15cce36ff4dcbd00b6dd88e7857a6ad',
           '111f8e2ac8aa800a7a99e3239f7bfb39'),
    'b3': ('ffd1fdc53d0ce67064dc6a9c7960ede0',
           'af6d107764bb5b1abb91932881670226'),
    'b4': ('18c95ad55216b8f92d7e70b3a046e2fc',
           'ebc24e6d6c33eaebbd558eafbeedf1ba'),
    'b5': ('ace28f2a6363774853a83a0b21b9421a',
           '38879255a25d3c92d5e44e04ae6cec6f'),
    'b6': ('165f6e37dce68623721b423839de8be5',
           '9ecce42647a20130c1f39a5d4cb75743'),
    'b7': ('8c03f828fec3ef71311cd463b6759d99',
           'cbcfe4450ddf6f3ad90b1b398090fe4a'),
}
60
# Baseline stage configuration, used when `EfficientNet` is called with
# `blocks_args='default'`. Each dict describes one stage:
#   - 'repeats' is scaled by the depth coefficient and gives the number of
#     blocks built for the stage; it is consumed by `EfficientNet` itself.
#   - 'filters_in' / 'filters_out' are scaled by the width coefficient.
#   - every key other than 'repeats' is forwarded as a keyword argument to
#     `block`.
# Only the first block of a stage uses the listed 'strides' and 'filters_in';
# the repeated blocks use stride 1 and 'filters_out' as their input width.
DEFAULT_BLOCKS_ARGS = [{
    'kernel_size': 3,
    'repeats': 1,
    'filters_in': 32,
    'filters_out': 16,
    'expand_ratio': 1,
    'id_skip': True,
    'strides': 1,
    'se_ratio': 0.25
}, {
    'kernel_size': 3,
    'repeats': 2,
    'filters_in': 16,
    'filters_out': 24,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 5,
    'repeats': 2,
    'filters_in': 24,
    'filters_out': 40,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 3,
    'repeats': 3,
    'filters_in': 40,
    'filters_out': 80,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 5,
    'repeats': 3,
    'filters_in': 80,
    'filters_out': 112,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 1,
    'se_ratio': 0.25
}, {
    'kernel_size': 5,
    'repeats': 4,
    'filters_in': 112,
    'filters_out': 192,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 3,
    'repeats': 1,
    'filters_in': 192,
    'filters_out': 320,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 1,
    'se_ratio': 0.25
}]
125
# Serialized initializer config passed as `kernel_initializer` /
# `depthwise_initializer` to every convolution layer built in this module.
CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        'distribution': 'truncated_normal'
    }
}

# Serialized initializer config passed as `kernel_initializer` to the final
# `Dense` classification layer ('predictions').
DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}

# Namespace through which all layers in this module are instantiated
# (`layers.Conv2D`, `layers.BatchNormalization`, ...).
# NOTE(review): presumably resolves to the layer implementations matching the
# active TF behavior version -- confirm against `VersionAwareLayers`.
layers = VersionAwareLayers()
145
# Docstring template shared by the `EfficientNetB0` ... `EfficientNetB7`
# constructors; `{name}` is filled in with the constructor name via
# `str.format`, so the text must not contain any other curly braces.
BASE_DOCSTRING = """Instantiates the {name} architecture.

  Reference:
  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
      https://arxiv.org/abs/1905.11946) (ICML 2019)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.
  If you have never configured it, it defaults to `"channels_last"`.

  Note: input preprocessing is included as part of the model (a
  `Rescaling(1/255)` layer followed by a `Normalization` layer), so
  `keras.applications.efficientnet.preprocess_input` is a pass-through
  function and does not need to be called on the inputs.

  Args:
    include_top: Whether to include the fully-connected
        layer at the top of the network. Defaults to True.
    weights: One of `None` (random initialization),
        'imagenet' (pre-training on ImageNet),
        or the path to the weights file to be loaded. Defaults to 'imagenet'.
    input_tensor: Optional Keras tensor
        (i.e. output of `layers.Input()`)
        to use as image input for the model.
    input_shape: Optional shape tuple, only to be specified
        if `include_top` is False.
        It should have exactly 3 input channels.
    pooling: Optional pooling mode for feature extraction
        when `include_top` is `False`. Defaults to None.
        - `None` means that the output of the model will be
            the 4D tensor output of the
            last convolutional layer.
        - `avg` means that global average pooling
            will be applied to the output of the
            last convolutional layer, and thus
            the output of the model will be a 2D tensor.
        - `max` means that global max pooling will
            be applied.
    classes: Optional number of classes to classify images
        into, only to be specified if `include_top` is True, and
        if no `weights` argument is specified. Defaults to 1000 (number of
        ImageNet classes).
    classifier_activation: A `str` or callable. The activation function to use
        on the "top" layer. Ignored unless `include_top=True`. Set
        `classifier_activation=None` to return the logits of the "top" layer.
        Defaults to 'softmax'.

  Returns:
    A `keras.Model` instance.
"""
192
193
def EfficientNet(
    width_coefficient,
    depth_coefficient,
    default_size,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    depth_divisor=8,
    activation='swish',
    blocks_args='default',
    model_name='efficientnet',
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the EfficientNet architecture using given scaling coefficients.

  Reference:
  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
      https://arxiv.org/abs/1905.11946) (ICML 2019)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Input preprocessing is part of the model: the stem starts with a
  `Rescaling(1/255)` layer followed by a `Normalization` layer.

  Args:
    width_coefficient: float, scaling coefficient for network width.
    depth_coefficient: float, scaling coefficient for network depth.
    default_size: integer, default input image size.
    dropout_rate: float, dropout rate before final classifier layer.
    drop_connect_rate: float, dropout rate at skip connections. The rate
        applied to a block grows linearly with the block index, starting at 0
        for the first block.
    depth_divisor: integer, a unit of network width; scaled filter counts are
        rounded to a multiple of it.
    activation: activation function.
    blocks_args: list of dicts, parameters to construct block modules, or the
        string 'default' to use `DEFAULT_BLOCKS_ARGS`. The list is deep-copied
        before use, so the caller's object is not modified.
    model_name: string, model name. When `weights='imagenet'`, its last two
        characters select the pretrained weight file and must be a key of
        `WEIGHTS_HASHES` ('b0' ... 'b7').
    include_top: whether to include the fully-connected
        layer at the top of the network.
    weights: one of `None` (random initialization),
        'imagenet' (pre-training on ImageNet),
        or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
        (i.e. output of `layers.Input()`)
        to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
        if `include_top` is False.
        It should have exactly 3 input channels.
    pooling: optional pooling mode for feature extraction
        when `include_top` is `False`.
        - `None` means that the output of the model will be
            the 4D tensor output of the
            last convolutional layer.
        - `avg` means that global average pooling
            will be applied to the output of the
            last convolutional layer, and thus
            the output of the model will be a 2D tensor.
        - `max` means that global max pooling will
            be applied.
    classes: optional number of classes to classify images
        into, only to be specified if `include_top` is True, and
        if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
        on the "top" layer. Ignored unless `include_top=True`. Set
        `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `weights='imagenet'` and `model_name` does not end with the
      name of a variant that has pretrained weights.
    ValueError: if an entry of `blocks_args` has a non-positive `repeats`.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  if blocks_args == 'default':
    blocks_args = DEFAULT_BLOCKS_ARGS

  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Fail fast, before any layer is built: the pretrained weight file is looked
  # up by the last two characters of `model_name`, so an unknown suffix would
  # otherwise only surface as a bare KeyError after the whole model was built.
  if weights == 'imagenet' and model_name[-2:] not in WEIGHTS_HASHES:
    raise ValueError(
        'Pretrained `imagenet` weights are only available for model names '
        'ending in one of {}. Received: model_name={!r}'.format(
            sorted(WEIGHTS_HASHES), model_name))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=default_size,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    # Wrap a raw (non-Keras) tensor in an Input layer; a Keras tensor can be
    # used as the model input directly.
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  def round_filters(filters, divisor=depth_divisor):
    """Round number of filters based on depth multiplier."""
    filters *= width_coefficient
    # Round to the nearest multiple of `divisor`, but never below `divisor`.
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_filters < 0.9 * filters:
      new_filters += divisor
    return int(new_filters)

  def round_repeats(repeats):
    """Round number of repeats based on depth multiplier."""
    return int(math.ceil(depth_coefficient * repeats))

  # Build stem
  x = img_input
  # Preprocessing is built into the model.
  x = layers.Rescaling(1. / 255.)(x)
  x = layers.Normalization(axis=bn_axis)(x)

  x = layers.ZeroPadding2D(
      padding=imagenet_utils.correct_pad(x, 3),
      name='stem_conv_pad')(x)
  x = layers.Conv2D(
      round_filters(32),
      3,
      strides=2,
      padding='valid',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name='stem_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
  x = layers.Activation(activation, name='stem_activation')(x)

  # Build blocks
  # The per-stage dicts are mutated below, so work on a private copy to leave
  # the caller's list (or DEFAULT_BLOCKS_ARGS) untouched.
  blocks_args = copy.deepcopy(blocks_args)

  # `b` is the running block index and `blocks` the total number of blocks;
  # together they scale the drop-connect rate linearly over the network.
  b = 0
  blocks = float(sum(round_repeats(args['repeats']) for args in blocks_args))
  for (i, args) in enumerate(blocks_args):
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would let an invalid stage be skipped silently.
    if not args['repeats'] > 0:
      raise ValueError(
          'Each entry of `blocks_args` must have `repeats` > 0. '
          'Received: repeats={!r} for entry {}'.format(args['repeats'], i))
    # Update block input and output filters based on depth multiplier.
    args['filters_in'] = round_filters(args['filters_in'])
    args['filters_out'] = round_filters(args['filters_out'])

    # 'repeats' is popped because the remaining keys are forwarded to `block`.
    for j in range(round_repeats(args.pop('repeats'))):
      # The first block needs to take care of stride and filter size increase.
      if j > 0:
        args['strides'] = 1
        args['filters_in'] = args['filters_out']
      # Blocks are named 'block<stage><letter>_': stages are numbered from 1
      # and blocks within a stage are lettered from 'a' (chr(97)).
      x = block(
          x,
          activation,
          drop_connect_rate * b / blocks,
          name='block{}{}_'.format(i + 1, chr(j + 97)),
          **args)
      b += 1

  # Build top
  x = layers.Conv2D(
      round_filters(1280),
      1,
      padding='same',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name='top_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
  x = layers.Activation(activation, name='top_activation')(x)
  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    if dropout_rate > 0:
      x = layers.Dropout(dropout_rate, name='top_dropout')(x)
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(
        classes,
        activation=classifier_activation,
        kernel_initializer=DENSE_KERNEL_INITIALIZER,
        name='predictions')(x)
  else:
    # Any other `pooling` value leaves the 4D feature map as the output.
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D(name='max_pool')(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name=model_name)

  # Load weights.
  if weights == 'imagenet':
    # WEIGHTS_HASHES values are (with-top hash, no-top hash) pairs keyed by
    # the variant suffix of the model name.
    if include_top:
      file_suffix = '.h5'
      file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
    else:
      file_suffix = '_notop.h5'
      file_hash = WEIGHTS_HASHES[model_name[-2:]][1]
    file_name = model_name + file_suffix
    weights_path = data_utils.get_file(
        file_name,
        BASE_WEIGHTS_PATH + file_name,
        cache_subdir='models',
        file_hash=file_hash)
    model.load_weights(weights_path)
  elif weights is not None:
    # `weights` is a path to a user-supplied weights file.
    model.load_weights(weights)
  return model
411
412
def block(inputs,
          activation='swish',
          drop_rate=0.,
          name='',
          filters_in=32,
          filters_out=16,
          kernel_size=3,
          strides=1,
          expand_ratio=1,
          se_ratio=0.,
          id_skip=True):
  """Builds one inverted residual block on top of `inputs`.

  The block consists of an optional 1x1 expansion convolution, a depthwise
  convolution, an optional squeeze-and-excitation gate, a 1x1 projection
  convolution and, when shapes allow it, a residual connection to `inputs`.

  Args:
      inputs: input tensor.
      activation: activation function.
      drop_rate: float between 0 and 1, fraction of the input units to drop.
        Only applied on the residual branch, and only when it is positive.
      name: string, block label, used as prefix of every layer name.
      filters_in: integer, the number of input filters.
      filters_out: integer, the number of output filters.
      kernel_size: integer, the dimension of the convolution window.
      strides: integer, the stride of the convolution.
      expand_ratio: integer, scaling coefficient for the input filters. The
        expansion convolution is skipped when it equals 1.
      se_ratio: float between 0 and 1, fraction to squeeze the input filters.
        Squeeze-and-excitation is skipped unless `0 < se_ratio <= 1`.
      id_skip: boolean, whether a residual connection may be added. It is
        added only if additionally `strides == 1` and
        `filters_in == filters_out`.

  Returns:
      output tensor for the block.
  """
  channels_last = backend.image_data_format() == 'channels_last'
  bn_axis = 3 if channels_last else 1
  expanded_filters = filters_in * expand_ratio

  # Expansion phase: widen the input with a pointwise convolution.
  x = inputs
  if expand_ratio != 1:
    x = layers.Conv2D(
        expanded_filters,
        1,
        padding='same',
        use_bias=False,
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'expand_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
    x = layers.Activation(activation, name=name + 'expand_activation')(x)

  # Depthwise convolution. Stride-2 blocks are padded explicitly and then
  # convolved with 'valid' padding; all other blocks use 'same' padding.
  depthwise_padding = 'same'
  if strides == 2:
    x = layers.ZeroPadding2D(
        padding=imagenet_utils.correct_pad(x, kernel_size),
        name=name + 'dwconv_pad')(x)
    depthwise_padding = 'valid'
  x = layers.DepthwiseConv2D(
      kernel_size,
      strides=strides,
      padding=depthwise_padding,
      use_bias=False,
      depthwise_initializer=CONV_KERNEL_INITIALIZER,
      name=name + 'dwconv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
  x = layers.Activation(activation, name=name + 'activation')(x)

  # Squeeze and Excitation phase: compute a per-channel gate and scale `x`.
  if 0 < se_ratio <= 1:
    squeezed_filters = max(1, int(filters_in * se_ratio))
    # Shape that puts the channel dimension where the data format expects it.
    if channels_last:
      gate_shape = (1, 1, expanded_filters)
    else:
      gate_shape = (expanded_filters, 1, 1)
    gate = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
    gate = layers.Reshape(gate_shape, name=name + 'se_reshape')(gate)
    gate = layers.Conv2D(
        squeezed_filters,
        1,
        padding='same',
        activation=activation,
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'se_reduce')(gate)
    gate = layers.Conv2D(
        expanded_filters,
        1,
        padding='same',
        activation='sigmoid',
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'se_expand')(gate)
    x = layers.multiply([x, gate], name=name + 'se_excite')

  # Output phase: project back down to `filters_out` channels.
  x = layers.Conv2D(
      filters_out,
      1,
      padding='same',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name=name + 'project_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)

  # A residual connection is only possible when input and output agree.
  if not (id_skip and strides == 1 and filters_in == filters_out):
    return x

  if drop_rate > 0:
    # The noise shape keeps only the batch dimension, so the dropout mask is
    # shared across all other dimensions of each sample.
    x = layers.Dropout(
        drop_rate, noise_shape=(None, 1, 1, 1), name=name + 'drop')(x)
  return layers.add([x, inputs], name=name + 'add')
519
520
@keras_export('keras.applications.efficientnet.EfficientNetB0',
              'keras.applications.EfficientNetB0')
def EfficientNetB0(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B0: width x1.0, depth x1.0, default input size 224, top dropout 0.2.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.0,
      depth_coefficient=1.0,
      default_size=224,
      dropout_rate=0.2,
      model_name='efficientnetb0',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
545
546
@keras_export('keras.applications.efficientnet.EfficientNetB1',
              'keras.applications.EfficientNetB1')
def EfficientNetB1(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B1: width x1.0, depth x1.1, default input size 240, top dropout 0.2.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.0,
      depth_coefficient=1.1,
      default_size=240,
      dropout_rate=0.2,
      model_name='efficientnetb1',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
571
572
@keras_export('keras.applications.efficientnet.EfficientNetB2',
              'keras.applications.EfficientNetB2')
def EfficientNetB2(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B2: width x1.1, depth x1.2, default input size 260, top dropout 0.3.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.1,
      depth_coefficient=1.2,
      default_size=260,
      dropout_rate=0.3,
      model_name='efficientnetb2',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
597
598
@keras_export('keras.applications.efficientnet.EfficientNetB3',
              'keras.applications.EfficientNetB3')
def EfficientNetB3(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B3: width x1.2, depth x1.4, default input size 300, top dropout 0.3.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.2,
      depth_coefficient=1.4,
      default_size=300,
      dropout_rate=0.3,
      model_name='efficientnetb3',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
623
624
@keras_export('keras.applications.efficientnet.EfficientNetB4',
              'keras.applications.EfficientNetB4')
def EfficientNetB4(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B4: width x1.4, depth x1.8, default input size 380, top dropout 0.4.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.4,
      depth_coefficient=1.8,
      default_size=380,
      dropout_rate=0.4,
      model_name='efficientnetb4',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
649
650
@keras_export('keras.applications.efficientnet.EfficientNetB5',
              'keras.applications.EfficientNetB5')
def EfficientNetB5(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B5: width x1.6, depth x2.2, default input size 456, top dropout 0.4.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.6,
      depth_coefficient=2.2,
      default_size=456,
      dropout_rate=0.4,
      model_name='efficientnetb5',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
675
676
@keras_export('keras.applications.efficientnet.EfficientNetB6',
              'keras.applications.EfficientNetB6')
def EfficientNetB6(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B6: width x1.8, depth x2.6, default input size 528, top dropout 0.5.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=1.8,
      depth_coefficient=2.6,
      default_size=528,
      dropout_rate=0.5,
      model_name='efficientnetb6',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
701
702
@keras_export('keras.applications.efficientnet.EfficientNetB7',
              'keras.applications.EfficientNetB7')
def EfficientNetB7(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  # B7: width x2.0, depth x3.1, default input size 600, top dropout 0.5.
  # The docstring is attached below the definitions from BASE_DOCSTRING.
  return EfficientNet(
      width_coefficient=2.0,
      depth_coefficient=3.1,
      default_size=600,
      dropout_rate=0.5,
      model_name='efficientnetb7',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)
727
728
# The variant constructors are defined without docstrings; each one gets the
# shared BASE_DOCSTRING template with `{name}` replaced by its own name.
EfficientNetB0.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB0')
EfficientNetB1.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB1')
EfficientNetB2.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB2')
EfficientNetB3.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB3')
EfficientNetB4.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB4')
EfficientNetB5.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB5')
EfficientNetB6.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB6')
EfficientNetB7.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB7')
737
738
@keras_export('keras.applications.efficientnet.preprocess_input')
def preprocess_input(x, data_format=None):  # pylint: disable=unused-argument
  """Returns the input unchanged.

  This function performs no preprocessing. The models built by this module
  begin with `Rescaling` and `Normalization` layers, so input preprocessing
  happens inside the model and callers do not need to invoke this function.

  Args:
    x: The input data; returned as-is.
    data_format: Unused.

  Returns:
    `x`, unchanged.
  """
  return x
742
743
@keras_export('keras.applications.efficientnet.decode_predictions')
def decode_predictions(preds, top=5):
  # Thin wrapper that delegates to `imagenet_utils.decode_predictions`, so the
  # function is exported under the `efficientnet` namespace.
  return imagenet_utils.decode_predictions(preds, top=top)


# Reuse the docstring of the wrapped function rather than duplicating it.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
750