1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api_test_util.h"
17
18 #include "tensorflow/c/eager/c_api.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/test.h"
21
22 using tensorflow::string;
23
TestScalarTensorHandle(float value)24 TFE_TensorHandle* TestScalarTensorHandle(float value) {
25 float data[] = {value};
26 TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
27 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
28 TF_Status* status = TF_NewStatus();
29 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
30 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
31 TF_DeleteTensor(t);
32 TF_DeleteStatus(status);
33 return th;
34 }
35
TestScalarTensorHandle(int value)36 TFE_TensorHandle* TestScalarTensorHandle(int value) {
37 int data[] = {value};
38 TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
39 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
40 TF_Status* status = TF_NewStatus();
41 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
42 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
43 TF_DeleteTensor(t);
44 TF_DeleteStatus(status);
45 return th;
46 }
47
TestScalarTensorHandle(bool value)48 TFE_TensorHandle* TestScalarTensorHandle(bool value) {
49 bool data[] = {value};
50 TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
51 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
52 TF_Status* status = TF_NewStatus();
53 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
54 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
55 TF_DeleteTensor(t);
56 TF_DeleteStatus(status);
57 return th;
58 }
59
DoubleTestMatrixTensorHandle()60 TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
61 int64_t dims[] = {2, 2};
62 double data[] = {1.0, 2.0, 3.0, 4.0};
63 TF_Tensor* t = TF_AllocateTensor(
64 TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
65 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
66 TF_Status* status = TF_NewStatus();
67 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
68 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
69 TF_DeleteTensor(t);
70 TF_DeleteStatus(status);
71 return th;
72 }
73
TestMatrixTensorHandle()74 TFE_TensorHandle* TestMatrixTensorHandle() {
75 int64_t dims[] = {2, 2};
76 float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
77 TF_Tensor* t = TF_AllocateTensor(
78 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
79 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
80 TF_Status* status = TF_NewStatus();
81 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
82 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
83 TF_DeleteTensor(t);
84 TF_DeleteStatus(status);
85 return th;
86 }
87
DoubleTestMatrixTensorHandle3X2()88 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
89 int64_t dims[] = {3, 2};
90 double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
91 TF_Tensor* t = TF_AllocateTensor(
92 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
93 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
94 TF_Status* status = TF_NewStatus();
95 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
96 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
97 TF_DeleteTensor(t);
98 TF_DeleteStatus(status);
99 return th;
100 }
101
TestMatrixTensorHandle3X2()102 TFE_TensorHandle* TestMatrixTensorHandle3X2() {
103 int64_t dims[] = {3, 2};
104 float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
105 TF_Tensor* t = TF_AllocateTensor(
106 TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
107 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
108 TF_Status* status = TF_NewStatus();
109 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
110 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
111 TF_DeleteTensor(t);
112 TF_DeleteStatus(status);
113 return th;
114 }
115
MatMulOp(TFE_Context * ctx,TFE_TensorHandle * a,TFE_TensorHandle * b)116 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
117 TF_Status* status = TF_NewStatus();
118
119 TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
120 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
121 TFE_OpAddInput(op, a, status);
122 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
123 TFE_OpAddInput(op, b, status);
124 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
125 TF_DeleteStatus(status);
126 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
127
128 return op;
129 }
130
ShapeOp(TFE_Context * ctx,TFE_TensorHandle * a)131 TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
132 TF_Status* status = TF_NewStatus();
133
134 TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
135 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
136 TFE_OpAddInput(op, a, status);
137 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
138 TF_DeleteStatus(status);
139 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
140
141 return op;
142 }
143
TestAxisTensorHandle()144 TFE_TensorHandle* TestAxisTensorHandle() {
145 int64_t dims[] = {1};
146 int data[] = {1};
147 TF_Tensor* t = TF_AllocateTensor(
148 TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
149 memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
150 TF_Status* status = TF_NewStatus();
151 TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
152 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
153 TF_DeleteTensor(t);
154 TF_DeleteStatus(status);
155 return th;
156 }
157
MinOp(TFE_Context * ctx,TFE_TensorHandle * input,TFE_TensorHandle * axis)158 TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
159 TFE_TensorHandle* axis) {
160 TF_Status* status = TF_NewStatus();
161
162 TFE_Op* op = TFE_NewOp(ctx, "Min", status);
163 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
164 TFE_OpAddInput(op, input, status);
165 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
166 TFE_OpAddInput(op, axis, status);
167 CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
168 TFE_OpSetAttrBool(op, "keep_dims", 1);
169 TFE_OpSetAttrType(op, "Tidx", TF_INT32);
170 TF_DeleteStatus(status);
171 TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
172
173 return op;
174 }
175
GetDeviceName(TFE_Context * ctx,string * device_name,const char * device_type)176 bool GetDeviceName(TFE_Context* ctx, string* device_name,
177 const char* device_type) {
178 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
179 TF_NewStatus(), TF_DeleteStatus);
180 TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
181 CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
182
183 const int num_devices = TF_DeviceListCount(devices);
184 for (int i = 0; i < num_devices; ++i) {
185 const string dev_type(TF_DeviceListType(devices, i, status.get()));
186 CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
187 const string dev_name(TF_DeviceListName(devices, i, status.get()));
188 CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
189 if (dev_type == device_type) {
190 *device_name = dev_name;
191 LOG(INFO) << "Found " << device_type << " device " << *device_name;
192 TF_DeleteDeviceList(devices);
193 return true;
194 }
195 }
196 TF_DeleteDeviceList(devices);
197 return false;
198 }
199