/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"

#include <string>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_shape.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"

namespace {

using tensorflow::tstring;

constexpr char kTestData[] = "cc/saved_model/testdata";
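// Tag set passed to TF_LoadSavedModelWithTags below; "serve" is the standard
// SavedModel serving tag.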
const char* kServeTag[] = {"serve"};

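// Returns the path to a test model under tensorflow/cc/saved_model/testdata.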
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
  return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
                                  kTestData, saved_model_dir);
}

// This value-parameterized test allows us to test both the TFRT and
// non-TFRT runtimes.
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};

TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);

  // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
  // That unblocks writing other tests that require a TF_SavedModel*,
  // like loading a ConcreteFunction. This test at least checks that the
  // C API builds and can be minimally run.
  EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadsSavedModel) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* compute_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

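  // Feed two scalar float inputs and invoke the function as an eager call op.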
  std::vector<TFE_TensorHandle*> compute_fn_inputs;
  TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
  TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
  compute_fn_inputs.push_back(input_a);
  compute_fn_inputs.push_back(input_b);

  TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
      compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  // (1 + 2) * (2 + 1) / 3 + 5 should be 8
  EXPECT_FLOAT_EQ(output_value, 8.0);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(compute_fn_outputs[0]);
  TFE_DeleteTensorHandle(input_a);
  TFE_DeleteTensorHandle(input_b);
  TFE_DeleteOp(compute_fn_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

// This tests running the "serving_default" SignatureDefFunction from the
// VarsAndArithmeticObjectGraph SavedModel. Here's what the signature_defs
// protobuf in the metagraph looks like:
// signature_def: {
//   key  : "serving_default"
//   value: {
//     inputs: {
//       key  : "a"
//       value: {
//         name : "serving_default_a:0"
//         dtype: DT_FLOAT
//         tensor_shape: {
//         }
//       }
//     }
//     inputs: {
//       key  : "b"
//       value: {
//         name : "serving_default_b:0"
//         dtype: DT_FLOAT
//         tensor_shape: {
//         }
//       }
//     }
//     outputs: {
//       key  : "output_0"
//       value: {
//         name : "StatefulPartitionedCall:0"
//         dtype: DT_FLOAT
//         tensor_shape: {
//         }
//       }
//     }
//     method_name: "tensorflow/serving/predict"
//   }
// }
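// The test below checks that the C API surfaces this metadata (argument and
// result names, dtypes, and scalar shapes), and that executing the signature
// produces the expected scalar result.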
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_SignatureDefFunction* serving_default =
      TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
                                           status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_SignatureDefFunctionMetadata* metadata =
      TF_SignatureDefFunctionGetMetadata(serving_default);

  const TF_SignatureDefParamList* args =
      TF_SignatureDefFunctionMetadataArgs(metadata);
  const TF_SignatureDefParamList* returns =
      TF_SignatureDefFunctionMetadataReturns(metadata);

  EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
  const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
  const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
  const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);

  // Input "a" is a scalar, float32 tensor
  EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
  EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
  EXPECT_EQ(0, TF_ShapeDims(shape_a));

  const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
  const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
  const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);

  // Input "b" is a scalar, float32 tensor
  EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
  EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
  EXPECT_EQ(0, TF_ShapeDims(shape_b));

  EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);

  const TF_SignatureDefParam* param_out =
      TF_SignatureDefParamListGet(returns, 0);
  const TF_TensorSpec* tensor_spec_out =
      TF_SignatureDefParamTensorSpec(param_out);
  const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);

  // Output "output_0" is a scalar, float32 tensor
  EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
  EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
  EXPECT_EQ(0, TF_ShapeDims(shape_out));

  std::vector<TFE_TensorHandle*> compute_fn_inputs;
  TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
  TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
  compute_fn_inputs.push_back(input_a);
  compute_fn_inputs.push_back(input_b);

  TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
      serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
      status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

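  // Unlike the ConcreteFunction tests, which hardcode a single output, the
  // output buffer here is sized from the signature's declared return count.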
  std::vector<TFE_TensorHandle*> compute_fn_outputs(
      TF_SignatureDefParamListSize(returns));
  int num_retvals = TF_SignatureDefParamListSize(returns);

  TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
              status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  // (1 + 2) * (2 + 1) / 3 + 5 should be 8
  EXPECT_FLOAT_EQ(output_value, 8.0);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(compute_fn_outputs[0]);
  TFE_DeleteTensorHandle(input_a);
  TFE_DeleteTensorHandle(input_b);
  TFE_DeleteOp(serving_default_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("AssetModule");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* read_file_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* read_file_op =
      TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  EXPECT_EQ(TF_NumDims(result), 0);
  tensorflow::tstring* output_value =
      static_cast<tensorflow::tstring*>(TF_TensorData(result));
  std::string file_contents(*output_value);
  EXPECT_NE(file_contents.find("TEST ASSET FILE CONTENTS"), std::string::npos);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
  TFE_DeleteOp(read_file_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = SavedModelPath("StaticHashTableModule");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);

  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TF_ConcreteFunction* lookup_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "lookup", status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following
  // mapping:
  // "foo" -> 0
  // "bar" -> 1
  // "baz" -> 2
  // "wombat" -> 3
  // all other strings -> -1

  // Call lookup function with input "foo", expecting an output of 0
  {
    std::vector<TFE_TensorHandle*> lookup_fn_inputs;
    TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo"));
    lookup_fn_inputs.push_back(input_foo);

    TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
        lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
    // inputs + outputs a function has.
    TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
    int num_retvals = 1;

    TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    EXPECT_EQ(TF_NumDims(result), 0);
    tensorflow::int64* output_value =
        static_cast<tensorflow::int64*>(TF_TensorData(result));
    EXPECT_EQ(*output_value, 0);

    TF_DeleteTensor(result);
    TFE_DeleteTensorHandle(input_foo);
    TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
    TFE_DeleteOp(lookup_op);
  }

  // Call lookup function with input "baz", expecting an output of 2
  {
    std::vector<TFE_TensorHandle*> lookup_fn_inputs;
    TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz"));
    lookup_fn_inputs.push_back(input_foo);

    TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
        lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
    // inputs + outputs a function has.
    TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
    int num_retvals = 1;

    TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    EXPECT_EQ(TF_NumDims(result), 0);
    tensorflow::int64* output_value =
        static_cast<tensorflow::int64*>(TF_TensorData(result));
    EXPECT_EQ(*output_value, 2);

    TF_DeleteTensor(result);
    TFE_DeleteTensorHandle(input_foo);
    TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
    TFE_DeleteOp(lookup_op);
  }

  // Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1
  {
    std::vector<TFE_TensorHandle*> lookup_fn_inputs;
    TFE_TensorHandle* input_foo =
        TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY"));
    lookup_fn_inputs.push_back(input_foo);

    TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
        lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
    // inputs + outputs a function has.
    TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
    int num_retvals = 1;

    TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
    EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

    EXPECT_EQ(TF_NumDims(result), 0);
    tensorflow::int64* output_value =
        static_cast<tensorflow::int64*>(TF_TensorData(result));
    EXPECT_EQ(*output_value, -1);

    TF_DeleteTensor(result);
    TFE_DeleteTensorHandle(input_foo);
    TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
    TFE_DeleteOp(lookup_op);
  }

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(),
      "c/experimental/saved_model/internal/testdata/UninitializedVariable");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

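  // Reach into the internal TFSavedModelAPI (via unwrap + down_cast) to check
  // the dtypes of the model's uninitialized variables.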
  tensorflow::TFSavedModelAPI* model_api =
      tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
          tensorflow::unwrap(saved_model));
  tensorflow::Variable* uninitialized_variable;
  ASSERT_EQ(tensorflow::Status::OK(),
            model_api->GetVariable("uninitialized_variable",
                                   &uninitialized_variable));
  ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());

  ASSERT_EQ(tensorflow::Status::OK(),
            model_api->GetVariable("sub_module.uninitialized_variable",
                                   &uninitialized_variable));
  ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

TEST_P(CSavedModelAPITest, LoadSavedModelWithWhileLoop) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  bool use_tfrt = GetParam();
  if (use_tfrt) {
    TFE_DeleteContextOptions(opts);
    TF_DeleteStatus(status);
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  std::string model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(),
      "c/experimental/saved_model/internal/testdata/SimpleWhileLoop");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_ConcreteFunction* while_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  std::vector<TFE_TensorHandle*> while_fn_inputs;
  while_fn_inputs.push_back(TestScalarTensorHandle(ctx, 10.0f));

  TFE_Op* while_fn_op = TF_ConcreteFunctionMakeCallOp(
      while_fn, while_fn_inputs.data(), while_fn_inputs.size(), status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_TensorHandle* while_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(while_fn_op, &while_fn_outputs[0], &num_retvals, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(while_fn_outputs[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  ASSERT_EQ(TF_NumDims(result), 0);
  float output_value = *static_cast<float*>(TF_TensorData(result));
  ASSERT_FLOAT_EQ(output_value, 55);  // 10+9+...+1

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(while_fn_outputs[0]);
  TFE_DeleteOp(while_fn_op);
  TFE_DeleteTensorHandle(while_fn_inputs[0]);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}

INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
                         ::testing::Bool());

}  // namespace