1"""Build rules for tf.distribute testing."""
2
3load("//tensorflow/python/tpu:tpu.bzl", _tpu_py_test = "tpu_py_test")
4load("//tensorflow:tensorflow.bzl", "cuda_py_test")
5
def distribute_py_test(
        name,
        srcs = [],
        deps = [],
        tags = [],
        data = [],
        main = None,
        size = "medium",
        args = [],
        tpu_args = [],
        tpu_tags = None,
        shard_count = 1,
        full_precision = False,
        disable_v2 = False,
        disable_v3 = False,
        disable_mlir_bridge = True,
        **kwargs):
    """Generates a CPU/GPU py_test target and, unless opted out, a TPU one.

    A `cuda_py_test` target named `name` is always created. Unless `tags`
    contains "notpu" or "no_tpu", a `tpu_py_test` target named
    `name + "_tpu"` is also created from the same sources, data and deps.

    Args:
        name: name of the CPU/GPU test target. The TPU test target, when
            generated, is named `name + "_tpu"`.
        srcs: source files for the tests.
        deps: additional dependencies for the test targets.
        tags: tags for the CPU/GPU test target. Also used for the TPU target
            when `tpu_tags` is None. If it contains "notpu" or "no_tpu", no
            TPU target is generated.
        data: data files that need to be associated with the test targets.
        main: optional main script; defaults to "<name>.py" when not given.
        size: size of test, to control timeout.
        args: arguments to the non-TPU test.
        tpu_args: arguments for the TPU test.
        tpu_tags: tags for the TPU test. If None, the value of `tags` is used.
        shard_count: number of shards to split each test target across.
        full_precision: unused; accepted and ignored.
        disable_v2: forwarded to `tpu_py_test`; set to True to disable the
            tests for TPU version 2.
        disable_v3: forwarded to `tpu_py_test`; set to True to disable the
            tests for TPU version 3.
        disable_mlir_bridge: forwarded to `tpu_py_test`; when True (the
            default) the TPU tests are not additionally run with the MLIR
            bridge enabled.
        **kwargs: extra keyword arguments forwarded only to the non-TPU test.
    """

    # Default to PY3 since multi worker tests require PY3. `kwargs` (and thus
    # this default) is only passed to the non-TPU test below.
    kwargs.setdefault("python_version", "PY3")

    # `full_precision` is deliberately unused; referencing it here presumably
    # keeps unused-argument lint quiet.
    _ignore = (full_precision)

    # Starlark has no `is` operator, so compare against None with `==`.
    tpu_tags = tags if tpu_tags == None else tpu_tags
    main = main if main else "%s.py" % name

    cuda_py_test(
        name = name,
        srcs = srcs,
        data = data,
        main = main,
        size = size,
        deps = deps,
        shard_count = shard_count,
        tags = tags,
        args = args,
        **kwargs
    )

    # Only `tags` (not `tpu_tags`) decides whether a TPU target is generated.
    if "notpu" not in tags and "no_tpu" not in tags:
        _tpu_py_test(
            disable_experimental = True,
            name = name + "_tpu",
            srcs = srcs,
            data = data,
            main = main,
            size = size,
            args = tpu_args,
            shard_count = shard_count,
            deps = deps,
            tags = tpu_tags,
            disable_v2 = disable_v2,
            disable_v3 = disable_v3,
            disable_mlir_bridge = disable_mlir_bridge,
        )
80