1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <functional>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"
43
44 namespace tensorflow {
45 namespace grappler {
46 namespace {
47
// Minimum GPU architecture as a (major, minor) pair: {7, 0} in GOOGLE_CUDA
// builds and {0, 0} otherwise. Its consumers lie outside this excerpt
// (presumably compared against a device's compute capability -- confirm).
constexpr std::pair<int, int> kMinGPUArch =
#if GOOGLE_CUDA
    {7, 0};
#else
    {0, 0};
#endif
53
// Fixed name fragments used by this optimizer. How they are combined into
// node names happens outside this excerpt.
constexpr char kSuffix[] = "AutoMixedPrecision";
constexpr char kCastToFp16[] = "CastToFp16";
constexpr char kCastToBf16[] = "CastToBf16";
constexpr char kCastToFp32[] = "CastToFp32";
58
59 // Instances of this class represent unique type attribute identifiers within a
60 // node. It handles regular type attributes, list type attributes (where
61 // type_index is set to the index in the type list), and fixed types.
62 struct TypeAttrId {
63 static constexpr int kSingleType = -1;
64
TypeAttrIdtensorflow::grappler::__anon69ea7d4c0111::TypeAttrId65 explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType)
66 : attr_name(_attr_name),
67 type_index(_type_index),
68 fixed_type(DT_INVALID) {}
69
TypeAttrIdtensorflow::grappler::__anon69ea7d4c0111::TypeAttrId70 explicit TypeAttrId(DataType _fixed_type)
71 : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {}
72
operator ==tensorflow::grappler::__anon69ea7d4c0111::TypeAttrId73 bool operator==(const TypeAttrId& other) const {
74 return attr_name == other.attr_name && type_index == other.type_index &&
75 fixed_type == other.fixed_type;
76 }
77
operator <tensorflow::grappler::__anon69ea7d4c0111::TypeAttrId78 bool operator<(const TypeAttrId& other) const {
79 return std::make_tuple(attr_name, type_index, fixed_type) <
80 std::make_tuple(other.attr_name, other.type_index, other.fixed_type);
81 }
82
83 template <typename H>
AbslHashValue(H h,const TypeAttrId & ta)84 friend H AbslHashValue(H h, const TypeAttrId& ta) {
85 return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type);
86 }
87
DebugStringtensorflow::grappler::__anon69ea7d4c0111::TypeAttrId88 string DebugString() const {
89 if (!attr_name.empty()) {
90 if (type_index == kSingleType) {
91 return attr_name;
92 } else {
93 return strings::StrCat(attr_name, "[", type_index, "]");
94 }
95 } else {
96 return tensorflow::DataTypeString(fixed_type);
97 }
98 }
99
100 string attr_name;
101 // If attr_name is a list(type), this is the index into the list. Otherwise
102 // this is kSingleType.
103 int type_index;
104 DataType fixed_type;
105 };
106
107 // Returns the data type of the given type attribute, or DT_INVALID if the type
108 // attribute is invalid.
GetDataType(const NodeDef & node,const TypeAttrId & type_attr)109 DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) {
110 if (type_attr.attr_name.empty()) {
111 return type_attr.fixed_type;
112 }
113 if (!node.attr().count(type_attr.attr_name)) {
114 return DT_INVALID;
115 }
116 const AttrValue& attr_value = node.attr().at(type_attr.attr_name);
117 if (type_attr.type_index == TypeAttrId::kSingleType) {
118 return attr_value.type();
119 } else {
120 if (type_attr.type_index < 0 ||
121 type_attr.type_index >= attr_value.list().type_size()) {
122 return DT_INVALID;
123 }
124 return attr_value.list().type(type_attr.type_index);
125 }
126 }
127
128 // Sets the data type of the given type attribute. Returns false if the type
129 // attribute is invalid, otherwise true.
SetDataType(NodeDef * node,const TypeAttrId & type_attr,DataType type)130 bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) {
131 if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) {
132 return false;
133 }
134 AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name);
135 if (type_attr.type_index == TypeAttrId::kSingleType) {
136 attr_value.set_type(type);
137 } else {
138 if (type_attr.type_index < 0 ||
139 type_attr.type_index >= attr_value.list().type_size()) {
140 return false;
141 }
142 attr_value.mutable_list()->set_type(type_attr.type_index, type);
143 }
144 return true;
145 }
146
ArgDefIndexes(const NodeDef & node,int arg_idx,const OpDef::ArgDef & arg_def)147 std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx,
148 const OpDef::ArgDef& arg_def) {
149 std::vector<std::pair<int, int>> argdef_inds;
150 if (!arg_def.type_list_attr().empty()) {
151 int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size();
152 for (int type_idx = 0; type_idx < num_types; ++type_idx) {
153 argdef_inds.push_back({arg_idx, type_idx});
154 }
155 } else {
156 int num_repeat = 1;
157 if (node.attr().count(arg_def.number_attr())) {
158 num_repeat = node.attr().at(arg_def.number_attr()).i();
159 }
160 argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1});
161 }
162 return argdef_inds;
163 }
164
165 // Returns a pair (arg_index, type_index) for each input to the node, where
166 // arg_index is the index of the input_arg in op_def and type_index is the index
167 // of the type in type_list_attr (only defined for list arguments).
InputPortArgDefIndexes(const NodeDef & node,const OpDef & op_def)168 std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node,
169 const OpDef& op_def) {
170 std::vector<std::pair<int, int>> argdef_inds;
171 argdef_inds.reserve(op_def.input_arg_size()); // Final size may differ.
172 for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) {
173 const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx);
174 auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
175 argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
176 arg_results.end());
177 }
178 return argdef_inds;
179 }
180
181 // Returns a pair (arg_index, type_index) for each output to the node, where
182 // arg_index is the index of the output_arg in op_def and type_index is the
183 // index of the type in type_list_attr (only defined for list arguments).
OutputPortArgDefIndexes(const NodeDef & node,const OpDef & op_def)184 std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node,
185 const OpDef& op_def) {
186 std::vector<std::pair<int, int>> argdef_inds;
187 argdef_inds.reserve(op_def.output_arg_size()); // Final size may differ.
188 for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) {
189 const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx);
190 auto arg_results = ArgDefIndexes(node, arg_idx, arg_def);
191 argdef_inds.insert(argdef_inds.end(), arg_results.begin(),
192 arg_results.end());
193 }
194 return argdef_inds;
195 }
196
GetTypeAttrId(const OpDef::ArgDef & arg_def,int arg_type_index)197 TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) {
198 if (!arg_def.type_list_attr().empty()) {
199 return TypeAttrId(arg_def.type_list_attr(), arg_type_index);
200 } else if (!arg_def.type_attr().empty()) {
201 return TypeAttrId(arg_def.type_attr());
202 } else {
203 return TypeAttrId(arg_def.type());
204 }
205 }
206
NonControlInputs(const NodeDef & node)207 std::vector<int> NonControlInputs(const NodeDef& node) {
208 std::vector<int> pos;
209 for (int i = 0; i < node.input_size(); i++) {
210 if (!IsControlInput(node.input(i))) {
211 pos.push_back(i);
212 }
213 }
214 return pos;
215 }
216
217 // A utility class to lookup node type attributes and type attribute <->
218 // input/output port mappings.
219 class NodeTypeAttrMap {
220 public:
NodeTypeAttrMap()221 NodeTypeAttrMap() {}
222
NodeTypeAttrMap(const GraphDef & graph)223 explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); }
224
Init(const GraphDef & graph)225 Status Init(const GraphDef& graph) {
226 if (graph_ != nullptr) {
227 return errors::InvalidArgument("NodeTypeAttrMap is already initialized.");
228 }
229 graph_ = &graph;
230 function_library_.reset(
231 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
232 for (const NodeDef& node : graph.node()) {
233 TF_RETURN_IF_ERROR(AddNode(node));
234 }
235 return Status::OK();
236 }
237
is_initialized() const238 bool is_initialized() const { return graph_ != nullptr; }
239
240 // Returns the set of all type attributes in the given node.
GetTypeAttrs(const NodeDef & node) const241 absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const {
242 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
243 absl::flat_hash_set<TypeAttrId> type_attrs;
244 const auto iter = type2io_.find(&node);
245 CHECK(iter != type2io_.end()); // Crash Ok
246 for (const auto& key_value : iter->second) {
247 type_attrs.insert(key_value.first);
248 }
249 return type_attrs;
250 }
251
GetInputPorts(const NodeDef & node,const TypeAttrId & type_attr) const252 const absl::flat_hash_set<int>& GetInputPorts(
253 const NodeDef& node, const TypeAttrId& type_attr) const {
254 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
255 return type2io_.at(&node).at(type_attr).first;
256 }
257
GetOutputPorts(const NodeDef & node,const TypeAttrId & type_attr) const258 const absl::flat_hash_set<int>& GetOutputPorts(
259 const NodeDef& node, const TypeAttrId& type_attr) const {
260 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
261 return type2io_.at(&node).at(type_attr).second;
262 }
263
GetInputTypeAttr(const NodeDef & node,int port) const264 TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const {
265 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
266 auto type_vec = io2type_.at(&node).first;
267 CHECK_GE(port, 0); // Crash Ok
268 CHECK_LT(port, type_vec.size()); // Crash Ok
269 return type_vec[port];
270 }
271
GetOutputTypeAttr(const NodeDef & node,int port) const272 TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const {
273 DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized";
274 auto type_vec = io2type_.at(&node).second;
275 CHECK_GE(port, 0); // Crash Ok
276 CHECK_LT(port, type_vec.size()); // Crash Ok
277 return type_vec[port];
278 }
279
280 private:
AddNode(const NodeDef & node)281 Status AddNode(const NodeDef& node) {
282 const OpDef* op_def_ptr = nullptr;
283 TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr));
284 const OpDef& op_def = *op_def_ptr;
285 auto& type2io_entry = type2io_[&node];
286 auto& io2type_entry = io2type_[&node];
287 auto input_arg_inds = InputPortArgDefIndexes(node, op_def);
288 if (NonControlInputs(node).size() != input_arg_inds.size()) {
289 return errors::InvalidArgument(
290 "Expected ", node.op(), " node ", node.name(), " to have ",
291 input_arg_inds.size(), " non-control input(s), but got ",
292 node.input_size());
293 }
294 // Note that the mappings generated here include inputs/outputs with fixed
295 // types. This makes the mappings complete (all inputs and outputs are
296 // included), and allows the graph rewriter to propagate deny paint
297 // from/through ops with fixed types.
298 io2type_entry.first.reserve(input_arg_inds.size());
299 for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) {
300 const auto& arg_inds = input_arg_inds[i];
301 const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first);
302 TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
303 if (!type_attr.attr_name.empty() &&
304 !node.attr().count(type_attr.attr_name)) {
305 return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
306 " is not present in node ", node.name());
307 }
308 type2io_entry[type_attr].first.insert(i);
309 io2type_entry.first.push_back(type_attr);
310 }
311
312 auto output_arg_inds = OutputPortArgDefIndexes(node, op_def);
313 io2type_entry.second.reserve(output_arg_inds.size());
314 for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) {
315 const auto& arg_inds = output_arg_inds[i];
316 const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first);
317 TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second);
318 if (!type_attr.attr_name.empty() &&
319 !node.attr().count(type_attr.attr_name)) {
320 return errors::InvalidArgument("Type attribute ", type_attr.attr_name,
321 " is not present in node ", node.name());
322 }
323 type2io_entry[type_attr].second.insert(i);
324 io2type_entry.second.push_back(type_attr);
325 }
326
327 // Also ensure that type attributes that aren't associated with any inputs
328 // or outputs (e.g., StackV2's elem_type) are added to the map.
329 for (const auto& attr : node.attr()) {
330 const string& attr_name = attr.first;
331 if (!attr_name.empty() && attr_name[0] == '_') continue;
332 const AttrValue& attr_value = attr.second;
333 const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def);
334 if (!attr_def) {
335 return errors::InvalidArgument("AttrDef not found for attribute ",
336 attr_name, " of node ", node.name());
337 }
338 if (attr_def->type() == "type") {
339 type2io_entry[TypeAttrId(attr_name)];
340 } else if (attr_def->type() == "list(type)") {
341 for (int i = 0; i < attr_value.list().type_size(); ++i) {
342 type2io_entry[TypeAttrId(attr_name, i)];
343 }
344 }
345 }
346 return Status::OK();
347 }
348
349 // WARN: `graph_` must outlive this object (node pointers must remain valid).
350 const GraphDef* graph_ = nullptr; // do not own
351 std::unique_ptr<FunctionLibraryDefinition> function_library_;
352
353 typedef absl::flat_hash_set<int> IntSet;
354 // Maps a type attr id -> (input port set, output port set)
355 typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap;
356 // Maps a node -> type attr mapping
357 absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_;
358 // Maps a port -> type attr id
359 typedef std::vector<TypeAttrId> TypeAttrIdVec;
360 // Maps a node -> (input port mapping, output port mapping)
361 absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>>
362 io2type_;
363 };
364
365 struct NodeTypeId {
NodeTypeIdtensorflow::grappler::__anon69ea7d4c0111::NodeTypeId366 NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr)
367 : node(_node), type_attr(_type_attr) {}
368
369 const NodeDef* node;
370 TypeAttrId type_attr;
371
operator ==tensorflow::grappler::__anon69ea7d4c0111::NodeTypeId372 bool operator==(const NodeTypeId& other) const {
373 return node == other.node && type_attr == other.type_attr;
374 }
375
376 template <typename H>
AbslHashValue(H h,const NodeTypeId & nt)377 friend H AbslHashValue(H h, const NodeTypeId& nt) {
378 return H::combine(std::move(h), nt.node, nt.type_attr);
379 }
380 };
381
382 struct NodeTypeIdEdge {
NodeTypeIdEdgetensorflow::grappler::__anon69ea7d4c0111::NodeTypeIdEdge383 NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst)
384 : src(_src), dst(_dst) {}
385 NodeTypeId src;
386 NodeTypeId dst;
387 };
388
389 // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be
390 // used instead of this modified version.
391 // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as
392 // the vertices instead of just NodeDef.
393 // For example, if node A has output A:0 with TypeAttrId 'T', and node B has
394 // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there
395 // will be an edge from (A, T) to (B, U).
396 class GraphTypeTopologyView {
397 public:
398 GraphTypeTopologyView() = default;
GraphTypeTopologyView(bool skip_invalid_edges)399 explicit GraphTypeTopologyView(bool skip_invalid_edges)
400 : skip_invalid_edges_(skip_invalid_edges) {}
401
402 // Initialize graph topology view from the graph. It's possible to pass
403 // additional edges that do not exist in a graph, but must be respected when
404 // computing graph topology. Example: Tensorflow runtime allows concurrent
405 // execution of dequeue/enqueue ops from the same queue resource, but we might
406 // want to enforce ordering between them for the purpose of graph analysis.
407 Status InitializeFromGraph(const GraphDef& graph,
408 const NodeTypeAttrMap& node_type_map);
409
410 Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges);
411
is_initialized() const412 bool is_initialized() const { return graph_ != nullptr; }
num_nodes() const413 int num_nodes() const { return num_nodes_; }
graph() const414 const GraphDef* graph() const { return graph_; }
415
416 // Returns true iff the node exists in the underlying graph.
417 bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const;
418
419 // Finds a node by name or returns `nullptr` if it's not in the graph.
420 const NodeTypeId* GetNode(absl::string_view node_name,
421 const TypeAttrId& type_attr) const;
422 // Returns a node corresponding to the given node index.
423 const NodeTypeId* GetNode(int node_idx) const;
424
425 // Returns a node index for the given node name, if the name exists in the
426 // underlying graph. Otherwise returns empty optional.
427 const absl::optional<int> GetNodeIndex(absl::string_view node_name,
428 const TypeAttrId& type_attr) const;
429 // Returns a node index for the given node, if the node belongs to the
430 // underlying graph. Otherwise returns empty optional.
431 const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const;
432
433 // Returns all the node indexes that are in the direct fanin of the given
434 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
435 const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
436 // Returns all the node indexes that are in the direct fanout of the given
437 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
438 const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
439
440 private:
441 // The key type used to uniquely identify a type attribute on a node.
442 struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> {
443 typedef std::pair<absl::string_view, TypeAttrId> Base;
444
445 // Inherit the set of constructors.
446 using Base::pair;
447
448 template <typename H>
AbslHashValue(H h,const NodeTypeKey & nt)449 friend H AbslHashValue(H h, const NodeTypeKey& nt) {
450 return H::combine(std::move(h), nt.first, nt.second);
451 }
452 };
453
454 // If true, all invalid edges and inputs (srd, dst or input node not found in
455 // a graph) will be skipped, otherwise initialization will fail with error.
456 bool skip_invalid_edges_ = false;
457
458 // WARN: `graph_` must outlive this object and graph nodes must not be
459 // destructed, because node names captured with absl::string_view.
460 const GraphDef* graph_ = nullptr; // do not own
461 int num_nodes_ = 0;
462 std::vector<NodeTypeId> node_type_attrs_;
463 absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
464 absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_;
465
466 std::vector<absl::InlinedVector<int, 4>> fanins_;
467 std::vector<absl::InlinedVector<int, 2>> fanouts_;
468
469 // We need a valid reference to return from GetFanin/GetFanout if the
470 // `node_idx` argument is outside of the [0, num_nodes_) range.
471 absl::InlinedVector<int, 4> empty_fanin_;
472 absl::InlinedVector<int, 2> empty_fanout_;
473 };
474
// Sorts the container `*v` ascending and erases duplicate elements in place.
template <typename T>
inline void SortAndRemoveDuplicates(T* v) {
  auto first = v->begin();
  auto last = v->end();
  std::sort(first, last);
  // sort() and unique() never change the size, so `last` is still v->end().
  v->erase(std::unique(first, last), last);
}
480
InitializeFromGraph(const GraphDef & graph,const NodeTypeAttrMap & node_type_map)481 Status GraphTypeTopologyView::InitializeFromGraph(
482 const GraphDef& graph, const NodeTypeAttrMap& node_type_map) {
483 if (graph_ != nullptr) {
484 return errors::InvalidArgument(
485 "GraphTypeTopologyView is already initialized.");
486 }
487
488 graph_ = &graph;
489 int num_nodedefs = graph.node_size();
490 node_name_to_index_.rehash(num_nodedefs);
491
492 // Build maps from name to index.
493 node_type_attrs_.reserve(num_nodedefs); // Only approximate.
494 node_type_name_to_index_.rehash(num_nodedefs); // Only approximate.
495 for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) {
496 const NodeDef& node = graph.node(node_idx);
497 node_name_to_index_.emplace(node.name(), node_idx);
498
499 for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) {
500 int node_type_idx = node_type_attrs_.size();
501 node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr),
502 node_type_idx);
503 node_type_attrs_.emplace_back(&node, type_attr);
504 }
505 }
506 num_nodes_ = node_type_attrs_.size();
507 fanins_.resize(num_nodes_);
508 fanouts_.resize(num_nodes_);
509
510 // Add graph edges to the adjacency lists.
511 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
512 const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx);
513 auto input_ports =
514 node_type_map.GetInputPorts(*node_type.node, node_type.type_attr);
515 fanins_[node_type_idx].reserve(input_ports.size());
516 for (int port : input_ports) {
517 const string& input = node_type.node->input(port);
518 TensorId tensor = ParseTensorName(input);
519 const auto it = node_name_to_index_.find(tensor.node());
520 const bool valid_input = it != node_name_to_index_.end();
521
522 if (!valid_input) {
523 const string error_message = absl::StrCat(
524 "Non-existent input ", input, " in node ", node_type.node->name());
525 if (skip_invalid_edges_) {
526 VLOG(3) << "Skip error: " << error_message;
527 } else {
528 return errors::InvalidArgument(error_message);
529 }
530 }
531
532 if (valid_input) {
533 const int input_idx = it->second;
534 const NodeDef& input_node = graph_->node(input_idx);
535 TypeAttrId input_type_attr =
536 node_type_map.GetOutputTypeAttr(input_node, tensor.index());
537 const auto it2 = node_type_name_to_index_.find(
538 NodeTypeKey(input_node.name(), input_type_attr));
539 if (it2 == node_type_name_to_index_.end()) {
540 if (!skip_invalid_edges_) {
541 return errors::InvalidArgument("Did not find type attr ",
542 input_type_attr.DebugString(),
543 " in node ", input_node.name());
544 }
545 continue;
546 }
547 int input_node_type_idx = it2->second;
548 fanins_[node_type_idx].push_back(input_node_type_idx);
549 fanouts_[input_node_type_idx].push_back(node_type_idx);
550 }
551 }
552
553 // Dedup the input list while it's still hot in cache.
554 SortAndRemoveDuplicates(&fanins_[node_type_idx]);
555 }
556
557 // Dedup outputs for all the graph nodes.
558 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
559 SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
560 }
561
562 return Status::OK();
563 }
564
AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges)565 Status GraphTypeTopologyView::AddEphemeralEdges(
566 absl::Span<const NodeTypeIdEdge> ephemeral_edges) {
567 // Add ephemeral edges to the adjacency lists.
568 for (const NodeTypeIdEdge& edge : ephemeral_edges) {
569 const auto src = node_name_to_index_.find(edge.src.node->name());
570 const bool valid_src = src != node_name_to_index_.end();
571
572 if (!valid_src) {
573 const string error_message =
574 absl::StrCat("Non-existent src node: ", edge.src.node->name());
575 if (skip_invalid_edges_) {
576 VLOG(0) << "Skip error: " << error_message;
577 } else {
578 return errors::InvalidArgument(error_message);
579 }
580 }
581
582 const auto dst = node_name_to_index_.find(edge.dst.node->name());
583 const bool valid_dst = dst != node_name_to_index_.end();
584
585 if (!valid_dst) {
586 const string error_message =
587 absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
588 if (skip_invalid_edges_) {
589 VLOG(0) << "Skip error: " << error_message;
590 } else {
591 return errors::InvalidArgument(error_message);
592 }
593 }
594
595 if (valid_dst && valid_src) {
596 // TODO(benbarsdell): Check for failure.
597 int src_node_type_idx = node_type_name_to_index_.at(
598 NodeTypeKey(edge.src.node->name(), edge.src.type_attr));
599 int dst_node_type_idx = node_type_name_to_index_.at(
600 NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr));
601 fanins_[dst_node_type_idx].push_back(src_node_type_idx);
602 fanouts_[src_node_type_idx].push_back(dst_node_type_idx);
603 }
604 }
605
606 // Dedup inputs and outputs for all the graph nodes.
607 for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) {
608 SortAndRemoveDuplicates(&fanins_[node_type_idx]);
609 SortAndRemoveDuplicates(&fanouts_[node_type_idx]);
610 }
611
612 return Status::OK();
613 }
614
HasNode(absl::string_view node_name,const TypeAttrId & type_attr) const615 bool GraphTypeTopologyView::HasNode(absl::string_view node_name,
616 const TypeAttrId& type_attr) const {
617 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
618 NodeTypeKey key(node_name, type_attr);
619 const auto it = node_type_name_to_index_.find(key);
620 return it != node_type_name_to_index_.end();
621 }
622
GetNode(absl::string_view node_name,const TypeAttrId & type_attr) const623 const NodeTypeId* GraphTypeTopologyView::GetNode(
624 absl::string_view node_name, const TypeAttrId& type_attr) const {
625 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
626 NodeTypeKey key(node_name, type_attr);
627 const auto it = node_type_name_to_index_.find(key);
628 return it == node_type_name_to_index_.end()
629 ? nullptr
630 : &node_type_attrs_.at(it->second);
631 }
632
GetNode(int node_idx) const633 const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const {
634 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
635 DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
636 return &node_type_attrs_.at(node_idx);
637 }
638
GetNodeIndex(absl::string_view node_name,const TypeAttrId & type_attr) const639 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
640 absl::string_view node_name, const TypeAttrId& type_attr) const {
641 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
642 NodeTypeKey key(node_name, type_attr);
643 const auto it = node_type_name_to_index_.find(key);
644 DCHECK(it != node_type_name_to_index_.end())
645 << "Node doesn't exist in a graph";
646 return it == node_type_name_to_index_.end() ? absl::nullopt
647 : absl::make_optional(it->second);
648 }
649
GetNodeIndex(const NodeTypeId & node) const650 const absl::optional<int> GraphTypeTopologyView::GetNodeIndex(
651 const NodeTypeId& node) const {
652 return GetNodeIndex(node.node->name(), node.type_attr);
653 }
654
GetFanin(int node_idx) const655 const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin(
656 int node_idx) const {
657 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
658 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
659 DCHECK(is_valid_node_idx) << "node_idx is out of range";
660 return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
661 }
662
GetFanout(int node_idx) const663 const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout(
664 int node_idx) const {
665 DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized";
666 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
667 DCHECK(is_valid_node_idx) << "node_idx is out of range";
668 return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
669 }
670
// Selects which edges a type-graph traversal follows. The consuming code is
// beyond this excerpt; the names suggest fanin, fanout, or both.
enum class TypeTraversalDirection {
  kFollowInputs = 0,
  kFollowOutputs = 1,
  kFollowInputsAndOutputs = 2,
};
676
// DFS callbacks invoked during a type-graph traversal.
//
// Non-empty `pre_order` / `post_order` functors are called on each reachable
// node (including the start nodes) in pre and post order. When a loop is found,
// `on_back_edge` is called with (src, dst) of the back edge, and the pre/post
// orders behave as though that back edge were cut.
struct DfsTypeCallbacks {
  DfsTypeCallbacks() = default;
  DfsTypeCallbacks(std::function<void(int)> pre_fn,
                   std::function<void(int)> post_fn,
                   std::function<void(int, int)> back_edge_fn)
      : pre_order(std::move(pre_fn)),
        post_order(std::move(post_fn)),
        on_back_edge(std::move(back_edge_fn)) {}

