1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/gpu/common/shape.h"
16
17 #include <stdint.h>
18
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24
25 namespace tflite {
26 namespace gpu {
27 namespace {
28
29 struct GetAxisByIndexFunc {
30 template <Layout T>
operator ()tflite::gpu::__anon55d2d3670111::GetAxisByIndexFunc31 Axis operator()() const {
32 return GetAxis<T>(index);
33 }
34 int32_t index;
35 };
36
37 struct GetIndexByAxisFunc {
38 template <Layout T>
operator ()tflite::gpu::__anon55d2d3670111::GetIndexByAxisFunc39 int operator()() const {
40 return GetAxisIndex<T>(axis);
41 }
42 Axis axis;
43 };
44
45 struct NumAxisFunc {
46 template <Layout T>
operator ()tflite::gpu::__anon55d2d3670111::NumAxisFunc47 int operator()() const {
48 return Size<T>();
49 }
50 };
51
52 } // namespace
53
ToString(Axis axis)54 std::string ToString(Axis axis) {
55 switch (axis) {
56 case Axis::BATCH:
57 return "batch";
58 case Axis::CHANNELS:
59 return "channels";
60 case Axis::INPUT_CHANNELS:
61 return "input_channels";
62 case Axis::OUTPUT_CHANNELS:
63 return "output_channels";
64 case Axis::HEIGHT:
65 return "height";
66 case Axis::WIDTH:
67 return "width";
68 case Axis::VALUE:
69 return "value";
70 case Axis::DEPTH:
71 return "depth";
72 case Axis::UNKNOWN:
73 return "unknown";
74 }
75 return "undefined";
76 }
77
ToString(Layout layout)78 std::string ToString(Layout layout) {
79 switch (layout) {
80 case Layout::SCALAR:
81 return "scalar";
82 case Layout::LINEAR:
83 return "linear";
84 case Layout::HW:
85 return "hw";
86 case Layout::HWD:
87 return "hwd";
88 case Layout::CHW:
89 return "chw";
90 case Layout::HWC:
91 return "hwc";
92 case Layout::HWDC:
93 return "hwdc";
94 case Layout::OHWI:
95 return "ohwi";
96 case Layout::IHWO:
97 return "ihwo";
98 case Layout::OIHW:
99 return "oihw";
100 case Layout::IOHW:
101 return "iohw";
102 case Layout::BHWC:
103 return "bhwc";
104 case Layout::BHWDC:
105 return "bhwdc";
106 case Layout::OHWDI:
107 return "ohwi";
108 case Layout::UNKNOWN:
109 return "unknown";
110 }
111 return "undefined";
112 }
113
GetAxis(Layout layout,int32_t index)114 Axis GetAxis(Layout layout, int32_t index) {
115 return DispatchByLayout(layout, GetAxisByIndexFunc{index});
116 }
117
GetAxisIndex(Layout layout,Axis axis)118 int GetAxisIndex(Layout layout, Axis axis) {
119 return DispatchByLayout(layout, GetIndexByAxisFunc{axis});
120 }
121
HasAxis(Layout layout,Axis axis)122 bool HasAxis(Layout layout, Axis axis) {
123 return GetAxisIndex(layout, axis) >= 0;
124 }
125
Size(Layout layout)126 int Size(Layout layout) { return DispatchByLayout(layout, NumAxisFunc()); }
127
ToString(const Shape & s)128 std::string ToString(const Shape& s) {
129 return absl::StrCat("{", ToString(s.layout), ", {",
130 absl::StrJoin(s.dimensions, ", "), "}}");
131 }
132
133 } // namespace gpu
134 } // namespace tflite
135