1# Description:
2#   Contains the Keras API (internal TensorFlow version).
3load("//tensorflow:tensorflow.bzl", "tf_py_test")
4load("//tensorflow:tensorflow.bzl", "cuda_py_test")
5
# License declaration for the targets in this package.
licenses(["notice"])  # Apache 2.0

# Targets in this package are publicly visible unless they override
# `visibility` themselves.
package(default_visibility = ["//visibility:public"])

# Allow other packages to reference the LICENSE file directly.
exports_files(["LICENSE"])
11
# Matches only when building with `--define UNUSED=unused`; referenced as a
# select() key by :pil_for_keras in this file.
config_setting(
    name = "empty_condition",
    values = {
        "define": "UNUSED=unused",
    },
)
16
# Main Keras library target: the package entry point plus the applications,
# datasets, estimator, preprocessing, utils and wrappers subpackages, built on
# top of the finer-grained targets declared below.
py_library(
    name = "keras",
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    srcs = [
        "__init__.py",
        # applications/
        "applications/__init__.py",
        "applications/densenet.py",
        "applications/imagenet_utils.py",
        "applications/inception_resnet_v2.py",
        "applications/inception_v3.py",
        "applications/mobilenet.py",
        "applications/mobilenet_v2.py",
        "applications/nasnet.py",
        "applications/resnet50.py",
        "applications/vgg16.py",
        "applications/vgg19.py",
        "applications/xception.py",
        # datasets/
        "datasets/__init__.py",
        "datasets/boston_housing.py",
        "datasets/cifar.py",
        "datasets/cifar10.py",
        "datasets/cifar100.py",
        "datasets/fashion_mnist.py",
        "datasets/imdb.py",
        "datasets/mnist.py",
        "datasets/reuters.py",
        # estimator/
        "estimator/__init__.py",
        # Top-level modules.
        "keras_parameterized.py",
        "ops.py",
        # preprocessing/
        "preprocessing/__init__.py",
        "preprocessing/image.py",
        "preprocessing/sequence.py",
        "preprocessing/text.py",
        # Top-level module.
        "testing_utils.py",
        # utils/
        "utils/__init__.py",
        "utils/multi_gpu_utils.py",
        "utils/np_utils.py",
        "utils/vis_utils.py",
        # wrappers/
        "wrappers/__init__.py",
        "wrappers/scikit_learn.py",
    ],
    deps = [
        ":backend",
        ":engine",
        ":layers",
        ":pil_for_keras",
        ":saving",
        "//tensorflow/python:training",
        "//tensorflow/python/keras/mixed_precision/experimental:mixed_precision_experimental",
        "//tensorflow/python/keras/optimizer_v2",
        "//tensorflow/python/saved_model",
        "@keras_applications_archive//:keras_applications",
    ],
)
72
# Placeholder dependency hook. Both select() branches are empty, so this
# target contributes no dependencies in any configuration.
# NOTE(review): presumably kept in this select() shape so that external build
# tooling can substitute a real PIL dependency here — confirm before
# simplifying it to a plain empty `deps`.
py_library(
    name = "pil_for_keras",
    deps = select({
        ":empty_condition": [],
        "//conditions:default": [],
    }),
)
80
# backend.py together with :backend_config and the TensorFlow libraries it
# depends on.
py_library(
    name = "backend",
    srcs_version = "PY2AND3",
    srcs = [
        "backend.py",
    ],
    deps = [
        ":backend_config",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:client",
        "//tensorflow/python:clip_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:ctc_ops",
        "//tensorflow/python:distribute",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:functional_ops",
        "//tensorflow/python:gradients",
        "//tensorflow/python:image_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:map_fn",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:metrics",
        "//tensorflow/python:nn",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:session",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:state_ops",
        "//tensorflow/python:summary",
        "//tensorflow/python:tensor_array_grad",
        "//tensorflow/python:tensor_array_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:training",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_coordinator",
    ],
)
125
# backend_config.py on its own; this target declares no dependencies.
py_library(
    name = "backend_config",
    srcs_version = "PY2AND3",
    srcs = [
        "backend_config.py",
    ],
)
131
# Core engine: the engine/ sources plus metrics.py, models.py and
# utils/metrics_utils.py, layered on the smaller component targets.
py_library(
    name = "engine",
    srcs_version = "PY2AND3",
    srcs = [
        "engine/__init__.py",
        "engine/base_layer.py",
        "engine/base_layer_utils.py",
        "engine/distributed_training_utils.py",
        "engine/input_layer.py",
        "engine/input_spec.py",
        "engine/network.py",
        "engine/partial_batch_padding_handler.py",
        "engine/saving.py",
        "engine/sequential.py",
        "engine/training.py",
        "engine/training_arrays.py",
        "engine/training_distributed.py",
        "engine/training_eager.py",
        "engine/training_generator.py",
        "engine/training_utils.py",
        # Needs base_layer, so it is built as part of this target.
        "metrics.py",
        "models.py",
        "utils/metrics_utils.py",
    ],
    deps = [
        ":activations",
        ":backend",
        ":callbacks",
        ":callbacks_v1",
        ":constraints",
        ":engine_utils",
        ":initializers",
        ":losses",
        ":mode_keys",
        ":optimizers",
        ":regularizers",
        ":saving",
        "//tensorflow/python/data",
        "//tensorflow/python/distribute:distribute_coordinator",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:reduce_util",
        "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable",
        "//tensorflow/python/keras/mixed_precision/experimental:policy",
        "//tensorflow/python/training/tracking:data_structures",
        "//tensorflow/tools/docs:doc_controls",
        "@six_archive//:six",
    ],
)
181
# Model saving/loading modules under saving/.
py_library(
    name = "saving",
    srcs_version = "PY2AND3",
    srcs = [
        "saving/__init__.py",
        "saving/hdf5_format.py",
        "saving/model_config.py",
        "saving/saved_model.py",
        "saving/saving_utils.py",
    ],
    deps = [
        ":backend",
        ":engine_utils",
        ":mode_keys",
        ":optimizers",
        "//tensorflow/python:lib",
        "//tensorflow/python:saver",
        "//tensorflow/python/saved_model",
        "//tensorflow/python/saved_model/model_utils",
    ],
)
203
# activations.py, built on :backend and :engine_utils.
py_library(
    name = "activations",
    srcs_version = "PY2AND3",
    srcs = ["activations.py"],
    deps = [
        ":backend",
        ":engine_utils",
    ],
)
215
# callbacks.py, built on :backend, :engine_utils and :mode_keys.
py_library(
    name = "callbacks",
    srcs_version = "PY2AND3",
    srcs = ["callbacks.py"],
    deps = [
        ":backend",
        ":engine_utils",
        ":mode_keys",
    ],
)
228
# callbacks_v1.py; additionally depends on the eager profiler library.
py_library(
    name = "callbacks_v1",
    srcs_version = "PY2AND3",
    srcs = ["callbacks_v1.py"],
    deps = [
        ":backend",
        ":engine_utils",
        "//tensorflow/python/eager:profiler",
    ],
)
241
# constraints.py, built on :backend and :engine_utils.
py_library(
    name = "constraints",
    srcs_version = "PY2AND3",
    srcs = ["constraints.py"],
    deps = [
        ":backend",
        ":engine_utils",
    ],
)
253
# initializers.py, built on :backend and :engine_utils.
py_library(
    name = "initializers",
    srcs_version = "PY2AND3",
    srcs = ["initializers.py"],
    deps = [
        ":backend",
        ":engine_utils",
    ],
)
265
# losses.py, built on :backend and :engine_utils.
py_library(
    name = "losses",
    srcs_version = "PY2AND3",
    srcs = ["losses.py"],
    deps = [
        ":backend",
        ":engine_utils",
    ],
)
277
# optimizers.py; also pulls in the optimizer_v2 package.
py_library(
    name = "optimizers",
    srcs_version = "PY2AND3",
    srcs = ["optimizers.py"],
    deps = [
        ":backend",
        ":engine_utils",
        "//tensorflow/python/keras/optimizer_v2",
    ],
)
290
# regularizers.py, built on :backend and :engine_utils.
py_library(
    name = "regularizers",
    srcs_version = "PY2AND3",
    srcs = ["regularizers.py"],
    deps = [
        ":backend",
        ":engine_utils",
    ],
)
302
# Utility modules under utils/ that the component targets above depend on.
py_library(
    name = "engine_utils",
    srcs_version = "PY2AND3",
    srcs = [
        "utils/conv_utils.py",
        "utils/data_utils.py",
        "utils/io_utils.py",
        "utils/losses_utils.py",
    ],
    deps = [
        ":backend",
        "//tensorflow/python/distribute:distribute_lib",
    ],
)
317
# Layer implementations under layers/ plus the utils/ modules they use.
py_library(
    name = "layers",
    srcs_version = "PY2AND3",
    srcs = [
        # layers/
        "layers/__init__.py",
        "layers/advanced_activations.py",
        "layers/convolutional.py",
        "layers/convolutional_recurrent.py",
        "layers/core.py",
        "layers/cudnn_recurrent.py",
        "layers/dense_attention.py",
        "layers/embeddings.py",
        "layers/kernelized.py",
        "layers/local.py",
        "layers/merge.py",
        "layers/noise.py",
        "layers/normalization.py",
        "layers/normalization_v2.py",
        "layers/pooling.py",
        "layers/recurrent.py",
        "layers/recurrent_v2.py",
        "layers/serialization.py",
        "layers/wrappers.py",
        # utils/
        "utils/kernelized_utils.py",
        "utils/layer_utils.py",
        "utils/tf_utils.py",
    ],
    deps = [
        ":engine",
        ":generic_utils",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:cudnn_rnn_ops_gen",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:nn",
        "//tensorflow/python:nn_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:standard_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//third_party/py/numpy",
    ],
)
368
# utils/generic_utils.py; depends only on TF util and numpy.
py_library(
    name = "generic_utils",
    srcs_version = "PY2AND3",
    srcs = ["utils/generic_utils.py"],
    deps = [
        "//tensorflow/python:util",
        "//third_party/py/numpy",
    ],
)
380
# utils/mode_keys.py, backed by the saved_model mode_keys library.
py_library(
    name = "mode_keys",
    srcs_version = "PY2AND3",
    srcs = ["utils/mode_keys.py"],
    deps = ["//tensorflow/python/saved_model/model_utils:mode_keys"],
)
391
# Runs integration_test.py across 12 shards.
tf_py_test(
    name = "integration_test",
    srcs = ["integration_test.py"],
    size = "medium",
    shard_count = 12,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:nn_ops",
    ],
)
406
# Runs activations_test.py.
tf_py_test(
    name = "activations_test",
    srcs = ["activations_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:nn_ops",
    ],
)
419
# Runs constraints_test.py.
tf_py_test(
    name = "constraints_test",
    srcs = ["constraints_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
431
# Runs initializers_test.py.
tf_py_test(
    name = "initializers_test",
    srcs = ["initializers_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:init_ops",
    ],
)
444
# Runs regularizers_test.py.
tf_py_test(
    name = "regularizers_test",
    srcs = ["regularizers_test.py"],
    size = "medium",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
455
# Runs optimizers_test.py across 8 shards.
tf_py_test(
    name = "optimizers_test",
    srcs = ["optimizers_test.py"],
    size = "medium",
    shard_count = 8,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:training",
    ],
)
470
# Runs losses_test.py.
tf_py_test(
    name = "losses_test",
    srcs = ["losses_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
482
# Runs metrics_functional_test.py.
tf_py_test(
    name = "metrics_functional_test",
    srcs = ["metrics_functional_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
493
# Runs metrics_test.py across 4 shards.
tf_py_test(
    name = "metrics_test",
    srcs = ["metrics_test.py"],
    size = "medium",
    shard_count = 4,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
506
# Runs metrics_confusion_matrix_test.py across 4 shards.
tf_py_test(
    name = "metrics_confusion_matrix_test",
    srcs = ["metrics_confusion_matrix_test.py"],
    size = "medium",
    shard_count = 4,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
519
# Runs metrics_correctness_test.py across 4 shards.
tf_py_test(
    name = "metrics_correctness_test",
    srcs = ["metrics_correctness_test.py"],
    size = "medium",
    shard_count = 4,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
532
# Runs applications/applications_test.py across 11 shards.
tf_py_test(
    name = "applications_test",
    srcs = ["applications/applications_test.py"],
    size = "medium",
    shard_count = 11,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
544
# Runs layers/advanced_activations_test.py.
tf_py_test(
    name = "advanced_activations_test",
    srcs = ["layers/advanced_activations_test.py"],
    size = "medium",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
555
# Runs layers/tensorflow_op_layer_test.py across 3 shards.
tf_py_test(
    name = "tensorflow_op_layer_test",
    srcs = ["layers/tensorflow_op_layer_test.py"],
    size = "medium",
    shard_count = 3,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
567
# Runs layers/convolutional_recurrent_test.py across 4 shards.
tf_py_test(
    name = "convolutional_recurrent_test",
    srcs = ["layers/convolutional_recurrent_test.py"],
    size = "medium",
    shard_count = 4,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
580
# Runs layers/convolutional_test.py (CUDA-capable test) across 8 shards.
cuda_py_test(
    name = "convolutional_test",
    srcs = ["layers/convolutional_test.py"],
    size = "medium",
    shard_count = 8,
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
594
# Runs layers/convolutional_transpose_test.py (CUDA-capable test).
cuda_py_test(
    name = "convolutional_transpose_test",
    srcs = ["layers/convolutional_transpose_test.py"],
    size = "medium",
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
607
# Runs layers/cudnn_recurrent_test.py (CUDA-capable test) across 4 shards.
cuda_py_test(
    name = "cudnn_recurrent_test",
    srcs = ["layers/cudnn_recurrent_test.py"],
    size = "medium",
    shard_count = 4,
    tags = [
        "no_rocm",
        "no_windows_gpu",
    ],
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
625
# Runs layers/pooling_test.py across 8 shards.
tf_py_test(
    name = "pooling_test",
    srcs = ["layers/pooling_test.py"],
    size = "medium",
    shard_count = 8,
    tags = [
        "no_windows_gpu",  # TODO(b/127881287): Re-enable.
    ],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
641
# Runs layers/core_test.py across 3 shards.
tf_py_test(
    name = "core_test",
    srcs = ["layers/core_test.py"],
    size = "medium",
    shard_count = 3,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
654
# Runs layers/dense_attention_test.py.
tf_py_test(
    name = "dense_attention_test",
    srcs = ["layers/dense_attention_test.py"],
    size = "medium",
    additional_deps = [
        ":keras",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
666
# Runs layers/embeddings_test.py (CUDA-capable test).
cuda_py_test(
    name = "embeddings_test",
    srcs = ["layers/embeddings_test.py"],
    size = "medium",
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
678
# Runs layers/local_test.py across 2 shards; not run on Windows.
tf_py_test(
    name = "local_test",
    srcs = ["layers/local_test.py"],
    size = "medium",
    shard_count = 2,
    tags = ["no_windows"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
692
# Runs layers/merge_test.py.
tf_py_test(
    name = "merge_test",
    srcs = ["layers/merge_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
704
# Runs layers/noise_test.py.
tf_py_test(
    name = "noise_test",
    srcs = ["layers/noise_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
715
# Runs layers/normalization_test.py across 3 shards.
tf_py_test(
    name = "normalization_test",
    srcs = ["layers/normalization_test.py"],
    size = "medium",
    shard_count = 3,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
729
# Runs layers/simplernn_test.py across 4 shards.
tf_py_test(
    name = "simplernn_test",
    srcs = ["layers/simplernn_test.py"],
    size = "medium",
    shard_count = 4,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
743
# Runs layers/gru_test.py across 4 shards.
tf_py_test(
    name = "gru_test",
    srcs = ["layers/gru_test.py"],
    size = "medium",
    shard_count = 4,
    tags = [
        "notsan",  # http://b/62136390
    ],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
757
# Runs layers/lstm_test.py across 4 shards.
tf_py_test(
    name = "lstm_test",
    srcs = ["layers/lstm_test.py"],
    size = "medium",
    shard_count = 4,
    tags = [
        "noasan",  # times out b/63678675
        "notsan",  # http://b/62189182
    ],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
774
# Runs layers/recurrent_test.py across 10 shards.
tf_py_test(
    name = "recurrent_test",
    srcs = ["layers/recurrent_test.py"],
    size = "medium",
    shard_count = 10,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
787
# Runs layers/recurrent_v2_test.py (CUDA-capable test) across 2 shards.
cuda_py_test(
    name = "recurrent_v2_test",
    srcs = ["layers/recurrent_v2_test.py"],
    size = "medium",
    shard_count = 2,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
800
# Runs layers/separable_convolutional_test.py (CUDA-capable test).
cuda_py_test(
    name = "separable_convolutional_test",
    srcs = ["layers/separable_convolutional_test.py"],
    size = "medium",
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
813
# Runs layers/lstm_v2_test.py (CUDA-capable test) across 8 shards.
cuda_py_test(
    name = "lstm_v2_test",
    srcs = ["layers/lstm_v2_test.py"],
    size = "medium",
    shard_count = 8,
    tags = ["no_rocm"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
827
# Runs layers/gru_v2_test.py (CUDA-capable test) across 8 shards.
cuda_py_test(
    name = "gru_v2_test",
    srcs = ["layers/gru_v2_test.py"],
    size = "medium",
    shard_count = 8,
    tags = ["no_rocm"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
841
# Runs layers/serialization_test.py.
tf_py_test(
    name = "serialization_test",
    srcs = ["layers/serialization_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
852
# Runs layers/kernelized_test.py; lists its fine-grained Keras and TF deps
# explicitly in addition to :keras.
tf_py_test(
    name = "kernelized_test",
    srcs = ["layers/kernelized_test.py"],
    size = "small",
    additional_deps = [
        ":backend",
        ":initializers",
        ":keras",
        ":layers",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:init_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:random_seed",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python/eager:context",
    ],
)
878
# Runs layers/wrappers_test.py across 4 shards.
tf_py_test(
    name = "wrappers_test",
    srcs = ["layers/wrappers_test.py"],
    size = "medium",
    shard_count = 4,
    tags = [
        "noasan",  # http://b/78599823
        "notsan",
    ],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
895
# Runs layers/time_distributed_learning_phase_test.py.
tf_py_test(
    name = "time_distributed_learning_phase_test",
    srcs = ["layers/time_distributed_learning_phase_test.py"],
    size = "small",
    tags = [
        "noasan",  # http://b/78599823
        "notsan",
    ],
    additional_deps = [
        ":keras",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
910
# Runs wrappers/scikit_learn_test.py.
tf_py_test(
    name = "scikit_learn_test",
    srcs = ["wrappers/scikit_learn_test.py"],
    size = "small",
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
923
# Runs utils/data_utils_test.py across 6 shards.
tf_py_test(
    name = "data_utils_test",
    srcs = ["utils/data_utils_test.py"],
    size = "medium",
    shard_count = 6,
    tags = [
        "no_oss",
        "no_windows",
        "noasan",  # times out
        "notsan",
        "optonly",  # times out
    ],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
943
# Runs utils/generic_utils_test.py.
tf_py_test(
    name = "generic_utils_test",
    srcs = ["utils/generic_utils_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
    ],
)
954
# Runs utils/tf_utils_test.py.
tf_py_test(
    name = "tf_utils_test",
    srcs = ["utils/tf_utils_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "//tensorflow/python:client_testlib",
    ],
)
964
# Runs utils/composite_tensor_support_test.py; depends on :engine and
# :layers directly rather than on :keras.
tf_py_test(
    name = "composite_tensor_support_test",
    srcs = ["utils/composite_tensor_support_test.py"],
    size = "medium",
    additional_deps = [
        ":engine",
        ":layers",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
    ],
)
985
# Runs utils/io_utils_test.py.
tf_py_test(
    name = "io_utils_test",
    srcs = ["utils/io_utils_test.py"],
    size = "small",
    tags = [
        "no_windows",  # TODO: needs investigation on Windows
        "notsan",
    ],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1001
# Runs utils/np_utils_test.py.
tf_py_test(
    name = "np_utils_test",
    srcs = ["utils/np_utils_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1013
# Runs utils/kernelized_utils_test.py against :layers (which contains
# utils/kernelized_utils.py).
tf_py_test(
    name = "kernelized_utils_test",
    srcs = ["utils/kernelized_utils_test.py"],
    size = "small",
    additional_deps = [
        ":layers",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
    ],
)
1025
# Runs utils/multi_gpu_utils_test.py (CUDA-capable test, tagged multi_gpu).
# No explicit size is set, so the macro's default applies.
cuda_py_test(
    name = "multi_gpu_utils_test",
    srcs = ["utils/multi_gpu_utils_test.py"],
    tags = [
        "guitar",
        "multi_gpu",
    ],
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1041
# Runs engine/training_gpu_test.py (CUDA-capable test).
cuda_py_test(
    name = "training_gpu_test",
    srcs = ["engine/training_gpu_test.py"],
    size = "small",
    xla_enable_strict_auto_jit = True,
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1054
# Runs utils/conv_utils_test.py.
tf_py_test(
    name = "conv_utils_test",
    srcs = ["utils/conv_utils_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1066
# Runs preprocessing/image_test.py.
tf_py_test(
    name = "image_test",
    srcs = ["preprocessing/image_test.py"],
    size = "medium",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1078
# Runs preprocessing/sequence_test.py.
tf_py_test(
    name = "sequence_test",
    srcs = ["preprocessing/sequence_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1090
# Runs preprocessing/text_test.py.
tf_py_test(
    name = "text_test",
    srcs = ["preprocessing/text_test.py"],
    size = "small",
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1102
# Runs callbacks_test.py across 4 shards.
tf_py_test(
    name = "callbacks_test",
    srcs = ["callbacks_test.py"],
    size = "medium",
    shard_count = 4,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1116
# Runs callbacks_v1_test.py.
tf_py_test(
    name = "callbacks_v1_test",
    srcs = ["callbacks_v1_test.py"],
    size = "medium",
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1129
# Runs engine/correctness_test.py across 2 shards.
tf_py_test(
    name = "correctness_test",
    srcs = ["engine/correctness_test.py"],
    size = "medium",
    shard_count = 2,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1143
# Runs engine/training_test.py across 16 shards.
tf_py_test(
    name = "training_test",
    srcs = ["engine/training_test.py"],
    size = "medium",
    shard_count = 16,
    tags = ["notsan"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1157
# Test target for engine/training_dataset_test.py, split into 4 shards.
# Sets no tags (contrast training_test, which is tagged "notsan").
tf_py_test(
    name = "training_dataset_test",
    size = "medium",
    srcs = ["engine/training_dataset_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    shard_count = 4,
)
1170
# Test target for engine/training_arrays_test.py. Besides the usual deps it
# lists :layers and //tensorflow/python/data/ops:dataset_ops explicitly —
# presumably because the test module imports them directly (verify).
tf_py_test(
    name = "training_arrays_test",
    size = "small",
    srcs = ["engine/training_arrays_test.py"],
    additional_deps = [
        ":keras",
        ":layers",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python:client_testlib",
    ],
)
1184
# Test target for engine/training_generator_test.py, split into 6 shards.
# Currently disabled via tags (see the TODO on "notap"); "no_oss" presumably
# also keeps it out of open-source test runs — confirm.
tf_py_test(
    name = "training_generator_test",
    size = "medium",
    srcs = ["engine/training_generator_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    shard_count = 6,
    tags = [
        "no_oss",
        "notap",  # TODO(b/123544294): Re-enable this test.
        "notsan",
    ],
)
1202
# Test target for engine/feature_columns_integration_test.py. Adds an explicit
# dependency on //tensorflow/python/feature_column:feature_column_py on top of
# the usual Keras test deps.
tf_py_test(
    name = "feature_columns_integration_test",
    size = "small",
    srcs = ["engine/feature_columns_integration_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/feature_column:feature_column_py",
    ],
    # NOTE(review): "notsan" presumably excludes this test from
    # ThreadSanitizer runs — confirm.
    tags = ["notsan"],
)
1216
# Test target for engine/training_eager_test.py (single shard).
tf_py_test(
    name = "training_eager_test",
    size = "medium",
    srcs = ["engine/training_eager_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    # NOTE(review): "notsan" presumably excludes this test from
    # ThreadSanitizer runs — confirm.
    tags = ["notsan"],
)
1229
# Test target for engine/training_utils_test.py (single shard).
tf_py_test(
    name = "training_utils_test",
    size = "medium",
    srcs = ["engine/training_utils_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    # NOTE(review): "notsan" presumably excludes this test from
    # ThreadSanitizer runs — confirm.
    tags = ["notsan"],
)
1242
# Test target for model_subclassing_test.py, which lives at the package root
# rather than under engine/; split into 4 shards.
tf_py_test(
    name = "model_subclassing_test",
    size = "medium",
    srcs = ["model_subclassing_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    shard_count = 4,
    # NOTE(review): "notsan" presumably excludes this test from
    # ThreadSanitizer runs — confirm.
    tags = ["notsan"],
)
1256
# Test target for engine/topology_test.py (single shard).
tf_py_test(
    name = "topology_test",
    size = "medium",
    srcs = ["engine/topology_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    tags = [
        # NOTE(review): presumably skips internal Python 3 test runs; note it
        # is hyphenated, unlike the underscore-style tags elsewhere — confirm
        # the spelling matches what the test infrastructure filters on.
        "no-internal-py3",
    ],
)
1271
# Test target for engine/base_layer_test.py, split into 8 shards.
tf_py_test(
    name = "base_layer_test",
    size = "medium",
    srcs = ["engine/base_layer_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    shard_count = 8,
    # NOTE(review): "no_rocm" presumably excludes this test from ROCm builds;
    # no bug reference is recorded for why — confirm.
    tags = ["no_rocm"],
)
1285
# Test target for saving/hdf5_format_test.py, split into 4 shards. No tags.
tf_py_test(
    name = "hdf5_format_test",
    size = "medium",
    srcs = ["saving/hdf5_format_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    shard_count = 4,
)
1298
# Test target for engine/sequential_test.py (single shard, no tags).
tf_py_test(
    name = "sequential_test",
    size = "medium",
    srcs = ["engine/sequential_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
)
1310
# Test target for models_test.py (package root), split into 8 shards. Adds an
# explicit //tensorflow/python:training dependency on top of the usual deps.
tf_py_test(
    name = "models_test",
    size = "medium",
    srcs = ["models_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:training",
    ],
    shard_count = 8,
    # The bug reference below documents why "notsan" is set.
    tags = ["notsan"],  # b/67509773
)
1325
# Test target for backend_test.py (package root), split into 4 shards. Adds an
# explicit //tensorflow/python:util dependency on top of the usual deps.
tf_py_test(
    name = "backend_test",
    size = "medium",
    srcs = ["backend_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
    ],
    shard_count = 4,
)
1339
# Test target for backend_config_test.py (package root). It is the only target
# in this part of the file that does not depend on absl's parameterized
# testing library.
tf_py_test(
    name = "backend_config_test",
    size = "medium",
    srcs = ["backend_config_test.py"],
    additional_deps = [
        ":keras",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:util",
    ],
)
1351
# Test target for keras_parameterized_test.py, the test for the
# keras_parameterized.py source listed in the :keras library.
tf_py_test(
    name = "keras_parameterized_test",
    size = "small",
    srcs = ["keras_parameterized_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    # NOTE(review): "notsan" presumably excludes this test from
    # ThreadSanitizer runs — confirm.
    tags = ["notsan"],
)
1364
# Test target for saving/saved_model_test.py (single shard).
# NOTE(review): "no_oss" presumably keeps it out of open-source test runs
# (tracked by the TODO below) and "no_windows" out of Windows runs — confirm.
tf_py_test(
    name = "saved_model_test",
    size = "medium",
    srcs = ["saving/saved_model_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    tags = [
        "no_oss",  # TODO(b/119349471): Re-enable
        "no_windows",
    ],
)
1380
# Test target for saving/saving_utils_test.py (single shard).
tf_py_test(
    name = "saving_utils_test",
    size = "medium",
    srcs = ["saving/saving_utils_test.py"],
    additional_deps = [
        ":keras",
        "@absl_py//absl/testing:parameterized",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
    ],
    # NOTE(review): "notsan" presumably excludes this test from
    # ThreadSanitizer runs — confirm.
    tags = ["notsan"],
)
1393