load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries")
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")

# Unless a rule overrides it, every target in this package is visible only to
# the `friends` group declared below.
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

# Visibility group: every package under //tensorflow/compiler/mlir/tfr, plus
# the packages pulled in through //third_party/mlir:subpackages.
package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/tfr/...",
    ],
)

# Runtime data label shared by both tests and the training library below.
# NOTE(review): no rule in this file declares `mnist_ops_mlir` (or
# `mnist_ops_py`), so both are presumably emitted by the `gen_op_libraries`
# call for `name = "mnist_ops"` -- confirm against build_defs.bzl.
_MNIST_OPS_MLIR = ":mnist_ops_mlir"

# Op library generated from the op definitions in ops_defs.py.
gen_op_libraries(
    name = "mnist_ops",
    src = "ops_defs.py",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

# Small, Python-3-only test of the generated ops; skipped on pip builds,
# Windows and macOS.
tf_py_test(
    name = "mnist_ops_test",
    size = "small",
    srcs = ["mnist_ops_test.py"],
    data = [_MNIST_OPS_MLIR],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_pip",
        "no_windows",  # TODO(b/170752141)
        "nomac",  # TODO(b/170752141)
    ],
    deps = [
        ":mnist_ops",
        ":mnist_ops_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/compiler/mlir/tfr:test_utils",
    ],
)

# Training code in mnist_train.py; it reads the generated MLIR at runtime and
# takes its flags via absl.
py_library(
    name = "mnist_train",
    srcs = ["mnist_train.py"],
    data = [_MNIST_OPS_MLIR],
    srcs_version = "PY3",
    deps = [
        ":mnist_ops",
        ":mnist_ops_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework",
        "@absl_py//absl/flags",
    ],
)

# Medium-size test of :mnist_train built with distribute_py_test and the
# tf.distribute combination/strategy helpers. Several variants are switched
# off to save resources and test time.
distribute_py_test(
    name = "mnist_train_test",
    size = "medium",
    srcs = ["mnist_train_test.py"],
    data = [_MNIST_OPS_MLIR],
    disable_v3 = True,  # Not needed. Save some resources and test time.
    python_version = "PY3",
    tags = [
        "no_cuda_asan",  # Not needed, and there were issues with timeouts.
        "no_oss",  # Avoid downloading mnist data set in oss.
        "nomultivm",  # Not needed. Save some resources and test time.
        "notap",  # The test is too long to run as part of llvm presubmits (b/173661843).
        "notsan",  # Not needed, and there were issues with timeouts.
    ],

    # TODO(b/175056184): Turn xla_enable_strict_auto_jit back on once the
    # GPU / MLIR bridge issues are resolved.
    xla_enable_strict_auto_jit = False,
    deps = [
        ":mnist_train",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:extra_py_tests_deps",
        "//tensorflow/python:is_mlir_bridge_test_true",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/distribute:test_util",
        "@absl_py//absl/testing:parameterized",
    ],
)