# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base ProcessingLayer and a subclass that uses Combiners."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import numpy as np
import six

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export


keras_kpl_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras/layers/preprocessing',
    'keras preprocessing layers usage', 'method')


@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
@six.add_metaclass(abc.ABCMeta)
class PreprocessingLayer(Layer):
  """Base class for PreprocessingLayers.

  Attributes:
    stateful: Whether the layer contains state that needs to be adapted via
      `PreprocessingLayer.adapt`.
    streaming: Whether a layer can be adapted multiple times without resetting
      the state of the layer.
  """
  _must_restore_from_config = True

  def __init__(self, stateful=False, streaming=True, **kwargs):
    super(PreprocessingLayer, self).__init__(**kwargs)
    self._stateful = stateful
    self._streaming = streaming
    self._is_compiled = False
    self._is_adapted = False

    # Sets `is_adapted=False` when `reset_state` is called.
    self._reset_state_impl = self.reset_state
    self.reset_state = self._reset_state_wrapper

    self._adapt_function = None

  @property
  def streaming(self):
    """Whether `adapt` can be called twice without resetting the state."""
    return self._streaming

  @property
  def is_adapted(self):
    """Whether the layer has been fit to data already."""
    return self._is_adapted

  def update_state(self, data):
    """Accumulates statistics for the preprocessing layer.

    Arguments:
      data: A mini-batch of inputs to the layer.
    """
    if self.stateful:
      raise NotImplementedError

  def reset_state(self):
    """Resets the statistics of the preprocessing layer."""
    if self.stateful:
      raise NotImplementedError

  def merge_state(self, layers):
    """Merge the statistics of multiple preprocessing layers.

    This layer will contain the merged state.

    Arguments:
      layers: Layers whose statistics should be merged with the statistics of
        this layer.
    """
    if self.stateful:
      raise NotImplementedError

  def finalize_state(self):
    """Finalize the statistics for the preprocessing layer.

    This method is called at the end of `adapt`. It handles any one-time
    operations that should occur after all data has been seen.
    """
    pass

  def make_adapt_function(self):
    """Creates a function to execute one step of `adapt`.

    This method can be overridden to support custom adapt logic.
    This method is called by `PreprocessingLayer.adapt`.

    Typically, this method directly controls `tf.function` settings,
    and delegates the actual state update logic to
    `PreprocessingLayer.update_state`.

    This function is cached the first time `PreprocessingLayer.adapt`
    is called. The cache is cleared whenever `PreprocessingLayer.compile`
    is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, retrieve a batch, and update the state of the
      layer.
    """
    if self._adapt_function is not None:
      return self._adapt_function

    def adapt_step(iterator):
      data = next(iterator)
      self._adapt_maybe_build(data)
      self.update_state(data)

    if self._steps_per_execution.numpy().item() == 1:
      adapt_fn = adapt_step
    else:

      def adapt_fn(iterator):
        for _ in math_ops.range(self._steps_per_execution):
          adapt_step(iterator)

    if not self._run_eagerly:
      adapt_fn = def_function.function(adapt_fn)

    self._adapt_function = adapt_fn
    return self._adapt_function
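
  # Illustrative note (not part of the original file): the function returned by
  # `make_adapt_function` consumes a `tf.data.Iterator`. Roughly, `adapt`
  # builds an iterator over the data and drives it like this (`dataset` is a
  # hypothetical `tf.data.Dataset`):
  #
  #   adapt_fn = layer.make_adapt_function()
  #   iterator = iter(dataset)
  #   adapt_fn(iterator)  # Pulls one batch (or `steps_per_execution` batches)
  #                       # and calls `update_state` on each.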

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Configures the layer for `adapt`.

    Arguments:
      run_eagerly: Bool. Defaults to `False`. If `True`, this layer's adapt
        logic will not be wrapped in a `tf.function`. Recommended to leave
        this as `None` unless your layer cannot be run inside a `tf.function`.
      steps_per_execution: Int. Defaults to 1. The number of batches to run
        during each `tf.function` call. Running multiple batches inside a
        single `tf.function` call can greatly improve performance on TPUs or
        small models with a large Python overhead.
    """
    if steps_per_execution is None:
      steps_per_execution = 1
    self._configure_steps_per_execution(steps_per_execution)

    if run_eagerly is None:
      run_eagerly = self.dynamic
    self._run_eagerly = run_eagerly

    self._is_compiled = True

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Fits the state of the preprocessing layer to the data being passed.

    Arguments:
      data: The data to train on. It can be passed either as a
        `tf.data.Dataset`, or as a numpy array.
      batch_size: Integer or `None`.
        Number of samples per state update.
        If unspecified, `batch_size` will default to 32.
        Do not specify the `batch_size` if your data is in the
        form of datasets, generators, or `keras.utils.Sequence` instances
        (since they generate batches).
      steps: Integer or `None`.
        Total number of steps (batches of samples).
        When training with input tensors such as
        TensorFlow data tensors, the default `None` is equal to
        the number of samples in your dataset divided by
        the batch size, or 1 if that cannot be determined. If `data` is a
        `tf.data` dataset and `steps` is None, the epoch will run until
        the input dataset is exhausted. When passing an infinitely
        repeating dataset, you must specify the `steps` argument. This
        argument is not supported with array inputs.
      reset_state: Optional argument specifying whether to clear the state of
        the layer at the start of the call to `adapt`, or whether to start
        from the existing state. This argument may not be relevant to all
        preprocessing layers: a subclass of PreprocessingLayer may choose to
        raise if `reset_state` is set to `False`.
    """
    _disallow_inside_tf_function('adapt')
    if not self.stateful:
      return
    if not self.streaming and self._is_adapted and not reset_state:
      raise ValueError('{} does not support calling `adapt` twice without '
                       'resetting the state.'.format(self.__class__.__name__))
    if not self._is_compiled:
      self.compile()  # Compile with defaults.
    if self.built and reset_state:
      self.reset_state()
    data_handler = data_adapter.DataHandler(
        data,
        batch_size=batch_size,
        steps_per_epoch=steps,
        epochs=1,
        steps_per_execution=self._steps_per_execution,
        distribute=False)
    self._adapt_function = self.make_adapt_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          self._adapt_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
    self.finalize_state()
    self._is_adapted = True

  def _reset_state_wrapper(self):
    """Calls `reset_state` and sets `is_adapted` to `False`."""
    self._reset_state_impl()
    self._is_adapted = False

  @trackable.no_automatic_dependency_tracking
  def _configure_steps_per_execution(self, steps_per_execution):
    self._steps_per_execution = variables.Variable(
        steps_per_execution,
        dtype='int64',
        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)

  # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
  def _adapt_maybe_build(self, data):
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown.
        data_shape = data.shape
        data_shape_nones = tuple([None] * len(data.shape))
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones
      self.build(data_shape)
      self.built = True
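

# Illustrative usage sketch (comments only; not part of the original file).
# A concrete stateful subclass is fit to data with `adapt`. `Normalization` is
# used here only as a stand-in for any `PreprocessingLayer` subclass that
# implements `update_state` and `reset_state`; the data values are arbitrary.
#
#   import numpy as np
#   import tensorflow as tf
#
#   layer = tf.keras.layers.experimental.preprocessing.Normalization()
#   layer.compile(steps_per_execution=4)  # Optional; `adapt` compiles with
#                                         # defaults when this is skipped.
#   layer.adapt(np.array([[0.], [2.], [4.]]), batch_size=2)
#   assert layer.is_adapted
#   outputs = layer(np.array([[2.]]))  # Applies the adapted transformation.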


# TODO(omalleyt): This class will be gradually replaced.
class CombinerPreprocessingLayer(PreprocessingLayer):
  """Base class for PreprocessingLayers that do computation using a Combiner.

  This class provides several helper methods to make creating a
  PreprocessingLayer easier. It assumes that the core of your computation will
  be done via a Combiner object. Subclassing this class to create a
  PreprocessingLayer allows your layer to be compatible with distributed
  computation.

  This class is compatible with TensorFlow 2.0+.
  """

  def __init__(self, combiner, **kwargs):
    super(CombinerPreprocessingLayer, self).__init__(stateful=True, **kwargs)
    self.state_variables = collections.OrderedDict()
    self._combiner = combiner
    self._adapt_accumulator = None

  def reset_state(self):
    self._adapt_accumulator = None

  def update_state(self, data):
    if self._adapt_accumulator is None:
      self._adapt_accumulator = self._get_accumulator()
    self._adapt_accumulator = self._combiner.compute(data,
                                                     self._adapt_accumulator)

  def merge_state(self, layers):
    accumulators = ([self._get_accumulator()] +
                    [l._get_accumulator() for l in layers])  # pylint: disable=protected-access
    merged_accumulator = self._combiner.merge(accumulators)
    self._set_accumulator(merged_accumulator)

  def finalize_state(self):
    self._set_accumulator(self._adapt_accumulator)

  def compile(self, run_eagerly=None, steps_per_execution=None):
    # TODO(omalleyt): Remove this once sublayers are switched to new APIs.
    if run_eagerly is None:
      run_eagerly = True
    super(CombinerPreprocessingLayer, self).compile(
        run_eagerly=run_eagerly, steps_per_execution=steps_per_execution)

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    if not reset_state:
      self._adapt_accumulator = self._combiner.restore(self._restore_updates())
    super(CombinerPreprocessingLayer, self).adapt(
        data, batch_size=batch_size, steps=steps, reset_state=reset_state)

  def _add_state_variable(self,
                          name,
                          shape,
                          dtype,
                          initializer=None,
                          partitioner=None,
                          use_resource=None,
                          **kwargs):
    """Add a variable that can hold state which is updated during adapt().

    Args:
      name: Variable name.
      shape: Variable shape. Defaults to scalar if unspecified.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: initializer instance (callable).
      partitioner: Partitioner to be passed to the `Trackable` API.
      use_resource: Whether to use a `ResourceVariable`.
      **kwargs: Additional keyword arguments. Accepted values are `getter` and
        `collections`.

    Returns:
      The created variable.
    """
    weight = self.add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=None,
        trainable=False,
        constraint=None,
        partitioner=partitioner,
        use_resource=use_resource,
        **kwargs)
    # TODO(momernick): Do not allow collisions here.
    self.state_variables[name] = weight
    return weight

  def _restore_updates(self):
    """Recreates a dict of updates from the layer's weights."""
    data_dict = {}
    for name, var in self.state_variables.items():
      data_dict[name] = var.numpy()
    return data_dict

  def _get_accumulator(self):
    if self._is_adapted:
      return self._combiner.restore(self._restore_updates())
    else:
      return None

  def _set_accumulator(self, accumulator):
    updates = self._combiner.extract(accumulator)
    self._set_state_variables(updates)
    self._adapt_accumulator = None  # Reset accumulator from adapt.

  def _set_state_variables(self, updates):
    """Directly update the internal state of this Layer.

    This method expects a string-keyed dict of {state_variable_name: state}.
    The precise nature of the state, and the names associated, are described
    by the subclasses of CombinerPreprocessingLayer.

    Args:
      updates: A string keyed dict of weights to update.

    Raises:
      RuntimeError: if `build()` was not called before `_set_state_variables()`.
    """
    # TODO(momernick): Do we need to do any more input sanitization?
    if not self.built:
      raise RuntimeError('_set_state_variables() must be called after build().')

    with ops.init_scope():
      for var_name, value in updates.items():
        self.state_variables[var_name].assign(value)
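

# Illustrative sketch (comments only; not part of the original file) of the
# pattern a subclass typically follows: pass a `Combiner` to the constructor
# and declare state variables whose names match the keys returned by
# `Combiner.extract`. `_ExampleVocabLayer`, `_VocabCombiner` and `vocab_size`
# are hypothetical names.
#
#   class _ExampleVocabLayer(CombinerPreprocessingLayer):
#
#     def __init__(self, vocab_size, **kwargs):
#       super(_ExampleVocabLayer, self).__init__(
#           combiner=_VocabCombiner(vocab_size), **kwargs)
#       self.vocab_size = vocab_size
#
#     def build(self, input_shape):
#       # Holds whatever `Combiner.extract` returns under the 'vocab' key.
#       self.vocab = self._add_state_variable(
#           name='vocab', shape=(self.vocab_size,), dtype=dtypes.string)
#       super(_ExampleVocabLayer, self).build(input_shape)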


def convert_to_list(values, sparse_default_value=None):
  """Convert a TensorLike, CompositeTensor, or ndarray into a Python list."""
  if tf_utils.is_ragged(values):
    # There is a corner case when dealing with ragged tensors: if you get an
    # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode,
    # you can't call to_list() on it without evaluating it first. However,
    # because we don't yet fully support composite tensors across Keras,
    # K.get_value() won't evaluate the tensor.
    # TODO(momernick): Get Keras to recognize composite tensors as Tensors
    # and then replace this with a call to K.get_value.
    if (isinstance(values, ragged_tensor.RaggedTensor) and
        not context.executing_eagerly()):
      values = K.get_session(values).run(values)
    values = values.to_list()

  if isinstance(values,
                (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
    if sparse_default_value is None:
      if dtypes.as_dtype(values.values.dtype) == dtypes.string:
        sparse_default_value = ''
      else:
        sparse_default_value = -1
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        values, default_value=sparse_default_value)
    values = K.get_value(dense_tensor)

  if isinstance(values, ops.Tensor):
    values = K.get_value(values)

  # We may get passed an ndarray or the code above may give us an ndarray.
  # In either case, we want to force it into a standard Python list.
  if isinstance(values, np.ndarray):
    values = values.tolist()

  return values
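

# Illustrative examples (comments only; not part of the original file) of
# `convert_to_list`. String values come back as bytes because the dense tensor
# is evaluated to a numpy array before `tolist()` is called.
#
#   convert_to_list(np.array([[1, 2], [3, 4]]))
#   # -> [[1, 2], [3, 4]]
#
#   sparse = sparse_tensor.SparseTensor(
#       indices=[[0, 0], [1, 1]], values=['a', 'b'], dense_shape=[2, 2])
#   convert_to_list(sparse)
#   # -> [[b'a', b''], [b'', b'b']] (missing entries filled with the default).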


# TODO(omalleyt): This class will be gradually replaced.
class Combiner(object):
  """Functional object that defines a shardable computation.

  This object defines functions required to create and manipulate data objects.
  These data objects, referred to below as 'accumulators', are computation-
  specific and may be implemented alongside concrete subclasses of Combiner
  (if necessary - some computations may be simple enough that standard Python
  types can be used as accumulators).

  The intent for this class is that by describing computations in this way, we
  can arbitrarily shard a dataset, perform computations on a subset, and then
  merge the computation into a final result. This enables distributed
  computation.

  The combiner itself does not own any state - all computational state is owned
  by the accumulator objects. This is so that we can have an arbitrary number of
  Combiners (thus sharding the computation N ways) without risking any change
  to the underlying computation. These accumulator objects are uniquely
  associated with each Combiner; a Combiner defines what the accumulator object
  should be and will only work with accumulators of that type.
  """
  __metaclass__ = abc.ABCMeta

  def __repr__(self):
    return '<{}>'.format(self.__class__.__name__)

  @abc.abstractmethod
  def compute(self, batch_values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    This method computes a step of the computation described by this Combiner.
    If an accumulator is passed, the data in that accumulator is also used; so
    compute(batch_values) results in f(batch_values), while
    compute(batch_values, accumulator) results in
    merge(f(batch_values), accumulator).

    Args:
      batch_values: A list of ndarrays representing the values of the inputs
        for this step of the computation.
      accumulator: the current accumulator. Can be None.

    Returns:
      An accumulator that includes the passed batch of inputs.
    """
    pass

  @abc.abstractmethod
  def merge(self, accumulators):
    """Merge several accumulators into a single accumulator.

    This method takes the partial values in several accumulators and combines
    them into a single accumulator. This computation must not be order-specific
    (that is, merge([a, b]) must return the same result as merge([b, a])).

    Args:
      accumulators: the accumulators to merge, as a list.

    Returns:
      A merged accumulator.
    """
    pass

  @abc.abstractmethod
  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: The accumulator to convert.

    Returns:
      A dict of ndarrays representing the data in this accumulator.
    """
    pass

  @abc.abstractmethod
  def restore(self, output):
    """Create an accumulator based on 'output'.

    This method creates a new accumulator with identical internal state to the
    one used to create the data in 'output'. This means that if you do

      output_data = combiner.extract(accumulator_1)
      accumulator_2 = combiner.restore(output_data)

    then accumulator_1 and accumulator_2 will have identical internal state,
    and computations using either of them will be equivalent.

    Args:
      output: The data output from a previous computation. Should be in the
        same form as provided by 'extract'.

    Returns:
      A new accumulator.
    """
    pass

  @abc.abstractmethod
  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    This function serializes an accumulator to be sent to a remote process.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A byte string representing the passed accumulator.
    """
    pass

  @abc.abstractmethod
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    This function deserializes an accumulator serialized by 'serialize()'.

    Args:
      encoded_accumulator: A byte string representing an accumulator.

    Returns:
      The accumulator represented by the passed byte string.
    """
    pass


def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`."""
  if ops.inside_function():
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)
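

# Minimal illustrative Combiner (not part of the original module): a running
# sum over all batch values. It exists only to show the `compute` / `merge` /
# `extract` / `restore` / `serialize` / `deserialize` contract; real subclasses
# (e.g. for vocabularies or moments) keep richer accumulators. The name
# `_ExampleSumCombiner` is hypothetical.
class _ExampleSumCombiner(Combiner):
  """Example combiner whose accumulator is a single float64 running sum."""

  def compute(self, batch_values, accumulator=None):
    # Add this batch's total to the running sum carried by the accumulator.
    batch_total = np.float64(np.sum(batch_values))
    if accumulator is None:
      return batch_total
    return accumulator + batch_total

  def merge(self, accumulators):
    # Summation is commutative, so merging is order-independent as required.
    return np.float64(sum(accumulators))

  def extract(self, accumulator):
    # Keys here must match the state variable names of the owning layer.
    return {'sum': np.array(accumulator)}

  def restore(self, output):
    # Inverse of `extract`: rebuild the accumulator from the extracted dict.
    return np.float64(output['sum'])

  def serialize(self, accumulator):
    # A byte string suitable for sending the accumulator to a remote process.
    return np.float64(accumulator).tobytes()

  def deserialize(self, encoded_accumulator):
    return np.frombuffer(encoded_accumulator, dtype=np.float64)[0]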