/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace parallel_device {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
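
// For example (illustrative sketch; `raw_handle` is an assumed
// TFE_TensorHandle* owned by the caller):
//
//   TensorHandlePtr handle(raw_handle);
//
// With the functor as the deleter type, there is no need to pass
// &TFE_DeleteTensorHandle at every construction site, as
// std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> would
// require.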

class ParallelTensor;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensors whose
// components are placed on each underlying device.
class ParallelDevice {
 public:
  // Eager async execution is only supported when remote eager is not in use
  // (b/157523095).
  explicit ParallelDevice(const std::vector<std::string>& devices,
                          const bool is_async = false);

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // Construct a parallel tensor consisting of the scalar values from `values`.
  template <typename DataType>
  std::unique_ptr<ParallelTensor> ScalarsFromSequence(
      absl::Span<const DataType> values, TFE_Context* context,
      TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // The number of devices that operations are run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from the underlying `TFE_Execute` calls,
  // or set directly if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

  // A non-blocking version of `Execute`. After each call, `Join` must be
  // called before `StartExecute` is called again.
  //
  // Using `StartExecute` with `Join` allows the caller to schedule computation
  // on multiple ParallelDevices without sequencing those operations: first
  // call `StartExecute` on each parallel device, then call `Join` on each.
  // Even if some of the `Join`s return a bad status, the caller must run all
  // of the `Join`s, or any future `StartExecute`s will deadlock. (See the
  // usage sketch following this class.)
  //
  // If `is_async=false` (constructor argument), `cancellation_manager` must
  // live until `Join` finishes. If `is_async=true` it must live until `Join`
  // is followed by `TFE_ContextAsyncWait` to clear pending operations. It will
  // be used to cancel all other operations if any one fails.
  void StartExecute(TFE_Context* context,
                    const std::vector<ParallelTensor*>& inputs,
                    const char* operation_name, const TFE_OpAttrs* attributes,
                    int expected_max_outputs,
                    CancellationManager& cancellation_manager) const;

  // Blocks until the previous `StartExecute` has run `TFE_Execute` on each
  // device. If `is_async=false` (constructor argument) this means the ops have
  // run and have results. If `is_async=true` it means that all of the
  // device-specific executors have scheduled the op.
  //
  // Accepts inferred shapes for outputs (`expected_output_shapes`) which, if
  // fully defined, avoid querying the shapes of the underlying TensorHandles
  // when ParallelTensor::Shape is called. This allows async computation to
  // continue without blocking.
  //
  // The return status and value are the same as for `Execute`.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Join(
      const std::vector<PartialTensorShape>& expected_output_shapes,
      TF_Status* status) const;

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
  // A cancellation manager to use if the caller does not provide one. When ops
  // are executed asynchronously this must outlive the queued op, so it can't be
  // function-local to Execute.
  mutable std::unique_ptr<CancellationManager> default_cancellation_manager_;
};
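
// Usage sketch for `Execute` and the `StartExecute`/`Join` pair. Illustrative
// only: `context`, `attributes`, `status`, and the inputs `x` and `y`
// (std::unique_ptr<ParallelTensor>) are assumed to be set up by the caller,
// and the device names are examples.
//
//   ParallelDevice device({"/job:localhost/replica:0/task:0/device:GPU:0",
//                          "/job:localhost/replica:0/task:0/device:GPU:1"});
//
//   // Synchronous form: runs "AddV2" once per component device.
//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> outputs =
//       device.Execute(context, {x.get(), y.get()}, "AddV2", attributes,
//                      /*expected_max_outputs=*/1, status);
//
//   // Non-blocking form: schedule first, then always Join.
//   CancellationManager cancellation_manager;
//   device.StartExecute(context, {x.get(), y.get()}, "AddV2", attributes,
//                       /*expected_max_outputs=*/1, cancellation_manager);
//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> joined =
//       device.Join(/*expected_output_shapes=*/{PartialTensorShape()}, status);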

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice. If called, ParallelTensor::Shape inspects
  // `components` to determine a shape.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);
  // Uses the provided shape without additional checks, which avoids blocking
  // when ParallelTensor::Shape is called.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
      TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

  // If the `shape` argument to `FromTensorHandles` was specified, returns that
  // shape.
  //
  // Otherwise, if all of the tensors have the same shape, returns that shape
  // via the `shape` output argument. This blocks waiting for async tensors,
  // may return a delayed bad status encountered during async execution, and
  // will return a bad status unless all tensors have the same shape.
  Status Shape(const std::vector<int64_t>** shape) const;
  TF_DataType dtype() const { return dtype_; }

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 absl::Span<const int64> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::vector<int64_t>(shape.begin(), shape.end())),
        dtype_(dtype) {}
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(absl::nullopt),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  // Parallel tensors are immutable but compute their shape lazily unless it is
  // provided on construction. The optional has a value if the lazy computation
  // has been completed or the shape was provided on construction.
  mutable absl::optional<std::vector<int64_t>> shape_;
  const TF_DataType dtype_;
};
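
// Sketch of wrapping per-device handles into a ParallelTensor and querying its
// shape. Illustrative only: `device`, `status`, and the raw TFE_TensorHandle*
// values `handle_on_gpu0` and `handle_on_gpu1` are assumed to exist and remain
// owned by the caller until wrapped here.
//
//   std::vector<TensorHandlePtr> components;
//   components.emplace_back(handle_on_gpu0);
//   components.emplace_back(handle_on_gpu1);
//   std::unique_ptr<ParallelTensor> parallel =
//       ParallelTensor::FromTensorHandles(device, std::move(components),
//                                         status);
//   const std::vector<int64_t>* shape;
//   Status s = parallel->Shape(&shape);  // May block on async execution.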

template <typename DataType>
std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
    absl::Span<DataType const> values, TFE_Context* context,
    TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());

  if (values.size() != num_underlying_devices()) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "Number of values did not match number of underlying devices.");
    return nullptr;
  }
  TF_DataType datatype_enum(
      static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
  for (int device_index = 0; device_index < num_underlying_devices();
       ++device_index) {
    // Heap-allocate the scalar so ownership can be handed to TF_NewTensor,
    // whose deallocator frees it when the TF_Tensor is deleted.
    auto device_value = absl::make_unique<DataType>();
    *device_value = values[device_index];
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
            device_value.release(), sizeof(DataType),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<DataType*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
        TFE_NewOp(context, "Const", status), TFE_DeleteOp);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}
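
// Sketch of building per-device scalar constants. Illustrative only: `context`
// and `status` are assumed to be live, and the CPU device names are examples.
//
//   ParallelDevice device({"/job:localhost/replica:0/task:0/device:CPU:0",
//                          "/job:localhost/replica:0/task:0/device:CPU:1"});
//   // One int32 scalar per component device: 2 on CPU:0, 3 on CPU:1.
//   std::unique_ptr<ParallelTensor> scalars =
//       device.ScalarsFromSequence<int32_t>({2, 3}, context, status);
//   // A scalar component index (0, 1, ...) on each device.
//   std::unique_ptr<ParallelTensor> ids = device.DeviceIDs(context, status);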

}  // namespace parallel_device
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_