1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/local_service.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/executable_build_options.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/execution_options_util.h"
28 #include "tensorflow/compiler/xla/service/backend.h"
29 #include "tensorflow/compiler/xla/service/computation_layout.h"
30 #include "tensorflow/compiler/xla/service/executable.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
35 #include "tensorflow/compiler/xla/service/platform_util.h"
36 #include "tensorflow/compiler/xla/shape_layout.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/core/lib/gtl/cleanup.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
44 
45 namespace xla {
46 
NewService(const ServiceOptions & options)47 /* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
48     const ServiceOptions& options) {
49   se::Platform* platform = options.platform();
50   if (platform == nullptr) {
51     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
52   }
53 
54   BackendOptions backend_options;
55   backend_options.set_platform(platform)
56       .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads())
57       .set_allowed_devices(options.allowed_devices());
58 
59   TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
60                       Backend::CreateBackend(backend_options));
61 
62   std::unique_ptr<LocalService> service(
63       new LocalService(options, std::move(backend)));
64   return std::move(service);
65 }
66 
LocalService(const ServiceOptions & options,std::unique_ptr<Backend> execute_backend)67 LocalService::LocalService(const ServiceOptions& options,
68                            std::unique_ptr<Backend> execute_backend)
69     : Service(options, std::move(execute_backend)) {}
70 
71 namespace {
72 
73 // Retrieves the parameter metadata for the given computation and parameter
74 // number.
75 //
76 // If the parameter number is invalid for this computation, nullopt is
77 // returned. When the return value has_value(), nullptr will never be
78 // the held value.
ParameterMetadata(const XlaComputation & computation,int parameter_number)79 absl::optional<const OpMetadata*> ParameterMetadata(
80     const XlaComputation& computation, int parameter_number) {
81   for (const HloComputationProto& comp : computation.proto().computations()) {
82     if (comp.id() == computation.proto().entry_computation_id()) {
83       for (const HloInstructionProto& instr : comp.instructions()) {
84         if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
85             instr.parameter_number() == parameter_number) {
86           if (!instr.has_metadata()) {
87             return absl::nullopt;
88           }
89           return &instr.metadata();
90         }
91       }
92     }
93   }
94   return absl::nullopt;
95 }
96 
CreateExecutionOptions(const ExecutableBuildOptions & build_options,const ProgramShape * program_shape)97 ExecutionOptions CreateExecutionOptions(
98     const ExecutableBuildOptions& build_options,
99     const ProgramShape* program_shape) {
100   ExecutionOptions execution_options = CreateDefaultExecutionOptions();
101   if (build_options.has_debug_options()) {
102     *execution_options.mutable_debug_options() = build_options.debug_options();
103   }
104   if (build_options.result_layout() != nullptr) {
105     *execution_options.mutable_shape_with_output_layout() =
106         build_options.result_layout()->ToProto();
107   } else {
108     Shape result_shape(program_shape->result());
109     LayoutUtil::SetToDefaultLayout(&result_shape);
110     *execution_options.mutable_shape_with_output_layout() =
111         result_shape.ToProto();
112   }
113   execution_options.set_num_replicas(build_options.num_replicas());
114   return execution_options;
115 }
116 
117 }  // namespace
118 
CompileExecutable(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & build_options)119 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
120     const XlaComputation& computation,
121     const absl::Span<const Shape* const> argument_layouts,
122     const ExecutableBuildOptions& build_options) {
123   const HloModuleProto& proto = computation.proto();
124   TF_RET_CHECK(proto.has_host_program_shape());
125   ProgramShape program_shape(proto.host_program_shape());
126 
127   // Validate incoming layouts.
128   if (argument_layouts.size() != program_shape.parameters_size()) {
129     return InvalidArgument(
130         "Invalid number of arguments for computation: expected %d, got %u.",
131         program_shape.parameters_size(), argument_layouts.size());
132   }
133 
134   for (int i = 0; i < argument_layouts.size(); ++i) {
135     const Shape& argument_shape = *argument_layouts[i];
136     TF_RETURN_IF_ERROR(
137         ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
138     if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
139       absl::optional<const OpMetadata*> metadata =
140           ParameterMetadata(computation, /*parameter_number=*/i);
141       auto metadata_string = [&metadata]() -> string {
142         if (!metadata.has_value()) {
143           return "";
144         }
145         CHECK(metadata.value() != nullptr);
146         const OpMetadata& m = *metadata.value();
147         if (!m.source_file().empty()) {
148           return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
149         }
150         return "";
151       };
152       return InvalidArgument(
153           "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
154           metadata_string(),
155           ShapeUtil::HumanString(program_shape.parameters(i)),
156           ShapeUtil::HumanString(argument_shape));
157     }
158   }
159   if (build_options.result_layout() != nullptr) {
160     TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
161                                            program_shape.result()));
162   }
163 
164   ExecutionOptions execution_options =
165       CreateExecutionOptions(build_options, &program_shape);
166 
167   TF_ASSIGN_OR_RETURN(
168       std::unique_ptr<HloModuleConfig> module_config,
169       CreateModuleConfig(program_shape, argument_layouts, &execution_options));
170 
171   VLOG(3) << "Computation Layout: "
172           << module_config->entry_computation_layout().ToString();
173 
174   TF_ASSIGN_OR_RETURN(
175       se::StreamExecutor * executor,
176       execute_backend_->stream_executor(build_options.device_ordinal()));
177 
178   return BuildExecutable(proto, std::move(module_config),
179                          execute_backend_.get(), executor,
180                          build_options.device_allocator());
181 }
182 
ReplicaNumberToDeviceOrdinal(int replica_number)183 StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
184   return backend().computation_placer()->DeviceId(
185       replica_number, /*computation=*/0, options_.number_of_replicas(),
186       /*computation_count=*/1);
187 }
188 
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)189 StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
190     const GlobalDataHandle& data, int replica_number) {
191   TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
192   if (replica_number >= buffers.size()) {
193     return InvalidArgument(
194         "replica_number %d out of range; must be less than num_replicas = %u.",
195         replica_number, buffers.size());
196   }
197   return buffers[replica_number];
198 }
199 
RegisterReplicatedBuffers(std::vector<ScopedShapedBuffer> replicated_buffers,const string & tag)200 StatusOr<GlobalDataHandle> LocalService::RegisterReplicatedBuffers(
201     std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
202   return allocation_tracker_.RegisterReplicatedBuffers(
203       std::move(replicated_buffers), tag);
204 }
205 
206 }  // namespace xla
207