/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/local_service.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

/* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }

  BackendOptions backend_options;
  backend_options.set_platform(platform)
      .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads())
      .set_allowed_devices(options.allowed_devices());

  TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
                      Backend::CreateBackend(backend_options));

  std::unique_ptr<LocalService> service(
      new LocalService(options, std::move(backend)));
  return std::move(service);
}
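
// Illustrative usage sketch (hedged): how a hypothetical caller might create a
// LocalService. It assumes the chained ServiceOptions setters declared in
// service.h; the snippet is commentary only and the surrounding code is
// hypothetical.
//
//   ServiceOptions options;
//   options.set_number_of_replicas(1).set_intra_op_parallelism_threads(4);
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalService> service,
//                       LocalService::NewService(options));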

LocalService::LocalService(const ServiceOptions& options,
                           std::unique_ptr<Backend> execute_backend)
    : Service(options, std::move(execute_backend)) {}

namespace {

// Retrieves the parameter metadata for the given computation and parameter
// number.
//
// Returns nullopt if the parameter number is invalid for this computation or
// if the parameter carries no metadata. When the return value has_value(),
// the held pointer is never null.
absl::optional<const OpMetadata*> ParameterMetadata(
    const XlaComputation& computation, int parameter_number) {
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() == computation.proto().entry_computation_id()) {
      for (const HloInstructionProto& instr : comp.instructions()) {
        if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
            instr.parameter_number() == parameter_number) {
          if (!instr.has_metadata()) {
            return absl::nullopt;
          }
          return &instr.metadata();
        }
      }
    }
  }
  return absl::nullopt;
}

}  // namespace

StatusOr<std::vector<std::unique_ptr<Executable>>>
LocalService::CompileExecutables(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options) {
  const HloModuleProto& proto = computation.proto();
  TF_RET_CHECK(proto.has_host_program_shape());
  ProgramShape program_shape(proto.host_program_shape());

  // Validate incoming layouts.
  if (argument_layouts.size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "Invalid number of arguments for computation: expected %d, got %u.",
        program_shape.parameters_size(), argument_layouts.size());
  }

  for (int i = 0; i < argument_layouts.size(); ++i) {
    const Shape& argument_shape = *argument_layouts[i];
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
    if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
      absl::optional<const OpMetadata*> metadata =
          ParameterMetadata(computation, /*parameter_number=*/i);
      auto metadata_string = [&metadata]() -> string {
        if (!metadata.has_value()) {
          return "";
        }
        CHECK(metadata.value() != nullptr);
        const OpMetadata& m = *metadata.value();
        if (!m.source_file().empty()) {
          return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
        }
        return "";
      };
      return InvalidArgument(
          "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
          metadata_string(),
          ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shape));
    }
  }
  if (build_options.result_layout() != nullptr) {
    TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
                                           program_shape.result()));
  }

  ExecutionOptions execution_options =
      CreateExecutionOptions(build_options, &program_shape);

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(program_shape, argument_layouts, &execution_options));

  VLOG(3) << "Computation Layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      se::StreamExecutor * executor,
      execute_backend_->stream_executor(build_options.device_ordinal()));

  // TODO(cjfj): Investigate why there are a couple of test failures when the
  // single partition computations are built using `BuildExecutables`, fix it,
  // and remove this special case (provided the performance is similar).
  if (build_options.num_partitions() == 1) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                        BuildExecutable(proto, std::move(module_config),
                                        execute_backend_.get(), executor,
                                        {build_options.device_allocator(),
                                         build_options.compile_thread_pool()},
                                        build_options.run_backend_only()));
    std::vector<std::unique_ptr<Executable>> executables;
    executables.push_back(std::move(executable));
    return executables;
  } else {
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
    module_configs.push_back(std::move(module_config));
    // BuildExecutables uses the length of `executors` to determine the number
    // of cores per module, but otherwise only uses the first executor.
    std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                               executor);

    return BuildExecutables(
        /*module_protos=*/{&proto}, std::move(module_configs),
        execute_backend_.get(), {executors},
        Compiler::CompileOptions{build_options.device_allocator(),
                                 build_options.compile_thread_pool()},
        build_options.run_backend_only());
  }
}
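
// Illustrative call-site sketch (hedged): how a caller might compile a
// computation for device 0. `computation` and `local_service` are hypothetical
// names, not defined in this file; the ExecutableBuildOptions and ShapeUtil
// calls are assumed from their public headers.
//
//   Shape arg_shape = ShapeUtil::MakeShape(F32, {2, 2});
//   ExecutableBuildOptions build_options;
//   build_options.set_device_ordinal(0);
//   TF_ASSIGN_OR_RETURN(
//       std::vector<std::unique_ptr<Executable>> executables,
//       local_service->CompileExecutables(computation, {&arg_shape},
//                                         build_options));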

StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return backend().computation_placer()->DeviceId(
      replica_number, /*computation=*/0, options_.number_of_replicas(),
      /*computation_count=*/1);
}
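// Note (hedged): with the default ComputationPlacer and computation_count=1,
// the DeviceId call above typically maps replica i to device ordinal i;
// backend-specific placers may choose a different assignment.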

StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
    const GlobalDataHandle& data, int replica_number) {
  TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
  if (replica_number >= buffers.size()) {
    return InvalidArgument(
        "replica_number %d out of range; must be less than num_replicas = %u.",
        replica_number, buffers.size());
  }
  return buffers[replica_number];
}

StatusOr<GlobalDataHandle> LocalService::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
  return allocation_tracker_.RegisterReplicatedBuffers(
      std::move(replicated_buffers), tag);
}

}  // namespace xla