1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/utils.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/client/executable_build_options.h"
20 #include "tensorflow/compiler/xla/client/xla_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo.pb.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
24 #include "tensorflow/compiler/xla/shape.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 
30 namespace {
GetShardedShape(const Shape & shape,const OpSharding & sharding)31 StatusOr<Shape> GetShardedShape(const Shape& shape,
32                                 const OpSharding& sharding) {
33   if (sharding.type() == OpSharding::TUPLE) {
34     if (!shape.IsTuple()) {
35       return InvalidArgument(
36           "Got tuple OpSharding (%s) for non-tuple shape (%s)",
37           sharding.DebugString(), shape.ToString());
38     }
39     if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
40       return InvalidArgument(
41           "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
42           " (OpSharding: %s, shape: %s)",
43           sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
44           sharding.DebugString(), shape.ToString());
45     }
46     std::vector<Shape> sharded_subshapes;
47     for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
48       TF_ASSIGN_OR_RETURN(
49           Shape sharded_subshape,
50           GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i)));
51       sharded_subshapes.emplace_back(std::move(sharded_subshape));
52     }
53     return ShapeUtil::MakeTupleShape(sharded_subshapes);
54   }
55   TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
56                       HloSharding::FromProto(sharding));
57   return hlo_sharding.TileShape(shape);
58 }
59 
GetShardedShape(const HloInstructionProto & instr)60 StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
61   const Shape unsharded_shape(instr.shape());
62   Shape sharded_shape;
63   if (instr.has_sharding()) {
64     TF_ASSIGN_OR_RETURN(sharded_shape,
65                         GetShardedShape(unsharded_shape, instr.sharding()));
66   } else {
67     sharded_shape = unsharded_shape;
68   }
69   LayoutUtil::ClearLayout(&sharded_shape);
70   return sharded_shape;
71 }
72 
73 // Returns sharded (argument shapes, result shape) without layouts.
GetShardedProgramShapes(const XlaComputation & computation,const ProgramShape & program_shape)74 StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
75     const XlaComputation& computation, const ProgramShape& program_shape) {
76   std::vector<Shape> arg_shapes;
77   arg_shapes.resize(program_shape.parameters_size());
78   Shape result_shape;
79   for (const HloComputationProto& comp : computation.proto().computations()) {
80     if (comp.id() != computation.proto().entry_computation_id()) {
81       continue;
82     }
83     for (const HloInstructionProto& instr : comp.instructions()) {
84       if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
85         if (instr.parameter_number() >= program_shape.parameters_size()) {
86           return InvalidArgument(
87               "Got invalid parameter number %d, expected %d parameters",
88               instr.parameter_number(), program_shape.parameters_size());
89         }
90         TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()],
91                             GetShardedShape(instr));
92       }
93       if (instr.id() == comp.root_id()) {
94         if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
95           return InvalidArgument("Found multiple root instructions");
96         }
97         TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
98       }
99     }
100   }
101   for (int i = 0; i < arg_shapes.size(); ++i) {
102     if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
103       return InvalidArgument("Couldn't find parameter %d", i);
104     }
105   }
106   if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
107     return InvalidArgument("Couldn't find root instruction");
108   }
109   return std::make_pair(arg_shapes, result_shape);
110 }
111 }  // namespace
112 
ParseDeviceAssignmentCompileOptions(bool compile_portable_executable,ExecutableBuildOptions * build_options,std::function<StatusOr<DeviceAssignment> (int,int)> GetDefaultDeviceAssignmentFunction,int * num_replicas,int * num_partitions,std::shared_ptr<DeviceAssignment> * device_assignment)113 Status ParseDeviceAssignmentCompileOptions(
114     bool compile_portable_executable, ExecutableBuildOptions* build_options,
115     std::function<StatusOr<DeviceAssignment>(int, int)>
116         GetDefaultDeviceAssignmentFunction,
117     int* num_replicas, int* num_partitions,
118     std::shared_ptr<DeviceAssignment>* device_assignment) {
119   if (compile_portable_executable) {
120     if (build_options->has_device_assignment()) {
121       return InvalidArgument(
122           "CompileOptions requests portable executable but "
123           "ExecutableBuildOptions includes a device assignment");
124     }
125     *num_replicas = 1;
126     *num_partitions = 1;
127   } else {
128     if (!build_options->has_device_assignment()) {
129       VLOG(2) << "Compile using default device_assignment.";
130       TF_ASSIGN_OR_RETURN(
131           DeviceAssignment device_assignment,
132           GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
133                                              build_options->num_partitions()));
134       build_options->set_device_assignment(device_assignment);
135     }
136     VLOG(2) << "Compile device_assignment:\n"
137             << build_options->device_assignment().ToString();
138     *num_replicas = build_options->device_assignment().replica_count();
139     *num_partitions = build_options->device_assignment().computation_count();
140     *device_assignment =
141         std::make_shared<DeviceAssignment>(build_options->device_assignment());
142   }
143   return Status::OK();
144 }
145 
DetermineArgumentLayoutsFromCompileOptions(const XlaComputation & computation,std::function<StatusOr<Shape> (Shape)> choose_compact_layout_for_shape_function,absl::optional<std::vector<Shape>> & argument_layouts,ExecutableBuildOptions * build_options,std::vector<const Shape * > * argument_layout_pointers)146 Status DetermineArgumentLayoutsFromCompileOptions(
147     const XlaComputation& computation,
148     std::function<StatusOr<Shape>(Shape)>
149         choose_compact_layout_for_shape_function,
150     absl::optional<std::vector<Shape>>& argument_layouts,
151     ExecutableBuildOptions* build_options,
152     std::vector<const Shape*>* argument_layout_pointers) {
153   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
154                       computation.GetProgramShape());
155   if (!argument_layouts) {
156     argument_layouts.emplace(program_shape.parameters());
157     for (Shape& shape : *argument_layouts) {
158       LayoutUtil::ClearLayout(&shape);
159     }
160   } else if (argument_layouts->size() != program_shape.parameters_size()) {
161     return InvalidArgument(
162         "CompileOptions specify %d argument layouts, but computation has %d "
163         "arguments",
164         argument_layouts->size(), program_shape.parameters_size());
165   }
166   argument_layout_pointers->reserve(argument_layouts->size());
167 
168   // Assign a default layout based on `sharded_shape` to any array subshapes in
169   // `dst_shape` that are missing layouts.
170   auto assign_layouts = [&choose_compact_layout_for_shape_function](
171                             const Shape& sharded_shape, Shape* dst_shape) {
172     return ShapeUtil::ForEachMutableSubshapeWithStatus(
173         dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
174           if (subshape->IsArray() && !subshape->has_layout()) {
175             CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx));
176             const Shape& sharded_subshape =
177                 ShapeUtil::GetSubshape(sharded_shape, idx);
178             LayoutUtil::SetToDefaultLayout(subshape);
179             TF_ASSIGN_OR_RETURN(
180                 Shape layout,
181                 choose_compact_layout_for_shape_function(sharded_subshape));
182             *subshape->mutable_layout() = layout.layout();
183           }
184           return Status::OK();
185         });
186   };
187   TF_ASSIGN_OR_RETURN(auto sharded_shapes,
188                       GetShardedProgramShapes(computation, program_shape));
189 
190   CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
191   for (int i = 0; i < argument_layouts->size(); ++i) {
192     Shape* layout = &(*argument_layouts)[i];
193     argument_layout_pointers->push_back(layout);
194     TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
195   }
196 
197   Shape result_layout;
198   if (build_options->result_layout()) {
199     result_layout = *build_options->result_layout();
200   } else {
201     result_layout = program_shape.result();
202     LayoutUtil::ClearLayout(&result_layout);
203   }
204   TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
205   build_options->set_result_layout(result_layout);
206   return Status::OK();
207 }
208 
GetParametersThatMustBeDonated(const HloModule & module,bool tuple_inputs)209 StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
210     const HloModule& module, bool tuple_inputs) {
211   HloComputation* computation = module.entry_computation();
212   int number_of_parameters = [&]() -> int {
213     if (tuple_inputs) {
214       CHECK_EQ(computation->num_parameters(), 1);
215       const Shape& input_tuple_shape =
216           computation->parameter_instruction(0)->shape();
217       CHECK(input_tuple_shape.IsTuple());
218       return input_tuple_shape.tuple_shapes_size();
219     } else {
220       return computation->num_parameters();
221     }
222   }();
223   // If any buffer in a parameter is aliased we will donate the entire input
224   // parameter.
225   absl::flat_hash_set<int> parameters_to_donate;
226   const HloInputOutputAliasConfig& config = module.input_output_alias_config();
227   TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
228       [&](const ShapeIndex& output_index,
229           const HloInputOutputAliasConfig::Alias& alias) {
230         if (tuple_inputs) {
231           if (alias.parameter_number != 0) {
232             return InvalidArgument(
233                 "Unexpected parameter number %d in alias config with tupled "
234                 "inputs",
235                 alias.parameter_number);
236           }
237           const ShapeIndex& index = alias.parameter_index;
238           if (!index.empty()) {
239             int this_parameter = index.data()[0];
240             if (this_parameter >= number_of_parameters) {
241               return InvalidArgument(
242                   "Unexpected parameter index %s in alias config with tupled "
243                   "inputs and %d parameters",
244                   index.ToString(), number_of_parameters);
245             }
246             parameters_to_donate.insert(this_parameter);
247           }
248         } else {
249           int this_parameter = alias.parameter_number;
250           if (this_parameter >= number_of_parameters) {
251             return InvalidArgument(
252                 "Unexpected parameter number %d in alias config without tupled "
253                 "inputs and %d parameters",
254                 this_parameter, number_of_parameters);
255           }
256           parameters_to_donate.insert(this_parameter);
257         }
258         return Status::OK();
259       }));
260   return parameters_to_donate;
261 }
262 
263 }  // namespace xla
264