# Description:
#   Contains the Keras save model API (internal TensorFlow version).

load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    # TODO(scottzhu): Remove non-keras deps from TF.
    default_visibility = [
        "//tensorflow/python/distribute:__pkg__",
        "//tensorflow/python/keras:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

# -----------------------------------------------------------------------------
# Source collections and libraries
# -----------------------------------------------------------------------------

# Python sources of this package, both at the top level and in plain
# subdirectories such as saved_model/. Visible only to the private TF API test.
filegroup(
    name = "all_py_srcs",
    srcs = glob([
        "*.py",
        "**/*.py",
    ]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)

# Holds only saved_model/load_context.py and has no deps of its own.
# NOTE(review): presumably kept separate so callers can use it without
# depending on all of :saving — confirm before folding it back in.
py_library(
    name = "load_context",
    srcs = [
        "saved_model/load_context.py",
    ],
    srcs_version = "PY3",
    deps = [],
)

# Main Keras saving library: hdf5_format.py, model_config.py, save.py,
# saving_utils.py, saved_model_experimental.py, and the SavedModel
# serialization / loading modules under saved_model/. Builds on :load_context.
py_library(
    name = "saving",
    srcs = [
        "__init__.py",
        "hdf5_format.py",
        "model_config.py",
        "save.py",
        "saved_model/base_serialization.py",
        "saved_model/constants.py",
        "saved_model/json_utils.py",
        "saved_model/layer_serialization.py",
        "saved_model/load.py",
        "saved_model/metric_serialization.py",
        "saved_model/model_serialization.py",
        "saved_model/network_serialization.py",
        "saved_model/save.py",
        "saved_model/save_impl.py",
        "saved_model/serialized_attributes.py",
        "saved_model/utils.py",
        "saved_model_experimental.py",
        "saving_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":load_context",
        "//tensorflow/python:lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:saver",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras:losses",
        "//tensorflow/python/keras:optimizers",
        "//tensorflow/python/keras:regularizers",
        "//tensorflow/python/keras/engine:input_spec",
        "//tensorflow/python/keras/mixed_precision:autocast_variable",
        "//tensorflow/python/keras/protobuf:saved_metadata_proto_py",
        "//tensorflow/python/keras/utils:engine_utils",
        "//tensorflow/python/keras/utils:metrics_utils",
        "//tensorflow/python/keras/utils:mode_keys",
        "//tensorflow/python/saved_model",
        "//tensorflow/python/saved_model/model_utils",
        "//tensorflow/python/training/tracking",
    ],
)

# -----------------------------------------------------------------------------
# Tests (alphabetical by target name)
# -----------------------------------------------------------------------------

# The only test here that depends on :saving directly instead of going through
# the full //tensorflow/python/keras target.
tf_py_test(
    name = "json_utils_test",
    size = "small",
    srcs = ["saved_model/json_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":saving",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "losses_serialization_test",
    size = "medium",
    srcs = ["losses_serialization_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "metrics_serialization_test",
    size = "medium",
    srcs = ["metrics_serialization_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "notsan",  # TODO(b/170870790)
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "revive_test",
    size = "medium",
    srcs = ["saved_model/revive_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "no_windows",  # b/158005583
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "save_test",
    size = "medium",
    srcs = ["save_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/feature_column:feature_column_v2",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "save_weights_test",
    size = "medium",
    srcs = ["save_weights_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "no_oss_py35",  # b/147011479
        "no_windows",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "saved_model_experimental_test",
    size = "medium",
    srcs = ["saved_model_experimental_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "no_oss",  # TODO(b/119349471): Re-enable
        "no_windows",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "saved_model_test",
    size = "medium",
    srcs = ["saved_model/saved_model_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "no_rocm",
        "no_windows",
        "notap",  # TODO(b/161198218): flaky timeout
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/distribute:mirrored_strategy",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

tf_py_test(
    name = "saving_utils_test",
    size = "medium",
    srcs = ["saving_utils_test.py"],
    python_version = "PY3",
    tags = ["notsan"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)