1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
18 
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "absl/types/variant.h"
24 #include "pybind11/numpy.h"
25 #include "pybind11/pybind11.h"
26 #include "pybind11/pytypes.h"
27 #include "tensorflow/compiler/xla/python/py_buffer.h"
28 #include "tensorflow/core/platform/logging.h"
29 
30 // TODO(jblespiau): The current implementation moves the Python logic to C++,
31 // as a preliminary step to executing the `pmap` execution path from C++.
32 // It implements the current Python behavior (thus, it may not be optimal, and
33 // we will be able to modify it later).
34 
35 namespace jax {
36 
37 // High level introduction.
38 //
39 // pmap and other parallel computation functions distribute some computation on
// several devices. As of December 2020, the device mesh (i.e. the N-dimensional
// array of devices onto which we map the computation) is defined by the user.
42 //
// We describe how to shard the inputs, and how to map them to the mesh of
// devices, using `ShardingSpec`. It's mainly based on 2 components:
45 // - `sharding`, which specifies how to shard the inputs.
46 // - `mesh_mapping`, which specifies how to map shards to devices.
47 //
48 
// The 3 following structs define how to shard one dimension of an ndarray.
50 //
// `NoSharding` (`None` in Python) means no sharding.
//
// `NoSharding` carries no state, so any two values are equivalent and always
// compare equal. Providing `==`/`!=` (like `Chunked` and `Unstacked` do) is
// required for `AvalDimSharding`, the variant holding these alternatives, to be
// equality-comparable.
struct NoSharding {
  bool operator==(const NoSharding& /*other*/) const { return true; }
  bool operator!=(const NoSharding& /*other*/) const { return false; }
};
53 
// `Chunked` means that the dimension is split into np.prod(chunks) chunks
// and the split dimension itself is preserved inside the map.
// Those chunks are distributed over `len(chunks)` sharded axes (see
// `ShardedAxis`), in major-to-minor order.
// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])] (with
// p dividing N, let S = N // p) the tensor will be split into p chunks of
// shape [S], such that sharded_t[k] = t[k * S: (k+1)*S] (left included, right
// excluded) for k in {0, ... p-1}.
struct Chunked {
  // Takes ownership of the per-axis chunk counts.
  explicit Chunked(std::vector<int> chunk_counts)
      : chunks(std::move(chunk_counts)) {}

  // The number of chunks per axis.
  std::vector<int> chunks;

  // Two `Chunked` shardings are equal iff their chunk counts are identical.
  bool operator==(const Chunked& other) const { return chunks == other.chunks; }
  bool operator!=(const Chunked& other) const { return !(*this == other); }
};
71 
72 // `Unstacked` means that the dimension is split into chunks of size 1, and
73 // doesn't appear inside the map. `size` is always the dimension size.
74 // For example, a Tensor t of shape [N] will be sharded into N tensors of shape
75 // [], when using `Unstacked(N)`.
struct Unstacked {
  // `dim_size` is the size of the unstacked dimension.
  explicit Unstacked(int dim_size) : size(dim_size) {}

  // The size of the dimension, i.e. the number of size-1 chunks it is split
  // into.
  int size;

  // Two `Unstacked` shardings are equal iff they have the same size.
  bool operator==(const Unstacked& other) const { return size == other.size; }
  bool operator!=(const Unstacked& other) const { return !(*this == other); }
};
84 
// How a single dimension of an array is sharded: not at all (`NoSharding`),
// split into chunks that stay inside the map (`Chunked`), or split into
// size-1 slices that do not appear inside the map (`Unstacked`).
using AvalDimSharding = absl::variant<NoSharding, Chunked, Unstacked>;
86 
87 // Assigns sharded axes to mesh dimensions.
88 //
// Each mesh dimension is either assigned to one of the sharded axes (i.e. the
// dimensions whose `AvalDimSharding` is not `NoSharding`), or, when no axis is
// assigned, the data is replicated along that mesh dimension.
91 // As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually
92 // sharded axis (i.e. counting as if the None dimensions of sharding were
93 // filtered out).
94 // For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry
95 // of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`.
96 
// Mesh-dimension assignment that partitions the data along the `axis`-th
// sharded axis (0-indexed, counting only the axes that are actually sharded).
struct ShardedAxis {
  // Default member initializer: a default-initialized `ShardedAxis`
  // (e.g. `ShardedAxis a;`) would otherwise hold an indeterminate value.
  // The struct remains an aggregate, so `ShardedAxis{1}` still works.
  int axis = 0;
  bool operator==(const ShardedAxis& other) const { return axis == other.axis; }
  bool operator!=(const ShardedAxis& other) const { return axis != other.axis; }
};
// Mesh-dimension assignment under which the data is not partitioned but
// replicated along that mesh dimension. `replicas` presumably holds the number
// of replicas (the size of that mesh dimension) — confirm against the Python
// `Replicated`.
struct Replicated {
  // Default member initializer: a default-initialized `Replicated`
  // (e.g. `Replicated r;`) would otherwise hold an indeterminate value.
  // The struct remains an aggregate, so `Replicated{4}` still works.
  int replicas = 0;
  bool operator==(const Replicated& other) const {
    return replicas == other.replicas;
  }
  bool operator!=(const Replicated& other) const {
    return replicas != other.replicas;
  }
};
111 
// What a single mesh dimension is assigned to: either one of the sharded axes
// (`ShardedAxis`) or replication of the data (`Replicated`).
using MeshDimAssignment = absl::variant<ShardedAxis, Replicated>;
113 
// Describes how each axis is sharded (if it is), and how it's mapped to the
// device mesh.
116 class ShardingSpec {
117  public:
ShardingSpec(std::vector<AvalDimSharding> sharding,std::vector<MeshDimAssignment> mesh_mapping)118   ShardingSpec(std::vector<AvalDimSharding> sharding,
119                std::vector<MeshDimAssignment> mesh_mapping)
120       : sharding_(std::move(sharding)),
121         mesh_mapping_(std::move(mesh_mapping)) {}
122 
GetSharding()123   const std::vector<AvalDimSharding>& GetSharding() const { return sharding_; }
GetMeshMapping()124   const std::vector<MeshDimAssignment>& GetMeshMapping() const {
125     return mesh_mapping_;
126   }
127 
128  private:
129   //  `sharding` specifies how the array is supposed to get partitioned into
130   //  chunks. Its length matchs the rank of the array. See the docstring
131   //  of `AvalDimSharding` for the supported partitioning schemes.
132   std::vector<AvalDimSharding> sharding_;
133   //  `mesh_mapping` describes an assignments of the array chunks created by
134   //  `sharding` to a logical device mesh. The length of the tuple is equal to
135   //  the rank of the mesh. Each mesh dimension can either get partitions of
136   //  data varying along one of the sharded dimensions, or the data can be
137   //  replicated.
138   std::vector<MeshDimAssignment> mesh_mapping_;
139 };
140 
141 // A ShardedDeviceArray is an ndarray sharded across devices.
142 //
143 // The purpose of a ShardedDeviceArray is to reduce the number of transfers when
144 // executing replicated computations, by allowing results to persist on the
145 // devices that produced them. That way dispatching a similarly replicated
146 // computation that consumes the same sharded memory layout does not incur any
147 // transfers.
148 
149 // A ShardedDeviceArray represents one logical ndarray value, and simulates the
150 // behavior of an ndarray so that it can be treated by user code as an ndarray;
151 // that is, it is only an optimization to reduce transfers.
152 
// Design note: we only move to C++ what will need to be accessed by C++ to
// execute a pmap computation. A large part of the logic is still in Python.
155 class ShardedDeviceArray : xla::DeviceArrayBase {
156  public:
ShardedDeviceArray(pybind11::handle aval,ShardingSpec sharding_spec,pybind11::list device_buffers)157   ShardedDeviceArray(
158       pybind11::handle aval, ShardingSpec sharding_spec,
159       // Buffers are expected to be xla::PyBuffer objects, but as there are
160       // alternative backend implementations, this may not be guaranteed.
161       // TODO(jblespiau): As soon as PjRtBuffer is supported by all
162       // implementations, we should be able to store this with the C++ objects.
163       pybind11::list device_buffers)
164       : DeviceArrayBase(),
165         aval_(pybind11::cast<pybind11::object>(aval)),
166         sharding_spec_(std::move(sharding_spec)),
167         device_buffers_(device_buffers) {}
168 
GetAval()169   pybind11::object GetAval() const { return aval_; }
GetShardingSpec()170   const ShardingSpec& GetShardingSpec() const { return sharding_spec_; }
GetDeviceBuffers()171   pybind11::list GetDeviceBuffers() const { return device_buffers_; }
172 
173  private:
174   // A ShapedArray indicating the shape and dtype of this array.
175   pybind11::object aval_;
176   // Describes how this array is sharded across `device_buffers`.
177   ShardingSpec sharding_spec_;
178   // The buffers containing the data for this array. Each buffer is the same
179   // shape and on a different device. Buffers are in row-major order, with
180   // replication treated as an extra innermost dimension.
181   pybind11::list device_buffers_;
182 };
183 
// Defined outside this header. Judging by its name, it builds a pmap
// submodule of the Python module `m` (presumably exposing the types declared
// above) — NOTE(review): confirm against the definition.
void BuildPmapSubmodule(pybind11::module& m);
185 
186 }  // namespace jax
187 
188 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PMAP_LIB_H_
189