1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/c/eager/c_api.h"
22 #include "tensorflow/c/eager/c_api_experimental.h"
23 #include "tensorflow/c/eager/c_api_test_util.h"
24 #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
25 #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
26 #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
27 #include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
28 #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
29 #include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
30 #include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
31 #include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
32 #include "tensorflow/c/tf_datatype.h"
33 #include "tensorflow/c/tf_shape.h"
34 #include "tensorflow/c/tf_status.h"
35 #include "tensorflow/c/tf_tensor.h"
36 #include "tensorflow/core/lib/io/path.h"
37 #include "tensorflow/core/platform/status.h"
38 #include "tensorflow/core/platform/stringpiece.h"
39 #include "tensorflow/core/platform/test.h"
40 #include "tensorflow/core/platform/tstring.h"
41
42 namespace {
43
44 using tensorflow::tstring;
45
// Directory holding the test SavedModels; SavedModelPath() joins it onto the
// TensorFlow source root.
constexpr char kTestData[] = "cc/saved_model/testdata";
// Tag list handed to TF_LoadSavedModelWithTags. Neither the array slots nor the
// strings they point at are meant to be modified, so both levels are const.
const char* const kServeTag[] = {"serve"};
48
SavedModelPath(tensorflow::StringPiece saved_model_dir)49 std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
50 return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
51 kTestData, saved_model_dir);
52 }
53
54 // This value parameterized test allows us to test both TFRT
55 // and non TFRT runtimes.
56 // https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
57 class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};
58
// Smoke test for TF_LoadSavedModelWithTags: creates an eager context, asks for
// the VarsAndArithmeticObjectGraph model with the "serve" tag and checks the
// status code the call currently reports.
TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
  const bool use_tfrt = GetParam();
  if (use_tfrt) {
    // Nothing has been allocated yet, so skipping here needs no cleanup.
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* context_options = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(context_options, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(context_options, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(context_options);

  const std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
  TF_SavedModel* saved_model =
      TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);

  // TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
  // That unblocks writing other tests that require a TF_SavedModel*,
  // like loading a ConcreteFunction. This test at least checks that the
  // C API builds and can be minimally run.
  EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
90
// Loads the VarsAndArithmeticObjectGraph SavedModel, fetches its "compute"
// ConcreteFunction, runs it on the scalar inputs (2.0, 1.0) and checks that the
// scalar float result is 8.0.
//
// Every status check that guards a later use of the returned handle is an
// ASSERT_*: continuing after a failed load/lookup/execute would pass a null
// handle to the next C API call and crash the test binary instead of reporting
// the failure. (An ASSERT_* failure returns early and skips the cleanup at the
// bottom; that is accepted for an already-failed test.)
TEST_P(CSavedModelAPITest, LoadsSavedModel) {
  const bool use_tfrt = GetParam();
  if (use_tfrt) {
    // Nothing has been allocated yet, so skipping here needs no cleanup.
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  const std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_ConcreteFunction* compute_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  std::vector<TFE_TensorHandle*> compute_fn_inputs;
  TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
  TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
  compute_fn_inputs.push_back(input_a);
  compute_fn_inputs.push_back(input_b);

  TFE_Op* compute_fn_op = TF_ConcreteFunctionMakeCallOp(
      compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // The tensor buffer is read as a single float below, so the scalar shape is
  // a hard precondition rather than a soft expectation.
  ASSERT_EQ(TF_NumDims(result), 0);
  const float output_value = *static_cast<float*>(TF_TensorData(result));
  // (1 + 2) * (2 + 1) / 3 + 5 should be 8
  EXPECT_FLOAT_EQ(output_value, 8.0);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(compute_fn_outputs[0]);
  TFE_DeleteTensorHandle(input_a);
  TFE_DeleteTensorHandle(input_b);
  TFE_DeleteOp(compute_fn_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
152
153 // This tests running the "serving_default" SignatureDefFunction from the
154 // VarsAndArithmeticObjectGraph savedmodel. Here's what the signature_defs
155 // protobuf in the metagraph looks like:
156 // signature_def: {
157 // key : "serving_default"
158 // value: {
159 // inputs: {
160 // key : "a"
161 // value: {
162 // name : "serving_default_a:0"
163 // dtype: DT_FLOAT
164 // tensor_shape: {
165 // }
166 // }
167 // }
168 // inputs: {
169 // key : "b"
170 // value: {
171 // name : "serving_default_b:0"
172 // dtype: DT_FLOAT
173 // tensor_shape: {
174 // }
175 // }
176 // }
177 // outputs: {
178 // key : "output_0"
179 // value: {
180 // name : "StatefulPartitionedCall:0"
181 // dtype: DT_FLOAT
182 // tensor_shape: {
183 // }
184 // }
185 // }
186 // method_name: "tensorflow/serving/predict"
187 // }
188 // }
TEST_P(CSavedModelAPITest,RunsSignatureDefFunction)189 TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
190 TF_Status* status = TF_NewStatus();
191 TFE_ContextOptions* opts = TFE_NewContextOptions();
192 bool use_tfrt = GetParam();
193 if (use_tfrt) {
194 TFE_DeleteContextOptions(opts);
195 TF_DeleteStatus(status);
196 GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
197 }
198
199 TFE_ContextOptionsSetTfrt(opts, use_tfrt);
200
201 TFE_Context* ctx = TFE_NewContext(opts, status);
202 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
203 TFE_DeleteContextOptions(opts);
204
205 std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
206
207 TF_SavedModel* saved_model =
208 TF_LoadSavedModel(model_dir.c_str(), ctx, status);
209
210 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
211 TF_SignatureDefFunction* serving_default =
212 TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
213 status);
214 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
215
216 TF_SignatureDefFunctionMetadata* metadata =
217 TF_SignatureDefFunctionGetMetadata(serving_default);
218
219 const TF_SignatureDefParamList* args =
220 TF_SignatureDefFunctionMetadataArgs(metadata);
221 const TF_SignatureDefParamList* returns =
222 TF_SignatureDefFunctionMetadataReturns(metadata);
223
224 EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
225 const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
226 const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
227 const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);
228
229 // Input "a" is a scalar, float32 tensor
230 EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
231 EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
232 EXPECT_EQ(0, TF_ShapeDims(shape_a));
233
234 const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
235 const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
236 const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);
237
238 // Input "b" is a scalar, float32 tensor
239 EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
240 EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
241 EXPECT_EQ(0, TF_ShapeDims(shape_b));
242
243 EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);
244
245 const TF_SignatureDefParam* param_out =
246 TF_SignatureDefParamListGet(returns, 0);
247 const TF_TensorSpec* tensor_spec_out =
248 TF_SignatureDefParamTensorSpec(param_out);
249 const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);
250
251 // Output "output_0" is a scalar, float32 tensor
252 EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
253 EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
254 EXPECT_EQ(0, TF_ShapeDims(shape_out));
255
256 std::vector<TFE_TensorHandle*> compute_fn_inputs;
257 TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
258 TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
259 compute_fn_inputs.push_back(input_a);
260 compute_fn_inputs.push_back(input_b);
261
262 TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
263 serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
264 status);
265 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
266
267 std::vector<TFE_TensorHandle*> compute_fn_outputs(
268 TF_SignatureDefParamListSize(returns));
269 int num_retvals = TF_SignatureDefParamListSize(returns);
270
271 TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
272 status);
273 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
274
275 TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
276 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
277
278 EXPECT_EQ(TF_NumDims(result), 0);
279 float output_value = *static_cast<float*>(TF_TensorData(result));
280 // (1 + 2) * (2 + 1) / 3 + 5 should be 8
281 EXPECT_FLOAT_EQ(output_value, 8.0);
282
283 TF_DeleteTensor(result);
284 TFE_DeleteTensorHandle(compute_fn_outputs[0]);
285 TFE_DeleteTensorHandle(input_a);
286 TFE_DeleteTensorHandle(input_b);
287 TFE_DeleteOp(serving_default_op);
288 TF_DeleteSavedModel(saved_model);
289 TF_DeleteStatus(status);
290 TFE_DeleteContext(ctx);
291 }
292
// Loads the AssetModule SavedModel, runs its zero-argument "read_file"
// ConcreteFunction and checks that the scalar string it returns contains
// "TEST ASSET FILE CONTENTS".
//
// Status checks that guard a later use of the returned handle are ASSERT_*:
// continuing after a failure would hand a null handle to the next C API call
// and crash instead of reporting the failure. (An ASSERT_* failure returns
// early and skips the cleanup at the bottom; that is accepted for an
// already-failed test.)
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
  const bool use_tfrt = GetParam();
  if (use_tfrt) {
    // Nothing has been allocated yet, so skipping here needs no cleanup.
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  const std::string model_dir = SavedModelPath("AssetModule");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_ConcreteFunction* read_file_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "read_file", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // "read_file" is called with no inputs.
  TFE_Op* read_file_op =
      TF_ConcreteFunctionMakeCallOp(read_file_fn, nullptr, 0, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
  // inputs + outputs a function has.
  TFE_TensorHandle* read_file_fn_outputs[1] = {nullptr};
  int num_retvals = 1;

  TFE_Execute(read_file_op, &read_file_fn_outputs[0], &num_retvals, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* result = TFE_TensorHandleResolve(read_file_fn_outputs[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // The tensor buffer is read as a single tstring below.
  ASSERT_EQ(TF_NumDims(result), 0);
  tensorflow::tstring* output_value =
      static_cast<tensorflow::tstring*>(TF_TensorData(result));
  std::string file_contents(*output_value);
  EXPECT_NE(file_contents.find("TEST ASSET FILE CONTENTS"), std::string::npos);

  TF_DeleteTensor(result);
  TFE_DeleteTensorHandle(read_file_fn_outputs[0]);
  TFE_DeleteOp(read_file_op);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
347
TEST_P(CSavedModelAPITest,LoadsStaticHashtableSavedModel)348 TEST_P(CSavedModelAPITest, LoadsStaticHashtableSavedModel) {
349 TF_Status* status = TF_NewStatus();
350 TFE_ContextOptions* opts = TFE_NewContextOptions();
351 bool use_tfrt = GetParam();
352 if (use_tfrt) {
353 TFE_DeleteContextOptions(opts);
354 TF_DeleteStatus(status);
355 GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
356 }
357
358 TFE_ContextOptionsSetTfrt(opts, use_tfrt);
359
360 TFE_Context* ctx = TFE_NewContext(opts, status);
361 ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
362 TFE_DeleteContextOptions(opts);
363
364 std::string model_dir = SavedModelPath("StaticHashTableModule");
365
366 TF_SavedModel* saved_model =
367 TF_LoadSavedModel(model_dir.c_str(), ctx, status);
368
369 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
370 TF_ConcreteFunction* lookup_fn =
371 TF_GetSavedModelConcreteFunction(saved_model, "lookup", status);
372 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
373
374 // Note(bmzhao): Based on static_hashtable_asset.txt, we expect the following
375 // mapping:
376 // "foo" -> 0
377 // "bar" -> 1
378 // "baz" -> 2
379 // "wombat" -> 3
380 // all other strings -> -1
381
382 // Call lookup function with input "foo", expecting an output of 0
383 {
384 std::vector<TFE_TensorHandle*> lookup_fn_inputs;
385 TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("foo"));
386 lookup_fn_inputs.push_back(input_foo);
387
388 TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
389 lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
390 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
391
392 // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
393 // inputs + outputs a function has.
394 TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
395 int num_retvals = 1;
396
397 TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
398 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
399
400 TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
401 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
402
403 EXPECT_EQ(TF_NumDims(result), 0);
404 tensorflow::int64* output_value =
405 static_cast<tensorflow::int64*>(TF_TensorData(result));
406 EXPECT_EQ(*output_value, 0);
407
408 TF_DeleteTensor(result);
409 TFE_DeleteTensorHandle(input_foo);
410 TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
411 TFE_DeleteOp(lookup_op);
412 }
413
414 // Call lookup function with input "baz", expecting an output of 2
415 {
416 std::vector<TFE_TensorHandle*> lookup_fn_inputs;
417 TFE_TensorHandle* input_foo = TestScalarTensorHandle(ctx, tstring("baz"));
418 lookup_fn_inputs.push_back(input_foo);
419
420 TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
421 lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
422 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
423
424 // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
425 // inputs + outputs a function has.
426 TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
427 int num_retvals = 1;
428
429 TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
430 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
431
432 TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
433 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
434
435 EXPECT_EQ(TF_NumDims(result), 0);
436 tensorflow::int64* output_value =
437 static_cast<tensorflow::int64*>(TF_TensorData(result));
438 EXPECT_EQ(*output_value, 2);
439
440 TF_DeleteTensor(result);
441 TFE_DeleteTensorHandle(input_foo);
442 TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
443 TFE_DeleteOp(lookup_op);
444 }
445
446 // Call lookup function w/input "NON-EXISTENT-KEY", expecting an output of -1
447 {
448 std::vector<TFE_TensorHandle*> lookup_fn_inputs;
449 TFE_TensorHandle* input_foo =
450 TestScalarTensorHandle(ctx, tstring("NON-EXISTENT-KEY"));
451 lookup_fn_inputs.push_back(input_foo);
452
453 TFE_Op* lookup_op = TF_ConcreteFunctionMakeCallOp(
454 lookup_fn, lookup_fn_inputs.data(), lookup_fn_inputs.size(), status);
455 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
456
457 // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
458 // inputs + outputs a function has.
459 TFE_TensorHandle* lookup_fn_outputs[1] = {nullptr};
460 int num_retvals = 1;
461
462 TFE_Execute(lookup_op, &lookup_fn_outputs[0], &num_retvals, status);
463 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
464
465 TF_Tensor* result = TFE_TensorHandleResolve(lookup_fn_outputs[0], status);
466 EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
467
468 EXPECT_EQ(TF_NumDims(result), 0);
469 tensorflow::int64* output_value =
470 static_cast<tensorflow::int64*>(TF_TensorData(result));
471 EXPECT_EQ(*output_value, -1);
472
473 TF_DeleteTensor(result);
474 TFE_DeleteTensorHandle(input_foo);
475 TFE_DeleteTensorHandle(lookup_fn_outputs[0]);
476 TFE_DeleteOp(lookup_op);
477 }
478
479 TF_DeleteSavedModel(saved_model);
480 TF_DeleteStatus(status);
481 TFE_DeleteContext(ctx);
482 }
483
// Loads the UninitializedVariable SavedModel and checks, through the
// underlying tensorflow::TFSavedModelAPI object, that the variables
// "uninitialized_variable" (DT_FLOAT) and "sub_module.uninitialized_variable"
// (DT_INT64) can be looked up.
//
// The load status is an ASSERT_*: GetVariable is called through the unwrapped
// model pointer, which must not happen when loading failed. (An ASSERT_*
// failure returns early and skips the cleanup at the bottom; that is accepted
// for an already-failed test.)
TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) {
  const bool use_tfrt = GetParam();
  if (use_tfrt) {
    // Nothing has been allocated yet, so skipping here needs no cleanup.
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(opts, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  const std::string model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(),
      "c/experimental/saved_model/internal/testdata/UninitializedVariable");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  tensorflow::TFSavedModelAPI* model_api =
      tensorflow::down_cast<tensorflow::TFSavedModelAPI*>(
          tensorflow::unwrap(saved_model));

  // Out-parameter of GetVariable; start from a defined value.
  tensorflow::Variable* uninitialized_variable = nullptr;
  ASSERT_EQ(tensorflow::Status::OK(),
            model_api->GetVariable("uninitialized_variable",
                                   &uninitialized_variable));
  ASSERT_EQ(tensorflow::DT_FLOAT, uninitialized_variable->dtype());

  ASSERT_EQ(tensorflow::Status::OK(),
            model_api->GetVariable("sub_module.uninitialized_variable",
                                   &uninitialized_variable));
  ASSERT_EQ(tensorflow::DT_INT64, uninitialized_variable->dtype());

  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
526
// Loads the SimpleWhileLoop SavedModel, runs its "compute" ConcreteFunction on
// the scalar input 10.0 and checks that the scalar float result is 55.
TEST_P(CSavedModelAPITest, LoadSavedModelWithWhileLoop) {
  const bool use_tfrt = GetParam();
  if (use_tfrt) {
    // Nothing has been allocated yet, so skipping here needs no cleanup.
    GTEST_SKIP();  // TODO(chky) : Enable this once TFRT is open sourced.
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* context_options = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(context_options, use_tfrt);

  TFE_Context* ctx = TFE_NewContext(context_options, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(context_options);

  const std::string model_dir = tensorflow::io::JoinPath(
      tensorflow::testing::TensorFlowSrcRoot(),
      "c/experimental/saved_model/internal/testdata/SimpleWhileLoop");

  TF_SavedModel* saved_model =
      TF_LoadSavedModel(model_dir.c_str(), ctx, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_ConcreteFunction* while_fn =
      TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  // Single scalar argument for the call.
  TFE_TensorHandle* loop_input = TestScalarTensorHandle(ctx, 10.0f);
  std::vector<TFE_TensorHandle*> call_args;
  call_args.push_back(loop_input);

  TFE_Op* call_op = TF_ConcreteFunctionMakeCallOp(
      while_fn, call_args.data(), call_args.size(), status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_TensorHandle* call_results[1] = {nullptr};
  int result_count = 1;
  TFE_Execute(call_op, &call_results[0], &result_count, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TF_Tensor* resolved = TFE_TensorHandleResolve(call_results[0], status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  ASSERT_EQ(TF_NumDims(resolved), 0);
  const float* scalar = static_cast<float*>(TF_TensorData(resolved));
  const float output_value = *scalar;
  ASSERT_FLOAT_EQ(output_value, 55);  // 10+9+...+1

  TF_DeleteTensor(resolved);
  TFE_DeleteTensorHandle(call_results[0]);
  TFE_DeleteOp(call_op);
  TFE_DeleteTensorHandle(loop_input);
  TF_DeleteSavedModel(saved_model);
  TF_DeleteStatus(status);
  TFE_DeleteContext(ctx);
}
582
583 INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
584 ::testing::Bool());
585
586 } // namespace
587