# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _internal_attr_name(name):
  """Returns `name` prefixed with a single underscore."""
  return "_" + name


class OptionsBase(object):
  """Base class for representing a set of tf.data options.

  Option values live in the `_options` dictionary rather than in ordinary
  instance attributes. Attribute assignment is restricted to names that
  already resolve on the object (for example properties built with
  `create_option`), so misspelled option names raise instead of silently
  creating a new attribute.

  Attributes:
    _options: Stores the option values.
  """

  def __init__(self):
    # NOTE: Cannot use `self._options` here as we override `__setattr__`
    object.__setattr__(self, "_options", {})

  def __eq__(self, other):
    """Returns whether every option stored on either side compares equal.

    Args:
      other: The object to compare against.

    Returns:
      `NotImplemented` if `other` is not an instance of this object's class,
      otherwise a boolean.
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    # Values are read through `getattr` instead of straight from `_options`.
    # For options declared with `create_option`, this makes an option that
    # was never touched on one side compare by its default value.
    for name in set(self._options) | set(other._options):  # pylint: disable=protected-access
      if getattr(self, name) != getattr(other, name):
        return False
    return True

  def __ne__(self, other):
    """Returns the negation of `__eq__` for objects of a compatible class."""
    if isinstance(other, self.__class__):
      return not self.__eq__(other)
    else:
      return NotImplemented

  def __setattr__(self, name, value):
    """Sets `name` to `value`, refusing names that do not already exist.

    Args:
      name: The attribute name to assign.
      value: The value to assign.

    Raises:
      AttributeError: If the object has no attribute called `name`.
    """
    if hasattr(self, name):
      object.__setattr__(self, name, value)
    else:
      raise AttributeError(
          "Cannot set the property %s on %s." % (name, type(self).__name__))


def create_option(name, ty, docstring, default_factory=lambda: None):
  """Creates a type-checked property.

  Args:
    name: The name to use.
    ty: The type to use. The type of the property will be validated when it
      is set.
    docstring: The docstring to use.
    default_factory: A callable that takes no arguments and returns a default
      value to use if not set.

  Returns:
    A type-checked property.
  """

  def get_fn(option):
    # pylint: disable=protected-access
    # The default is created on first read and stored, so later reads return
    # the same object (this matters for mutable defaults such as nested
    # option sets).
    if name not in option._options:
      option._options[name] = default_factory()
    return option._options.get(name)

  def set_fn(option, value):
    if not isinstance(value, ty):
      raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
                      (name, ty, value, type(value)))
    option._options[name] = value  # pylint: disable=protected-access

  return property(get_fn, set_fn, None, docstring)


def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `OptionsBase` (and thus each represent a set of options). The method outputs
  an object of the same type created by merging the sets of options represented
  by the input arguments.

  The sets of options can be merged as long as there does not exist an option
  with different non-default values.

  If an option is an instance of `OptionsBase` itself, then this method is
  applied recursively to the set of options represented by this option. Nested
  option sets are always copied into the result, so mutating the returned
  object never changes any of the inputs.

  Args:
    *options_list: options to merge

  Raises:
    TypeError: if the input arguments are incompatible or not derived from
      `OptionsBase`
    ValueError: if the given options cannot be merged

  Returns:
    A new options object which is the result of merging the given options.
  """
  if len(options_list) < 1:
    raise ValueError("At least one options should be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError("Incompatible options type: %r vs %r" % (type(options),
                                                               result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  # A pristine instance supplies the default value of every option.
  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Iterate over all set options and merge them into the result.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        continue
      elif this == default:
        if isinstance(that, OptionsBase):
          # Copy the nested option set instead of storing the caller's object;
          # otherwise the result would share mutable state with its input.
          that = merge_options(that)
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        setattr(result, name, merge_options(this, that))
      elif this != that:
        raise ValueError(
            "Cannot merge incompatible values (%r and %r) of option: %s" %
            (this, that, name))
  return result