1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || \
17     (INTEL_MKL && defined(ENABLE_INTEL_MKL_BFLOAT16))
18 
19 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
20 
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
25 #include "tensorflow/cc/ops/list_ops.h"
26 #include "tensorflow/cc/ops/math_ops.h"
27 #include "tensorflow/cc/ops/standard_ops.h"
28 #include "tensorflow/core/framework/function_testlib.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/tensor_testutil.h"
31 #include "tensorflow/core/grappler/clusters/single_machine.h"
32 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
33 #include "tensorflow/core/grappler/devices.h"
34 #include "tensorflow/core/grappler/graph_view.h"
35 #include "tensorflow/core/grappler/utils/grappler_test.h"
36 #include "tensorflow/core/lib/core/status_test_util.h"
37 #include "tensorflow/core/lib/random/random.h"
38 
39 // TODO(benbarsdell): Improve the numerical checks in these tests. The tests
40 // were originally written only to check the graph coloring, so the graphs do
41 // not have particularly realistic numerical behavior.
42 
43 namespace tensorflow {
44 namespace grappler {
45 namespace {
46 
47 template <DataType DTYPE>
GenerateIdentityMatrix(int64 height,int64 width)48 Tensor GenerateIdentityMatrix(int64 height, int64 width) {
49   typedef typename EnumToDataType<DTYPE>::Type T;
50   Tensor tensor(DTYPE, TensorShape{height, width});
51   for (int64 i = 0; i < height; ++i) {
52     for (int64 j = 0; j < width; ++j) {
53       tensor.matrix<T>()(i, j) = i == j;
54     }
55   }
56   return tensor;
57 }
58 
59 template <DataType DTYPE>
GenerateRandomTensorInRange(const TensorShape & shape,double minval,double maxval)60 Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
61                                    double maxval) {
62   typedef typename EnumToDataType<DTYPE>::Type T;
63   Tensor tensor(DTYPE, shape);
64   for (auto i = 0; i < tensor.NumElements(); i++)
65     tensor.flat<T>()(i) =
66         (random::New64() % 65536 / 65536.0) * (maxval - minval) + minval;
67   return tensor;
68 }
69 
VerifyGraphsEquivalent(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)70 void VerifyGraphsEquivalent(const GraphDef& original_graph,
71                             const GraphDef& optimized_graph,
72                             const string& func) {
73   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
74   GraphView optimized_view(&optimized_graph);
75   for (int i = 0; i < original_graph.node_size(); ++i) {
76     const NodeDef& original = original_graph.node(i);
77     const NodeDef& optimized = *optimized_view.GetNode(original.name());
78     EXPECT_EQ(original.name(), optimized.name()) << func;
79     EXPECT_EQ(original.op(), optimized.op()) << func;
80     EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
81     if (original.input_size() == optimized.input_size()) {
82       for (int j = 0; j < original.input_size(); ++j) {
83         EXPECT_EQ(original.input(j), optimized.input(j)) << func;
84       }
85     }
86   }
87 }
88 
// Currently, this test suite only passes when TensorFlow is built with CUDA,
// because otherwise the optimizer will not turn clearlist nodes to float16.
// When looking at clearlist nodes, this optimizer checks if the nodes have a
// float16 GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
93 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94 
// Minimum GPU architecture that every available GPU must satisfy for the
// tests to run on real hardware (presumably a {major, minor} compute
// capability pair -- confirm against GetNumAvailableGPUs).
const std::pair<int, int> kMinGPUArch{7, 0};
96 
97 class AutoMixedPrecisionTest : public GrapplerTest {
98  protected:
SetUp()99   void SetUp() override {
100     int num_gpus = GetNumAvailableGPUs();
101     // If GPUs are available, require that they all satisfy the min arch.
102     gpu_available_ = (num_gpus > 0);
103 #if GOOGLE_CUDA
104     gpu_available_ =
105         gpu_available_ && (num_gpus == GetNumAvailableGPUs(kMinGPUArch));
106 #endif
107     if (gpu_available_) {
108       virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
109     } else {
110       DeviceProperties device_properties;
111       device_properties.set_type("GPU");
112 #if GOOGLE_CUDA
113       device_properties.mutable_environment()->insert({"architecture", "7"});
114       device_properties.mutable_environment()->insert({"cuda", "9010"});
115 #endif
116       virtual_cluster_.reset(
117           new VirtualCluster({{"/GPU:1", device_properties}}));
118     }
119     TF_CHECK_OK(virtual_cluster_->Provision());
120   }
121 
TearDown()122   void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
123 
AddSimpleNode(const string & name,const string & op,const std::vector<string> & inputs,GraphDef * graph) const124   NodeDef* AddSimpleNode(const string& name, const string& op,
125                          const std::vector<string>& inputs,
126                          GraphDef* graph) const {
127     std::vector<std::pair<string, AttrValue>> attributes;
128     if (op == "AddN" || op == "ShapeN") {
129       AttrValue num_inputs;
130       num_inputs.set_i(inputs.size());
131       attributes.emplace_back("N", num_inputs);
132     }
133     if (op == "ShapeN") {
134       AttrValue out_type;
135       out_type.set_type(DT_INT32);
136       attributes.emplace_back("out_type", out_type);
137     }
138     AttrValue type;
139     type.set_type(DT_FLOAT);
140     if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
141         op == "VarHandleOp" || op == "ReadVariableOp") {
142       attributes.emplace_back("dtype", type);
143     } else if (op == "SparseMatMul") {
144       attributes.emplace_back("Ta", type);
145       attributes.emplace_back("Tb", type);
146     } else if (op == "IdentityN") {
147       AttrValue type_list;
148       for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
149         type_list.mutable_list()->add_type(DT_FLOAT);
150       }
151       attributes.emplace_back("T", type_list);
152     } else if (op == "StackV2" || op == "StackPopV2") {
153       attributes.emplace_back("elem_type", type);
154     } else if (op == "Cast") {
155       attributes.emplace_back("SrcT", type);
156       attributes.emplace_back("DstT", type);
157     } else {
158       attributes.emplace_back("T", type);
159     }
160     return AddNode(name, op, inputs, attributes, graph);
161   }
162 
TestSimpleUnaryInferOp(double input_min,double input_max,double atol,double rtol,const std::function<Output (const tensorflow::Scope &,Output)> & test_op_factory)163   void TestSimpleUnaryInferOp(
164       double input_min, double input_max, double atol, double rtol,
165       const std::function<Output(const tensorflow::Scope&, Output)>&
166           test_op_factory) {
167     int size = 128;
168     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
169     Output eye = ops::Const(s.WithOpName("eye"),
170                             GenerateIdentityMatrix<DT_FLOAT>(size, size));
171     Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
172     Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
173     Output infer1 = test_op_factory(s.WithOpName("infer1"), allow1);
174     Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, eye);
175     Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
176     GrapplerItem item;
177     item.fetch = {"fetch1"};
178     TF_CHECK_OK(s.ToGraphDef(&item.graph));
179     auto input_tensor = GenerateRandomTensorInRange<DT_FLOAT>(
180         TensorShape({size, size}), input_min, input_max);
181     std::vector<std::pair<string, Tensor>> feed = {{"input", input_tensor}};
182     auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
183 
184     AutoMixedPrecision optimizer;
185     GraphDef output;
186     TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
187 
188     VLOG(1) << output.DebugString();
189 
190     GraphView output_view(&output);
191     EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
192               DT_FLOAT);
193     EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
194     EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
195     EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
196 
197     auto tensors = EvaluateNodes(output, item.fetch, feed);
198     EXPECT_EQ(tensors.size(), tensors_expected.size());
199     EXPECT_EQ(tensors.size(), item.fetch.size());
200     for (int i = 0; i < item.fetch.size(); ++i) {
201       test::ExpectClose(tensors_expected[i], tensors[i], atol, rtol);
202     }
203   }
204 
205   std::unique_ptr<Cluster> virtual_cluster_;
206   bool gpu_available_;
207 };
208 
// A graph made only of Exp/Relu/Sqrt/Identity nodes (no MatMul or similar)
// must come out of the optimizer unchanged, with every node still DT_FLOAT.
TEST_F(AutoMixedPrecisionTest, NoOp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
245 
// A graph that already casts to DT_HALF around its MatMul must be left
// structurally unchanged, keeping the existing half/float node types.
TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
283 
// Chain mixing Exp, Relu, Sqrt, MatMul, Log and SparseMatMul nodes. Only the
// MatMul (allow1) and the Relu nodes directly around it (clr2, clr3) are
// expected to become DT_HALF; all other nodes stay DT_FLOAT, and the graph
// grows by two nodes (presumably the casts into and out of the half region).
TEST_F(AutoMixedPrecisionTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
  Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
332 
// clr2 and clr4 are connected to the MatMul (allow1) only through the ShapeN
// node clr3, which shares the input clr1 with allow1. All of clr1..clr4 are
// expected to become DT_HALF along with allow1.
TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
  Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
371 
// Nodes listed in item.fetch (allow1, clr2, clr3) are expected to keep
// DT_FLOAT even though allow1 is a MatMul. The non-fetched MatMul allow2 is
// still expected to become DT_HALF.
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output deny1 = ops::Exp(s.WithOpName("deny1"), infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), deny1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
  Output deny2 = ops::Exp(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);

  GrapplerItem item;
  item.fetch = {"allow1", "clr2", "clr3"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
416 
// A MatMul explicitly placed on a CPU device (allow2) is expected to stay
// DT_FLOAT, as are its neighbours infer1 and clr2, while the MatMul without
// a device (allow1) and its input clr1 are expected to become DT_HALF.
TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
  Output allow2 =
      ops::MatMul(s.WithOpName("allow2").WithDevice(
                      "/job:localhost/replica:0/task:0/device:CPU:0"),
                  infer1, infer1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
457 
// An Identity that directly follows a Variable (clr1) is expected to stay
// DT_FLOAT, whereas an Identity that follows a Const (clr2) is expected to
// become DT_HALF. Both MatMuls are expected to become DT_HALF.
TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
  Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
  Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
  Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch1", "fetch2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // The variable's value is supplied through the feed when evaluating.
  auto var1_tensor =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({32, 32}), 3.141593f);
  std::vector<std::pair<string, Tensor>> feed = {{"var1", var1_tensor}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 5);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-3);
  }
}
501 
TEST_F(AutoMixedPrecisionTest,FusedBatchNorm)502 TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
503   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
504   // Uses NHWC data format because non-GPU execution does not support NCHW.
505   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {8, 56, 56, 16});
506   Output weight = ops::Const(s.WithOpName("weight"), 2.f, {3, 3, 16, 16});
507   Output scale = ops::Const(s.WithOpName("scale"), 3.f, {16});
508   Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
509   Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
510   Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
511   Output allow1 =
512       ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
513                   ops::Conv2D::DataFormat("NHWC"));
514   auto fbn1_op =
515       ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
516                           variance, ops::FusedBatchNorm::DataFormat("NHWC"));
517   Output fbn1 = fbn1_op.y;
518   Output fbn1_rs1 = fbn1_op.reserve_space_1;
519   Output fbn1_rs2 = fbn1_op.reserve_space_2;
520   Output bng1 = ops::FusedBatchNormGrad(
521                     s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
522                     fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
523                     .x_backprop;
524   Output infer1 = ops::Add(s.WithOpName("infer1"), fbn1, bng1);
525   Output allow2 =
526       ops::Conv2D(s.WithOpName("allow2"), infer1, weight, {1, 1, 1, 1}, "SAME",
527                   ops::Conv2D::DataFormat("NHWC"));
528   Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
529 
530   GrapplerItem item;
531   item.fetch = {"fetch"};
532   TF_CHECK_OK(s.ToGraphDef(&item.graph));
533   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
534 
535   AutoMixedPrecision optimizer;
536   GraphDef output;
537   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
538 
539   VLOG(1) << output.DebugString();
540 
541   GraphView output_view(&output);
542   EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
543   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
544   EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
545   EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
546   EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
547   EXPECT_EQ(output_view.GetNode("bng1")->op(), "FusedBatchNormGradV2");
548   EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
549   EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
550   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
551   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
552 
553   auto tensors = EvaluateNodes(output, item.fetch);
554   EXPECT_EQ(tensors.size(), tensors_expected.size());
555   EXPECT_EQ(tensors.size(), item.fetch.size());
556   for (int i = 0; i < item.fetch.size(); ++i) {
557     test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-3);
558   }
559 }
560 
// Covers ops whose type attribute applies to several inputs: IdentityN
// (list-valued "T") and AddN. Every entry of IdentityN's type list and AddN's
// "T" are expected to become DT_HALF between the two MatMuls.
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
  Output infer1 =
      ops::AddN(s.WithOpName("infer1"),
                {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  const auto& clr1_types =
      output_view.GetNode("clr1")->attr().at("T").list().type();
  // clr1 was built with three inputs; checking the length keeps the loop
  // below from passing vacuously on an empty type list.
  EXPECT_EQ(clr1_types.size(), 3);
  for (auto type : clr1_types) {
    EXPECT_EQ(type, DT_HALF);
  }
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
600 
// A Cast from DT_BOOL to DT_FLOAT that feeds a MatMul is expected to be
// retargeted to DT_HALF (SrcT stays DT_BOOL), with the MatMul becoming
// DT_HALF and the graph growing by one node.
TEST_F(AutoMixedPrecisionTest, ExistingCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fatal checks: the loop below indexes both vectors by fetch position.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
632 
TEST_F(AutoMixedPrecisionTest,RecurrentEdgeColorMismatch)633 TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
634   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
635   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
636   Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
637   Output ent1 =
638       ops::internal::Enter(s.WithOpName("ent1"), deny1, "loop1").output;
639   // Note that the second input is later replaced with "nxt1".
640   Output mrg1 = ops::Merge(s.WithOpName("mrg1"), {ent1, ent1}).output;
641   // For simplicity, the loop condition is constant false.
642   Output con1 = ops::Const(s.WithOpName("con1"), false, {});
643   Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
644   auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
645   Output infer1 = ops::Sqrt(s.WithOpName("infer1"), swt1.output_true);
646   Output allow1 = ops::MatMul(s.WithOpName("allow1"), infer1, infer1);
647   Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
648   Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
649   Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
650   // Add a second merge node from the same NextIteration node. This case arises
651   // during graph optimization of some models.
652   auto mrg2 = ops::Merge(s.WithOpName("mrg2"), {ent1, nxt1});
653 
654   GrapplerItem item;
655   item.fetch = {"fetch"};
656   TF_CHECK_OK(s.ToGraphDef(&item.graph));
657   NodeMap node_map_original(&item.graph);
658   auto merge_node = node_map_original.GetNode("mrg1");
659   // Modify the graph to create a loop.
660   merge_node->set_input(1, "nxt1");
661   // Add a control edge to ensure the loop condition is inside the frame.
662   auto const_node = node_map_original.GetNode("con1");
663   const_node->add_input("^mrg1");
664   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
665 
666   AutoMixedPrecision optimizer;
667   GraphDef output;
668   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
669 
670   VLOG(1) << output.DebugString();
671 
672   GraphView output_view(&output);
673   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
674   // Note that mrg1 gets painted deny because it is between deny1 and infer1.
675   // This forces nxt1 and mrg2 to be painted deny as well (they would otherwise
676   // be painted allow because they are clear and have a direct path to allow1).
677   EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
678   EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
679   EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
680   EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
681   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
682   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
683   EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
684   EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
685   EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);
686 
687   auto tensors = EvaluateNodes(output, item.fetch);
688   EXPECT_EQ(tensors.size(), tensors_expected.size());
689   EXPECT_EQ(tensors.size(), item.fetch.size());
690   for (int i = 0; i < item.fetch.size(); ++i) {
691     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
692   }
693 }
694 
TEST_F(AutoMixedPrecisionTest,TensorListSetGet)695 TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
696   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
697   tensorflow::Input shape = {32, 32};
698   auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
699   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
700   Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
701   Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
702   Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
703   auto tl1w1 =
704       ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
705   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
706   auto tl1w2 =
707       ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
708   // Ensure that TensorListResize doesn't cause any problems.
709   Output tl1rs =
710       ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
711   Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
712                                         shape, DT_FLOAT)
713                      .item;
714   Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
715   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
716   auto tl1w3 =
717       ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
718   Output tl1r2 =
719       ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
720                              shape, DT_FLOAT)
721           .item;
722   auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
723   auto tl2w1 =
724       ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
725   Output tl2r1 =
726       ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
727                              shape, DT_FLOAT)
728           .item;
729   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
730   Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
731 
732   GrapplerItem item;
733   item.fetch = {"fetch1", "fetch2"};
734   TF_CHECK_OK(s.ToGraphDef(&item.graph));
735   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
736 
737   AutoMixedPrecision optimizer;
738   GraphDef output;
739   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
740 
741   VLOG(1) << output.DebugString();
742 
743   GraphView output_view(&output);
744   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
745   const char* type_key = "element_dtype";
746   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
747   EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
748   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
749   EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
750   EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
751   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
752   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
753   EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
754   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
755   EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
756   EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
757 
758   auto tensors = EvaluateNodes(output, item.fetch);
759   EXPECT_EQ(tensors.size(), tensors_expected.size());
760   EXPECT_EQ(tensors.size(), item.fetch.size());
761   for (int i = 0; i < item.fetch.size(); ++i) {
762     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
763   }
764 }
765 
TEST_F(AutoMixedPrecisionTest,TensorListPushPop)766 TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
767   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
768   tensorflow::Input shape = {32, 32};
769   auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
770   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
771   auto tl1w1 =
772       ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
773   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
774   auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
775                                        tl1w1.output_handle, allow1);
776   Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
777                                         tl1w2.output_handle, shape, DT_FLOAT)
778                      .tensor;
779   Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
780   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
781   auto tl1w3 =
782       ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
783   Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
784                                         tl1w3.output_handle, shape, DT_FLOAT)
785                      .tensor;
786   auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
787   auto tl2w1 =
788       ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, input);
789   Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
790                                         tl2w1.output_handle, shape, DT_FLOAT)
791                      .tensor;
792   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
793   Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
794 
795   GrapplerItem item;
796   item.fetch = {"fetch1", "fetch2"};
797   TF_CHECK_OK(s.ToGraphDef(&item.graph));
798   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
799 
800   AutoMixedPrecision optimizer;
801   GraphDef output;
802   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
803 
804   VLOG(1) << output.DebugString();
805 
806   GraphView output_view(&output);
807   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
808   const char* type_key = "element_dtype";
809   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
810   EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
811   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
812   EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
813   EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
814   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
815   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
816   EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
817   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
818   EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
819   EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
820 
821   auto tensors = EvaluateNodes(output, item.fetch);
822   EXPECT_EQ(tensors.size(), tensors_expected.size());
823   EXPECT_EQ(tensors.size(), item.fetch.size());
824   for (int i = 0; i < item.fetch.size(); ++i) {
825     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
826   }
827 }
828 
// Creates tensor lists with TensorListFromTensor from a MatMul output and
// checks that the lists, their readers/writers and the surrounding ops all
// end up as DT_HALF after optimization, with fetched values close to those of
// the unoptimized graph.
TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Input shape = {32};
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
  auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
  Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
                                      shape, DT_FLOAT)
                     .tensor;
  Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl1r1);
  Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);

  // This tests that an allow-painted object node (tl2) will force an unpainted
  // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
  // would remain unpainted, producing an invalid graph).
  auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
  auto tl2w1 =
      ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  const char* type_key = "element_dtype";
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fail hard on a size mismatch so the loop below can never index past the
  // end of either result vector.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 2e-4);
  }
}
878 
TEST_F(AutoMixedPrecisionTest,TensorListPushBackBatchAndConcatLists)879 TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
880   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
881   tensorflow::Input shape = {32, 32};
882   auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
883   auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
884   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
885   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
886   Output tl1_tl2 =
887       ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
888   Output allow1_allow1 =
889       ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
890   auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
891                                              allow1_allow1);
892   OutputList tl12w1_outputs =
893       ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
894           .output;
895   Output scalar_shape = ops::Const(s.WithOpName("scalar_shape"), 0, {0});
896   Output tl12w1_output0 = ops::Reshape(s.WithOpName("tl12w1_output0"),
897                                        tl12w1_outputs[0], scalar_shape);
898   Output tl12w1_output1 = ops::Reshape(s.WithOpName("tl12w1_output1"),
899                                        tl12w1_outputs[1], scalar_shape);
900   Output tl3 = ops::TensorListConcatLists(s.WithOpName("tl3"), tl12w1_output0,
901                                           tl12w1_output1, DT_FLOAT);
902   Output tl3r1 =
903       ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
904           .tensor;
905   Output infer1 = ops::Tanh(s.WithOpName("infer1"), tl3r1);
906   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
907   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
908 
909   GrapplerItem item;
910   item.fetch = {"fetch1"};
911   TF_CHECK_OK(s.ToGraphDef(&item.graph));
912   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
913 
914   AutoMixedPrecision optimizer;
915   GraphDef output;
916   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
917 
918   VLOG(1) << output.DebugString();
919 
920   GraphView output_view(&output);
921   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
922   const char* type_key = "element_dtype";
923   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
924   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
925   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
926   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
927   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
928   EXPECT_EQ(output_view.GetNode("tl3")->attr().at(type_key).type(), DT_HALF);
929   EXPECT_EQ(output_view.GetNode("tl3r1")->attr().at(type_key).type(), DT_HALF);
930 
931   auto tensors = EvaluateNodes(output, item.fetch);
932   EXPECT_EQ(tensors.size(), tensors_expected.size());
933   EXPECT_EQ(tensors.size(), item.fetch.size());
934   for (int i = 0; i < item.fetch.size(); ++i) {
935     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
936   }
937 }
938 
TEST_F(AutoMixedPrecisionTest,TensorListThroughFunction)939 TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
940   // This test passes a tensor list handle through a function with its own
941   // Tensor List ops inside to test that the types are not changed to a
942   // conflicting state.
943   // A separate Tensor List cluster is added to test that it is still changed to
944   // DT_HALF.
945   FunctionDefLibrary function_lib;
946   const Tensor kShape = test::AsTensor<int32>({32, 32});
947   FunctionDef func1 = FunctionDefHelper::Define(
948       "Func1", {"ihandle: variant", "x: float"},
949       {"ohandle: variant", "y: float"}, {},
950       {
951           {{"tl1w1_handle"},
952            "TensorListPushBack",
953            {"ihandle", "x"},
954            {{"element_dtype", DT_FLOAT}}},
955           {{"shape"}, "Const", {}, {{"value", kShape}, {"dtype", DT_INT32}}},
956           {{"tl1r1_handle", "tl1r1_data"},
957            "TensorListPopBack",
958            {"tl1w1_handle", "shape"},
959            {{"element_dtype", DT_FLOAT}}},
960           {{"ohandle"}, "Identity", {"tl1r1_handle"}, {{"T", DT_VARIANT}}},
961           {{"y"}, "Identity", {"tl1r1_data"}, {{"T", DT_FLOAT}}},
962       });
963   function_lib.add_function()->Swap(&func1);
964 
965   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
966   TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
967   tensorflow::Input shape = {32, 32};
968   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
969   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
970   Output infer1 = ops::Tanh(s.WithOpName("infer1"), allow1);
971   auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
972   auto tl1w1 =
973       ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, infer1);
974   auto _infer1 = tensorflow::ops::AsNodeOut(s, infer1);
975   auto _tl1w1_handle = tensorflow::ops::AsNodeOut(s, tl1w1.output_handle);
976   auto builder =
977       tensorflow::NodeBuilder("Func1", "Func1", s.graph()->op_registry());
978   tensorflow::Node* func1_op;
979   TF_CHECK_OK(builder.Input(_tl1w1_handle)
980                   .Input(_infer1)
981                   .Finalize(s.graph(), &func1_op));
982   Output func1_handle(func1_op, 0);
983   Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), func1_handle,
984                                         shape, DT_FLOAT)
985                      .tensor;
986   auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
987   auto tl2w1 =
988       ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.handle, infer1);
989   Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
990                                         tl2w1.output_handle, shape, DT_FLOAT)
991                      .tensor;
992   Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
993   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
994 
995   GrapplerItem item;
996   item.fetch = {"fetch1"};
997   TF_CHECK_OK(s.ToGraphDef(&item.graph));
998   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
999 
1000   AutoMixedPrecision optimizer;
1001   GraphDef output;
1002   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
1003 
1004   VLOG(1) << output.DebugString();
1005 
1006   GraphView output_view(&output);
1007   const char* type_key = "element_dtype";
1008   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
1009   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
1010   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_HALF);
1011   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
1012   EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
1013   EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_HALF);
1014 
1015   auto tensors = EvaluateNodes(output, item.fetch);
1016   EXPECT_EQ(tensors.size(), tensors_expected.size());
1017   EXPECT_EQ(tensors.size(), item.fetch.size());
1018   for (int i = 0; i < item.fetch.size(); ++i) {
1019     test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
1020   }
1021 }
1022 
GetCudaVersion(const Cluster & cluster)1023 int GetCudaVersion(const Cluster& cluster) {
1024   auto devices = cluster.GetDevices();
1025   for (const auto& device : devices) {
1026     const DeviceProperties& device_properties = device.second;
1027     if (device_properties.type() == "GPU") {
1028       const auto& device_env = device_properties.environment();
1029       auto it = device_env.find("cuda");
1030       if (it != device_env.end()) {
1031         string cuda_version_str = it->second;
1032         return std::stoi(cuda_version_str);
1033       }
1034     }
1035   }
1036   return 0;
1037 }
1038 
// Optimizes a graph containing a single BatchMatMul. The op is expected to be
// converted to DT_HALF (with two nodes added) only when GetCudaVersion()
// reports at least 9010 for the cluster; otherwise the graph must be left
// untouched.
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
  Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);

  GrapplerItem item;
  item.fetch = {"fetch1"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  if (GetCudaVersion(*virtual_cluster_.get()) >= 9010) {
    EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
  } else {
    EXPECT_EQ(output.node_size(), item.graph.node_size());
    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
  }

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fail hard on a size mismatch so the loop below can never index past the
  // end of either result vector.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 3.0e-3);
  }
}
1073 
// Runs the shared unary-op harness with ops::Elu as the op under test.
TEST_F(AutoMixedPrecisionTest, EluOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Elu(root, operand);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, make_op);
}
1081 
// Runs the shared unary-op harness with ops::Erf as the op under test.
TEST_F(AutoMixedPrecisionTest, ErfOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Erf(root, operand);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, make_op);
}
1089 
// Runs the shared unary-op harness with ops::Erfc as the op under test.
TEST_F(AutoMixedPrecisionTest, ErfcOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Erfc(root, operand);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, make_op);
}
1097 
// Runs the shared unary-op harness with ops::Inv as the op under test.
TEST_F(AutoMixedPrecisionTest, InvOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Inv(root, operand);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, make_op);
}
1105 
// Runs the shared unary-op harness with ops::Log as the op under test.
TEST_F(AutoMixedPrecisionTest, LogOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Log(root, operand);
  };
  TestSimpleUnaryInferOp(0.01, 10, 1.0e-3, 2.0e-3, make_op);
}
1113 
// Runs the shared unary-op harness with ops::Log1p as the op under test.
TEST_F(AutoMixedPrecisionTest, Log1pOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Log1p(root, operand);
  };
  TestSimpleUnaryInferOp(-0.99, 9, 1.0e-3, 5.0e-3, make_op);
}
1121 
// Runs the shared unary-op harness with ops::LogSoftmax as the op under test.
TEST_F(AutoMixedPrecisionTest, LogSoftmaxOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::LogSoftmax(root, operand);
  };
  TestSimpleUnaryInferOp(-8, 8, -1, 1.0e-2, make_op);
}
1129 
// Runs the shared unary-op harness with ops::Reciprocal as the op under test.
TEST_F(AutoMixedPrecisionTest, ReciprocalOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Reciprocal(root, operand);
  };
  TestSimpleUnaryInferOp(0.01, 10, -1, 1.0e-3, make_op);
}
1137 
// Runs the shared unary-op harness with ops::Sigmoid as the op under test.
TEST_F(AutoMixedPrecisionTest, SigmoidOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Sigmoid(root, operand);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, make_op);
}
1145 
// Runs the shared unary-op harness with ops::Softmax as the op under test.
TEST_F(AutoMixedPrecisionTest, SoftmaxOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Softmax(root, operand);
  };
  TestSimpleUnaryInferOp(-8, 8, 2.0e-3, -1, make_op);
}
1153 
// Runs the shared unary-op harness with ops::Softplus as the op under test.
TEST_F(AutoMixedPrecisionTest, SoftplusOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Softplus(root, operand);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, 1.0e-3, make_op);
}
1161 
// Runs the shared unary-op harness with ops::Sqrt as the op under test.
TEST_F(AutoMixedPrecisionTest, SqrtOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Sqrt(root, operand);
  };
  TestSimpleUnaryInferOp(0, 10, 1.0e-3, 1.0e-3, make_op);
}
1169 
// Runs the shared unary-op harness with ops::Tanh as the op under test.
TEST_F(AutoMixedPrecisionTest, TanhOp) {
  const auto make_op = [](const tensorflow::Scope& root,
                          Output operand) -> Output {
    return ops::Tanh(root, operand);
  };
  TestSimpleUnaryInferOp(-5, 5, 1.0e-3, -1, make_op);
}
1177 
1178 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1179 
1180 #if INTEL_MKL
1181 #ifdef ENABLE_INTEL_MKL_BFLOAT16
1182 
1183 class AutoMixedPrecisionMklTest : public GrapplerTest {
1184  protected:
SetUp()1185   void SetUp() override {
1186     virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
1187     TF_CHECK_OK(virtual_cluster_->Provision());
1188   }
TearDown()1189   void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
1190 
1191   std::unique_ptr<Cluster> virtual_cluster_;
1192 };
1193 
// Feeds the MKL-mode optimizer a graph that already casts to DT_BFLOAT16
// around its MatMul and checks that the graph comes out equivalent, with every
// node keeping the type it started with.
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
  VLOG(1) << output.DebugString();

  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
  GraphView output_view(&output);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fail hard on a size mismatch so the loop below can never index past the
  // end of either result vector.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
  }
}
1231 
// Runs the MKL-mode optimizer over a linear chain of Exp, Relu, Sqrt, MatMul,
// Log and SparseMatMul nodes. Only the MatMul and the Relu nodes directly
// around it (clr2, allow1, clr3) are expected to become DT_BFLOAT16; every
// other node must stay DT_FLOAT, and two nodes are expected to be added.
TEST_F(AutoMixedPrecisionMklTest, Simple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
  Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
  Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
  Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
  Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
  Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
  Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
  Output deny2 = ops::Log(s.WithOpName("deny2"), clr3);
  Output clr4 = ops::Relu(s.WithOpName("clr4"), deny2);
  Output deny3 = ops::SparseMatMul(s.WithOpName("deny3"), clr4, clr4);
  Output clr5 = ops::Relu(s.WithOpName("clr5"), deny3);
  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);

  GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));

  VLOG(1) << output.DebugString();

  GraphView output_view(&output);
  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
  EXPECT_EQ(output_view.GetNode("deny2")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Ta").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("deny3")->attr().at("Tb").type(), DT_FLOAT);
  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);

  auto tensors = EvaluateNodes(output, item.fetch);
  // Fail hard on a size mismatch so the loop below can never index past the
  // end of either result vector.
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  ASSERT_EQ(tensors.size(), item.fetch.size());
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
  }
}
1280 
TEST_F(AutoMixedPrecisionMklTest,TensorListSetGet)1281 TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
1282   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1283   tensorflow::Input shape = {32, 32};
1284   auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
1285   Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
1286   Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
1287   Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
1288   Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
1289   auto tl1w1 =
1290       ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
1291   Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
1292   auto tl1w2 =
1293       ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
1294   // Ensure that TensorListResize doesn't cause any problems.
1295   Output tl1rs =
1296       ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
1297   Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
1298                                         shape, DT_FLOAT)
1299                      .item;
1300   Output infer1 = ops::Mul(s.WithOpName("infer1"), tl1r1, tl1r1);
1301   Output allow2 = ops::MatMul(s.WithOpName("allow2"), infer1, infer1);
1302   auto tl1w3 =
1303       ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
1304   Output tl1r2 =
1305       ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
1306                              shape, DT_FLOAT)
1307           .item;
1308   auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
1309   auto tl2w1 =
1310       ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
1311   Output tl2r1 =
1312       ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
1313                              shape, DT_FLOAT)
1314           .item;
1315   Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
1316   Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
1317 
1318   GrapplerItem item;
1319   item.fetch = {"fetch1", "fetch2"};
1320   TF_CHECK_OK(s.ToGraphDef(&item.graph));
1321   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1322 
1323   AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
1324   GraphDef output;
1325   TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
1326 
1327   VLOG(1) << output.DebugString();
1328 
1329   GraphView output_view(&output);
1330   EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
1331   const char* type_key = "element_dtype";
1332   EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
1333             DT_BFLOAT16);
1334   EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
1335             DT_BFLOAT16);
1336   EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
1337   EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
1338             DT_BFLOAT16);
1339   EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
1340             DT_BFLOAT16);
1341   EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_BFLOAT16);
1342   EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
1343   EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
1344             DT_BFLOAT16);
1345   EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
1346   EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
1347   EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
1348 
1349   auto tensors = EvaluateNodes(output, item.fetch);
1350   EXPECT_EQ(tensors.size(), tensors_expected.size());
1351   EXPECT_EQ(tensors.size(), item.fetch.size());
1352   for (int i = 0; i < item.fetch.size(); ++i) {
1353     test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
1354   }
1355 }
1356 
1357 #endif  // ENABLE_INTEL_MKL_BFLOAT16
1358 #endif  // INTEL_MKL
1359 
1360 }  // namespace
1361 }  // namespace grappler
1362 }  // namespace tensorflow
1363 
1364 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || (INTEL_MKL &&
1365         // defined(ENABLE_INTEL_MKL_BFLOAT16))
1366