1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
17 
18 #include <array>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/c/eager/c_api_experimental.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 // NOTE(allenl): These tests currently go through TFE_Execute and so are
27 // integration testing rather than purely testing the parallel device. They
28 // correspond fairly well to the implementation, but testing the C++ directly is
29 // another option.
30 
31 namespace tensorflow {
32 namespace parallel_device {
33 
Create(TFE_Context * context,TF_DataType type,const int64_t * dims,const int num_dims,const char * device,TF_Status * status)34 Variable* Variable::Create(TFE_Context* context, TF_DataType type,
35                            const int64_t* dims, const int num_dims,
36                            const char* device, TF_Status* status) {
37   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
38       TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
39   if (TF_GetCode(status) != TF_OK) return nullptr;
40   TFE_OpSetAttrType(op.get(), "dtype", type);
41   TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
42   TFE_OpSetAttrString(op.get(), "container", "", 0);
43   // Use the special GUID for no buffer sharing
44   //
45   // TODO(allenl): Should we provide a better API for this? AFAIK this is the
46   // only reasonable way to make variables with no aliasing using the eager C
47   // API.
48   std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
49   TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
50                       no_sharing.length());
51   TFE_OpSetDevice(op.get(), device, status);
52   if (TF_GetCode(status) != TF_OK) return nullptr;
53   TFE_TensorHandle* var_handle = nullptr;
54   int num_retvals = 1;
55   TFE_Execute(op.get(), &var_handle, &num_retvals, status);
56   if (TF_GetCode(status) != TF_OK) return nullptr;
57   return new Variable(var_handle, type);
58 }
59 
Destroy(TFE_Context * context,TF_Status * status)60 void Variable::Destroy(TFE_Context* context, TF_Status* status) {
61   // Free the backing buffer for the variable.
62   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
63       TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
64   if (TF_GetCode(status) != TF_OK) return;
65   TFE_OpAddInput(op.get(), handle_, status);
66   if (TF_GetCode(status) != TF_OK) return;
67   const char* device = TFE_TensorHandleDeviceName(handle_, status);
68   if (TF_GetCode(status) != TF_OK) return;
69   TFE_OpSetDevice(op.get(), device, status);
70   if (TF_GetCode(status) != TF_OK) return;
71   int num_retvals = 0;
72   TFE_Execute(op.get(), nullptr, &num_retvals, status);
73   if (TF_GetCode(status) != TF_OK) return;
74   // Delete the variable handle itself.
75   TFE_DeleteTensorHandle(handle_);
76 }
77 
Read(TFE_Context * context,TF_Status * status)78 TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
79   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
80       TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
81   if (TF_GetCode(status) != TF_OK) return nullptr;
82   TFE_OpAddInput(op.get(), handle_, status);
83   if (TF_GetCode(status) != TF_OK) return nullptr;
84   const char* device = TFE_TensorHandleDeviceName(handle_, status);
85   if (TF_GetCode(status) != TF_OK) return nullptr;
86   TFE_OpSetDevice(op.get(), device, status);
87   if (TF_GetCode(status) != TF_OK) return nullptr;
88   TFE_OpSetAttrType(op.get(), "dtype", type_);
89   int num_retvals = 1;
90   TFE_TensorHandle* var_value = nullptr;
91   TFE_Execute(op.get(), &var_value, &num_retvals, status);
92   if (TF_GetCode(status) != TF_OK) return nullptr;
93   return TensorHandlePtr(var_value);
94 }
95 
GeneralAssignment(const char * op_name,TFE_Context * context,TFE_TensorHandle * value,TF_Status * status)96 void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
97                                  TFE_TensorHandle* value, TF_Status* status) {
98   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
99       TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
100   if (TF_GetCode(status) != TF_OK) return;
101   TFE_OpSetAttrType(op.get(), "dtype", type_);
102   TFE_OpAddInput(op.get(), handle_, status);
103   if (TF_GetCode(status) != TF_OK) return;
104   TFE_OpAddInput(op.get(), value, status);
105   if (TF_GetCode(status) != TF_OK) return;
106   const char* device = TFE_TensorHandleDeviceName(handle_, status);
107   if (TF_GetCode(status) != TF_OK) return;
108   TFE_OpSetDevice(op.get(), device, status);
109 
110   int num_retvals = 0;
111   TFE_Execute(op.get(), nullptr, &num_retvals, status);
112   if (TF_GetCode(status) != TF_OK) return;
113 }
114 
AssignAdd(TFE_Context * context,TFE_TensorHandle * value,TF_Status * status)115 void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
116                          TF_Status* status) {
117   GeneralAssignment("AssignAddVariableOp", context, value, status);
118 }
119 
Assign(TFE_Context * context,TFE_TensorHandle * value,TF_Status * status)120 void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
121                       TF_Status* status) {
122   GeneralAssignment("AssignVariableOp", context, value, status);
123 }
124 
// Deallocation callback handed to `TF_NewTensor` for buffers that were
// allocated with `new float[...]`. The length and user-argument parameters are
// not used.
static void FloatDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  float* const floats = static_cast<float*>(data);
  delete[] floats;
}
130 
131 // Creates a TFE_TensorHandle with value `v`.
FloatTensorHandle(float v,TF_Status * status)132 TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
133   const int num_bytes = sizeof(float);
134   float* values = new float[1];
135   values[0] = v;
136   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
137       TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
138                    nullptr),
139       TF_DeleteTensor);
140   return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
141 }
142 
143 // Creates a rank-one TFE_TensorHandle with value `v`.
VectorFloatTensorHandle(const std::vector<float> & v,TF_Status * status)144 TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
145                                         TF_Status* status) {
146   const int num_bytes = v.size() * sizeof(float);
147   float* values = new float[v.size()];
148   memcpy(values, v.data(), num_bytes);
149   int64_t dims = v.size();
150   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
151       TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
152                    &FloatDeallocator, nullptr),
153       TF_DeleteTensor);
154   return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
155 }
156 
157 // Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
158 template <std::size_t num_replicas>
ExtractPerDeviceValues(TFE_Context * context,TFE_TensorHandle * input,std::array<TensorHandlePtr,num_replicas> * components,TF_Status * status)159 void ExtractPerDeviceValues(
160     TFE_Context* context, TFE_TensorHandle* input,
161     std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
162   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
163       TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
164   if (TF_GetCode(status) != TF_OK) return;
165   TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
166   TFE_OpAddInput(op.get(), input, status);
167   if (TF_GetCode(status) != TF_OK) return;
168   const char* device = TFE_TensorHandleDeviceName(input, status);
169   if (TF_GetCode(status) != TF_OK) return;
170   TFE_OpSetDevice(op.get(), device, status);
171   if (TF_GetCode(status) != TF_OK) return;
172 
173   TFE_TensorHandle* result_handles[num_replicas];
174   int num_retvals = num_replicas;
175   TFE_Execute(op.get(), result_handles, &num_retvals, status);
176   if (TF_GetCode(status) != TF_OK) return;
177   for (int i = 0; i < num_replicas; ++i) {
178     (*components)[i].reset(result_handles[i]);
179   }
180 }
181 
Multiply(TFE_Context * context,TFE_TensorHandle * first,TFE_TensorHandle * second,TF_Status * status)182 TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
183                          TFE_TensorHandle* second, TF_Status* status) {
184   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
185       TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
186   if (TF_GetCode(status) != TF_OK) return nullptr;
187   TFE_OpAddInput(op.get(), first, status);
188   if (TF_GetCode(status) != TF_OK) return nullptr;
189   TFE_OpAddInput(op.get(), second, status);
190   if (TF_GetCode(status) != TF_OK) return nullptr;
191   const char* first_device = TFE_TensorHandleDeviceName(first, status);
192   if (TF_GetCode(status) != TF_OK) return nullptr;
193   TFE_OpSetDevice(op.get(), first_device, status);
194 
195   TFE_TensorHandle* result_handle;
196   int num_retvals = 1;
197   TFE_Execute(op.get(), &result_handle, &num_retvals, status);
198   if (TF_GetCode(status) != TF_OK) return nullptr;
199   return TensorHandlePtr(result_handle);
200 }
201 
202 // Create and modify a variable placed on a parallel device which composes
203 // `first_device` and `second_device`.
BasicTestsForTwoDevices(TFE_Context * context,const char * first_device,const char * second_device)204 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
205                              const char* second_device) {
206   // Register the custom device
207   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
208       TF_NewStatus(), TF_DeleteStatus);
209   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
210   std::array<const char*, 2> underlying_devices{first_device, second_device};
211   RegisterParallelDevice(context, device_name, underlying_devices,
212                          status.get());
213   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
214 
215   // Create a variable handle (uninitialized to start) placed on the parallel
216   // device.
217   std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
218     to_delete->Destroy(context, status.get());
219     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
220     delete to_delete;
221   };
222   std::unique_ptr<Variable, decltype(variable_deleter)> variable(
223       Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
224                        status.get()),
225       variable_deleter);
226   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
227 
228   // Assign an initial value to the variable, implicitly mirroring it to each
229   // component device.
230   {
231     TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
232     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
233 
234     variable->Assign(context, initial_value.get(), status.get());
235   }
236 
237   // Read from the variable and verify that we have a parallel tensor.
238   {
239     TensorHandlePtr read = variable->Read(context, status.get());
240     std::array<TensorHandlePtr, 2> components;
241     ExtractPerDeviceValues(context, read.get(), &components, status.get());
242     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
243 
244     ExpectScalarEq<float>(components[0].get(), 20.);
245     ExpectScalarEq<float>(components[1].get(), 20.);
246 
247     std::string first_device =
248         TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
249     ASSERT_EQ(underlying_devices[0], first_device);
250     std::string second_device =
251         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
252     ASSERT_EQ(underlying_devices[1], second_device);
253   }
254 
255   // Add a parallel tensor with different values on each device to the variable.
256   {
257     TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
258     TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
259     std::array<TFE_TensorHandle*, 2> components{value_one.get(),
260                                                 value_two.get()};
261     TensorHandlePtr combined_value =
262         CreatePerDeviceValues(context, components, device_name, status.get());
263     variable->AssignAdd(context, combined_value.get(), status.get());
264   }
265 
266   // Read the variable and verify that each component has the right modified
267   // value.
268   {
269     TensorHandlePtr read = variable->Read(context, status.get());
270     std::array<TensorHandlePtr, 2> components;
271     ExtractPerDeviceValues(context, read.get(), &components, status.get());
272     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
273 
274     ExpectScalarEq<float>(components[0].get(), 23.);
275     ExpectScalarEq<float>(components[1].get(), 18.);
276 
277     std::string first_device =
278         TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
279     ASSERT_EQ(underlying_devices[0], first_device);
280     std::string second_device =
281         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
282     ASSERT_EQ(underlying_devices[1], second_device);
283   }
284 }
285 
286 }  // namespace parallel_device
287 }  // namespace tensorflow
288