• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Description:
#   Contains the Keras model-saving API (internal TensorFlow version).
3
4load("//tensorflow:tensorflow.bzl", "tf_py_test")
5
# Package defaults: unless a target overrides `visibility`, it is visible only
# to //tensorflow/python/distribute itself and to //tensorflow/python/keras and
# all of its subpackages.
package(
    # TODO(scottzhu): Remove non-keras deps from TF.
    default_visibility = [
        "//tensorflow/python/distribute:__pkg__",
        "//tensorflow/python/keras:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)
14
# Make this package's LICENSE file referenceable as a source file by labels in
# other packages.
exports_files(["LICENSE"])
16
# Every Python source file in this package, including files in subdirectories
# that are not separate Bazel packages. The group is exposed only to the
# private TF API test.
# `**` matches zero or more path segments, so "**/*.py" already includes the
# top-level *.py files; a separate "*.py" pattern would be redundant.
filegroup(
    name = "all_py_srcs",
    srcs = glob(["**/*.py"]),
    visibility = ["//tensorflow/python/keras/google/private_tf_api_test:__pkg__"],
)
25
# Main Keras saving library: the HDF5 format code, model config, the top-level
# save entry points, and the saved_model/ serialization modules.
# saved_model/load_context.py is deliberately not listed here. It lives in its
# own :load_context target, which this library depends on.
py_library(
    name = "saving",
    srcs = [
        "__init__.py",
        "hdf5_format.py",
        "model_config.py",
        "save.py",
        "saved_model/base_serialization.py",
        "saved_model/constants.py",
        "saved_model/json_utils.py",
        "saved_model/layer_serialization.py",
        "saved_model/load.py",
        "saved_model/metric_serialization.py",
        "saved_model/model_serialization.py",
        "saved_model/network_serialization.py",
        "saved_model/save.py",
        "saved_model/save_impl.py",
        "saved_model/serialized_attributes.py",
        "saved_model/utils.py",
        "saved_model_experimental.py",
        "saving_utils.py",
    ],
    srcs_version = "PY3",
    deps = [
        ":load_context",
        "//tensorflow/python:lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:saver",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/keras:backend",
        "//tensorflow/python/keras:losses",
        "//tensorflow/python/keras:optimizers",
        "//tensorflow/python/keras:regularizers",
        "//tensorflow/python/keras/engine:input_spec",
        "//tensorflow/python/keras/mixed_precision:autocast_variable",
        "//tensorflow/python/keras/protobuf:saved_metadata_proto_py",
        "//tensorflow/python/keras/utils:engine_utils",
        "//tensorflow/python/keras/utils:metrics_utils",
        "//tensorflow/python/keras/utils:mode_keys",
        "//tensorflow/python/saved_model",
        "//tensorflow/python/saved_model/model_utils",
        "//tensorflow/python/training/tracking",
    ],
)
72
# Standalone target for saved_model/load_context.py. It declares no deps, so
# other targets can depend on it without pulling in the rest of :saving.
py_library(
    name = "load_context",
    srcs = [
        "saved_model/load_context.py",
    ],
    srcs_version = "PY3",
    deps = [],
)
81
# Tests for metrics_serialization_test.py, split across 8 shards.
# Excluded from TSAN runs (see the bug referenced on the tag).
tf_py_test(
    name = "metrics_serialization_test",
    size = "medium",
    srcs = ["metrics_serialization_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "notsan",  # TODO(b/170870790)
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
98
# Tests for losses_serialization_test.py, split across 4 shards.
tf_py_test(
    name = "losses_serialization_test",
    size = "medium",
    srcs = ["losses_serialization_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
112
# Tests for save_weights_test.py, split across 4 shards; disabled on Windows.
tf_py_test(
    name = "save_weights_test",
    size = "medium",
    srcs = ["save_weights_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        # NOTE(review): "no_oss_py35" may be obsolete if Python 3.5 is no longer
        # part of the OSS test matrix; confirm before removing.
        "no_oss_py35",  # b/147011479
        "no_windows",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
131
# Tests for save_test.py, split across 4 shards. Also depends on
# feature_column_v2.
tf_py_test(
    name = "save_test",
    size = "medium",
    srcs = ["save_test.py"],
    python_version = "PY3",
    shard_count = 4,
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/feature_column:feature_column_v2",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
147
# Tests for saved_model_experimental_test.py, split across 4 shards.
# Currently disabled in OSS and on Windows.
tf_py_test(
    name = "saved_model_experimental_test",
    size = "medium",
    srcs = ["saved_model_experimental_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "no_oss",  # TODO(b/119349471): Re-enable
        "no_windows",
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
165
# Tests for saved_model/saved_model_test.py, split across 4 shards.
# Also depends on v2_compat and mirrored_strategy. Excluded on ROCm, Windows,
# and TAP (see the bug referenced on the "notap" tag).
tf_py_test(
    name = "saved_model_test",
    size = "medium",
    srcs = ["saved_model/saved_model_test.py"],
    python_version = "PY3",
    shard_count = 4,
    tags = [
        "no_rocm",
        "no_windows",
        "notap",  # TODO(b/161198218): flaky timeout
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/distribute:mirrored_strategy",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
187
# Tests for saving_utils_test.py (unsharded).
tf_py_test(
    name = "saving_utils_test",
    size = "medium",
    srcs = ["saving_utils_test.py"],
    python_version = "PY3",
    # NOTE(review): no bug reference is recorded for excluding this test from
    # TSAN; add one if it is known.
    tags = ["notsan"],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//tensorflow/python/keras:combinations",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
202
# Tests for saved_model/revive_test.py, split across 8 shards; disabled on
# Windows.
tf_py_test(
    name = "revive_test",
    size = "medium",
    srcs = ["saved_model/revive_test.py"],
    python_version = "PY3",
    shard_count = 8,
    tags = [
        "no_windows",  # b/158005583
    ],
    deps = [
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
219
# Small, unsharded test for saved_model/json_utils_test.py. Unlike the other
# tests in this file, it depends on :saving directly rather than on
# //tensorflow/python/keras.
tf_py_test(
    name = "json_utils_test",
    size = "small",
    srcs = ["saved_model/json_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":saving",
        "//tensorflow/python:client_testlib",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)
232