/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <memory>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"

namespace tensorflow {
namespace parallel_device {
namespace {

class OpDeleter {
 public:
  void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
 public:
  NamedParallelDevice(const std::string& name,
                      std::unique_ptr<ParallelDevice> parallel_device)
      : device_name_(name), parallel_device_(std::move(parallel_device)) {}
  const std::string& name() const { return device_name_; }
  const ParallelDevice& device() const { return *parallel_device_; }

 private:
  std::string device_name_;
  std::unique_ptr<ParallelDevice> parallel_device_;
};
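
// Construction sketch (the device names below are hypothetical; in practice
// they come from AllocateParallelDevice at the bottom of this file):
//
//   std::unique_ptr<ParallelDevice> underlying(new ParallelDevice(
//       {"/job:localhost/replica:0/task:0/device:CPU:0",
//        "/job:localhost/replica:0/task:0/device:CPU:1"}));
//   NamedParallelDevice named(
//       "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//       std::move(underlying));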

absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    const ParallelDevice& parallel_device,
    const std::string& parallel_device_name, TFE_Context* context,
    std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
    const TFE_OpAttrs* attributes, int expected_max_outputs,
    TF_Status* status) {
  absl::optional<std::vector<MaybeParallelTensorOwned>> result;
  // TODO(allenl): We should remove "TPU" from these op names at the very
  // least, or consider other ways of packing/unpacking parallel tensors.
  if (operation_name == std::string("TPUReplicatedInput")) {
    // Special-cased operation for packing per-device tensors into one parallel
    // tensor.
    if (inputs.size() != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " inputs to TPUReplicatedInput, but got ", inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    std::vector<TensorHandlePtr> components;
    components.reserve(inputs.size());
    for (int i = 0; i < inputs.size(); ++i) {
      if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
        std::string message(absl::StrCat(
            "Expected all inputs to TPUReplicatedInput to be non-parallel "
            "TensorHandles. The input ",
            i,
            " was a parallel tensor (already "
            "placed on the parallel device)."));
        TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
        return result;
      }
      components.emplace_back(TFE_TensorHandleCopySharingTensor(
          absl::get<TFE_TensorHandle*>(inputs[i]), status));
    }
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(ParallelTensor::FromTensorHandles(
        parallel_device, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  } else if (operation_name == std::string("TPUReplicatedOutput")) {
    // Special-cased operation for un-packing one parallel tensor into
    // per-device tensors.
    OpPtr op(TFE_NewOp(context, operation_name, status));
    TFE_OpAddAttrs(op.get(), attributes);
    int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
    if (TF_GetCode(status) != TF_OK) return result;
    if (expected_outputs != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " outputs for TPUReplicatedOutput, but got ", expected_outputs));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   "Expected the input to "
                   "TPUReplicatedOutput to be a parallel tensor (placed on the "
                   "parallel device).");
      return result;
    }
    ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
    std::vector<MaybeParallelTensorOwned> outputs;
    outputs.reserve(t->num_tensors());
    for (int i = 0; i < t->num_tensors(); ++i) {
      TensorHandlePtr this_output(
          TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
      outputs.emplace_back(std::move(this_output));
      if (TF_GetCode(status) != TF_OK) return result;
    }
    result.emplace(std::move(outputs));
    return result;
  }
  std::vector<ParallelTensor*> parallel_inputs;
  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
  parallel_inputs.reserve(inputs.size());
  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
  for (const auto& input : inputs) {
    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
      // Non-parallel tensors are implicitly broadcast, i.e. set as the input
      // to each parallel operation.
      //
      // TODO(allenl): There may be smarter ways to do this copy in some
      // cases, i.e. with a collective broadcast. We'll need to be careful
      // about things that are taken as inputs on the host or on their
      // existing device (for multi-device functions).
      std::unique_ptr<ParallelTensor> parallel_tensor(
          parallel_device.CopyToParallelDevice(
              context, absl::get<TFE_TensorHandle*>(input), status));
      if (TF_GetCode(status) != TF_OK) return result;
      parallel_inputs.push_back(parallel_tensor.get());
      implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
    } else {
      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
    }
  }
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          parallel_device.Execute(context, parallel_inputs, operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
      std::move(maybe_parallel_results.value()));
  std::vector<MaybeParallelTensorOwned> result_content;
  result_content.reserve(parallel_results.size());
  for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
    result_content.push_back(
        MaybeParallelTensorOwned(std::move(parallel_result)));
  }
  result.emplace(std::move(result_content));
  return result;
}
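
// Packing/unpacking sketch for the special-cased ops above (not part of this
// file's API; the two-device setup and per-device handles h0/h1 are
// hypothetical):
//
//   // Pack: one non-parallel handle per underlying device in, one handle
//   // wrapping a ParallelTensor out.
//   std::vector<MaybeParallelTensorUnowned> pack_inputs{h0, h1};
//   auto packed = ExecuteWithSpecialOps(
//       parallel_device, parallel_device_name, context,
//       std::move(pack_inputs), "TPUReplicatedInput", attributes,
//       /*expected_max_outputs=*/1, status);
//
//   // Unpack: one ParallelTensor in, one per-device handle out per component.
//   std::vector<MaybeParallelTensorUnowned> unpack_inputs{parallel_tensor};
//   auto unpacked = ExecuteWithSpecialOps(
//       parallel_device, parallel_device_name, context,
//       std::move(unpack_inputs), "TPUReplicatedOutput", attributes,
//       /*expected_max_outputs=*/2, status);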

// Used as an argument to TFE_NewCustomDeviceTensorHandle, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, void* arg) {
  delete reinterpret_cast<ParallelTensor*>(data);
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing the
// number of dimensions of a parallel tensor.
int ParallelTensorNumDims(void* data, void* arg, TF_Status* status) {
  const std::vector<int64_t>* shape;
  Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return -1;
  }
  return shape->size();
}

// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing a
// dimension of a parallel tensor.
int64_t ParallelTensorDim(void* data, int dim_index, void* arg,
                          TF_Status* status) {
  const std::vector<int64_t>* shape;
  Status s = reinterpret_cast<ParallelTensor*>(data)->Shape(&shape);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return -1;
  }
  return (*shape)[dim_index];
}

TensorHandlePtr ParallelTensorToTensorHandle(
    const std::string& parallel_device_name, TFE_Context* context,
    std::unique_ptr<ParallelTensor> t, TF_Status* status) {
  // The resulting TensorHandle owns an opaque pointer to "device memory", which
  // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
  // deleted, it will call ParallelTensorDeallocator to free the struct.
  ParallelTensor* t_released = t.release();
  return TensorHandlePtr(TFE_NewCustomDeviceTensorHandle(
      context, parallel_device_name.c_str(), t_released->dtype(), t_released,
      &ParallelTensorNumDims, &ParallelTensorDim, &ParallelTensorDeallocator,
      nullptr, status));
}
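
// Behavior sketch for the wrapped handle (values are hypothetical): for a
// ParallelTensor whose components share the shape {2, 3},
// TFE_TensorHandleNumDims on the returned handle dispatches to
// ParallelTensorNumDims and yields 2, TFE_TensorHandleDim(handle, 1, status)
// dispatches to ParallelTensorDim and yields 3, and releasing the last
// reference to the handle runs ParallelTensorDeallocator, which deletes the
// ParallelTensor.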

// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
// a ParallelTensor with one copy of `tensor` for each device in the
// ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// NamedParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* tensor,
                                       TF_Status* status, void* device_info) {
  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  const ParallelDevice& dev = named_device->device();
  std::unique_ptr<ParallelTensor> parallel_tensor(
      dev.CopyToParallelDevice(context, tensor, status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return ParallelTensorToTensorHandle(named_device->name(), context,
                                      std::move(parallel_tensor), status)
      .release();
}

// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                               TFE_TensorHandle* tensor,
                                               const char* target_device_name,
                                               TF_Status* status,
                                               void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED,
               "Trying to copy a tensor out of a parallel device. Since there "
               "are multiple components to parallel tensors, they must be "
               "unpacked explicitly.");
  return nullptr;
}

// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// NamedParallelDevice.
void ParallelDeviceExecute(const TFE_Op* original_op, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  const char* requested_placement = TFE_OpGetDevice(original_op, status);
  if (*requested_placement == '\0') {
    TF_SetStatus(
        status, TF_INTERNAL,
        "Ops must be placed on the parallel device explicitly, or their inputs "
        "first un-packed. Got an un-placed op with an input placed on the "
        "parallel device.");
    return;
  }
  TFE_Context* context = TFE_OpGetContext(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const char* operation_name = TFE_OpGetName(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);

  NamedParallelDevice* named_device =
      reinterpret_cast<NamedParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
  if (TF_GetCode(status) != TF_OK) return;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, i, status);
    if (TF_GetCode(status) != TF_OK) return;
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(input, status);
    if (TF_GetCode(status) != TF_OK) return;
    if (named_device->name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
          TFE_TensorHandleDevicePointer(input, status)));
      if (TF_GetCode(status) != TF_OK) return;
    } else {
      typed_inputs.emplace_back(input);
    }
  }

  absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
      ExecuteWithSpecialOps(named_device->device(), named_device->name(),
                            context, std::move(typed_inputs), operation_name,
                            attributes, *num_outputs, status));
  if (TF_GetCode(status) != TF_OK) return;
  if (!maybe_typed_outputs.has_value()) {
    TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
    return;
  }

  std::vector<MaybeParallelTensorOwned> typed_outputs(
      std::move(maybe_typed_outputs.value()));

  if (typed_outputs.size() > *num_outputs) {
    TF_SetStatus(status, TF_INTERNAL,
                 "The allocated output buffer was too small.");
    return;
  }

  for (int i = 0; i < typed_outputs.size(); ++i) {
    MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
    if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
      outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
    } else {
      outputs[i] = ParallelTensorToTensorHandle(
                       named_device->name(), context,
                       std::move(absl::get<std::unique_ptr<ParallelTensor>>(
                           typed_output)),
                       status)
                       .release();
      if (TF_GetCode(status) != TF_OK) return;
    }
  }
  *num_outputs = typed_outputs.size();
}

// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in using a C-style generic. It must always be a
// NamedParallelDevice.
void DeleteParallelDevice(void* device_info) {
  delete reinterpret_cast<NamedParallelDevice*>(device_info);
}

}  // namespace

void AllocateParallelDevice(const char* device_name,
                            const char* const* underlying_devices,
                            int num_underlying_devices,
                            TFE_CustomDevice* device, void** device_info) {
  device->copy_tensor_to_device = &CopyToParallelDevice;
  device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
  device->delete_device = &DeleteParallelDevice;
  device->execute = &ParallelDeviceExecute;
  std::vector<std::string> underlying_devices_vector;
  underlying_devices_vector.reserve(num_underlying_devices);
  for (int device_index = 0; device_index < num_underlying_devices;
       ++device_index) {
    underlying_devices_vector.push_back(underlying_devices[device_index]);
  }
  std::unique_ptr<ParallelDevice> parallel_device(
      new ParallelDevice(underlying_devices_vector));
  *device_info =
      new NamedParallelDevice{device_name, std::move(parallel_device)};
}
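
// Registration sketch (not part of this file; the device and underlying
// device names below are hypothetical, and error handling is omitted).
// AllocateParallelDevice fills in a TFE_CustomDevice and its device_info,
// which are then handed to TFE_RegisterCustomDevice:
//
//   TFE_CustomDevice device;
//   void* device_info;
//   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
//   const char* underlying[] = {
//       "/job:localhost/replica:0/task:0/device:CPU:0",
//       "/job:localhost/replica:0/task:0/device:CPU:1"};
//   tensorflow::parallel_device::AllocateParallelDevice(
//       name, underlying, /*num_underlying_devices=*/2, &device, &device_info);
//   TFE_RegisterCustomDevice(context, device, name, device_info, status);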
}  // namespace parallel_device
}  // namespace tensorflow