# Build targets for the cluster coordinator sources in this package:
# the coordinator library, its helper libraries, and their tests.

load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

# Library built from cluster_coordinator.py; depends on the local
# :metric_utils and :utils helpers declared further down in this file.
py_library(
    name = "cluster_coordinator",
    srcs = ["cluster_coordinator.py"],
    srcs_version = "PY3",
    deps = [
        ":metric_utils",
        ":utils",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:func_graph",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python:util",
        "//tensorflow/python/distribute:input_lib",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute:values",
        "//tensorflow/python/eager:cancellation",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:executor",
        "//tensorflow/python/eager:function",
        "//tensorflow/python/eager:remote",
        "@six_archive//:six",
    ],
)

# Test for :cluster_coordinator, declared "small" and split into 50 shards.
tf_py_test(
    name = "cluster_coordinator_test",
    size = "small",
    srcs = ["cluster_coordinator_test.py"],
    python_version = "PY3",
    shard_count = 50,
    tags = [
        "no_pip",
        "notsan",  # TODO(b/171040359): Flaky timeout, even if maximum shards
    ],
    deps = [
        ":cluster_coordinator",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:training_lib",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python:util",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
    ],
)

# Test built from fault_tolerance_test.py, split into 27 shards. It depends on
# the multi-process runner, hence the sanitizer exclusions in `tags`.
tf_py_test(
    name = "fault_tolerance_test",
    srcs = ["fault_tolerance_test.py"],
    python_version = "PY3",
    shard_count = 27,
    tags = [
        "noasan",  # Multi-process runner does not work with test sanitizers
        "nomac",  # TODO(b/177065434)
        "notsan",  # Multi-process runner does not work with test sanitizers
    ],
    deps = [
        ":cluster_coordinator",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:errors",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:platform",
        "//tensorflow/python:random_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/distribute:multi_process_runner",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/training:training_lib",
    ],
)

# Library built from metric_utils.py; its only dependency is eager monitoring.
py_library(
    name = "metric_utils",
    srcs = ["metric_utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python/eager:monitoring",
    ],
)

# Test for :metric_utils; also depends on :cluster_coordinator.
tf_py_test(
    name = "metric_utils_test",
    srcs = ["metric_utils_test.py"],
    python_version = "PY3",
    deps = [
        ":cluster_coordinator",
        ":metric_utils",
        "//tensorflow/python:training_server_lib",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute:parameter_server_strategy_v2",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:test",
    ],
)

# Library built from utils.py.
py_library(
    name = "utils",
    srcs = ["utils.py"],
    srcs_version = "PY3",
    deps = [
        "//tensorflow/python:training_server_lib",
    ],
)

# Publicly visible target that declares no srcs or deps in this file.
# NOTE(review): presumably kept as a stable label for external dependents;
# confirm before removing or extending it.
py_library(
    name = "remote_eager_lib",
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
)