/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {
namespace {
Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) {
  Tensor tensor(dtype, TensorShape(shape));

  // Conveniently, all numeric data types have 0x0 == zero. Otherwise we would
  // need a giant switch statement here.
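  // (For example, an all-zeros bit pattern is 0.0f for DT_FLOAT under
  // IEEE 754 and 0 for the integer types.)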
  memset(const_cast<char*>(tensor.tensor_data().data()), 0,
         tensor.tensor_data().size());

  return tensor;
}

// Replaces occurrences of the "AccumulateNV2" stub operator with a graph of
// lower-level ops. The graph is equivalent (modulo certain corner cases)
// to the semantics of the original accumulate_n() Python op in math_ops.py.
// Implementing the op with a rewrite allows this new variant of accumulate_n
// to be differentiable.
//
// The binary code that generates AccumulateNV2 stub ops is located in a
// dynamic library built out of tensorflow/contrib/framework. Ideally, this
// class would also be in contrib, but calls to REGISTER_OPTIMIZATION() from
// third-party libraries aren't currently supported.
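//
// Roughly speaking, each AccumulateNV2 node with inputs x_1, ..., x_N is
// replaced with the following subgraph (built in rewriteNode() below):
//
//   TemporaryVariable --> Assign(zeros) --> AssignAdd(x_1)
//                                       --> ...
//                                       --> AssignAdd(x_N)
//                                       --> DestroyTemporaryVariable
//
// Each AssignAdd also gets a control edge to DestroyTemporaryVariable, which
// takes over the original node's name and emits the accumulated sum.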
class AccumulateNV2RemovePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    // TODO(freiss.oss@gmail.com): Substantial shared code with
    // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
    // a third similar rewrite.
    if (options.graph == nullptr) {
      // TODO(apassos) returning OK feels weird here as we can't do anything
      // without a graph, but some tests require this.
      return Status::OK();
    }

    Graph* g = options.graph->get();
    if (g == nullptr) {
      return errors::Internal(
          "AccumulateNV2 removal should happen before partitioning and a "
          "graph should be available.");
    }

    // Build up a todo list of ops to replace, *then* modify the graph
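    // (modifying the graph while iterating over g->op_nodes() would not be
    // safe).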
    gtl::InlinedVector<Node*, 2> matches;
    for (Node* n : g->op_nodes()) {
      if (n->type_string() == "AccumulateNV2") {
        matches.push_back(n);
      }
    }
    for (Node* n : matches) {
      TF_RETURN_IF_ERROR(rewriteNode(n, g));
    }
    return Status::OK();
  }

  Status rewriteNode(Node* n, Graph* g) {
    AttrSlice n_attrs = n->attrs();
    auto base_make_node = [n, &n_attrs](const string& op, const string& name) {
      NodeDebugInfo debug_info(*n);
      NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);

      // The pieces of AccumulateNV2 should all be on the same device.
      node_builder.Device(n->requested_device());
      string colo;
      if (GetNodeAttr(n_attrs, kColocationAttrName, &colo).ok()) {
        node_builder.Attr(kColocationAttrName, colo);
      }
      return node_builder;
    };
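    // make_node wraps base_make_node, generating a fresh internal name that
    // is namespaced under the original node's name.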
    auto make_node = [n, g, &base_make_node](string op) {
      return base_make_node(
          op, g->NewName(strings::StrCat(n->name(), "/Internal")));
    };

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    TensorShapeProto shape;
    TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));

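    // Partition the placeholder's inputs: data edges carry the values to
    // accumulate, while control edges will be rerouted below to gate the
    // accumulator's initialization.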
    std::vector<const Edge*> data_edges, control_edges;
    for (const Edge* input_edge : n->in_edges()) {
      if (input_edge->IsControlEdge()) {
        control_edges.push_back(input_edge);
      } else {
        data_edges.push_back(input_edge);
      }
    }

    // Create the following ops to replace the AccumulateNV2 placeholder:
    Node* create_accumulator = nullptr;            // TemporaryVariable op
    Node* initial_val = nullptr;                   // Const op
    Node* initialize_accumulator = nullptr;        // Assign op
    std::vector<Node*> add_values_to_accumulator;  // AssignAdd ops
    Node* clean_up_accumulator = nullptr;          // DestroyTemporaryVariable

    const string accumulator_name =
        strings::StrCat(n->name(), "/Internal/Accumulator");
    TensorShapeProto variable_shape;
    variable_shape.add_dim()->set_size(0);
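    // The accumulator's true shape is generally not known until runtime, so
    // the temporary variable is created with a dummy shape; the Assign below
    // uses validate_shape=false, which lets the variable take on the shape of
    // the zero tensor it is initialized with.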
    TF_RETURN_IF_ERROR(make_node("TemporaryVariable")
                           .Attr("shape", variable_shape)
                           .Attr("dtype", dtype)
                           .Attr("var_name", accumulator_name)
                           .Finalize(g, &create_accumulator));

    // Use a Shape/Fill pair to build a zero tensor whose shape matches the
    // first input at runtime (the shape attr may not be fully defined
    // statically).
    Node* shape_node;
    TF_RETURN_IF_ERROR(
        make_node("Shape")
            .Input(data_edges[0]->src(), data_edges[0]->src_output())
            .Finalize(g, &shape_node));
    Node* zero;
    TF_RETURN_IF_ERROR(make_node("Const")
                           .Attr("value", make_zeros(dtype, TensorShapeProto()))
                           .Attr("dtype", dtype)
                           .Finalize(g, &zero));
    TF_RETURN_IF_ERROR(make_node("Fill")
                           .Input(shape_node)
                           .Input(zero)
                           .Finalize(g, &initial_val));
    TF_RETURN_IF_ERROR(make_node("Assign")
                           .Attr("T", dtype)
                           .Input(create_accumulator)  // ref: Ref(T)
                           .Input(initial_val)         // value: T
                           .Attr("validate_shape", false)
                           .Finalize(g, &initialize_accumulator));
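    // Add each input value to the accumulator. Every AssignAdd reads the ref
    // output of the Assign above, so none of them can run before the
    // accumulator is initialized; use_locking=true makes the updates atomic
    // with respect to one another.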
    for (int i = 0; i < data_edges.size(); ++i) {
      Node* assignAdd;
      TF_RETURN_IF_ERROR(make_node("AssignAdd")
                             .Attr("T", dtype)
                             .Attr("use_locking", true)
                             .Input(initialize_accumulator)  // ref: Ref(T)
                             .Input(data_edges[i]->src(),
                                    data_edges[i]->src_output())  // value: T
                             .Finalize(g, &assignAdd));

      add_values_to_accumulator.push_back(assignAdd);
    }

    // Note that we reuse the original placeholder op's name here, so that
    // consumers that refer to the AccumulateNV2 node by name now read the
    // accumulated result from this DestroyTemporaryVariable node.
    TF_RETURN_IF_ERROR(base_make_node("DestroyTemporaryVariable", n->name())
                           .Attr("T", dtype)
                           .Attr("var_name", accumulator_name)
                           .Input(initialize_accumulator)
                           .Finalize(g, &clean_up_accumulator));

    // Add edges to the graph to ensure that operations occur in the right
    // order:
    // 1. Do anything that had a control edge to the AccumulateNV2 placeholder
    // 2. Initialize accumulator
    // 3. Add input values to accumulator (already handled by data edges
    //    added above)
    // 4. Reclaim the buffer that held the accumulator
    // 5. Do anything that depended on the AccumulateNV2 placeholder
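    // Step 1: incoming control edges now gate accumulator initialization.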
    for (const Edge* control_edge : control_edges) {
      g->AddControlEdge(control_edge->src(), initialize_accumulator);
    }

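    // Step 4: the accumulator's buffer may only be reclaimed once every
    // AssignAdd has run.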
    for (Node* assign_add : add_values_to_accumulator) {
      g->AddControlEdge(assign_add, clean_up_accumulator);
    }

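    // Step 5: re-point the placeholder's outgoing edges at the
    // DestroyTemporaryVariable node, which emits the final sum.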
    for (const Edge* out_edge : n->out_edges()) {
      if (out_edge->IsControlEdge()) {
        g->AddControlEdge(clean_up_accumulator, out_edge->dst());
      } else {
        g->AddEdge(clean_up_accumulator, 0, out_edge->dst(),
                   out_edge->dst_input());
      }
    }

    // Remove the original AccumulateNV2 placeholder op. Removing the node
    // invalidates its incoming/outgoing edge sets, so this must happen after
    // we are done using them.
    g->RemoveNode(n);

    return Status::OK();
  }
};
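
// Register at PRE_PLACEMENT so the rewrite runs before device placement and
// graph partitioning, as the error message in Run() above requires.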
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      AccumulateNV2RemovePass);

}  // namespace
}  // namespace tensorflow