"""Build rules for tf.distribute testing."""

load("//tensorflow/python/tpu:tpu.bzl", _tpu_py_test = "tpu_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")

def distribute_py_test(
        name,
        srcs = [],
        deps = [],
        tags = [],
        data = [],
        main = None,
        size = "medium",
        args = [],
        tpu_args = [],
        tpu_tags = None,
        shard_count = 1,
        full_precision = False,
        disable_v2 = False,
        disable_v3 = False,
        disable_mlir_bridge = True,
        **kwargs):
    """Generates a CPU/GPU test target and, unless opted out, a TPU test target.

    A `cuda_py_test` named `name` is always generated. Unless `tags` contains
    "notpu" or "no_tpu", a `tpu_py_test` named `name + "_tpu"` is generated
    as well, sharing `srcs`, `data`, `main`, `size`, `shard_count` and `deps`.

    Args:
      name: name of the CPU/GPU test target; the TPU target is `name + "_tpu"`.
      srcs: source files for the tests.
      deps: additional dependencies for the test targets.
      tags: tags for the CPU/GPU test. Also used for the TPU test when
        `tpu_tags` is unspecified. If it contains "notpu" or "no_tpu", no TPU
        test target is generated (this check reads `tags`, not `tpu_tags`).
      data: data files that need to be associated with the test targets.
      main: optional main script. Defaults to "<name>.py".
      size: size of test, to control timeout.
      args: arguments to the non-TPU test only.
      tpu_args: arguments to the TPU test only.
      tpu_tags: tags for the TPU test. If unspecified, uses value of `tags`.
      shard_count: number of shards to split the tests across.
      full_precision: unused; accepted only so existing callers keep working.
      disable_v2: if True, tests for TPU version 2 are not generated.
      disable_v3: if True, tests for TPU version 3 are not generated.
      disable_mlir_bridge: if True (the default), the TPU test is not
        additionally run with the MLIR bridge enabled.
      **kwargs: extra keyword arguments forwarded to the non-TPU test only.
        `python_version` defaults to "PY3" unless supplied here.
    """

    # Default to PY3 since multi worker tests require PY3.
    kwargs.setdefault("python_version", "PY3")

    # `full_precision` is accepted but unused; presumably it is referenced here
    # only so that linters do not flag an unused argument.
    _ignore = (full_precision)

    # Starlark has no `is` operator, so compare against None with `==`.
    tpu_tags = tags if (tpu_tags == None) else tpu_tags
    main = main if main else "%s.py" % name

    # CPU/GPU variant: receives `args` and any extra **kwargs.
    cuda_py_test(
        name = name,
        srcs = srcs,
        data = data,
        main = main,
        size = size,
        deps = deps,
        shard_count = shard_count,
        tags = tags,
        args = args,
        **kwargs
    )

    # TPU variant: receives `tpu_args`/`tpu_tags` instead of `args`/`tags`, and
    # none of **kwargs. The opt-out is decided from the CPU/GPU `tags`.
    if "notpu" not in tags and "no_tpu" not in tags:
        _tpu_py_test(
            # NOTE(review): semantics of `disable_experimental` are defined in
            # tpu.bzl and not visible here.
            disable_experimental = True,
            name = name + "_tpu",
            srcs = srcs,
            data = data,
            main = main,
            size = size,
            args = tpu_args,
            shard_count = shard_count,
            deps = deps,
            tags = tpu_tags,
            disable_v2 = disable_v2,
            disable_v3 = disable_v3,
            disable_mlir_bridge = disable_mlir_bridge,
        )