1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/gpu/common/shape.h"
16 
17 #include <stdint.h>
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 
25 namespace tflite {
26 namespace gpu {
27 namespace {
28 
29 struct GetAxisByIndexFunc {
30   template <Layout T>
operator ()tflite::gpu::__anon55d2d3670111::GetAxisByIndexFunc31   Axis operator()() const {
32     return GetAxis<T>(index);
33   }
34   int32_t index;
35 };
36 
37 struct GetIndexByAxisFunc {
38   template <Layout T>
operator ()tflite::gpu::__anon55d2d3670111::GetIndexByAxisFunc39   int operator()() const {
40     return GetAxisIndex<T>(axis);
41   }
42   Axis axis;
43 };
44 
45 struct NumAxisFunc {
46   template <Layout T>
operator ()tflite::gpu::__anon55d2d3670111::NumAxisFunc47   int operator()() const {
48     return Size<T>();
49   }
50 };
51 
52 }  // namespace
53 
ToString(Axis axis)54 std::string ToString(Axis axis) {
55   switch (axis) {
56     case Axis::BATCH:
57       return "batch";
58     case Axis::CHANNELS:
59       return "channels";
60     case Axis::INPUT_CHANNELS:
61       return "input_channels";
62     case Axis::OUTPUT_CHANNELS:
63       return "output_channels";
64     case Axis::HEIGHT:
65       return "height";
66     case Axis::WIDTH:
67       return "width";
68     case Axis::VALUE:
69       return "value";
70     case Axis::DEPTH:
71       return "depth";
72     case Axis::UNKNOWN:
73       return "unknown";
74   }
75   return "undefined";
76 }
77 
ToString(Layout layout)78 std::string ToString(Layout layout) {
79   switch (layout) {
80     case Layout::SCALAR:
81       return "scalar";
82     case Layout::LINEAR:
83       return "linear";
84     case Layout::HW:
85       return "hw";
86     case Layout::HWD:
87       return "hwd";
88     case Layout::CHW:
89       return "chw";
90     case Layout::HWC:
91       return "hwc";
92     case Layout::HWDC:
93       return "hwdc";
94     case Layout::OHWI:
95       return "ohwi";
96     case Layout::IHWO:
97       return "ihwo";
98     case Layout::OIHW:
99       return "oihw";
100     case Layout::IOHW:
101       return "iohw";
102     case Layout::BHWC:
103       return "bhwc";
104     case Layout::BHWDC:
105       return "bhwdc";
106     case Layout::OHWDI:
107       return "ohwi";
108     case Layout::UNKNOWN:
109       return "unknown";
110   }
111   return "undefined";
112 }
113 
GetAxis(Layout layout,int32_t index)114 Axis GetAxis(Layout layout, int32_t index) {
115   return DispatchByLayout(layout, GetAxisByIndexFunc{index});
116 }
117 
GetAxisIndex(Layout layout,Axis axis)118 int GetAxisIndex(Layout layout, Axis axis) {
119   return DispatchByLayout(layout, GetIndexByAxisFunc{axis});
120 }
121 
HasAxis(Layout layout,Axis axis)122 bool HasAxis(Layout layout, Axis axis) {
123   return GetAxisIndex(layout, axis) >= 0;
124 }
125 
Size(Layout layout)126 int Size(Layout layout) { return DispatchByLayout(layout, NumAxisFunc()); }
127 
ToString(const Shape & s)128 std::string ToString(const Shape& s) {
129   return absl::StrCat("{", ToString(s.layout), ", {",
130                       absl::StrJoin(s.dimensions, ", "), "}}");
131 }
132 
133 }  // namespace gpu
134 }  // namespace tflite
135