1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api_test_util.h"
17 
18 #include "tensorflow/c/eager/c_api.h"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/test.h"
21 
22 using tensorflow::string;
23 
TestScalarTensorHandle(float value)24 TFE_TensorHandle* TestScalarTensorHandle(float value) {
25   float data[] = {value};
26   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
27   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
28   TF_Status* status = TF_NewStatus();
29   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
30   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
31   TF_DeleteTensor(t);
32   TF_DeleteStatus(status);
33   return th;
34 }
35 
TestScalarTensorHandle(int value)36 TFE_TensorHandle* TestScalarTensorHandle(int value) {
37   int data[] = {value};
38   TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
39   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
40   TF_Status* status = TF_NewStatus();
41   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
42   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
43   TF_DeleteTensor(t);
44   TF_DeleteStatus(status);
45   return th;
46 }
47 
TestScalarTensorHandle(bool value)48 TFE_TensorHandle* TestScalarTensorHandle(bool value) {
49   bool data[] = {value};
50   TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
51   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
52   TF_Status* status = TF_NewStatus();
53   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
54   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
55   TF_DeleteTensor(t);
56   TF_DeleteStatus(status);
57   return th;
58 }
59 
DoubleTestMatrixTensorHandle()60 TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
61   int64_t dims[] = {2, 2};
62   double data[] = {1.0, 2.0, 3.0, 4.0};
63   TF_Tensor* t = TF_AllocateTensor(
64       TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
65   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
66   TF_Status* status = TF_NewStatus();
67   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
68   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
69   TF_DeleteTensor(t);
70   TF_DeleteStatus(status);
71   return th;
72 }
73 
TestMatrixTensorHandle()74 TFE_TensorHandle* TestMatrixTensorHandle() {
75   int64_t dims[] = {2, 2};
76   float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
77   TF_Tensor* t = TF_AllocateTensor(
78       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
79   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
80   TF_Status* status = TF_NewStatus();
81   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
82   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
83   TF_DeleteTensor(t);
84   TF_DeleteStatus(status);
85   return th;
86 }
87 
DoubleTestMatrixTensorHandle3X2()88 TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
89   int64_t dims[] = {3, 2};
90   double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
91   TF_Tensor* t = TF_AllocateTensor(
92       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
93   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
94   TF_Status* status = TF_NewStatus();
95   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
96   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
97   TF_DeleteTensor(t);
98   TF_DeleteStatus(status);
99   return th;
100 }
101 
TestMatrixTensorHandle3X2()102 TFE_TensorHandle* TestMatrixTensorHandle3X2() {
103   int64_t dims[] = {3, 2};
104   float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
105   TF_Tensor* t = TF_AllocateTensor(
106       TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
107   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
108   TF_Status* status = TF_NewStatus();
109   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
110   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
111   TF_DeleteTensor(t);
112   TF_DeleteStatus(status);
113   return th;
114 }
115 
MatMulOp(TFE_Context * ctx,TFE_TensorHandle * a,TFE_TensorHandle * b)116 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
117   TF_Status* status = TF_NewStatus();
118 
119   TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
120   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
121   TFE_OpAddInput(op, a, status);
122   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
123   TFE_OpAddInput(op, b, status);
124   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
125   TF_DeleteStatus(status);
126   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
127 
128   return op;
129 }
130 
ShapeOp(TFE_Context * ctx,TFE_TensorHandle * a)131 TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
132   TF_Status* status = TF_NewStatus();
133 
134   TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
135   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
136   TFE_OpAddInput(op, a, status);
137   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
138   TF_DeleteStatus(status);
139   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
140 
141   return op;
142 }
143 
TestAxisTensorHandle()144 TFE_TensorHandle* TestAxisTensorHandle() {
145   int64_t dims[] = {1};
146   int data[] = {1};
147   TF_Tensor* t = TF_AllocateTensor(
148       TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
149   memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
150   TF_Status* status = TF_NewStatus();
151   TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
152   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
153   TF_DeleteTensor(t);
154   TF_DeleteStatus(status);
155   return th;
156 }
157 
MinOp(TFE_Context * ctx,TFE_TensorHandle * input,TFE_TensorHandle * axis)158 TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
159               TFE_TensorHandle* axis) {
160   TF_Status* status = TF_NewStatus();
161 
162   TFE_Op* op = TFE_NewOp(ctx, "Min", status);
163   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
164   TFE_OpAddInput(op, input, status);
165   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
166   TFE_OpAddInput(op, axis, status);
167   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
168   TFE_OpSetAttrBool(op, "keep_dims", 1);
169   TFE_OpSetAttrType(op, "Tidx", TF_INT32);
170   TF_DeleteStatus(status);
171   TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
172 
173   return op;
174 }
175 
GetDeviceName(TFE_Context * ctx,string * device_name,const char * device_type)176 bool GetDeviceName(TFE_Context* ctx, string* device_name,
177                    const char* device_type) {
178   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
179       TF_NewStatus(), TF_DeleteStatus);
180   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
181   CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
182 
183   const int num_devices = TF_DeviceListCount(devices);
184   for (int i = 0; i < num_devices; ++i) {
185     const string dev_type(TF_DeviceListType(devices, i, status.get()));
186     CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
187     const string dev_name(TF_DeviceListName(devices, i, status.get()));
188     CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
189     if (dev_type == device_type) {
190       *device_name = dev_name;
191       LOG(INFO) << "Found " << device_type << " device " << *device_name;
192       TF_DeleteDeviceList(devices);
193       return true;
194     }
195   }
196   TF_DeleteDeviceList(devices);
197   return false;
198 }
199