1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/python_api.h"
17 
18 #include "tensorflow/c/c_api_internal.h"
19 #include "tensorflow/python/framework/cpp_shape_inference.pb.h"
20 
21 namespace tensorflow {
22 
AddControlInput(TF_Graph * graph,TF_Operation * op,TF_Operation * input)23 void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
24   mutex_lock l(graph->mu);
25   graph->graph.AddControlEdge(&input->node, &op->node);
26   RecordMutation(graph, *op, "adding control input");
27 }
28 
SetAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Buffer * attr_value_proto,TF_Status * status)29 void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
30              TF_Buffer* attr_value_proto, TF_Status* status) {
31   AttrValue attr_val;
32   if (!attr_val.ParseFromArray(attr_value_proto->data,
33                                attr_value_proto->length)) {
34     status->status =
35         tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
36     return;
37   }
38 
39   mutex_lock l(graph->mu);
40   op->node.AddAttr(attr_name, attr_val);
41   RecordMutation(graph, *op, "setting attribute");
42 }
43 
ClearAttr(TF_Graph * graph,TF_Operation * op,const char * attr_name,TF_Status * status)44 void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
45                TF_Status* status) {
46   AttrValue attr_val;
47 
48   mutex_lock l(graph->mu);
49   op->node.ClearAttr(attr_name);
50   RecordMutation(graph, *op, "clearing attribute");
51 }
52 
SetRequestedDevice(TF_Graph * graph,TF_Operation * op,const char * device)53 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
54   mutex_lock l(graph->mu);
55   op->node.set_requested_device(device);
56   RecordMutation(graph, *op, "setting device");
57 }
58 
UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)59 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
60                 TF_Status* status) {
61   mutex_lock l(graph->mu);
62   tensorflow::shape_inference::InferenceContext* ic =
63       graph->refiner.GetContext(&new_src.oper->node);
64 
65   if (ic->num_outputs() <= new_src.index) {
66     status->status = tensorflow::errors::OutOfRange(
67         "Cannot update edge. Output index [", new_src.index,
68         "] is greater than the number of total outputs [", ic->num_outputs(),
69         "].");
70     return;
71   }
72   tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
73 
74   tensorflow::shape_inference::InferenceContext* ic_dst =
75       graph->refiner.GetContext(&dst.oper->node);
76   if (ic_dst->num_inputs() <= dst.index) {
77     status->status = tensorflow::errors::OutOfRange(
78         "Cannot update edge. Input index [", dst.index,
79         "] is greater than the number of total inputs [", ic_dst->num_inputs(),
80         "].");
81     return;
82   }
83   if (!ic_dst->MergeInput(dst.index, shape)) {
84     status->status = tensorflow::errors::InvalidArgument(
85         "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
86         " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
87     return;
88   }
89   status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
90                                            &dst.oper->node, dst.index);
91 
92   if (TF_GetCode(status) == TF_OK) {
93     // This modification only updates the destination node for
94     // the purposes of running this graph in a session. Thus, we don't
95     // record the source node as being modified.
96     RecordMutation(graph, *dst.oper, "updating input tensor");
97   }
98 }
99 
RemoveAllControlInputs(TF_Graph * graph,TF_Operation * op)100 void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
101   mutex_lock l(graph->mu);
102   std::vector<const Edge*> control_edges;
103   for (const Edge* edge : op->node.in_edges()) {
104     if (!edge->IsControlEdge()) continue;
105     control_edges.push_back(edge);
106   }
107   for (const Edge* edge : control_edges) {
108     graph->graph.RemoveControlEdge(edge);
109   }
110 }
111 
SetRequireShapeInferenceFns(TF_Graph * graph,bool require)112 void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
113   mutex_lock l(graph->mu);
114   graph->refiner.set_require_shape_inference_fns(require);
115 }
116 
ExtendSession(TF_Session * session,TF_Status * status)117 void ExtendSession(TF_Session* session, TF_Status* status) {
118   ExtendSessionGraphHelper(session, status);
119   session->extend_before_run = false;
120 }
121 
GetHandleShapeAndType(TF_Graph * graph,TF_Output output)122 std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
123   Node* node = &output.oper->node;
124   CppShapeInferenceResult::HandleData handle_data;
125   handle_data.set_is_set(true);
126   {
127     mutex_lock l(graph->mu);
128     tensorflow::shape_inference::InferenceContext* ic =
129         graph->refiner.GetContext(node);
130     CHECK(ic != nullptr);
131     CHECK_LT(output.index, ic->num_outputs());
132     const auto* shapes_and_types =
133         ic->output_handle_shapes_and_types(output.index);
134     if (shapes_and_types == nullptr) return "";
135 
136     for (const auto& p : *shapes_and_types) {
137       auto* out_shape_and_type = handle_data.add_shape_and_type();
138       ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
139       out_shape_and_type->set_dtype(p.dtype);
140     }
141   }
142   string result;
143   handle_data.SerializeToString(&result);
144   return result;
145 }
146 
SetHandleShapeAndType(TF_Graph * graph,TF_Output output,const void * proto,size_t proto_len,TF_Status * status)147 void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
148                            size_t proto_len, TF_Status* status) {
149   tensorflow::CppShapeInferenceResult::HandleData handle_data;
150   if (!handle_data.ParseFromArray(proto, proto_len)) {
151     status->status = tensorflow::errors::InvalidArgument(
152         "Couldn't deserialize HandleData proto");
153     return;
154   }
155   DCHECK(handle_data.is_set());
156 
157   tensorflow::mutex_lock l(graph->mu);
158   tensorflow::shape_inference::InferenceContext* ic =
159       graph->refiner.GetContext(&output.oper->node);
160 
161   std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
162   for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
163     tensorflow::shape_inference::ShapeHandle shape;
164     status->status =
165         ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
166     if (TF_GetCode(status) != TF_OK) return;
167     shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
168   }
169   ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
170 }
171 
AddWhileInputHack(TF_Graph * graph,TF_Output new_src,TF_Operation * dst,TF_Status * status)172 void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
173                        TF_Status* status) {
174   mutex_lock l(graph->mu);
175   status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
176                                                   new_src.index, &dst->node);
177   if (TF_GetCode(status) == TF_OK) {
178     // This modification only updates the destination node for
179     // the purposes of running this graph in a session. Thus, we don't
180     // record the source node as being modified.
181     RecordMutation(graph, *dst, "adding input tensor");
182   }
183 }
184 
185 }  // namespace tensorflow
186