  // Callbacks with only a pre-order functor set.
  static DfsTypeCallbacks PreOrder(std::function<void(int)> pre_fn) {
    DfsTypeCallbacks callbacks;
    callbacks.pre_order = std::move(pre_fn);
    return callbacks;
  }

  // Callbacks with only a post-order functor set.
  static DfsTypeCallbacks PostOrder(std::function<void(int)> post_fn) {
    DfsTypeCallbacks callbacks;
    callbacks.post_order = std::move(post_fn);
    return callbacks;
  }

  std::function<void(int)> pre_order;
  std::function<void(int)> post_order;
  std::function<void(int, int)> on_back_edge;
};
704
// DFS predicates controlling a type-graph traversal.
//
// `enter` decides whether the traversal may enter a node. `advance` decides
// whether it may follow inputs/outputs onward from a node. An empty predicate
// (the default) means every node may be entered / advanced from.
struct DfsTypePredicates {
  DfsTypePredicates() = default;
  DfsTypePredicates(std::function<bool(int)> enter_fn,
                    std::function<bool(int)> advance_fn)
      : enter(std::move(enter_fn)), advance(std::move(advance_fn)) {}

  // Predicates with only `enter` set.
  static DfsTypePredicates Enter(std::function<bool(int)> enter_fn) {
    DfsTypePredicates predicates;
    predicates.enter = std::move(enter_fn);
    return predicates;
  }

