1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/optimization_registry.h"
17 #include "tensorflow/core/graph/node_builder.h"
18 
19 namespace tensorflow {
20 namespace {
21 
make_zeros(const DataType & dtype,const TensorShapeProto & shape)22 Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) {
23   Tensor tensor(dtype, TensorShape(shape));
24 
25   // Conveniently, all numeric data types have 0x0 == zero.  Otherwise we would
26   // need a giant switch statement here.
27   memset(const_cast<char*>(tensor.tensor_data().data()), 0,
28          tensor.tensor_data().size());
29 
30   return tensor;
31 }
32 
33 // Replaces occurrences of the "AccumulateNV2" stub operator with a graph of
34 // lower-level ops. The graph is equivalent (modulo certain corner cases)
35 // to the semantics of the original accumulate_n() Python op in math_ops.py.
36 // Implementing the op with a rewrite allows this new variant of accumulate_n
37 // to be differentiable.
38 //
39 // The binary code that generates AccumulateNV2 stub ops is located in a
40 // dynamic library built out of tensorflow/contrib/framework. Ideally, this
41 // class would also be in contrib, but calls to REGISTER_OPTIMIZATION() from
42 // third-party libraries aren't currently supported.
43 class AccumulateNV2RemovePass : public GraphOptimizationPass {
44  public:
Run(const GraphOptimizationPassOptions & options)45   Status Run(const GraphOptimizationPassOptions& options) override {
46     // TODO(freiss.oss@gmail.com): Substantial shared code with
47     // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
48     // a third similar rewrite.
49     if (options.graph == nullptr) {
50       // TODO(apassos) returning OK feels weird here as we can't do anything
51       // without a graph, but some tests require this.
52       return Status::OK();
53     }
54 
55     Graph* g = options.graph->get();
56     if (g == nullptr) {
57       return errors::Internal(
58           "AccumulateNV2 removal should happen before partitioning and a "
59           "graph should be available.");
60     }
61 
62     // Build up a todo list of ops to replace, *then* modify the graph
63     gtl::InlinedVector<Node*, 2> matches;
64     for (Node* n : g->op_nodes()) {
65       if (n->type_string() == "AccumulateNV2") {
66         matches.push_back(n);
67       }
68     }
69     for (Node* n : matches) {
70       TF_RETURN_IF_ERROR(rewriteNode(n, g));
71     }
72     return Status::OK();
73   }
74 
rewriteNode(Node * n,Graph * g)75   Status rewriteNode(Node* n, Graph* g) {
76     AttrSlice n_attrs = n->attrs();
77     auto base_make_node = [n, &n_attrs](const string& op, const string& name) {
78       NodeDebugInfo debug_info(*n);
79       NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);
80 
81       // The pieces of AccumulateNV2 should all be on the same node.
82       node_builder.Device(n->requested_device());
83       string colo;
84       if (GetNodeAttr(n_attrs, kColocationAttrName, &colo).ok()) {
85         node_builder.Attr(kColocationAttrName, colo);
86       }
87       return node_builder;
88     };
89     auto make_node = [n, g, &base_make_node](string op) {
90       return base_make_node(
91           op, g->NewName(strings::StrCat(n->name(), "/Internal")));
92     };
93 
94     DataType dtype;
95     TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
96     TensorShapeProto shape;
97     TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));
98 
99     std::vector<const Edge*> data_edges, control_edges;
100     for (const Edge* input_edge : n->in_edges()) {
101       if (input_edge->IsControlEdge()) {
102         control_edges.push_back(input_edge);
103       } else {
104         data_edges.push_back(input_edge);
105       }
106     }
107 
108     // Create the following ops to replace the AccumulateNV2 placeholder:
109     Node* create_accumulator = nullptr;            // TemporaryVariable op
110     Node* initial_val = nullptr;                   // Const op
111     Node* initialize_accumulator = nullptr;        // Assign op
112     std::vector<Node*> add_values_to_accumulator;  // AssignAdd ops
113     Node* clean_up_accumulator = nullptr;          // DestroyTemporaryVariable
114 
115     const string accumulator_name =
116         strings::StrCat(n->name(), "/Internal/Accumulator");
117     TensorShapeProto variable_shape;
118     variable_shape.add_dim()->set_size(0);
119     TF_RETURN_IF_ERROR(make_node("TemporaryVariable")
120                            .Attr("shape", variable_shape)
121                            .Attr("dtype", dtype)
122                            .Attr("var_name", accumulator_name)
123                            .Finalize(g, &create_accumulator));
124     PartialTensorShape partial_shape(shape);
125     // Make a Fill operation to make a zero tensor with the shape of the first
126     // input.
127     Node* shape_node;
128     TF_RETURN_IF_ERROR(
129         make_node("Shape")
130             .Input(data_edges[0]->src(), data_edges[0]->src_output())
131             .Finalize(g, &shape_node));
132     Node* zero;
133     TF_RETURN_IF_ERROR(make_node("Const")
134                            .Attr("value", make_zeros(dtype, TensorShapeProto()))
135                            .Attr("dtype", dtype)
136                            .Finalize(g, &zero));
137     TF_RETURN_IF_ERROR(make_node("Fill")
138                            .Input(shape_node)
139                            .Input(zero)
140                            .Finalize(g, &initial_val));
141     TF_RETURN_IF_ERROR(make_node("Assign")
142                            .Attr("T", dtype)
143                            .Input(create_accumulator)  // ref: Ref(T)
144                            .Input(initial_val)         // value: T
145                            .Attr("validate_shape", false)
146                            .Finalize(g, &initialize_accumulator));
147     for (int i = 0; i < data_edges.size(); ++i) {
148       Node* assignAdd;
149       TF_RETURN_IF_ERROR(make_node("AssignAdd")
150                              .Attr("T", dtype)
151                              .Attr("use_locking", true)
152                              .Input(initialize_accumulator)  // ref: Ref(T)
153                              .Input(data_edges[i]->src(),
154                                     data_edges[i]->src_output())  // value: T
155                              .Finalize(g, &assignAdd));
156 
157       add_values_to_accumulator.push_back(assignAdd);
158     }
159 
160     // Note that we use the original placeholder op's name here
161     TF_RETURN_IF_ERROR(base_make_node("DestroyTemporaryVariable", n->name())
162                            .Attr("T", dtype)
163                            .Attr("var_name", accumulator_name)
164                            .Input(initialize_accumulator)
165                            .Finalize(g, &clean_up_accumulator));
166 
167     // Add edges to the graph to ensure that operations occur in the right
168     // order:
169     // 1. Do anything that had a control edge to the AccumulateNV2 placeholder
170     // 2. Initialize accumulator
171     // 3. Add input values to accumulator (already handled by data edges
172     //    added above)
173     // 4. Reclaim the buffer that held the accumulator
174     // 5. Do anything that depended on the AccumulateNV2 placeholder
175     for (const Edge* control_edge : control_edges) {
176       g->AddControlEdge(control_edge->src(), initialize_accumulator);
177     }
178 
179     for (Node* assign_add : add_values_to_accumulator) {
180       g->AddControlEdge(assign_add, clean_up_accumulator);
181     }
182 
183     for (const Edge* out_edge : n->out_edges()) {
184       if (out_edge->IsControlEdge()) {
185         g->AddControlEdge(clean_up_accumulator, out_edge->dst());
186       } else {
187         g->AddEdge(clean_up_accumulator, 0, out_edge->dst(),
188                    out_edge->dst_input());
189       }
190     }
191 
192     // Remove the original AccumulateNV2 placeholder op.
193     // This removal modifies the op and must happen after we have finished
194     // using its incoming/outgoing edge sets.
195     g->RemoveNode(n);
196 
197     return Status::OK();
198   }
199 };
// Registers AccumulateNV2RemovePass with the optimization pass registry under
// the PRE_PLACEMENT grouping. The second argument (0) is presumably the
// phase/ordering within that grouping -- NOTE(review): confirm against the
// definition of REGISTER_OPTIMIZATION.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      AccumulateNV2RemovePass);
202 
203 }  // namespace
204 }  // namespace tensorflow
205