# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=missing-docstring
"""EfficientNet models for Keras.

Reference:
  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
      https://arxiv.org/abs/1905.11946) (ICML 2019)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math

from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import keras_export


BASE_WEIGHTS_PATH = 'https://storage.googleapis.com/keras-applications/'

# Maps the model-name suffix ('b0'..'b7') to a pair of file hashes:
# index 0 is used for the weights file with the classification top
# ('<model_name>.h5'), index 1 for the headless one ('<model_name>_notop.h5').
WEIGHTS_HASHES = {
    'b0': ('902e53a9f72be733fc0bcb005b3ebbac',
           '50bc09e76180e00e4465e1a485ddc09d'),
    'b1': ('1d254153d4ab51201f1646940f018540',
           '74c4e6b3e1f6a1eea24c589628592432'),
    'b2': ('b15cce36ff4dcbd00b6dd88e7857a6ad',
           '111f8e2ac8aa800a7a99e3239f7bfb39'),
    'b3': ('ffd1fdc53d0ce67064dc6a9c7960ede0',
           'af6d107764bb5b1abb91932881670226'),
    'b4': ('18c95ad55216b8f92d7e70b3a046e2fc',
           'ebc24e6d6c33eaebbd558eafbeedf1ba'),
    'b5': ('ace28f2a6363774853a83a0b21b9421a',
           '38879255a25d3c92d5e44e04ae6cec6f'),
    'b6': ('165f6e37dce68623721b423839de8be5',
           '9ecce42647a20130c1f39a5d4cb75743'),
    'b7': ('8c03f828fec3ef71311cd463b6759d99',
           'cbcfe4450ddf6f3ad90b1b398090fe4a'),
}

# Per-stage block configuration for the baseline (B0) network. Filter counts
# and repeat counts are scaled by `width_coefficient` / `depth_coefficient`
# inside `EfficientNet` before the blocks are built.
DEFAULT_BLOCKS_ARGS = [{
    'kernel_size': 3,
    'repeats': 1,
    'filters_in': 32,
    'filters_out': 16,
    'expand_ratio': 1,
    'id_skip': True,
    'strides': 1,
    'se_ratio': 0.25
}, {
    'kernel_size': 3,
    'repeats': 2,
    'filters_in': 16,
    'filters_out': 24,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 5,
    'repeats': 2,
    'filters_in': 24,
    'filters_out': 40,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 3,
    'repeats': 3,
    'filters_in': 40,
    'filters_out': 80,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 5,
    'repeats': 3,
    'filters_in': 80,
    'filters_out': 112,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 1,
    'se_ratio': 0.25
}, {
    'kernel_size': 5,
    'repeats': 4,
    'filters_in': 112,
    'filters_out': 192,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 2,
    'se_ratio': 0.25
}, {
    'kernel_size': 3,
    'repeats': 1,
    'filters_in': 192,
    'filters_out': 320,
    'expand_ratio': 6,
    'id_skip': True,
    'strides': 1,
    'se_ratio': 0.25
}]

CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        'distribution': 'truncated_normal'
    }
}

DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}

layers = VersionAwareLayers()

BASE_DOCSTRING = """Instantiates the {name} architecture.

  Reference:
  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
      https://arxiv.org/abs/1905.11946) (ICML 2019)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.
  If you have never configured it, it defaults to `"channels_last"`.

  Args:
    include_top: Whether to include the fully-connected
        layer at the top of the network. Defaults to True.
    weights: One of `None` (random initialization),
          'imagenet' (pre-training on ImageNet),
          or the path to the weights file to be loaded. Defaults to 'imagenet'.
    input_tensor: Optional Keras tensor
        (i.e. output of `layers.Input()`)
        to use as image input for the model.
    input_shape: Optional shape tuple, only to be specified
        if `include_top` is False.
        It should have exactly 3 inputs channels.
    pooling: Optional pooling mode for feature extraction
        when `include_top` is `False`. Defaults to None.
        - `None` means that the output of the model will be
            the 4D tensor output of the
            last convolutional layer.
        - `avg` means that global average pooling
            will be applied to the output of the
            last convolutional layer, and thus
            the output of the model will be a 2D tensor.
        - `max` means that global max pooling will
            be applied.
    classes: Optional number of classes to classify images
        into, only to be specified if `include_top` is True, and
        if no `weights` argument is specified. Defaults to 1000 (number of
        ImageNet classes).
    classifier_activation: A `str` or callable. The activation function to use
        on the "top" layer. Ignored unless `include_top=True`. Set
        `classifier_activation=None` to return the logits of the "top" layer.
        Defaults to 'softmax'.

  Returns:
    A `keras.Model` instance.
