/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace parallel_device {

// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
 public:
  void operator()(TFE_TensorHandle* to_delete) const {
    TFE_DeleteTensorHandle(to_delete);
  }
};

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
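// Illustrative sketch only (not part of this header's API): with a raw
// function-pointer deleter the deleter must be supplied at construction,
// while TensorHandlePtr default-constructs its functor deleter. `handle`
// below stands in for an assumed TFE_TensorHandle*.
//
//   std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
//       with_function_pointer(handle, TFE_DeleteTensorHandle);
//   TensorHandlePtr with_functor(handle);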

class ParallelTensor;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensors with
// components placed on each underlying device.
class ParallelDevice {
 public:
  // Eager async execution is only supported when remote eager is not in use
  // (b/157523095).
  explicit ParallelDevice(const std::vector<std::string>& devices,
                          const bool is_async = false);
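  //
  // Illustrative construction sketch; the device names below are assumptions
  // for the example, not defined by this header:
  //
  //   std::vector<std::string> devices{
  //       "/job:localhost/replica:0/task:0/device:CPU:0",
  //       "/job:localhost/replica:0/task:0/device:CPU:1"};
  //   ParallelDevice parallel_device(devices, /*is_async=*/false);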

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // Construct a parallel tensor consisting of the scalar values from `values`.
  template <typename DataType>
  std::unique_ptr<ParallelTensor> ScalarsFromSequence(
      absl::Span<const DataType> values, TFE_Context* context,
      TF_Status* status) const;
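  //
  // Illustrative call sketch; `parallel_device` and the values are assumptions
  // for the example, and one value is required per underlying device:
  //
  //   std::vector<int32_t> values{0, 1};
  //   std::unique_ptr<ParallelTensor> scalars =
  //       parallel_device.ScalarsFromSequence<int32_t>(values, context,
  //                                                    status);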

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // The number of devices that operations run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // set if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;
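  //
  // Illustrative call sketch; `parallel_inputs`, the op name, and `attributes`
  // are assumptions for the example:
  //
  //   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> results =
  //       parallel_device.Execute(context, parallel_inputs, "AddV2",
  //                               attributes, /*expected_max_outputs=*/1,
  //                               status);
  //   if (TF_GetCode(status) != TF_OK) { /* handle the failure */ }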

  // A non-blocking version of `Execute`. After each call, `Join` must be
  // called before `StartExecute` is called again. Using `StartExecute` with
  // `Join` allows the caller to schedule computation on multiple
  // ParallelDevices without sequencing those operations (first call
  // `StartExecute` on each parallel device, then call `Join` on each; even if
  // some of the `Join`s return a bad status the caller must run all of the
  // `Join`s or any future `StartExecute`s will deadlock).
  //
  // If `is_async=false` (constructor argument), `cancellation_manager` must
  // live until `Join` finishes. If `is_async=true` it must live until `Join`
  // is followed by `TFE_ContextAsyncWait` to clear pending operations. It
  // will be used to cancel all other operations if any fails.
  void StartExecute(TFE_Context* context,
                    const std::vector<ParallelTensor*>& inputs,
                    const char* operation_name, const TFE_OpAttrs* attributes,
                    int expected_max_outputs,
                    CancellationManager& cancellation_manager) const;

  // Blocks until the previous `StartExecute` has run `TFE_Execute` on each
  // device. If is_async=false (constructor argument) this means the ops have
  // run and have results. If is_async=true it means that all of the
  // device-specific executors have scheduled the op.
  //
  // Accepts inferred shapes for outputs (`expected_output_shapes`), which if
  // fully defined will avoid querying the shapes of the underlying
  // TensorHandles when ParallelTensor::Shape is called. This allows async
  // computation to continue without blocking.
  //
  // The return status and value are the same as for `Execute`.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Join(
      const std::vector<PartialTensorShape>& expected_output_shapes,
      TF_Status* status) const;
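  //
  // Illustrative StartExecute/Join sketch across two parallel devices;
  // `first_device`, `second_device`, the inputs, and `attributes` are
  // assumptions for the example:
  //
  //   CancellationManager cancellation_manager;
  //   first_device.StartExecute(context, first_inputs, "AddV2", attributes,
  //                             /*expected_max_outputs=*/1,
  //                             cancellation_manager);
  //   second_device.StartExecute(context, second_inputs, "AddV2", attributes,
  //                              /*expected_max_outputs=*/1,
  //                              cancellation_manager);
  //   // Join must be called on every device that ran StartExecute, even if an
  //   // earlier Join reported a bad status.
  //   std::vector<PartialTensorShape> expected_shapes(1);  // unknown shape
  //   auto first_outputs = first_device.Join(expected_shapes, status);
  //   auto second_outputs = second_device.Join(expected_shapes, status);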

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
  // A cancellation manager to use if the caller does not provide one. When ops
  // are executed asynchronously this must outlive the queued op, so it can't
  // be function-local to Execute.
  mutable std::unique_ptr<CancellationManager> default_cancellation_manager_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice. If called, ParallelTensor::Shape inspects
  // `components` to determine a shape.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);
  // Uses the provided shape without additional checks, which avoids blocking
  // when ParallelTensor::Shape is called.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
      TF_Status* status);

  size_t num_tensors() const { return tensors_.size(); }
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

  // If the `shape` argument to `FromTensorHandles` is specified, returns that.
  //
  // Otherwise if all of the tensors have the same shape, returns that via the
  // `shape` output argument. This blocks waiting for async tensors, may return
  // a delayed bad status encountered during async execution, and will return a
  // bad status unless all tensors have the same shape.
  Status Shape(const std::vector<int64_t>** shape) const;
  TF_DataType dtype() const { return dtype_; }

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 absl::Span<const int64> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::vector<int64_t>(shape.begin(), shape.end())),
        dtype_(dtype) {}
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(absl::nullopt),
        dtype_(dtype) {}

  const ParallelDevice& device_;
  const std::vector<TensorHandlePtr> tensors_;
  // Parallel tensors are immutable but compute their shape lazily unless it is
  // provided on construction. The optional has a value if the lazy computation
  // has been completed or the shape was provided on construction.
  mutable absl::optional<std::vector<int64_t>> shape_;
  const TF_DataType dtype_;
};

template <typename DataType>
std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
    absl::Span<DataType const> values, TFE_Context* context,
    TF_Status* status) const {
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());

  if (values.size() != num_underlying_devices()) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        "Number of values did not match number of underlying devices.");
    return nullptr;
  }
  TF_DataType datatype_enum(
      static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
  for (int device_index = 0; device_index < num_underlying_devices();
       ++device_index) {
    auto device_value = absl::make_unique<DataType>();
    *device_value = values[device_index];
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
            device_value.release(), sizeof(DataType),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<DataType*>(data);
            },
            nullptr),
        TF_DeleteTensor);
    // TODO(allenl): Here and when executing regular operations, we could hold
    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
    // device names repeatedly.
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
        TFE_NewOp(context, "Const", status), TFE_DeleteOp);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
                    status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(device_handle);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}

}  // namespace parallel_device
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_