1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/compiler/tf2xla/sharding_util.h"

#include <utility>

#include "absl/strings/match.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
21 
22 namespace tensorflow {
namespace {
// Substring that marks a device type as a replicated TPU core; the device id
// of such a device is interpreted as a core index by ParseShardingFromDevice.
constexpr char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE";
// Node attribute that holds a serialized xla::OpSharding proto.
constexpr char kShardingAttribute[] = "_XlaSharding";
}  // namespace
27 
28 namespace {
CreateOpMetadata(const std::string & op_type,const std::string & op_name)29 xla::OpMetadata CreateOpMetadata(const std::string& op_type,
30                                  const std::string& op_name) {
31   xla::OpMetadata metadata;
32   metadata.set_op_type(op_type);
33   metadata.set_op_name(op_name);
34   return metadata;
35 }
36 
AssignOpMetadataToSharding(xla::OpSharding & sharding,const string & op_type,const string & op_name)37 void AssignOpMetadataToSharding(xla::OpSharding& sharding,
38                                 const string& op_type, const string& op_name) {
39   auto metadata = CreateOpMetadata(op_type, op_name);
40   if (sharding.type() == xla::OpSharding::TUPLE) {
41     for (auto& sharding_element : *sharding.mutable_tuple_shardings()) {
42       *sharding_element.add_metadata() = metadata;
43     }
44   } else {
45     *sharding.add_metadata() = metadata;
46   }
47 }
48 
CoreOutOfRangeError(int core,int num_cores_per_replica)49 Status CoreOutOfRangeError(int core, int num_cores_per_replica) {
50   return errors::InvalidArgument(
51       "Invalid replicated core id: ", core,
52       "; num_cores_per_replica=", num_cores_per_replica);
53 }
54 }  // namespace
55 
ParseShardingFromDevice(const string & device_name,int num_cores_per_replica,absl::optional<xla::OpSharding> explicit_sharding,absl::optional<xla::OpMetadata> metadata)56 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
57     const string& device_name, int num_cores_per_replica,
58     absl::optional<xla::OpSharding> explicit_sharding,
59     absl::optional<xla::OpMetadata> metadata) {
60   if (device_name.empty()) {
61     return explicit_sharding;
62   }
63   DeviceNameUtils::ParsedName parsed_device;
64   if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) {
65     return errors::InvalidArgument("Malformed assigned device '", device_name,
66                                    "'");
67   }
68 
69   if (explicit_sharding.has_value()) {
70     return explicit_sharding;
71   } else if (!parsed_device.has_type || !parsed_device.has_id ||
72              !absl::StrContains(parsed_device.type,
73                                 kDeviceSuffixReplicatedCore)) {
74     return absl::optional<xla::OpSharding>();
75   } else {
76     const int core = parsed_device.id;
77     if (core < 0 || core >= num_cores_per_replica) {
78       return CoreOutOfRangeError(core, num_cores_per_replica);
79     }
80     auto sharding = xla::sharding_builder::AssignDevice(core);
81     if (metadata.has_value()) {
82       *sharding.add_metadata() = metadata.value();
83     }
84     return absl::optional<xla::OpSharding>(sharding);
85   }
86 }
87 
ParseShardingFromDevice(const NodeDef & node_def,int num_cores_per_replica,bool add_metadata)88 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
89     const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) {
90   const string& device_name = node_def.device();
91   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
92                       GetShardingFromNodeDef(node_def, add_metadata));
93   return ParseShardingFromDevice(
94       device_name, num_cores_per_replica, sharding,
95       add_metadata ? absl::optional<xla::OpMetadata>(
96                          CreateOpMetadata(node_def.op(), node_def.name()))
97                    : absl::nullopt);
98 }
99 
ParseShardingFromDevice(const Node & node,int num_cores_per_replica,bool add_metadata)100 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromDevice(
101     const Node& node, int num_cores_per_replica, bool add_metadata) {
102   string device_name = node.assigned_device_name();
103   if (device_name.empty()) {
104     device_name = node.requested_device();
105   }
106   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
107                       GetShardingFromNodeDef(node.def(), add_metadata));
108   return ParseShardingFromDevice(
109       device_name, num_cores_per_replica, sharding,
110       add_metadata ? absl::optional<xla::OpMetadata>(
111                          CreateOpMetadata(node.type_string(), node.name()))
112                    : absl::nullopt);
113 }
114 
ParseShardingFromEdgeSource(const Edge & edge,int num_cores_per_replica,bool add_metadata)115 xla::StatusOr<absl::optional<xla::OpSharding>> ParseShardingFromEdgeSource(
116     const Edge& edge, int num_cores_per_replica, bool add_metadata) {
117   if (edge.src() == nullptr) {
118     return tensorflow::errors::InvalidArgument(
119         "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString());
120   }
121   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
122                       ParseShardingFromDevice(
123                           *edge.src(), num_cores_per_replica, add_metadata));
124   if (sharding.has_value() &&
125       sharding.value().type() == xla::OpSharding::TUPLE) {
126     if (edge.src_output() < 0 ||
127         edge.src_output() >= sharding.value().tuple_shardings_size()) {
128       return tensorflow::errors::InvalidArgument(
129           "Tuple index out of bound: edge=", edge.DebugString(),
130           " sharding=", sharding->DebugString());
131     }
132     absl::optional<xla::OpSharding> subsharding =
133         sharding.value().tuple_shardings(edge.src_output());
134     return subsharding;
135   }
136   return sharding;
137 }
138 
SetShardingDeviceAssignmentFromNode(const Node & src,Node * dst)139 void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) {
140   string device_name = src.assigned_device_name();
141   if (device_name.empty()) {
142     device_name = src.requested_device();
143   }
144   dst->set_assigned_device_name(device_name);
145   if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) {
146     dst->AddAttr(kShardingAttribute, *attr);
147   }
148 }
149 
GetShardingFromNodeDef(const NodeDef & node_def,bool add_metadata)150 xla::StatusOr<absl::optional<xla::OpSharding>> GetShardingFromNodeDef(
151     const NodeDef& node_def, bool add_metadata) {
152   if (!HasNodeAttr(node_def, kShardingAttribute)) {
153     return absl::optional<xla::OpSharding>();
154   }
155   string value;
156   xla::OpSharding sharding;
157   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value));
158   if (!sharding.ParseFromString(value)) {
159     return xla::InvalidArgument(
160         "Experimental _XlaSharding attribute was not a valid encoded "
161         "xla::OpSharding proto.");
162   }
163   if (add_metadata) {
164     AssignOpMetadataToSharding(sharding, node_def.op(), node_def.name());
165   }
166   return absl::optional<xla::OpSharding>(sharding);
167 }
168 }  // namespace tensorflow
169