1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#import "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
17
18#import <Metal/Metal.h>
19
20#include "tensorflow/lite/delegates/gpu/common/shape.h"
21#include "tensorflow/lite/delegates/gpu/common/util.h"
22#include "tensorflow/lite/delegates/gpu/metal/common.h"
23
24using ::tflite::gpu::BHWC;
25using ::tflite::gpu::DivideRoundUp;
26using ::tflite::gpu::metal::CreateComputeProgram;
27
namespace {

// Metal kernel: dense BHWC (always 32-bit float) -> packed FLT4 layout.
//
// `size` is {W, H, C, B} (see the uniforms in -convertWithEncoder:). Each thread
// handles one (X, B) pair from gid.x, one row Y from gid.y and one 4-channel
// slice S from gid.z. The output element index is ((S * H + Y) * W + X) * B + b.
// Channels of the last slice that lie past C are left at zero.
constexpr char kBHWCToPHWC4Shader[] = R"(
  #include <metal_stdlib>
  using namespace metal;
  kernel void ComputeFunction(device float* const input_buffer [[buffer(0)]],
                              device FLT4* output_buffer [[buffer(1)]],
                              constant int4& size [[buffer(2)]],
                              uint3 gid[[thread_position_in_grid]]) {
    int linear_id = static_cast<int>(gid.x);
    int X = linear_id / size.w;
    int B = linear_id % size.w;
    int Y = static_cast<int>(gid.y);
    int S = static_cast<int>(gid.z);
    if (X >= size.x || Y >= size.y) {
      return;
    }
    FLT4 value = FLT4(0.0);
    for (int i = 0; i < 4; i++) {
      int channel = S * 4 + i;
      if (channel >= size.z) break;
      const int bhwc_index = ((B * size.y + Y) * size.x + X) * size.z + channel;
      value[i] = input_buffer[bhwc_index];
    }
    const int shwbc4_index = ((S * size.y + Y) * size.x + X) * size.w + B;
    output_buffer[shwbc4_index] = value;
  }
)";

// Metal kernel: packed FLT4 layout -> dense BHWC (always 32-bit float).
//
// Inverse of kBHWCToPHWC4Shader: reads one FLT4 slice element and writes only
// the channels below C back to the dense buffer (padding lanes are dropped).
constexpr char kPHWC4ToBHWCShader[] = R"(
  #include <metal_stdlib>
  using namespace metal;
  kernel void ComputeFunction(device FLT4* const input_buffer [[buffer(0)]],
                              device float* output_buffer [[buffer(1)]],
                              constant int4& size [[buffer(2)]],
                              uint3 gid[[thread_position_in_grid]]) {
    int linear_id = static_cast<int>(gid.x);
    int X = linear_id / size.w;
    int B = linear_id % size.w;
    int Y = static_cast<int>(gid.y);
    int S = static_cast<int>(gid.z);
    if (X >= size.x || Y >= size.y) {
      return;
    }
    const int shwbc4_index = ((S * size.y + Y) * size.x + X) * size.w + B;
    FLT4 value = input_buffer[shwbc4_index];
    for (int i = 0; i < 4; i++) {
      int channel = S * 4 + i;
      if (channel >= size.z) break;
      const int bhwc_index = ((B * size.y + Y) * size.x + X) * size.z + channel;
      output_buffer[bhwc_index] = value[i];
    }
  }
)";

}  // namespace

@implementation TFLBufferConvert {
  // Compiled "ComputeFunction" pipeline for the direction chosen at init time.
  id<MTLComputePipelineState> _program;
}

// Compiles the conversion kernel on `device`.
//
// isFloat16:       when true the packed side is half4, otherwise float4. The
//                  dense BHWC side is always 32-bit float in both kernels.
// convertToPBHWC4: when true the kernel converts dense BHWC -> packed layout;
//                  when false it converts packed layout -> dense BHWC.
// Returns nil if the kernel fails to compile.
- (id)initWithDevice:(id<MTLDevice>)device
           isFloat16:(bool)isFloat16
     convertToPBHWC4:(bool)convertToPBHWC4 {
  self = [super init];
  if (!self) {
    return nil;
  }

  const char* shaderSource = convertToPBHWC4 ? kBHWCToPHWC4Shader : kPHWC4ToBHWCShader;
  // FLT4 is substituted as a preprocessor macro when the kernel is compiled.
  NSDictionary* macros = @{@"FLT4" : (isFloat16 ? @"half4" : @"float4")};
  // Boxed expression decodes as UTF-8 rather than the platform-dependent
  // default C-string encoding.
  NSString* code = @(shaderSource);

  id<MTLComputePipelineState> program;
  if (!CreateComputeProgram(device, code, @"ComputeFunction", macros, &program).ok()) {
    return nil;
  }
  _program = program;
  return self;
}

// Encodes one dispatch of the conversion kernel on `encoder`.
//
// sourceBuffer is bound as the kernel input (index 0), convertedBuffer as the
// output (index 1), and {shape.w, shape.h, shape.c, shape.b} as the int4
// uniform (index 2). The grid covers (w * b, h, ceil(c / 4)) threads, rounded up
// to 16x8x1 threadgroups; the kernel discards threads past w or h.
// Nothing is encoded for an empty shape.
// NOTE(review): buffer sizes are not validated against `shape` here; callers
// must allocate them accordingly.
- (void)convertWithEncoder:(id<MTLComputeCommandEncoder>)encoder
                     shape:(const BHWC&)shape
              sourceBuffer:(id<MTLBuffer>)sourceBuffer
           convertedBuffer:(id<MTLBuffer>)convertedBuffer {
  // A zero-sized dimension would yield a zero threadgroup count, which Metal
  // does not accept as a dispatch size; there is also no work to do.
  if (shape.b <= 0 || shape.h <= 0 || shape.w <= 0 || shape.c <= 0) {
    return;
  }

  [encoder setComputePipelineState:_program];
  [encoder setBuffer:sourceBuffer offset:0 atIndex:0];
  [encoder setBuffer:convertedBuffer offset:0 atIndex:1];

  // Order must match the kernel's `size` usage: x = W, y = H, z = C, w = B.
  const int uniforms[4] = {shape.w, shape.h, shape.c, shape.b};
  [encoder setBytes:uniforms length:sizeof(uniforms) atIndex:2];

  MTLSize group_size = MTLSizeMake(16, 8, 1);
  int slices = DivideRoundUp(shape.c, 4);
  int groups_x = DivideRoundUp(shape.w * shape.b, group_size.width);
  int groups_y = DivideRoundUp(shape.h, group_size.height);
  int groups_z = DivideRoundUp(slices, group_size.depth);
  MTLSize groups_count = MTLSizeMake(groups_x, groups_y, groups_z);
  [encoder dispatchThreadgroups:groups_count threadsPerThreadgroup:group_size];
}

@end
124