"""


def EfficientNet(
    width_coefficient,
    depth_coefficient,
    default_size,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    depth_divisor=8,
    activation='swish',
    blocks_args='default',
    model_name='efficientnet',
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the EfficientNet architecture using given scaling coefficients.

  Reference:
  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](
      https://arxiv.org/abs/1905.11946) (ICML 2019)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Args:
    width_coefficient: float, scaling coefficient for network width.
    depth_coefficient: float, scaling coefficient for network depth.
    default_size: integer, default input image size.
    dropout_rate: float, dropout rate before final classifier layer.
    drop_connect_rate: float, dropout rate at skip connections. The rate is
        ramped linearly from 0 for the first block towards this value for
        the last block.
    depth_divisor: integer, a unit of network width.
    activation: activation function.
    blocks_args: list of dicts, parameters to construct block modules, or the
        string 'default' to use `DEFAULT_BLOCKS_ARGS`. Every dict must have a
        positive `repeats` entry. The list is deep-copied, so the caller's
        dicts are never modified.
    model_name: string, model name. When `weights='imagenet'`, its last two
        characters (e.g. 'b0') select the entry in `WEIGHTS_HASHES`.
    include_top: whether to include the fully-connected
        layer at the top of the network.
    weights: one of `None` (random initialization),
          'imagenet' (pre-training on ImageNet),
          or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
        (i.e. output of `layers.Input()`)
        to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
        if `include_top` is False.
        It should have exactly 3 input channels.
    pooling: optional pooling mode for feature extraction
        when `include_top` is `False`.
        - `None` means that the output of the model will be
            the 4D tensor output of the
            last convolutional layer.
        - `avg` means that global average pooling
            will be applied to the output of the
            last convolutional layer, and thus
            the output of the model will be a 2D tensor.
        - `max` means that global max pooling will
            be applied.
    classes: optional number of classes to classify images
        into, only to be specified if `include_top` is True, and
        if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
        on the "top" layer. Ignored unless `include_top=True`. Set
        `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
    ValueError: if any entry of `blocks_args` has a non-positive `repeats`.
  """
  if blocks_args == 'default':
    blocks_args = DEFAULT_BLOCKS_ARGS

  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Validate block configuration up front with a real exception rather than
  # an `assert`, which would be stripped when Python runs with `-O`.
  for args in blocks_args:
    if args['repeats'] <= 0:
      raise ValueError('Each entry of `blocks_args` must have a positive '
                       '`repeats` value; got {}.'.format(args['repeats']))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=default_size,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  def round_filters(filters, divisor=depth_divisor):
    """Round number of filters based on width multiplier."""
    filters *= width_coefficient
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_filters < 0.9 * filters:
      new_filters += divisor
    return int(new_filters)

  def round_repeats(repeats):
    """Round number of repeats based on depth multiplier."""
    return int(math.ceil(depth_coefficient * repeats))

  # Build stem. Input rescaling and normalization are part of the model,
  # which is why `preprocess_input` below is a passthrough.
  x = img_input
  x = layers.Rescaling(1. / 255.)(x)
  x = layers.Normalization(axis=bn_axis)(x)

  x = layers.ZeroPadding2D(
      padding=imagenet_utils.correct_pad(x, 3),
      name='stem_conv_pad')(x)
  x = layers.Conv2D(
      round_filters(32),
      3,
      strides=2,
      padding='valid',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name='stem_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
  x = layers.Activation(activation, name='stem_activation')(x)

  # Build blocks. Work on a deep copy because the loop below mutates the
  # dicts (rescaled filters, popped 'repeats', overridden strides).
  blocks_args = copy.deepcopy(blocks_args)

  b = 0
  blocks = float(sum(round_repeats(args['repeats']) for args in blocks_args))
  for (i, args) in enumerate(blocks_args):
    # Update block input and output filters based on width multiplier.
    args['filters_in'] = round_filters(args['filters_in'])
    args['filters_out'] = round_filters(args['filters_out'])

    for j in range(round_repeats(args.pop('repeats'))):
      # The first block needs to take care of stride and filter size increase.
      if j > 0:
        args['strides'] = 1
        args['filters_in'] = args['filters_out']
      x = block(
          x,
          activation,
          drop_connect_rate * b / blocks,
          name='block{}{}_'.format(i + 1, chr(j + 97)),
          **args)
      b += 1

  # Build top
  x = layers.Conv2D(
      round_filters(1280),
      1,
      padding='same',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name='top_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
  x = layers.Activation(activation, name='top_activation')(x)
  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    if dropout_rate > 0:
      x = layers.Dropout(dropout_rate, name='top_dropout')(x)
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(
        classes,
        activation=classifier_activation,
        kernel_initializer=DENSE_KERNEL_INITIALIZER,
        name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D(name='max_pool')(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name=model_name)

  # Load weights.
  if weights == 'imagenet':
    if include_top:
      file_suffix = '.h5'
      file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
    else:
      file_suffix = '_notop.h5'
      file_hash = WEIGHTS_HASHES[model_name[-2:]][1]
    file_name = model_name + file_suffix
    weights_path = data_utils.get_file(
        file_name,
        BASE_WEIGHTS_PATH + file_name,
        cache_subdir='models',
        file_hash=file_hash)
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)
  return model


def block(inputs,
          activation='swish',
          drop_rate=0.,
          name='',
          filters_in=32,
          filters_out=16,
          kernel_size=3,
          strides=1,
          expand_ratio=1,
          se_ratio=0.,
          id_skip=True):
  """An inverted residual block.

  The block consists of an optional 1x1 expansion convolution, a depthwise
  convolution, an optional squeeze-and-excitation stage, and a 1x1 projection
  convolution. A residual connection is added when `id_skip` is set and the
  input and output shapes match.

  Args:
      inputs: input tensor.
      activation: activation function.
      drop_rate: float between 0 and 1, fraction of the input units to drop.
          Applied (with one mask value per sample) only on the residual path.
      name: string, block label.
      filters_in: integer, the number of input filters.
      filters_out: integer, the number of output filters.
      kernel_size: integer, the dimension of the convolution window.
      strides: integer, the stride of the convolution.
      expand_ratio: integer, scaling coefficient for the input filters.
      se_ratio: float between 0 and 1, fraction to squeeze the input filters.
          Squeeze-and-excitation is skipped unless `0 < se_ratio <= 1`.
      id_skip: boolean, whether to add the identity skip connection. It is
          only added when `strides == 1` and `filters_in == filters_out`.

  Returns:
      output tensor for the block.
  """
  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  # Expansion phase
  filters = filters_in * expand_ratio
  if expand_ratio != 1:
    x = layers.Conv2D(
        filters,
        1,
        padding='same',
        use_bias=False,
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'expand_conv')(
            inputs)
    x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
    x = layers.Activation(activation, name=name + 'expand_activation')(x)
  else:
    x = inputs

  # Depthwise Convolution. Strided blocks pad explicitly and use 'valid'
  # padding so the output size matches the reference implementation.
  if strides == 2:
    x = layers.ZeroPadding2D(
        padding=imagenet_utils.correct_pad(x, kernel_size),
        name=name + 'dwconv_pad')(x)
    conv_pad = 'valid'
  else:
    conv_pad = 'same'
  x = layers.DepthwiseConv2D(
      kernel_size,
      strides=strides,
      padding=conv_pad,
      use_bias=False,
      depthwise_initializer=CONV_KERNEL_INITIALIZER,
      name=name + 'dwconv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
  x = layers.Activation(activation, name=name + 'activation')(x)

  # Squeeze and Excitation phase. The squeeze width is derived from
  # `filters_in` (not the expanded `filters`).
  if 0 < se_ratio <= 1:
    filters_se = max(1, int(filters_in * se_ratio))
    se = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
    if bn_axis == 1:
      se_shape = (filters, 1, 1)
    else:
      se_shape = (1, 1, filters)
    se = layers.Reshape(se_shape, name=name + 'se_reshape')(se)
    se = layers.Conv2D(
        filters_se,
        1,
        padding='same',
        activation=activation,
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'se_reduce')(
            se)
    se = layers.Conv2D(
        filters,
        1,
        padding='same',
        activation='sigmoid',
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'se_expand')(se)
    x = layers.multiply([x, se], name=name + 'se_excite')

  # Output phase
  x = layers.Conv2D(
      filters_out,
      1,
      padding='same',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name=name + 'project_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)
  if id_skip and strides == 1 and filters_in == filters_out:
    if drop_rate > 0:
      # noise_shape (None, 1, 1, 1) drops the whole residual branch per
      # sample rather than individual activations.
      x = layers.Dropout(
          drop_rate, noise_shape=(None, 1, 1, 1), name=name + 'drop')(x)
    x = layers.add([x, inputs], name=name + 'add')
  return x


@keras_export('keras.applications.efficientnet.EfficientNetB0',
              'keras.applications.EfficientNetB0')
def EfficientNetB0(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.0,
      1.0,
      224,
      0.2,
      model_name='efficientnetb0',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB1',
              'keras.applications.EfficientNetB1')
def EfficientNetB1(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.0,
      1.1,
      240,
      0.2,
      model_name='efficientnetb1',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB2',
              'keras.applications.EfficientNetB2')
def EfficientNetB2(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.1,
      1.2,
      260,
      0.3,
      model_name='efficientnetb2',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB3',
              'keras.applications.EfficientNetB3')
def EfficientNetB3(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.2,
      1.4,
      300,
      0.3,
      model_name='efficientnetb3',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB4',
              'keras.applications.EfficientNetB4')
def EfficientNetB4(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.4,
      1.8,
      380,
      0.4,
      model_name='efficientnetb4',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB5',
              'keras.applications.EfficientNetB5')
def EfficientNetB5(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.6,
      2.2,
      456,
      0.4,
      model_name='efficientnetb5',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB6',
              'keras.applications.EfficientNetB6')
def EfficientNetB6(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      1.8,
      2.6,
      528,
      0.5,
      model_name='efficientnetb6',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


@keras_export('keras.applications.efficientnet.EfficientNetB7',
              'keras.applications.EfficientNetB7')
def EfficientNetB7(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   classifier_activation='softmax',
                   **kwargs):
  return EfficientNet(
      2.0,
      3.1,
      600,
      0.5,
      model_name='efficientnetb7',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation,
      **kwargs)


EfficientNetB0.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB0')
EfficientNetB1.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB1')
EfficientNetB2.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB2')
EfficientNetB3.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB3')
EfficientNetB4.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB4')
EfficientNetB5.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB5')
EfficientNetB6.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB6')
EfficientNetB7.__doc__ = BASE_DOCSTRING.format(name='EfficientNetB7')


@keras_export('keras.applications.efficientnet.preprocess_input')
def preprocess_input(x, data_format=None):  # pylint: disable=unused-argument
  """Returns `x` unchanged.

  EfficientNet models built by this module already contain a
  `Rescaling(1/255)` layer followed by a `Normalization` layer in their stem,
  so no preprocessing is applied here and raw pixel values (presumably in the
  [0, 255] range) should be fed to the model directly. This function is
  presumably kept for API parity with the other `keras.applications` modules.

  Args:
    x: input batch (tensor or array); returned as-is.
    data_format: ignored.

  Returns:
    `x`, unchanged.
  """
  return x


@keras_export('keras.applications.efficientnet.decode_predictions')
def decode_predictions(preds, top=5):
  return imagenet_utils.decode_predictions(preds, top=top)


# Reuse the shared ImageNet helper's documentation verbatim.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__