  // Predicates with only `advance` set.
  static DfsTypePredicates Advance(std::function<bool(int)> advance_fn) {
    DfsTypePredicates predicates;
    predicates.advance = std::move(advance_fn);
    return predicates;
  }

  std::function<bool(int)> enter;
  std::function<bool(int)> advance;
};
730
// One entry of the explicit DFS stack.
struct DfsStackElem {
  DfsStackElem(int node_index, bool visited, int src_index)
      : node(node_index), children_visited(visited), src(src_index) {}
  // A start entry: children not yet pushed, and no source node (-1).
  explicit DfsStackElem(int node_index)
      : node(node_index), children_visited(false), src(-1) {}

  // Index of the node in the graph, in [0, num_nodes).
  int node;
  // True once all of the node's input/output nodes have been pushed.
  bool children_visited;
  // Index of the node from which `node` was entered, or -1 for start nodes.
  int src;
};
744
// Per-node DFS bookkeeping. kNotVisited must stay the first (zero-valued)
// enumerator: a value-initialized NodeState, such as one created by
// flat_hash_map::operator[], has to mean "not visited yet".
enum class NodeState {
  kNotVisited,
  kVisiting,
  kDone,
};
746
DfsTypeTraversal(const GraphTypeTopologyView & graph_type_view,const absl::Span<const NodeTypeId * const> from,const TypeTraversalDirection direction,const DfsTypePredicates & predicates,const DfsTypeCallbacks & callbacks)747 void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view,
748 const absl::Span<const NodeTypeId* const> from,
749 const TypeTraversalDirection direction,
750 const DfsTypePredicates& predicates,
751 const DfsTypeCallbacks& callbacks) {
752 std::vector<DfsStackElem> stack;
753 stack.reserve(from.size());
754
755 for (const NodeTypeId* node : from) {
756 const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node);
757 DCHECK(node_idx.has_value())
758 << "Illegal start node: " << node->node->name();
759 if (node_idx.has_value()) {
760 stack.emplace_back(node_idx.value());
761 }
762 }
763
764 absl::flat_hash_map<int, NodeState> node_state;
765 while (!stack.empty()) {
766 DfsStackElem w = stack.back();
767 stack.pop_back();
768
769 NodeState& state = node_state[w.node];
770 if (state == NodeState::kDone) continue;
771
772 // Skip nodes that we should not enter.
773 if (predicates.enter && !predicates.enter(w.node)) {
774 state = NodeState::kDone;
775 continue;
776 }
777
778 // We've processed all the children of this node.
779 if (w.children_visited) {
780 state = NodeState::kDone;
781 if (callbacks.post_order) {
782 callbacks.post_order(w.node);
783 }
784 continue;
785 }
786
787 // Loop detected.
788 if (state == NodeState::kVisiting) {
789 if (callbacks.on_back_edge) {
790 callbacks.on_back_edge(w.src, w.node);
791 }
792 continue;
793 }
794
795 state = NodeState::kVisiting;
796 if (callbacks.pre_order) {
797 callbacks.pre_order(w.node);
798 }
799
800 // Enqueue the node again with the children_visited flag set to true.
801 stack.emplace_back(w.node, true, w.src);
802
803 // Check if we can continue traversal from the current node.
804 if (predicates.advance && !predicates.advance(w.node)) {
805 continue;
806 }
807
808 // Now enqueue the fanin/fanout nodes.
809 if (direction == TypeTraversalDirection::kFollowInputs ||
810 direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
811 for (const int fanin : graph_type_view.GetFanin(w.node)) {
812 stack.emplace_back(fanin, false, w.node);
813 }
814 }
815 if (direction == TypeTraversalDirection::kFollowOutputs ||
816 direction == TypeTraversalDirection::kFollowInputsAndOutputs) {
817 for (const int fanout : graph_type_view.GetFanout(w.node)) {
818 stack.emplace_back(fanout, false, w.node);
819 }
820 }
821 }
822 }
823
AllowedDataTypes(const OpDef::AttrDef & attr_def)824 DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) {
825 const auto& allowed_types = attr_def.allowed_values().list().type();
826 if (allowed_types.empty()) {
827 return AllTypes();
828 }
829 uint32 dtype_mask = 0;
830 for (int dtype : allowed_types) {
831 dtype_mask |= 1u << dtype;
832 }
833 return DataTypeSet(dtype_mask);
834 }
835
AllowedDataTypes(const OpDef & op_def,const TypeAttrId & t_attr_id)836 DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
837 if (t_attr_id.attr_name.empty()) {
838 return ToSet(t_attr_id.fixed_type);
839 }
840 const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def);
841 CHECK(attr_def); // Crash Ok
842 return AllowedDataTypes(*attr_def);
843 }
844
ValidateLists(const gtl::FlatSet<string> & allow_list,const gtl::FlatSet<string> & deny_list,const gtl::FlatSet<string> & infer_list,const gtl::FlatSet<string> & clear_list)845 Status ValidateLists(const gtl::FlatSet<string>& allow_list,
846 const gtl::FlatSet<string>& deny_list,
847 const gtl::FlatSet<string>& infer_list,
848 const gtl::FlatSet<string>& clear_list) {
849 std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list,
850 clear_list};
851 std::multiset<string> counts;
852 for (const auto& list : lists) {
853 counts.insert(list.begin(), list.end());
854 }
855 bool duplicates = false;
856 for (const auto& s : counts) {
857 if (counts.count(s) > 1) {
858 duplicates = true;
859 LOG(ERROR) << "Op present in multiple lists: " << s;
860 }
861 }
862 if (duplicates) {
863 return errors::InvalidArgument("Op lists have conflicting entries");
864 } else {
865 return Status::OK();
866 }
867 }
868
HasInputOrOutputRefs(const NodeDef & node)869 bool HasInputOrOutputRefs(const NodeDef& node) {
870 const OpDef* op_def;
871 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
872 if (!status.ok()) {
873 return true;
874 }
875 for (const auto& input : op_def->input_arg()) {
876 if (input.is_ref()) {
877 return true;
878 }
879 }
880 for (const auto& output : op_def->output_arg()) {
881 if (output.is_ref()) {
882 return true;
883 }
884 }
885 return false;
886 }
887
888 // See TF issue 25977 for no-FP16 on SCEWL
CanForceFP16(const NodeDef & node)889 bool CanForceFP16(const NodeDef& node) {
890 return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" &&
891 !IsStateful(node) && !HasInputOrOutputRefs(node);
892 }
893
GetCudaVersion(const Cluster & cluster)894 int GetCudaVersion(const Cluster& cluster) {
895 auto devices = cluster.GetDevices();
896 for (const auto& device : devices) {
897 const DeviceProperties& device_properties = device.second;
898 if (device_properties.type() == "GPU") {
899 const auto& device_env = device_properties.environment();
900 auto it = device_env.find("cuda");
901 if (it != device_env.end()) {
902 string cuda_version_str = it->second;
903 return std::stoi(cuda_version_str);
904 }
905 }
906 }
907 return 0;
908 }
909
GetCudnnVersion(const Cluster & cluster)910 int GetCudnnVersion(const Cluster& cluster) {
911 auto devices = cluster.GetDevices();
912 for (const auto& device : devices) {
913 const DeviceProperties& device_properties = device.second;
914 if (device_properties.type() == "GPU") {
915 const auto& device_env = device_properties.environment();
916 auto it = device_env.find("cudnn");
917 if (it != device_env.end()) {
918 string cudnn_version_str = it->second;
919 return std::stoi(cudnn_version_str);
920 }
921 }
922 }
923 return 0;
924 }
925
926 class AutoMixedPrecisionImpl {
927 public:
AutoMixedPrecisionImpl(Cluster * cluster,const std::unordered_set<string> & nodes_to_preserve,GraphDef * graph,string id,AutoMixedPrecisionMode mode)928 AutoMixedPrecisionImpl(Cluster* cluster,
929 const std::unordered_set<string>& nodes_to_preserve,
930 GraphDef* graph, string id,
931 AutoMixedPrecisionMode mode)
932 : virtual_placer_(cluster->GetDevices()),
933 nodes_to_preserve_(nodes_to_preserve),
934 graph_(graph),
935 function_library_(OpRegistry::Global(), graph->library()),
936 id_(id),
937 graph_view_(graph),
938 cuda_version_(GetCudaVersion(*cluster)),
939 cudnn_version_(GetCudnnVersion(*cluster)),
940 mode_(mode),
941 target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
942 : DT_BFLOAT16) {}
943
944 Status Optimize();
945
946 private:
947 typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
948
get_mixed_precision_lists() const949 std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
950 switch (mode_) {
951 case AutoMixedPrecisionMode::CUDA:
952 return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
953 cudnn_version_);
954 case AutoMixedPrecisionMode::MKL:
955 return std::make_unique<AutoMixedPrecisionListsMkl>();
956 }
957 }
958 Status PrintDebugLogs(bool preop, size_t timestamp);
959 void LogSkippedNode(const NodeDef& node) const;
960 bool MustPreserve(const NodeDef& node) const;
961 bool IsOnDevice(const NodeDef& node, const string& device_type) const;
962 bool IsOnSuitableGPUArch(const NodeDef& node) const;
963 bool ShouldProcess(const NodeDef& node) const;
964 bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
965 bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
966 void ConvertBatchNormOpsToV2();
967 bool SupportsF16(const NodeTypeId& node_type) const;
968 bool SupportsF16DataType(const NodeTypeId& node_type) const;
969 const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
970 bool IsSourceOrSinkOp(const string& op) const;
971 void FindFloat32TensorListOpClustersAndDenylistUnsafe(
972 std::vector<absl::flat_hash_set<const NodeDef*>>* clusters,
973 absl::flat_hash_set<int>* deny_set) const;
974 void FindTensorListImplicitFloat32Edges(
975 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
976 std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
977 void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
978 void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const;
979 void PropagateDenyFwdThroughClearAndInfer(
980 absl::flat_hash_set<int>* deny_set) const;
981 void ForceColorMatchBetweenTensorListOps(
982 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
983 absl::flat_hash_set<int>* allow_set,
984 absl::flat_hash_set<int>* deny_set) const;
985 void AddClearAndInferToAllowIfBetweenAllow(
986 const absl::flat_hash_set<int>& deny_set,
987 absl::flat_hash_set<int>* allow_set) const;
988 void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set,
989 absl::flat_hash_set<int>* allow_set) const;
990 Status ForceColorMatchOnRecurrentEdges(
991 absl::flat_hash_set<int>* allow_set) const;
992 void MakeCastsAllowIfAllOutputsAllow(
993 absl::flat_hash_set<int>* allow_set) const;
994 NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
995 const string& device) const;
996 Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
997
998 VirtualPlacer virtual_placer_;
999 std::unordered_set<string> nodes_to_preserve_;
1000 GraphDef* graph_;
1001 FunctionLibraryDefinition function_library_;
1002 string id_;
1003 MutableGraphView graph_view_;
1004 int cuda_version_;
1005 int cudnn_version_;
1006 NodeTypeAttrMap node_type_map_;
1007 GraphTypeTopologyView graph_type_view_;
1008 bool force_all_fp16_;
1009 AutoMixedPrecisionMode mode_;
1010 gtl::FlatSet<string> f16_allowlist_;
1011 gtl::FlatSet<string> f16_denylist_;
1012 gtl::FlatSet<string> f16_inferlist_;
1013 gtl::FlatSet<string> f16_clearlist_;
1014 absl::flat_hash_set<const NodeDef*> should_process_nodes_;
1015 DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16
1016 };
1017
BuildCastNode(const MutableGraphView::OutputPort & src,bool to_f16,const string & device) const1018 NodeDef AutoMixedPrecisionImpl::BuildCastNode(
1019 const MutableGraphView::OutputPort& src, bool to_f16,
1020 const string& device) const {
1021 DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
1022 DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
1023 const char* cast_string =
1024 !to_f16 ? kCastToFp32
1025 : target_dtype_ == DT_HALF ? kCastToFp16 : kCastToBf16;
1026 string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
1027 cast_string, "-", kSuffix);
1028 NodeDef node;
1029 node.set_name(name);
1030 node.set_op("Cast");
1031 node.set_device(device);
1032 node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
1033 (*node.mutable_attr())["SrcT"].set_type(src_type);
1034 (*node.mutable_attr())["DstT"].set_type(dst_type);
1035 (*node.mutable_attr())["Truncate"].set_b(false);
1036 return node;
1037 }
1038
NodeHasF16KernelForTypeAttr(const NodeDef & node,TypeAttrId taid) const1039 bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
1040 const NodeDef& node, TypeAttrId taid) const {
1041 NodeDef node_copy(node);
1042 if (node.device().empty()) {
1043 string device_name = virtual_placer_.get_canonical_device_name(node);
1044 node_copy.set_device(device_name);
1045 }
1046 if (!SetDataType(&node_copy, taid, target_dtype_)) {
1047 return false;
1048 }
1049 return IsKernelRegisteredForNode(node_copy).ok();
1050 }
1051
PrintDebugLogs(bool preop,size_t timestamp)1052 Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
1053 string prepend_path;
1054 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1055 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path));
1056 if (prepend_path.empty()) return Status::OK();
1057
1058 string suffix =
1059 strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp);
1060
1061 string fname =
1062 io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb"));
1063 std::fstream f;
1064 f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
1065 f << graph_->SerializeAsString();
1066 f.close();
1067 LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1068 << " graph as binary to " << fname;
1069
1070 fname = io::JoinPath(prepend_path,
1071 strings::StrCat("graphdef", suffix, ".pb.txt"));
1072 f.open(fname.c_str(), std::fstream::out);
1073 f << graph_->DebugString();
1074 f.close();
1075 LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization")
1076 << " graph as text to " << fname;
1077
1078 if (!preop) {
1079 fname = io::JoinPath(prepend_path,
1080 strings::StrCat("paintbuckets", suffix, ".txt"));
1081 f.open(fname.c_str(), std::fstream::out);
1082 std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1083 get_mixed_precision_lists();
1084 f << "AllowList:\n";
1085 for (const auto& x : mp_lists->AllowList()) {
1086 f << x << "\n";
1087 }
1088 f << "\nDenyList:\n";
1089 for (const auto& x : mp_lists->DenyList()) {
1090 f << x << "\n";
1091 }
1092 f << "\nInferList:\n";
1093 for (const auto& x : mp_lists->InferList()) {
1094 f << x << "\n";
1095 }
1096 f << "\nClearList:\n";
1097 for (const auto& x : mp_lists->ClearList()) {
1098 f << x << "\n";
1099 }
1100 f.close();
1101 LOG(INFO) << "Saved paint bucket info to " << fname;
1102 }
1103 return Status::OK();
1104 }
1105
LogSkippedNode(const NodeDef & node) const1106 void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
1107 VLOG(2) << "Skipping " << node.op() << " node " << node.name()
1108 << " because it "
1109 << (MustPreserve(node)
1110 ? "must be preserved"
1111 : "is not on the GPU, or the GPU arch is not suitable");
1112 }
1113
MustPreserve(const NodeDef & node) const1114 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
1115 return nodes_to_preserve_.count(node.name());
1116 }
1117
IsOnDevice(const NodeDef & node,const string & device_type) const1118 bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
1119 const string& device_type) const {
1120 string device_name;
1121 if (node.device().empty()) {
1122 device_name = virtual_placer_.get_canonical_device_name(node);
1123 } else {
1124 device_name = node.device();
1125 }
1126 string device;
1127 string not_used;
1128 if (DeviceNameUtils::SplitDeviceName(device_name, ¬_used, &device) &&
1129 absl::StrContains(absl::AsciiStrToLower(device),
1130 absl::AsciiStrToLower(device_type))) {
1131 return true;
1132 }
1133 return false;
1134 }
1135
1136 // Returns the GPU architecture (compute capability) as a (major, minor) pair.
GetDeviceGPUArch(const DeviceProperties & device_properties)1137 std::pair<int, int> GetDeviceGPUArch(
1138 const DeviceProperties& device_properties) {
1139 if (device_properties.type() != "GPU") return {0, 0};
1140 string arch_str = device_properties.environment().at("architecture");
1141 std::vector<string> split_arch_str = str_util::Split(arch_str, '.');
1142 if (split_arch_str.empty()) {
1143 return {0, 0};
1144 }
1145
1146 int major, minor;
1147 if (!strings::safe_strto32(split_arch_str[0], &major)) {
1148 return {0, 0};
1149 }
1150
1151 if (split_arch_str.size() > 1) {
1152 if (strings::safe_strto32(split_arch_str[1], &minor)) {
1153 return {major, minor};
1154 } else {
1155 return {0, 0};
1156 }
1157 } else {
1158 return {major, 0};
1159 }
1160 }
1161
IsOnSuitableGPUArch(const NodeDef & node) const1162 bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const {
1163 return GetDeviceGPUArch(virtual_placer_.get_device(node)) >= kMinGPUArch;
1164 }
1165
ShouldProcess(const NodeDef & node) const1166 bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const {
1167 return should_process_nodes_.count(&node);
1168 }
1169
IsFloat32(const NodeTypeId & node_type)1170 bool IsFloat32(const NodeTypeId& node_type) {
1171 return GetDataType(*node_type.node, node_type.type_attr) ==
1172 DataType::DT_FLOAT;
1173 }
1174
IsTensorListOp(const string & op)1175 bool IsTensorListOp(const string& op) {
1176 return op.find("TensorList") != string::npos;
1177 }
1178
IsTensorListReaderOp(const string & op)1179 bool IsTensorListReaderOp(const string& op) {
1180 static const gtl::FlatSet<string> tensor_list_reader_ops = {
1181 "TensorListConcat", "TensorListConcatV2", "TensorListGather",
1182 "TensorListGetItem", "TensorListPopBack", "TensorListStack"};
1183 return tensor_list_reader_ops.count(op);
1184 }
1185
IsTensorListWriterOp(const string & op)1186 bool IsTensorListWriterOp(const string& op) {
1187 static const gtl::FlatSet<string> tensor_list_writer_ops = {
1188 "TensorListFromTensor", "TensorListPushBack",
1189 "TensorListPushBackBatch", "TensorListScatter",
1190 "TensorListScatterV2", "TensorListScatterIntoExistingList",
1191 "TensorListSetItem", "TensorListSplit"};
1192 return tensor_list_writer_ops.count(op);
1193 }
1194
SupportsF16(const NodeTypeId & node_type) const1195 bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
1196 const OpDef* op_def;
1197 Status status =
1198 OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1199 if (!status.ok()) return false;
1200 return AllowedDataTypes(*op_def, node_type.type_attr)
1201 .Contains(target_dtype_) &&
1202 NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
1203 }
1204
SupportsF16DataType(const NodeTypeId & node_type) const1205 bool AutoMixedPrecisionImpl::SupportsF16DataType(
1206 const NodeTypeId& node_type) const {
1207 const OpDef* op_def;
1208 Status status =
1209 OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
1210 if (!status.ok()) return false;
1211 return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_);
1212 }
1213
1214 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
1215 // make sure that doing this won't break anything.
ConvertBatchNormOpsToV2()1216 void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() {
1217 for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) {
1218 NodeDef* node = graph_->mutable_node(node_idx);
1219 if (!ShouldProcess(*node)) continue;
1220 bool changed = false;
1221 if (node->op() == "FusedBatchNorm") {
1222 VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1223 << " to FusedBatchNormV2";
1224 node->set_op("FusedBatchNormV2");
1225 changed = true;
1226 } else if (node->op() == "FusedBatchNormGrad") {
1227 VLOG(2) << "Changing op of " << node->op() << " node " << node->name()
1228 << " to FusedBatchNormGradV2";
1229 node->set_op("FusedBatchNormGradV2");
1230 changed = true;
1231 }
1232 if (changed) {
1233 (*node->mutable_attr())["U"].set_type(DT_FLOAT);
1234 }
1235 }
1236 }
1237
1238 // A helper function to decide whether to ignore the effect on performance when
1239 // rewriting the graph. This can be useful for testing the numerical effects of
1240 // reduced precision on systems that have poor mixed precision performance.
ShouldIgnorePerformance()1241 bool ShouldIgnorePerformance() {
1242 static bool is_enabled = [] {
1243 bool ret = false;
1244 TF_CHECK_OK(ReadBoolFromEnvVar(
1245 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE",
1246 /*default_val=*/false, &ret));
1247 return ret;
1248 }();
1249 return is_enabled;
1250 }
1251
Optimize()1252 Status AutoMixedPrecisionImpl::Optimize() {
1253 string optimization_level;
1254 TF_RETURN_IF_ERROR(ReadStringFromEnvVar(
1255 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
1256 optimization_level = absl::AsciiStrToUpper(optimization_level);
1257 force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
1258 if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
1259 // Many ops do not support bfloat16 on the CPU so we disallowing forcing to
1260 // bfloat16.
1261 return errors::InvalidArgument(
1262 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
1263 "UNSAFE_FORCE_ALL when MKL is used");
1264 }
1265
1266 std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
1267 get_mixed_precision_lists();
1268 f16_allowlist_ = mp_lists->AllowList();
1269 f16_denylist_ = mp_lists->DenyList();
1270 f16_inferlist_ = mp_lists->InferList();
1271 f16_clearlist_ = mp_lists->ClearList();
1272 TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_,
1273 f16_inferlist_, f16_clearlist_));
1274
1275 size_t timestamp = Env::Default()->NowMicros() / 1000;
1276 TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
1277
1278 VLOG(2) << "Identifying nodes that should be processed";
1279 for (const NodeDef& node : graph_->node()) {
1280 bool should_process;
1281 switch (mode_) {
1282 case AutoMixedPrecisionMode::CUDA:
1283 should_process =
1284 !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
1285 (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
1286 break;
1287 case AutoMixedPrecisionMode::MKL:
1288 should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
1289 break;
1290 }
1291 if (should_process) {
1292 should_process_nodes_.insert(&node);
1293 } else {
1294 LogSkippedNode(node);
1295 }
1296 }
1297
1298 VLOG(2) << "Converting FusedBatchNorm* ops to V2";
1299 ConvertBatchNormOpsToV2();
1300
1301 VLOG(2) << "Building node type map for graph";
1302 TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_));
1303
1304 VLOG(2) << "Constructing graph type attribute topology view";
1305 TF_RETURN_IF_ERROR(
1306 graph_type_view_.InitializeFromGraph(*graph_, node_type_map_));
1307
1308 absl::flat_hash_set<int> deny_set;
1309
1310 std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters;
1311 FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters,
1312 &deny_set);
1313 std::vector<NodeTypeIdEdge> ephemeral_edges;
1314 for (const auto& cluster : tensor_list_clusters) {
1315 VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
1316 for (const NodeDef* node : cluster) {
1317 VLOG(2) << " Cluster member: " << node->op() << " node " << node->name();
1318 }
1319 FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
1320 }
1321 TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
1322
1323 // The goal here is to change performance-critical ops to fp16 or bf16, and to
1324 // do so with the minimal number of casts, subject to the constraint that the
1325 // model's convergence is not affected. This is achieved by first identifying
1326 // which nodes should be changed to f16 and then inserting casts at the
1327 // boundaries between f16/non-f16 nodes.
1328
1329 // The algorithm for deciding which nodes to change to f16 is as follows:
1330 // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
1331 // This is done under the assumption that allowlist ops are always
1332 // numerically-safe in f16 and that they are the most important ops for
1333 // improving performance.
1334 // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka
1335 // "denylist" ops) or they are on a forward path from a denylist node to
1336 // a deny/infer node (including the node at the end of the path) through
1337 // non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops).
1338 // This is done to prevent numerically-dangerous ops and their downstream
1339 // effects from being changed to f16, which would risk breaking the
1340 // numerical accuracy of the model.
1341 // 3) For all remaining nodes that are not considered dangerous (inferlist
1342 // and clearlist ops), find those that are between (i.e., both upstream
1343 // and downstream of) allow nodes, and add them to the allow_set.
1344 // This is done to avoid unnecessary casts between allowlist ops.
1345 // 4) For all remaining clearlist nodes, add them to the allow_set if they are
1346 // connected to a node in the allow_set via other clearlist nodes.
1347 // This is done to increase the number of ops in the allow_set without
1348 // affecting numerical stability.
1349
1350 absl::flat_hash_set<int> allow_set;
1351 VLOG(2) << "Beginning pass 1 to add allowlist ops";
1352 AddAllowlistOps(&allow_set);
1353 VLOG(2) << "Finished pass 1";
1354
1355 if (allow_set.empty()) {
1356 LOG(INFO) << "No allowlist ops found, nothing to do";
1357 return Status::OK();
1358 }
1359
1360 VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops "
1361 "through clear/inferlist ops";
1362 PropagateDenyFwdThroughClearAndInfer(&deny_set);
1363 VLOG(2) << "Finished pass 2";
1364
1365 VLOG(2) << "Forcing color match between data structure ops";
1366 for (const auto& cluster : tensor_list_clusters) {
1367 ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1368 }
1369
1370 VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they "
1371 "are between allow ops";
1372 AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set);
1373 VLOG(2) << "Finished pass 3";
1374
1375 VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
1376 "clearlist ops";
1377 PropagateAllowThroughClear(deny_set, &allow_set);
1378 VLOG(2) << "Finished pass 4";
1379
1380 VLOG(2) << "Beginning pass 5 to remove some nodes which could not be changed "
1381 "to F16"
1382 "from allow set";
1383 RemoveAllowsetWithFp32(&allow_set);
1384 VLOG(2) << "Finished pass 5";
1385
1386 VLOG(2) << "Forcing color match between data structure ops";
1387 for (const auto& cluster : tensor_list_clusters) {
1388 ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set);
1389 }
1390
1391 VLOG(2) << "Forcing color match on loop edges";
1392 TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
1393
1394 VLOG(2) << "Finding existing casts that can be made allow";
1395 MakeCastsAllowIfAllOutputsAllow(&allow_set);
1396
1397 VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
1398 "ops at paint boundaries";
1399 TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
1400 VLOG(2) << "Finished final pass";
1401
1402 TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
1403
1404 return Status::OK();
1405 }
1406
1407 // If node is a Tensor List op with a float32 data type attribute then this
1408 // returns a pointer to the NodeTypeId representing that type attribute. In
1409 // all other cases this returns nullptr.
GetTensorListFloat32NodeTypeId(const NodeDef & node) const1410 const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId(
1411 const NodeDef& node) const {
1412 if (!IsTensorListOp(node.op())) return nullptr;
1413 for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) {
1414 const NodeTypeId* node_type =
1415 graph_type_view_.GetNode(node.name(), type_attr);
1416 // This assumes that the float32 data type on a Tensor List op is always a
1417 // non-fixed type attribute containing a single type, and that this type
1418 // attribute represents the dtype of the values in the list.
1419 // TODO(benbarsdell): A new Tensor List op could theoretically break these
1420 // assumptions.
1421 if (node_type && node_type->type_attr.fixed_type == DT_INVALID &&
1422 node_type->type_attr.type_index == TypeAttrId::kSingleType &&
1423 IsFloat32(*node_type)) {
1424 return node_type;
1425 }
1426 }
1427 return nullptr;
1428 }
1429
IsSourceOrSinkOp(const string & op) const1430 bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const {
1431 const gtl::FlatSet<string> source_and_sink_ops = {
1432 "_Arg",
1433 "_Retval",
1434 "OptionalFromValue",
1435 "OptionalGetValue",
1436 "PartitionedCall",
1437 "Placeholder",
1438 "StatefulPartitionedCall",
1439 };
1440 return source_and_sink_ops.count(op) || function_library_.Find(op);
1441 }
1442
1443 // Finds all clusters of float32 Tensor List nodes that are connected via their
1444 // handle edges. Unsafe clusters (those with unprocessable nodes, or with edges
1445 // that cross untraversable boundaries via _Arg, _Ret, PartitionedCall etc.
1446 // nodes) are added to deny_set. The caller should paint all nodes in a cluster
1447 // the same color, as they may all refer to the same Tensor List.
FindFloat32TensorListOpClustersAndDenylistUnsafe(std::vector<absl::flat_hash_set<const NodeDef * >> * tensor_list_clusters,absl::flat_hash_set<int> * deny_set) const1448 void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe(
1449 std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters,
1450 absl::flat_hash_set<int>* deny_set) const {
1451 absl::flat_hash_set<const NodeDef*> tensor_list_prop_set;
1452 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1453 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1454 if (!ShouldProcess(*root.node) ||
1455 root.type_attr.fixed_type != DataType::DT_VARIANT ||
1456 !GetTensorListFloat32NodeTypeId(*root.node) ||
1457 tensor_list_prop_set.count(root.node)) {
1458 continue;
1459 }
1460 const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1461 const absl::optional<int> maybe_root_fp32_idx =
1462 graph_type_view_.GetNodeIndex(*root_fp32);
1463 DCHECK(maybe_root_fp32_idx.has_value())
1464 << "Type attribute " << root_fp32->type_attr.DebugString()
1465 << " of node " << root.node->name() << " not found in graph view";
1466 int root_fp32_idx = maybe_root_fp32_idx.value();
1467 // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all
1468 // connected Tensor List nodes.
1469 absl::flat_hash_set<const NodeDef*> cluster({root.node});
1470 DfsTypeTraversal(graph_type_view_, {&root},
1471 TypeTraversalDirection::kFollowInputsAndOutputs,
1472 DfsTypePredicates::Enter([&](int idx) -> bool {
1473 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1474 return !tensor_list_prop_set.count(item.node);
1475 }),
1476 DfsTypeCallbacks::PreOrder([&](int idx) {
1477 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1478 const NodeDef* node = item.node;
1479 if (GetTensorListFloat32NodeTypeId(*node)) {
1480 cluster.insert(node);
1481 if (!ShouldProcess(*node)) {
1482 // The cluster contains an un-processable node.
1483 deny_set->insert(root_fp32_idx);
1484 }
1485 // TODO(benbarsdell): In a theoretical pathological
1486 // case of a Tensor List of Tensor List handles, the
1487 // Tensor List itself would need to be treated as a
1488 // sink.
1489 } else if (IsSourceOrSinkOp(node->op())) {
1490 // The cluster crosses an untraversable boundary.
1491 deny_set->insert(root_fp32_idx);
1492 }
1493 }));
1494 tensor_list_clusters->push_back(cluster);
1495 }
1496 }
1497
1498 // Finds all writer -> reader pairs in the given set that are connected via
1499 // their handles, and adds corresponding float32 edges to *implicit_fp32_edges.
FindTensorListImplicitFloat32Edges(const absl::flat_hash_set<const NodeDef * > & tensor_list_nodes,std::vector<NodeTypeIdEdge> * implicit_fp32_edges) const1500 void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
1501 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1502 std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const {
1503 for (const NodeDef* root_node : tensor_list_nodes) {
1504 if (!IsTensorListReaderOp(root_node->op())) continue;
1505 NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT));
1506 const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node);
1507 CHECK(root_fp32) << "No float32 type attribute found for " // Crash OK
1508 << root.node->op() << " node " << root.node->name();
1509 // Search backwards through handle edges (DT_VARIANT) for all writer ops,
1510 // adding direct implicit edges between them and the reader.
1511 DfsTypeTraversal(
1512 graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1513 DfsTypePredicates::Enter([&](int idx) -> bool {
1514 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1515 return ShouldProcess(*item.node);
1516 }),
1517 DfsTypeCallbacks::PreOrder([&](int idx) {
1518 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1519 if (IsTensorListWriterOp(item.node->op())) {
1520 const NodeTypeId* item_fp32 =
1521 GetTensorListFloat32NodeTypeId(*item.node);
1522 CHECK(item_fp32) // Crash OK
1523 << "No float32 type attribute found for " << item.node->op()
1524 << " node " << item.node->name();
1525 VLOG(2) << "Adding ephemeral float32 edge from "
1526 << item_fp32->node->op() << " node "
1527 << item_fp32->node->name() << " to "
1528 << root_fp32->node->op() << " node "
1529 << root_fp32->node->name();
1530 implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32);
1531 }
1532 }));
1533 }
1534 }
1535
AddAllowlistOps(absl::flat_hash_set<int> * allow_set) const1536 void AutoMixedPrecisionImpl::AddAllowlistOps(
1537 absl::flat_hash_set<int>* allow_set) const {
1538 // Add allowlisted ops to allow_set.
1539 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1540 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1541 if (!ShouldProcess(*root.node)) continue;
1542 bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
1543 if (f16_allowlist_.count(root.node->op()) || force_allow) {
1544 bool inserted = allow_set->insert(root_idx).second;
1545 if (VLOG_IS_ON(2) && inserted) {
1546 VLOG(2) << "Painting type " << root.type_attr.DebugString()
1547 << " of node " << root.node->name() << " ALLOW because its op "
1548 << root.node->op() << " is on the allowlist";
1549 }
1550 }
1551 }
1552 }
1553
1554 // Adds nodes to deny_set iff they are on the denylist or they are on a
1555 // forward path from a denylist node to a deny/infer node (including the node
1556 // at the end of the path) through clear and infer nodes.
1557 // E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer
1558 // becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer.
PropagateDenyFwdThroughClearAndInfer(absl::flat_hash_set<int> * deny_set) const1559 void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer(
1560 absl::flat_hash_set<int>* deny_set) const {
1561 if (force_all_fp16_) return;
1562
1563 // Find clear nodes that are upstream of deny or infer.
1564 absl::flat_hash_set<int> upstream_of_deny_or_infer_set;
1565 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1566 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1567 if (!(f16_denylist_.count(root.node->op()) ||
1568 f16_inferlist_.count(root.node->op()))) {
1569 continue;
1570 }
1571 DfsTypeTraversal(graph_type_view_, {&root},
1572 TypeTraversalDirection::kFollowInputs,
1573 DfsTypePredicates::Enter([&](int idx) -> bool {
1574 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1575 return idx == root_idx ||
1576 (!upstream_of_deny_or_infer_set.count(idx) &&
1577 f16_clearlist_.count(item.node->op()));
1578 }),
1579 DfsTypeCallbacks::PreOrder([&](int idx) {
1580 upstream_of_deny_or_infer_set.insert(idx);
1581 }));
1582 }
1583
1584 // Propagate deny forward through nodes in upstream_of_deny_or_infer_set.
1585 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1586 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1587 if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) {
1588 continue;
1589 }
1590 DfsTypeTraversal(
1591 graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1592 DfsTypePredicates::Enter([&](int idx) -> bool {
1593 return idx == root_idx || (!deny_set->count(idx) &&
1594 upstream_of_deny_or_infer_set.count(idx));
1595 }),
1596 DfsTypeCallbacks::PreOrder([&](int idx) {
1597 bool inserted = deny_set->insert(idx).second;
1598 if (VLOG_IS_ON(2) && inserted) {
1599 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1600 VLOG(2) << "Painting type " << item.type_attr.DebugString()
1601 << " of " << item.node->op() << " node "
1602 << item.node->name() << " DENY";
1603 }
1604 }));
1605 }
1606 }
1607
AddClearAndInferToAllowIfBetweenAllow(const absl::flat_hash_set<int> & deny_set,absl::flat_hash_set<int> * allow_set) const1608 void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow(
1609 const absl::flat_hash_set<int>& deny_set,
1610 absl::flat_hash_set<int>* allow_set) const {
1611 // Find clear/inferlist ops that are downstream of allow ops.
1612 absl::flat_hash_set<int> downstream_of_allow_set;
1613 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1614 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1615 if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
1616 continue;
1617 }
1618 DfsTypeTraversal(
1619 graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs,
1620 DfsTypePredicates::Enter([&](int idx) -> bool {
1621 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1622 return idx == root_idx ||
1623 (!downstream_of_allow_set.count(idx) &&
1624 !f16_allowlist_.count(item.node->op()) &&
1625 !deny_set.count(idx) && ShouldProcess(*item.node) &&
1626 // TODO(benbarsdell): Consider allowing propagation through
1627 // ops that are already float16 in order to reduce the number
1628 // of casts.
1629 IsFloat32(item) && SupportsF16(item) &&
1630 (f16_clearlist_.count(item.node->op()) ||
1631 f16_inferlist_.count(item.node->op())));
1632 }),
1633 DfsTypeCallbacks::PreOrder(
1634 [&](int idx) { downstream_of_allow_set.insert(idx); }));
1635 }
1636
1637 // Set nodes that are both downstream and upstream of allow ops to allow.
1638 absl::flat_hash_set<int> upstream_of_allow_set;
1639 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1640 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1641 if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
1642 !f16_allowlist_.count(root.node->op())) {
1643 continue;
1644 }
1645 DfsTypeTraversal(
1646 graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
1647 DfsTypePredicates::Enter([&](int idx) -> bool {
1648 return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
1649 downstream_of_allow_set.count(idx));
1650 }),
1651 DfsTypeCallbacks::PreOrder([&](int idx) {
1652 upstream_of_allow_set.insert(idx);
1653 bool inserted = allow_set->insert(idx).second;
1654 if (VLOG_IS_ON(2) && inserted) {
1655 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1656 VLOG(2) << "Painting type " << item.type_attr.DebugString()
1657 << " of " << item.node->op() << " node "
1658 << item.node->name() << " ALLOW";
1659 }
1660 }));
1661 }
1662 }
1663
PropagateAllowThroughClear(const absl::flat_hash_set<int> & deny_set,absl::flat_hash_set<int> * allow_set) const1664 void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
1665 const absl::flat_hash_set<int>& deny_set,
1666 absl::flat_hash_set<int>* allow_set) const {
1667 // Propagate allow from allow nodes through clearlist ops.
1668 absl::flat_hash_set<int> clear_prop_set;
1669 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1670 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1671 if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
1672 !allow_set->count(root_idx)) {
1673 continue;
1674 }
1675 DfsTypeTraversal(
1676 graph_type_view_, {&root},
1677 TypeTraversalDirection::kFollowInputsAndOutputs,
1678 DfsTypePredicates::Enter([&](int idx) -> bool {
1679 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1680 return idx == root_idx ||
1681 (!allow_set->count(idx) && !deny_set.count(idx) &&
1682 ShouldProcess(*item.node) && IsFloat32(item) &&
1683 SupportsF16(item) &&
1684 (f16_clearlist_.count(item.node->op())) &&
1685 // We don't propagate (backwards) through nodes that read
1686 // Variables because it can break the behavior of TensorBoard
1687 // visualization and/or (in the case of Enter nodes) the model
1688 // itself. This is only a problem for non-resource variables.
1689 !NodeImplicitlyReadsNonResourceVariable(*item.node));
1690 }),
1691 DfsTypeCallbacks::PreOrder([&](int idx) {
1692 clear_prop_set.insert(idx);
1693 bool inserted = allow_set->insert(idx).second;
1694 if (VLOG_IS_ON(2) && inserted) {
1695 const NodeTypeId& item = *graph_type_view_.GetNode(idx);
1696 VLOG(2) << "Painting type " << item.type_attr.DebugString()
1697 << " of " << item.node->op() << " node "
1698 << item.node->name() << " ALLOW";
1699 }
1700 }));
1701 }
1702 }
1703
1704 // If ops have one or more type_attr, But this type_attr could not be converted
1705 // to F16. Such as FusedBatchNormV2/FusedBatchNormV3, its type_attr 'U' only
1706 // support float. So we will remove this node from allow_set.
RemoveAllowsetWithFp32(absl::flat_hash_set<int> * allow_set) const1707 void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32(
1708 absl::flat_hash_set<int>* allow_set) const {
1709 for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
1710 const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
1711 if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) &&
1712 !SupportsF16DataType(root)) {
1713 auto erased = allow_set->erase(root_idx);
1714 if (VLOG_IS_ON(2) && erased) {
1715 VLOG(2) << "UnPainting type " << root.type_attr.DebugString()
1716 << " of node " << root.node->name() << " ALLOW because its op "
1717 << root.node->op() << " is not support F16 DataType";
1718 }
1719 }
1720 }
1721 }
1722
1723 // Forces NextIteration nodes and their output Merge node(s) to have the same
1724 // color. Specifically, it removes them all from allow_set if any of the Merge
1725 // nodes is not in allow_set, otherwise it adds the NextIteration node to
1726 // allow_set.
ForceColorMatchOnRecurrentEdges(absl::flat_hash_set<int> * allow_set) const1727 Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
1728 absl::flat_hash_set<int>* allow_set) const {
1729 for (const NodeDef& node : graph_->node()) {
1730 if (node.op() == "NextIteration") {
1731 GraphView::OutputPort output_port(&node, 0);
1732 const auto& fanout = graph_view_.GetFanout(output_port);
1733 std::vector<int> merge_idxs;
1734 merge_idxs.reserve(fanout.size());
1735 bool any_merge_is_not_allow = false;
1736 for (const auto& output : fanout) {
1737 const NodeDef& merge_node = *output.node;
1738 if (merge_node.op() != "Merge") {
1739 return errors::FailedPrecondition(
1740 "Expected Merge node after NextIteration, got ", merge_node.op());
1741 }
1742 const absl::optional<int> maybe_merge_idx =
1743 graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T"));
1744 if (!maybe_merge_idx.has_value()) {
1745 return errors::Internal("Type attribute T of Merge node ",
1746 merge_node.name(),
1747 " not found in graph view");
1748 }
1749 int merge_idx = maybe_merge_idx.value();
1750 merge_idxs.push_back(merge_idx);
1751 any_merge_is_not_allow =
1752 any_merge_is_not_allow || !allow_set->count(merge_idx);
1753 }
1754 const absl::optional<int> maybe_nextiter_idx =
1755 graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
1756 if (!maybe_nextiter_idx.has_value()) {
1757 return errors::Internal("Type attribute T of NextIteration node ",
1758 node.name(), " not found in graph view");
1759 }
1760 int nextiter_idx = maybe_nextiter_idx.value();
1761 if (any_merge_is_not_allow) {
1762 for (int merge_idx : merge_idxs) {
1763 if (allow_set->erase(merge_idx)) {
1764 VLOG(2) << "Painting type T of Merge node "
1765 << graph_type_view_.GetNode(merge_idx)->node->name()
1766 << " DENY to match the color of its sibling Merge nodes "
1767 "with common NextIteration node "
1768 << node.name();
1769 }
1770 }
1771 if (allow_set->erase(nextiter_idx)) {
1772 VLOG(2) << "Painting type T of NextIteration node " << node.name()
1773 << " DENY to match the color of its output Merge node(s)";
1774 }
1775 } else {
1776 if (allow_set->insert(nextiter_idx).second) {
1777 VLOG(2) << "Painting type T of NextIteration node " << node.name()
1778 << " ALLOW to match the color of its output Merge node(s)";
1779 }
1780 }
1781 }
1782 }
1783 return Status::OK();
1784 }
1785
1786 // Forces all of the given Tensor List nodes into the same color set.
ForceColorMatchBetweenTensorListOps(const absl::flat_hash_set<const NodeDef * > & tensor_list_nodes,absl::flat_hash_set<int> * allow_set,absl::flat_hash_set<int> * deny_set) const1787 void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
1788 const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
1789 absl::flat_hash_set<int>* allow_set,
1790 absl::flat_hash_set<int>* deny_set) const {
1791 bool any_deny = false;
1792 bool any_allow = false;
1793 std::vector<int> node_type_idxs;
1794 node_type_idxs.reserve(tensor_list_nodes.size());
1795 for (const NodeDef* node : tensor_list_nodes) {
1796 const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node);
1797 const absl::optional<int> maybe_node_type_idx =
1798 graph_type_view_.GetNodeIndex(node_type);
1799 DCHECK(maybe_node_type_idx.has_value())
1800 << "Type attribute " << node_type.type_attr.DebugString() << " of node "
1801 << node->name() << " not found in graph view";
1802 node_type_idxs.push_back(maybe_node_type_idx.value());
1803 }
1804 for (int node_type_idx : node_type_idxs) {
1805 if (deny_set->count(node_type_idx)) {
1806 any_deny = true;
1807 break;
1808 } else if (allow_set->count(node_type_idx)) {
1809 any_allow = true;
1810 }
1811 }
1812 if (!any_deny && !any_allow) return;
1813 for (int node_type_idx : node_type_idxs) {
1814 const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
1815 VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
1816 << node_type.node->op() << " node " << node_type.node->name() << " "
1817 << (any_deny ? "DENY" : "ALLOW")
1818 << " because at least one of its siblings is "
1819 << (any_deny ? "DENY" : "ALLOW");
1820 if (any_deny) {
1821 allow_set->erase(node_type_idx);
1822 deny_set->insert(node_type_idx);
1823 } else {
1824 allow_set->insert(node_type_idx);
1825 }
1826 }
1827 }
1828
NodeImplicitlyReadsNonResourceVariable(const NodeDef & node) const1829 bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
1830 const NodeDef& node) const {
1831 if (node.op() == "Identity" || node.op() == "Enter") {
1832 GraphView::InputPort node_input(&node, 0);
1833 MutableGraphView::OutputPort prev_output =
1834 graph_view_.GetRegularFanin(node_input);
1835 const NodeDef* input = prev_output.node;
1836 if (input && ((node.op() == "Identity" && (input->op() == "Variable" ||
1837 input->op() == "VariableV2")) ||
1838 (node.op() == "Enter" &&
1839 NodeImplicitlyReadsNonResourceVariable(*input)))) {
1840 return true;
1841 }
1842 }
1843 return false;
1844 }
1845
1846 // This adds existing Cast nodes to allow_set if all of their outputs are allow,
1847 // avoiding the need to add a new Cast node after an existing Cast.
MakeCastsAllowIfAllOutputsAllow(absl::flat_hash_set<int> * allow_set) const1848 void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
1849 absl::flat_hash_set<int>* allow_set) const {
1850 int num_nodes_preop = graph_->node_size();
1851 for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
1852 NodeDef* node = graph_->mutable_node(node_idx);
1853 NodeTypeId node_type(node, TypeAttrId("DstT"));
1854 if (node->op() != "Cast" || !IsFloat32(node_type)) {
1855 continue;
1856 }
1857 bool all_fanouts_allow = true;
1858 MutableGraphView::OutputPort src(node, 0);
1859 const auto& fanout = graph_view_.GetFanout(src);
1860 for (const MutableGraphView::InputPort& dst : fanout) {
1861 TypeAttrId dst_type_attr =
1862 node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
1863 const absl::optional<int> maybe_dst_type_idx =
1864 graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
1865 DCHECK(maybe_dst_type_idx.has_value())
1866 << "Type attribute " << dst_type_attr.DebugString() << " of node "
1867 << dst.node->name() << " not found in graph view";
1868 int dst_type_idx = maybe_dst_type_idx.value();
1869 bool dst_is_allow = allow_set->count(dst_type_idx);
1870 if (!dst_is_allow) {
1871 all_fanouts_allow = false;
1872 break;
1873 }
1874 }
1875 if (!fanout.empty() && all_fanouts_allow) {
1876 const absl::optional<int> maybe_node_type_idx =
1877 graph_type_view_.GetNodeIndex(node_type);
1878 DCHECK(maybe_node_type_idx.has_value())
1879 << "Type attribute " << node_type.type_attr.DebugString()
1880 << " of node " << node_type.node->name()
1881 << " not found in graph view";
1882 int node_type_idx = maybe_node_type_idx.value();
1883 allow_set->insert(node_type_idx);
1884 }
1885 }
1886 }
1887
// Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
// inserts Cast nodes at node outputs for all edges that connect
// allow-painted <-> non-allow-painted type attributes.
//
// Only the nodes that exist when the pass starts (num_nodes_preop) are
// visited, so Cast nodes added here are not re-processed. For each float32
// type attribute of such a node:
//   - if the attribute is in allow_set, it is set to target_dtype_;
//   - for each output port governed by the attribute, at most one Cast node
//     is added. BuildCastNode is called with to_f16 = "consumer is
//     allow-painted".
//   - every consumer whose color differs from the source's is rewired to
//     read from that shared Cast.
//
// Returns Internal if a type attribute is missing from the type graph view
// or cannot be set. Errors from rewiring fanins are propagated.
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
    const absl::flat_hash_set<int>& allow_set) {
  // Incremented once per changed type attribute, so a node with several
  // allow-painted float32 attributes contributes more than once.
  int num_nodes_changed = 0;
  // Casts to the target type, excluding those whose source is a Const,
  // Variable, or implicit non-resource variable reader.
  int num_nonvar_casts_to_f16 = 0;
  // Snapshot of the node count; nodes appended by AddNode below lie beyond it.
  int num_nodes_preop = graph_->node_size();
  for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) {
      const absl::optional<int> maybe_node_type_idx =
          graph_type_view_.GetNodeIndex(node->name(), type_attr);
      if (!maybe_node_type_idx.has_value()) {
        return errors::Internal("Type attribute ", type_attr.DebugString(),
                                " of ", node->op(), " node ", node->name(),
                                " not found in graph view");
      }
      int node_type_idx = maybe_node_type_idx.value();
      // Only float32 type attributes are candidates for conversion or casts.
      if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
      bool src_is_allow = allow_set.count(node_type_idx);
      if (src_is_allow) {
        VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
                << node->op() << " node " << node->name() << " to "
                << DataTypeString(target_dtype_);
        if (!SetDataType(node, type_attr, target_dtype_)) {
          return errors::Internal("Failed to set type attribute");
        }
        ++num_nodes_changed;
      }
      for (int output_port : node_type_map_.GetOutputPorts(*node, type_attr)) {
        MutableGraphView::OutputPort src(node, output_port);
        // Created lazily and shared by all mismatching consumers of this port.
        NodeDef* added_cast_node = nullptr;
        // Note: This is copied so that edges can be modified inside the loop.
        auto fanout = graph_view_.GetFanout(src);
        for (const MutableGraphView::InputPort& dst : fanout) {
          TypeAttrId dst_type_attr =
              node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id);
          const absl::optional<int> maybe_dst_type_idx =
              graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr);
          if (!maybe_dst_type_idx.has_value()) {
            return errors::Internal("Type attribute ",
                                    dst_type_attr.DebugString(), " of ",
                                    dst.node->op(), " node ", dst.node->name(),
                                    " not found in graph view");
          }
          int dst_type_idx = maybe_dst_type_idx.value();
          bool dst_is_allow = allow_set.count(dst_type_idx);
          // A cast is only needed across a color boundary.
          if (src_is_allow != dst_is_allow) {
            if (!added_cast_node) {
              bool to_f16 = dst_is_allow;
              VLOG(1) << "Inserting cast to "
                      << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
                      << " at " << src.node->op() << " " << src.node->name()
                      << ":" << src.port_id;
              added_cast_node = graph_view_.AddNode(
                  BuildCastNode(src, to_f16, src.node->device()));
              if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
                  !NodeImplicitlyReadsNonResourceVariable(*node)) {
                ++num_nonvar_casts_to_f16;
              }
            }
            // Rewire this consumer to read output 0 of the Cast node.
            TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
                dst.node->name(), dst.port_id, {added_cast_node->name(), 0}));
          }
        }
      }
    }
  }
  // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
  // since many Python users will see this message.
  const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
  LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
            << " nodes to " << type_str << " precision using "
            << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
            << " (excluding Const and Variable casts)";
  return Status::OK();
}
1966
GetNumGPUs(const Cluster & cluster,const std::pair<int,int> & min_arch={0, 0})1967 int GetNumGPUs(const Cluster& cluster,
1968 const std::pair<int, int>& min_arch = {0, 0}) {
1969 auto devices = cluster.GetDevices();
1970 int num_gpus = 0;
1971 for (const auto& device : devices) {
1972 const DeviceProperties& device_properties = device.second;
1973 std::pair<int, int> arch = GetDeviceGPUArch(device_properties);
1974 if (device_properties.type() == "GPU" && arch >= min_arch) {
1975 num_gpus++;
1976 }
1977 }
1978 return num_gpus;
1979 }
1980
1981 } // end namespace
1982
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * output)1983 Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
1984 GraphDef* output) {
1985 if (cluster == nullptr) {
1986 return errors::InvalidArgument("cluster == nullptr");
1987 }
1988
1989 #if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
1990 if (mode_ == AutoMixedPrecisionMode::MKL) {
1991 return errors::Unimplemented(
1992 "The auto_mixed_precision_mkl optimizer cannot be used since "
1993 "this build of TensorFlow is not compiled with MKL support for "
1994 "bfloat16. "
1995 "For information on MKL builds, see: "
1996 "https://software.intel.com/en-us/articles/intel-optimization-for-"
1997 "tensorflow-installation-guide");
1998 }
1999 #endif
2000
2001 // Start by copying input graph to output.
2002 *output = item.graph;
2003
2004 int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
2005 : GetNumGPUs(*cluster, kMinGPUArch);
2006 if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
2007 // AutoMixedPrecision is currently only tuned for GPU.
2008 LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
2009 << " graph optimizer";
2010 return Status::OK();
2011 }
2012
2013 // Optimize the output graph in-place.
2014 AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
2015 item.id, mode_);
2016 if (item.id == "tf_graph") {
2017 LOG(INFO) << "Running " << name() << " graph optimizer";
2018 } else {
2019 VLOG(1) << "Running " << name() << " graph optimizer on " << item.id;
2020 }
2021 Status status = optimizer.Optimize();
2022 if (!status.ok()) {
2023 // Restore the original graph.
2024 *output = item.graph;
2025 LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString();
2026 }
2027 return status;
2028 }
2029
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimize_output,double result)2030 void AutoMixedPrecision::Feedback(Cluster* cluster, const GrapplerItem& item,
2031 const GraphDef& optimize_output,
2032 double result) {
2033 // Nothing to do for AutoMixedPrecision.
2034 }
2035
2036 } // end namespace grappler
2037 } // end namespace tensorflow
2038