1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/weights_conversion.h"
17 
18 namespace tflite {
19 namespace gpu {
GetTotalElementsCountForLayout(const WeightsDescription & weight_desc,const OHWI & shape)20 uint GetTotalElementsCountForLayout(const WeightsDescription& weight_desc,
21                                     const OHWI& shape) {
22   if (weight_desc.layout == WeightsLayout::kOHWIOGroupI4O4 ||
23       weight_desc.layout == WeightsLayout::kOHWIOGroupO4I4 ||
24       weight_desc.layout == WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4 ||
25       weight_desc.layout == WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4) {
26     uint i_aligned = AlignByN(shape.i, 4);
27     uint o_aligned = AlignByN(shape.o, 4 * weight_desc.output_group_size);
28     return i_aligned * o_aligned * shape.h * shape.w;
29   } else if (weight_desc.layout == WeightsLayout::kOICustomSpatialI4O4 ||
30              weight_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
31     uint i_aligned = AlignByN(shape.i, 4);
32     uint o_aligned = AlignByN(shape.o, 4);
33     return i_aligned * o_aligned * weight_desc.spatial_remap.size();
34   } else {
35     return -1;
36   }
37 }
38 
RearrangeWeights(const tflite::gpu::Tensor<OHWI,DataType::FLOAT32> & weights,const WeightsDescription & dst_weight_desc,DataType dst_type,absl::Span<uint8_t> dst)39 void RearrangeWeights(
40     const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
41     const WeightsDescription& dst_weight_desc, DataType dst_type,
42     absl::Span<uint8_t> dst) {
43   const uint flt_count =
44       GetTotalElementsCountForLayout(dst_weight_desc, weights.shape);
45   if (dst_weight_desc.layout == WeightsLayout::kOHWIOGroupI4O4) {
46     if (dst_type == DataType::FLOAT32) {
47       float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
48       RearrangeWeightsToOHWIOGroupI4O4(weights,
49                                        dst_weight_desc.output_group_size,
50                                        absl::MakeSpan(f32_ptr, flt_count / 4));
51     } else if (dst_type == DataType::FLOAT16) {
52       half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
53       RearrangeWeightsToOHWIOGroupI4O4(weights,
54                                        dst_weight_desc.output_group_size,
55                                        absl::MakeSpan(f16_ptr, flt_count / 4));
56     }
57     return;
58   } else if (dst_weight_desc.layout == WeightsLayout::kOHWIOGroupO4I4) {
59     if (dst_type == DataType::FLOAT32) {
60       float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
61       RearrangeWeightsToOHWIOGroupO4I4(weights,
62                                        dst_weight_desc.output_group_size,
63                                        absl::MakeSpan(f32_ptr, flt_count / 4));
64     } else if (dst_type == DataType::FLOAT16) {
65       half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
66       RearrangeWeightsToOHWIOGroupO4I4(weights,
67                                        dst_weight_desc.output_group_size,
68                                        absl::MakeSpan(f16_ptr, flt_count / 4));
69     }
70     return;
71   } else if (dst_weight_desc.layout == WeightsLayout::kOICustomSpatialI4O4) {
72     if (dst_type == DataType::FLOAT32) {
73       float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
74       RearrangeWeightsToOICustomSpatialI4O4(
75           weights, dst_weight_desc.spatial_remap,
76           absl::MakeSpan(f32_ptr, flt_count / 4));
77     } else if (dst_type == DataType::FLOAT16) {
78       half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
79       RearrangeWeightsToOICustomSpatialI4O4(
80           weights, dst_weight_desc.spatial_remap,
81           absl::MakeSpan(f16_ptr, flt_count / 4));
82     }
83     return;
84   } else if (dst_weight_desc.layout == WeightsLayout::kOICustomSpatialO4I4) {
85     if (dst_type == DataType::FLOAT32) {
86       float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
87       RearrangeWeightsToOICustomSpatialO4I4(
88           weights, dst_weight_desc.spatial_remap,
89           absl::MakeSpan(f32_ptr, flt_count / 4));
90     } else if (dst_type == DataType::FLOAT16) {
91       half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
92       RearrangeWeightsToOICustomSpatialO4I4(
93           weights, dst_weight_desc.spatial_remap,
94           absl::MakeSpan(f16_ptr, flt_count / 4));
95     }
96     return;
97   } else if (dst_weight_desc.layout ==
98              WeightsLayout::k2DX4I4YIsHWIAndXIsOOGroupO4) {
99     if (dst_type == DataType::FLOAT32) {
100       float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
101       RearrangeWeightsToI4HWIOOGroupO4(weights,
102                                        dst_weight_desc.output_group_size,
103                                        absl::MakeSpan(f32_ptr, flt_count / 4));
104     } else if (dst_type == DataType::FLOAT16) {
105       half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
106       RearrangeWeightsToI4HWIOOGroupO4(weights,
107                                        dst_weight_desc.output_group_size,
108                                        absl::MakeSpan(f16_ptr, flt_count / 4));
109     }
110     return;
111   } else if (dst_weight_desc.layout ==
112              WeightsLayout::k2DX4O4YIsHWIAndXIsOOGroupI4) {
113     if (dst_type == DataType::FLOAT32) {
114       float4* f32_ptr = reinterpret_cast<float4*>(dst.data());
115       RearrangeWeightsToO4HWIOOGroupI4(weights,
116                                        dst_weight_desc.output_group_size,
117                                        absl::MakeSpan(f32_ptr, flt_count / 4));
118     } else if (dst_type == DataType::FLOAT16) {
119       half4* f16_ptr = reinterpret_cast<half4*>(dst.data());
120       RearrangeWeightsToO4HWIOOGroupI4(weights,
121                                        dst_weight_desc.output_group_size,
122                                        absl::MakeSpan(f16_ptr, flt_count / 4));
123     }
124     return;
125   }
126 }
127 
128 }  // namespace gpu
129 }  // namespace tflite
130