1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_test_util.h"
17 
18 #include "tensorflow/c/c_api_experimental.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/public/session_options.h"
25 
26 using tensorflow::GraphDef;
27 using tensorflow::NodeDef;
28 
// Deallocators handed to TF_NewTensor for buffers created with `new T[]`.
// Only the data pointer is used; the byte length and the deallocator
// argument are ignored.
static void BoolDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  bool* const buffer = static_cast<bool*>(data);
  delete[] buffer;
}

static void Int32Deallocator(void* data, size_t /*len*/, void* /*arg*/) {
  int32_t* const buffer = static_cast<int32_t*>(data);
  delete[] buffer;
}

static void DoubleDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  double* const buffer = static_cast<double*>(data);
  delete[] buffer;
}

static void FloatDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  float* const buffer = static_cast<float*>(data);
  delete[] buffer;
}
44 
BoolTensor(bool v)45 TF_Tensor* BoolTensor(bool v) {
46   const int num_bytes = sizeof(bool);
47   bool* values = new bool[1];
48   values[0] = v;
49   return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
50                       nullptr);
51 }
52 
Int8Tensor(const int64_t * dims,int num_dims,const char * values)53 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
54   int64_t num_values = 1;
55   for (int i = 0; i < num_dims; ++i) {
56     num_values *= dims[i];
57   }
58   TF_Tensor* t =
59       TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
60   memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
61   return t;
62 }
63 
Int32Tensor(const int64_t * dims,int num_dims,const int32_t * values)64 TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
65                        const int32_t* values) {
66   int64_t num_values = 1;
67   for (int i = 0; i < num_dims; ++i) {
68     num_values *= dims[i];
69   }
70   TF_Tensor* t =
71       TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
72   memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
73   return t;
74 }
75 
Int32Tensor(const std::vector<int32_t> & values)76 TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
77   int64_t dims = values.size();
78   return Int32Tensor(&dims, 1, values.data());
79 }
80 
Int32Tensor(int32_t v)81 TF_Tensor* Int32Tensor(int32_t v) {
82   const int num_bytes = sizeof(int32_t);
83   int32_t* values = new int32_t[1];
84   values[0] = v;
85   return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes,
86                       &Int32Deallocator, nullptr);
87 }
88 
DoubleTensor(double v)89 TF_Tensor* DoubleTensor(double v) {
90   const int num_bytes = sizeof(double);
91   double* values = new double[1];
92   values[0] = v;
93   return TF_NewTensor(TF_DOUBLE, nullptr, 0, values, num_bytes,
94                       &DoubleDeallocator, nullptr);
95 }
96 
FloatTensor(float v)97 TF_Tensor* FloatTensor(float v) {
98   const int num_bytes = sizeof(float);
99   float* values = new float[1];
100   values[0] = v;
101   return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes,
102                       &FloatDeallocator, nullptr);
103 }
104 
105 // All the *Helper methods are used as a workaround for the restrictions that
106 // one cannot call ASSERT_* methods in non-void-returning functions (when
107 // exceptions are disabled during compilation)
PlaceholderHelper(TF_Graph * graph,TF_Status * s,const char * name,TF_DataType dtype,const std::vector<int64_t> & dims,TF_Operation ** op)108 void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
109                        TF_DataType dtype, const std::vector<int64_t>& dims,
110                        TF_Operation** op) {
111   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
112   TF_SetAttrType(desc, "dtype", dtype);
113   if (!dims.empty()) {
114     TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
115   }
116   *op = TF_FinishOperation(desc, s);
117   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
118   ASSERT_NE(*op, nullptr);
119 }
120 
Placeholder(TF_Graph * graph,TF_Status * s,const char * name,TF_DataType dtype,const std::vector<int64_t> & dims)121 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
122                           TF_DataType dtype, const std::vector<int64_t>& dims) {
123   TF_Operation* op;
124   PlaceholderHelper(graph, s, name, dtype, dims, &op);
125   return op;
126 }
127 
ConstHelper(TF_Tensor * t,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)128 void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
129                  TF_Operation** op) {
130   TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
131   TF_SetAttrTensor(desc, "value", t, s);
132   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
133   TF_SetAttrType(desc, "dtype", TF_TensorType(t));
134   *op = TF_FinishOperation(desc, s);
135   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
136   ASSERT_NE(*op, nullptr);
137 }
138 
Const(TF_Tensor * t,TF_Graph * graph,TF_Status * s,const char * name)139 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
140                     const char* name) {
141   TF_Operation* op;
142   ConstHelper(t, graph, s, name, &op);
143   return op;
144 }
145 
ScalarConst(bool v,TF_Graph * graph,TF_Status * s,const char * name)146 TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
147                           const char* name) {
148   unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
149   return Const(tensor.get(), graph, s, name);
150 }
151 
ScalarConst(int32_t v,TF_Graph * graph,TF_Status * s,const char * name)152 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
153                           const char* name) {
154   unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
155   return Const(tensor.get(), graph, s, name);
156 }
157 
ScalarConst(double v,TF_Graph * graph,TF_Status * s,const char * name)158 TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
159                           const char* name) {
160   unique_tensor_ptr tensor(DoubleTensor(v), TF_DeleteTensor);
161   return Const(tensor.get(), graph, s, name);
162 }
163 
ScalarConst(float v,TF_Graph * graph,TF_Status * s,const char * name)164 TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
165                           const char* name) {
166   unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor);
167   return Const(tensor.get(), graph, s, name);
168 }
169 
AddOpHelper(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op,bool check)170 void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
171                  TF_Status* s, const char* name, TF_Operation** op,
172                  bool check) {
173   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
174   TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
175   TF_AddInputList(desc, add_inputs, 2);
176   *op = TF_FinishOperation(desc, s);
177   if (check) {
178     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
179     ASSERT_NE(*op, nullptr);
180   }
181 }
182 
Add(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)183 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
184                   TF_Status* s, const char* name) {
185   TF_Operation* op;
186   AddOpHelper(l, r, graph, s, name, &op, true);
187   return op;
188 }
189 
AddNoCheck(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)190 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
191                          TF_Status* s, const char* name) {
192   TF_Operation* op;
193   AddOpHelper(l, r, graph, s, name, &op, false);
194   return op;
195 }
196 
AddWithCtrlDependency(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Operation * ctrl_op,TF_Status * s,const char * name)197 TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
198                                     TF_Graph* graph, TF_Operation* ctrl_op,
199                                     TF_Status* s, const char* name) {
200   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
201   TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
202   TF_AddInputList(desc, add_inputs, 2);
203   TF_AddControlInput(desc, ctrl_op);
204   return TF_FinishOperation(desc, s);
205 }
206 
207 // If `op_device` is non-empty, set the created op on that device.
BinaryOpHelper(const char * op_name,TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op,const string & op_device,bool check)208 void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r,
209                     TF_Graph* graph, TF_Status* s, const char* name,
210                     TF_Operation** op, const string& op_device, bool check) {
211   TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name);
212   if (!op_device.empty()) {
213     TF_SetDevice(desc, op_device.c_str());
214   }
215   TF_AddInput(desc, {l, 0});
216   TF_AddInput(desc, {r, 0});
217   *op = TF_FinishOperation(desc, s);
218   if (check) {
219     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
220     ASSERT_NE(*op, nullptr);
221   }
222 }
223 
MinWithDevice(TF_Operation * l,TF_Operation * r,TF_Graph * graph,const string & op_device,TF_Status * s,const char * name)224 TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
225                             const string& op_device, TF_Status* s,
226                             const char* name) {
227   TF_Operation* op;
228   BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true);
229   return op;
230 }
231 
Min(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)232 TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
233                   TF_Status* s, const char* name) {
234   return MinWithDevice(l, r, graph, /*op_device=*/"", s, name);
235 }
236 
Mul(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)237 TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
238                   TF_Status* s, const char* name) {
239   TF_Operation* op;
240   BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true);
241   return op;
242 }
243 
Add(TF_Output l,TF_Output r,TF_Graph * graph,TF_Status * s,const char * name)244 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
245                   const char* name) {
246   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
247   TF_Output inputs[2] = {l, r};
248   TF_AddInputList(desc, inputs, 2);
249   return TF_FinishOperation(desc, s);
250 }
251 
NegHelper(TF_Operation * n,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)252 void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name,
253                TF_Operation** op) {
254   TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", name);
255   TF_Output neg_input = {n, 0};
256   TF_AddInput(desc, neg_input);
257   *op = TF_FinishOperation(desc, s);
258   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
259   ASSERT_NE(*op, nullptr);
260 }
261 
Neg(TF_Operation * n,TF_Graph * graph,TF_Status * s,const char * name)262 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
263                   const char* name) {
264   TF_Operation* op;
265   NegHelper(n, graph, s, name, &op);
266   return op;
267 }
268 
LessThan(TF_Output l,TF_Output r,TF_Graph * graph,TF_Status * s)269 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
270                        TF_Status* s) {
271   TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
272   TF_AddInput(desc, l);
273   TF_AddInput(desc, r);
274   return TF_FinishOperation(desc, s);
275 }
276 
RandomUniform(TF_Operation * shape,TF_DataType dtype,TF_Graph * graph,TF_Status * s)277 TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
278                             TF_Graph* graph, TF_Status* s) {
279   TF_OperationDescription* desc =
280       TF_NewOperation(graph, "RandomUniform", "random_uniform");
281   TF_AddInput(desc, {shape, 0});
282   TF_SetAttrType(desc, "dtype", dtype);
283   return TF_FinishOperation(desc, s);
284 }
285 
Split3Helper(TF_Operation * input,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)286 void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
287                   const char* name, TF_Operation** op) {
288   TF_Operation* zero = ScalarConst(
289       0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
290   TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
291   TF_AddInput(desc, {zero, 0});
292   TF_AddInput(desc, {input, 0});
293   TF_SetAttrInt(desc, "num_split", 3);
294   TF_SetAttrType(desc, "T", TF_INT32);
295   // Set device to CPU since there is no version of split for int32 on GPU
296   // TODO(iga): Convert all these helpers and tests to use floats because
297   // they are usually available on GPUs. After doing this, remove TF_SetDevice
298   // call in c_api_function_test.cc
299   TF_SetDevice(desc, "/cpu:0");
300   *op = TF_FinishOperation(desc, s);
301   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
302   ASSERT_NE(*op, nullptr);
303 }
304 
Split3(TF_Operation * input,TF_Graph * graph,TF_Status * s,const char * name)305 TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
306                      const char* name) {
307   TF_Operation* op;
308   Split3Helper(input, graph, s, name, &op);
309   return op;
310 }
311 
IsPlaceholder(const tensorflow::NodeDef & node_def)312 bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
313   if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
314     return false;
315   }
316   bool found_dtype = false;
317   bool found_shape = false;
318   for (const auto& attr : node_def.attr()) {
319     if (attr.first == "dtype") {
320       if (attr.second.type() == tensorflow::DT_INT32) {
321         found_dtype = true;
322       } else {
323         return false;
324       }
325     } else if (attr.first == "shape") {
326       found_shape = true;
327     }
328   }
329   return found_dtype && found_shape;
330 }
331 
IsScalarConst(const tensorflow::NodeDef & node_def,int v)332 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v) {
333   if (node_def.op() != "Const" || node_def.name() != "scalar") {
334     return false;
335   }
336   bool found_dtype = false;
337   bool found_value = false;
338   for (const auto& attr : node_def.attr()) {
339     if (attr.first == "dtype") {
340       if (attr.second.type() == tensorflow::DT_INT32) {
341         found_dtype = true;
342       } else {
343         return false;
344       }
345     } else if (attr.first == "value") {
346       if (attr.second.has_tensor() &&
347           attr.second.tensor().int_val_size() == 1 &&
348           attr.second.tensor().int_val(0) == v) {
349         found_value = true;
350       } else {
351         return false;
352       }
353     }
354   }
355   return found_dtype && found_value;
356 }
357 
IsAddN(const tensorflow::NodeDef & node_def,int n)358 bool IsAddN(const tensorflow::NodeDef& node_def, int n) {
359   if (node_def.op() != "AddN" || node_def.name() != "add" ||
360       node_def.input_size() != n) {
361     return false;
362   }
363   bool found_t = false;
364   bool found_n = false;
365   for (const auto& attr : node_def.attr()) {
366     if (attr.first == "T") {
367       if (attr.second.type() == tensorflow::DT_INT32) {
368         found_t = true;
369       } else {
370         return false;
371       }
372     } else if (attr.first == "N") {
373       if (attr.second.i() == n) {
374         found_n = true;
375       } else {
376         return false;
377       }
378     }
379   }
380   return found_t && found_n;
381 }
382 
IsNeg(const tensorflow::NodeDef & node_def,const string & input)383 bool IsNeg(const tensorflow::NodeDef& node_def, const string& input) {
384   return node_def.op() == "Neg" && node_def.name() == "neg" &&
385          node_def.input_size() == 1 && node_def.input(0) == input;
386 }
387 
GetGraphDef(TF_Graph * graph,tensorflow::GraphDef * graph_def)388 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def) {
389   TF_Status* s = TF_NewStatus();
390   TF_Buffer* buffer = TF_NewBuffer();
391   TF_GraphToGraphDef(graph, buffer, s);
392   bool ret = TF_GetCode(s) == TF_OK;
393   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
394   if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
395   TF_DeleteBuffer(buffer);
396   TF_DeleteStatus(s);
397   return ret;
398 }
399 
GetNodeDef(TF_Operation * oper,tensorflow::NodeDef * node_def)400 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
401   TF_Status* s = TF_NewStatus();
402   TF_Buffer* buffer = TF_NewBuffer();
403   TF_OperationToNodeDef(oper, buffer, s);
404   bool ret = TF_GetCode(s) == TF_OK;
405   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
406   if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
407   TF_DeleteBuffer(buffer);
408   TF_DeleteStatus(s);
409   return ret;
410 }
411 
GetFunctionDef(TF_Function * func,tensorflow::FunctionDef * func_def)412 bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
413   TF_Status* s = TF_NewStatus();
414   TF_Buffer* buffer = TF_NewBuffer();
415   TF_FunctionToFunctionDef(func, buffer, s);
416   bool ret = TF_GetCode(s) == TF_OK;
417   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
418   if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
419   TF_DeleteBuffer(buffer);
420   TF_DeleteStatus(s);
421   return ret;
422 }
423 
GetAttrValue(TF_Operation * oper,const char * attr_name,tensorflow::AttrValue * attr_value,TF_Status * s)424 bool GetAttrValue(TF_Operation* oper, const char* attr_name,
425                   tensorflow::AttrValue* attr_value, TF_Status* s) {
426   TF_Buffer* buffer = TF_NewBuffer();
427   TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
428   bool ret = TF_GetCode(s) == TF_OK;
429   if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
430   TF_DeleteBuffer(buffer);
431   return ret;
432 }
433 
GetGradDefs(const tensorflow::GraphDef & graph_def)434 std::vector<std::pair<string, string>> GetGradDefs(
435     const tensorflow::GraphDef& graph_def) {
436   std::vector<std::pair<string, string>> grads;
437   for (const tensorflow::GradientDef& grad : graph_def.library().gradient()) {
438     grads.emplace_back(grad.function_name(), grad.gradient_func());
439   }
440   std::sort(grads.begin(), grads.end());
441   return grads;
442 }
443 
GetFuncNames(const tensorflow::GraphDef & graph_def)444 std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) {
445   std::vector<string> names;
446   for (const tensorflow::FunctionDef& func : graph_def.library().function()) {
447     names.push_back(func.signature().name());
448   }
449   std::sort(names.begin(), names.end());
450   return names;
451 }
452 
CSession(TF_Graph * graph,TF_Status * s,bool use_XLA)453 CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) {
454   TF_SessionOptions* opts = TF_NewSessionOptions();
455   TF_EnableXLACompilation(opts, use_XLA);
456   session_ = TF_NewSession(graph, opts, s);
457   TF_DeleteSessionOptions(opts);
458 }
459 
CSession(TF_Session * session)460 CSession::CSession(TF_Session* session) : session_(session) {}
461 
~CSession()462 CSession::~CSession() {
463   TF_Status* s = TF_NewStatus();
464   CloseAndDelete(s);
465   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
466   TF_DeleteStatus(s);
467 }
468 
SetInputs(std::vector<std::pair<TF_Operation *,TF_Tensor * >> inputs)469 void CSession::SetInputs(
470     std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
471   DeleteInputValues();
472   inputs_.clear();
473   for (const auto& p : inputs) {
474     inputs_.emplace_back(TF_Output{p.first, 0});
475     input_values_.emplace_back(p.second);
476   }
477 }
478 
SetOutputs(std::initializer_list<TF_Operation * > outputs)479 void CSession::SetOutputs(std::initializer_list<TF_Operation*> outputs) {
480   ResetOutputValues();
481   outputs_.clear();
482   for (TF_Operation* o : outputs) {
483     outputs_.emplace_back(TF_Output{o, 0});
484   }
485   output_values_.resize(outputs_.size());
486 }
487 
SetOutputs(const std::vector<TF_Output> & outputs)488 void CSession::SetOutputs(const std::vector<TF_Output>& outputs) {
489   ResetOutputValues();
490   outputs_ = outputs;
491   output_values_.resize(outputs_.size());
492 }
493 
SetTargets(std::initializer_list<TF_Operation * > targets)494 void CSession::SetTargets(std::initializer_list<TF_Operation*> targets) {
495   targets_.clear();
496   for (TF_Operation* t : targets) {
497     targets_.emplace_back(t);
498   }
499 }
500 
Run(TF_Status * s)501 void CSession::Run(TF_Status* s) {
502   if (inputs_.size() != input_values_.size()) {
503     ADD_FAILURE() << "Call SetInputs() before Run()";
504     return;
505   }
506   ResetOutputValues();
507   output_values_.resize(outputs_.size(), nullptr);
508 
509   const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0];
510   TF_Tensor* const* input_values_ptr =
511       input_values_.empty() ? nullptr : &input_values_[0];
512 
513   const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0];
514   TF_Tensor** output_values_ptr =
515       output_values_.empty() ? nullptr : &output_values_[0];
516 
517   TF_Operation* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0];
518 
519   TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, inputs_.size(),
520                 outputs_ptr, output_values_ptr, outputs_.size(), targets_ptr,
521                 targets_.size(), nullptr, s);
522 
523   DeleteInputValues();
524 }
525 
CloseAndDelete(TF_Status * s)526 void CSession::CloseAndDelete(TF_Status* s) {
527   DeleteInputValues();
528   ResetOutputValues();
529   if (session_ != nullptr) {
530     TF_CloseSession(session_, s);
531     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
532     TF_DeleteSession(session_, s);
533     session_ = nullptr;
534   }
535 }
536 
DeleteInputValues()537 void CSession::DeleteInputValues() {
538   for (size_t i = 0; i < input_values_.size(); ++i) {
539     TF_DeleteTensor(input_values_[i]);
540   }
541   input_values_.clear();
542 }
543 
ResetOutputValues()544 void CSession::ResetOutputValues() {
545   for (size_t i = 0; i < output_values_.size(); ++i) {
546     if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]);
547   }
548   output_values_.clear();
549 }
550