1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 
20 namespace tensorflow {
XlaResourceOpKindToString(XlaResourceOpKind op_kind)21 /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
22     XlaResourceOpKind op_kind) {
23   switch (op_kind) {
24     case XlaResourceOpKind::kRead:
25       return "Read";
26     case XlaResourceOpKind::kWrite:
27       return "Write";
28     case XlaResourceOpKind::kReadWrite:
29       return "Modify";
30   }
31 }
32 
33 static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap()34 CreateResourceOpInfoMap() {
35   auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
36 
37   auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
38                  XlaResourceKind resource_kind) {
39     auto insert_result =
40         result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
41     CHECK(insert_result.second);
42   };
43 
44   auto kRead = XlaResourceOpKind::kRead;
45   auto kWrite = XlaResourceOpKind::kWrite;
46   auto kReadWrite = XlaResourceOpKind::kReadWrite;
47 
48   auto kVariable = XlaResourceKind::kVariable;
49   auto kStack = XlaResourceKind::kStack;
50   auto kTensorArray = XlaResourceKind::kTensorArray;
51 
52   // clang-format off
53   add("AssignAddVariableOp"                  , kReadWrite, kVariable);
54   add("AssignSubVariableOp"                  , kReadWrite, kVariable);
55   add("AssignVariableOp"                     , kWrite,     kVariable);
56   add("ReadVariableOp"                       , kRead,      kVariable);
57   add("ResourceApplyAdaMax"                  , kReadWrite, kVariable);
58   add("ResourceApplyAdadelta"                , kReadWrite, kVariable);
59   add("ResourceApplyAdagrad"                 , kReadWrite, kVariable);
60   add("ResourceApplyAdagradV2"               , kReadWrite, kVariable),
61   add("ResourceApplyAdagradDA"               , kReadWrite, kVariable);
62   add("ResourceApplyAdam"                    , kReadWrite, kVariable);
63   add("ResourceApplyAddSign"                 , kReadWrite, kVariable);
64   add("ResourceApplyCenteredRMSProp"         , kReadWrite, kVariable);
65   add("ResourceApplyFtrl"                    , kReadWrite, kVariable);
66   add("ResourceApplyFtrlV2"                  , kReadWrite, kVariable);
67   add("ResourceApplyGradientDescent"         , kReadWrite, kVariable);
68   add("ResourceApplyMomentum"                , kReadWrite, kVariable);
69   add("ResourceApplyKerasMomentum"           , kReadWrite, kVariable);
70   add("ResourceApplyPowerSign"               , kReadWrite, kVariable);
71   add("ResourceApplyProximalAdagrad"         , kReadWrite, kVariable);
72   add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable);
73   add("ResourceApplyRMSProp"                 , kReadWrite, kVariable);
74   add("ResourceGather"                       , kRead,      kVariable);
75   add("ResourceScatterAdd"                   , kReadWrite, kVariable);
76   add("ResourceScatterDiv"                   , kReadWrite, kVariable);
77   add("ResourceScatterMax"                   , kReadWrite, kVariable);
78   add("ResourceScatterMin"                   , kReadWrite, kVariable);
79   add("ResourceScatterMul"                   , kReadWrite, kVariable);
80   add("ResourceScatterNdAdd"                 , kReadWrite, kVariable);
81   add("ResourceScatterNdSub"                 , kReadWrite, kVariable);
82   add("ResourceScatterNdUpdate"              , kReadWrite, kVariable);
83   add("ResourceScatterSub"                   , kReadWrite, kVariable);
84   add("ResourceScatterUpdate"                , kReadWrite, kVariable);
85   add("ResourceStridedSliceAssign"           , kReadWrite, kVariable);
86   add("RngReadAndSkip"                       , kReadWrite, kVariable);
87   add("RngSkip"                              , kReadWrite, kVariable);
88   add("StatefulStandardNormalV2"             , kReadWrite, kVariable);
89   add("StatefulTruncatedNormal"              , kReadWrite, kVariable);
90   add("StatefulUniform"                      , kReadWrite, kVariable);
91   add("StatefulUniformFullInt"               , kReadWrite, kVariable);
92   add("StatefulUniformInt"                   , kReadWrite, kVariable);
93   add("VarIsInitializedOp"                   , kRead,      kVariable);
94   add("VariableShape"                        , kRead,      kVariable);
95 
96   add("StackV2"                              , kWrite,     kStack);
97   add("StackCloseV2"                         , kRead,      kStack);
98   add("StackPopV2"                           , kReadWrite, kStack);
99   add("StackPushV2"                          , kReadWrite, kStack);
100 
101   add("TensorArrayV3"                        , kWrite,     kTensorArray);
102   add("TensorArrayConcatV3"                  , kRead,      kTensorArray);
103   add("TensorArrayGatherV3"                  , kRead,      kTensorArray);
104   add("TensorArrayScatterV3"                 , kWrite,     kTensorArray);
105   add("TensorArrayGradV3"                    , kRead,      kTensorArray);
106   add("TensorArrayCloseV3"                   , kRead,      kTensorArray);
107   add("TensorArrayReadV3"                    , kRead,      kTensorArray);
108   add("TensorArraySizeV3"                    , kRead,      kTensorArray);
109   add("TensorArraySplitV3"                   , kWrite,     kTensorArray);
110   add("TensorArrayWriteV3"                   , kWrite,     kTensorArray);
111   // clang-format on
112 
113   return result;
114 }
115 
116 static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap()117 GetStaticResourceOpInfoMap() {
118   static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
119       op_info_map = CreateResourceOpInfoMap();
120   return *op_info_map;
121 }
122 
GetResourceOpInfoForOp(absl::string_view op)123 const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
124   const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
125       GetStaticResourceOpInfoMap();
126   auto it = op_infos.find(op);
127   return it == op_infos.end() ? nullptr : &it->second;
128 }
129 
130 namespace resource_op_table_internal {
GetKnownResourceOps()131 std::vector<absl::string_view> GetKnownResourceOps() {
132   std::vector<absl::string_view> result;
133   for (const auto& p : GetStaticResourceOpInfoMap()) {
134     result.push_back(p.first);
135   }
136   absl::c_sort(result);
137   return result;
138 }
139 }  // namespace resource_op_table_internal
140 }  // namespace tensorflow
141