1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
18 
#import <Metal/Metal.h>

#include <functional>
#include <list>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
36 
37 namespace tflite {
38 namespace gpu {
39 namespace metal {
40 
41 struct MetalNode {
42   ComputeTask task;
43   std::vector<ValueId> inputs;
44   std::vector<ValueId> outputs;
45 
46   // Mostly for debug purposes.
47   std::string name;
48 
49   MetalNode() = default;
50 
51   MetalNode(MetalNode&& node) = default;
52   MetalNode& operator=(MetalNode&& node) = default;
53   MetalNode(const MetalNode&) = delete;
54   MetalNode& operator=(const MetalNode&) = delete;
55 };
56 
57 class InferenceContext {
58  public:
59   struct CreateInferenceInfo {
60     CalculationsPrecision precision;
61     TensorStorageType storage_type;
62     ModelHints hints;
63   };
64 
65   InferenceContext() = default;
66 
67   // IMPORTANT: If InitFromGraph used, RunGraphTransforms must be applied for
68   // this graph upfront, otherwise not guaranteed correct behavior
69   absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
70                              const GraphFloat32& graph,
71                              id<MTLDevice> device_id);
72 
73   // Applies specific transformations to the graph before the
74   // initialization. These transformations are either impossible or useless in
75   // other backends.
76   absl::Status InitFromGraphWithTransforms(
77       const CreateInferenceInfo& create_info, GraphFloat32* graph,
78       id<MTLDevice> device_id);
79 
80   // Updates MTLBuffer handles in MetalSpatialTensors and kernels that use this
81   // tensors.
82   void UpdatePreallocatedTensors(
83       const std::map<ValueId, id<MTLBuffer>>& preallocated);
84 
85   /// Inserts all GPU compute tasks into the command encoder.
86   /// @param inputOutputBuffers Must be created and passed into the method
87   /// with pairs ID:buffer
88   /// @discussion No GPU synchronization functions are used inside. All GPU
89   /// resources must be created
90   ///             with the same device which has been used in
91   ///             compileModelWithDevice() method.
92   void EncodeWithEncoder(id<MTLComputeCommandEncoder> command_encoder);
93 
94   /// Inserts all GPU compute tasks into the command buffer. For every task will
95   /// be used separate
96   ///   encoder.
97   /// @param inputOutputBuffers Must be created and passed into the method with
98   /// pairs ID:buffer
99   /// @discussion No GPU synchronization functions are used inside. All GPU
100   /// resources must be created
101   ///             with the same device which has been used in
102   ///             compileModelWithDevice() method.
103   void EncodeWithCommandBuffer(id<MTLCommandBuffer> command_buffer);
104 
105   /// Adds all GPU compute tasks to the command queue. For every task will be
106   /// used separate
107   ///   encoder. Few encoders(flushPeriod) batched into compute buffer that sent
108   ///   for execution.
109   /// @param inputOutputBuffers Must be created and passed into the method with
110   /// pairs ID:buffer
111   /// @discussion No GPU synchronization functions are used inside. All GPU
112   /// resources must be created
113   ///             with the same device which has been used in
114   ///             compileModelWithDevice() method.
115   void EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
116                               int flush_period);
117 
118   void Profile(id<MTLDevice> device, ProfilingInfo* result);
119 
120  private:
121   enum class TensorMemoryType {
122     kStrongShape,
123     kBuffer,
124     kVariable,
125     kConst,
126     kPreallocated
127   };
128   absl::Status Compile(const GraphFloat32& graph, const GpuInfo& gpu_info,
129                        ModelHints hints);
130 
131   void ReserveGraphTensors(const CreateInferenceInfo& create_info,
132                            const GpuInfo& gpu_info, const GraphFloat32& graph,
133                            const std::set<ValueId>& preallocated_ids);
134 
135   absl::Status CompileOperations(MetalDevice* device);
136 
137   absl::Status Merge();
138   absl::Status AllocateTensors(MetalDevice* device,
139                                const std::set<ValueId>& preallocated_ids);
140   absl::Status AllocateMemoryForConstTensors(MetalDevice* device);
141   absl::Status AllocateMemoryForBuffers(MetalDevice* device);
142   absl::Status AllocateMemoryForStrongShapes(MetalDevice* device);
143   void BindTensorsToOperations();
144   absl::Status UpdateParams(const GpuInfo& gpu_info);
145   MetalSpatialTensor* GetTensor(ValueId tensor_id);
146   void GetUsages(const std::function<bool(ValueId)>& functor,
147                  std::map<ValueId, int2>* usages);
148   TensorMemoryType GetTensorMemoryType(ValueId id);
149   absl::Status Tune(TuningType tuning_type, MetalDevice* device);
150 
151   struct DummyTensor {
152     BHWC shape;
153     TensorDescriptor descriptor;
154 
155     bool operator==(const DummyTensor& b) const {
156       return shape == b.shape && descriptor == b.descriptor;
157     }
158   };
159 
160   class TensorReserver {
161    public:
TensorReserver()162     TensorReserver() : next_(0) {}
Add(const DummyTensor & dummy)163     ValueId Add(const DummyTensor& dummy) {
164       reservations_[next_] = dummy;
165       return next_++;
166     }
Add(ValueId id,const DummyTensor & dummy)167     void Add(ValueId id, const DummyTensor& dummy) {
168       reservations_[id] = dummy;
169     }
SetNext(ValueId id)170     void SetNext(ValueId id) { next_ = id; }
Get(ValueId id)171     DummyTensor Get(ValueId id) { return reservations_[id]; }
172 
GetTensorDescs()173     std::vector<std::pair<ValueId, TensorDescriptor>> GetTensorDescs() const {
174       std::vector<std::pair<ValueId, TensorDescriptor>> result;
175       for (auto& v : reservations_) {
176         TensorDescriptor desc = v.second.descriptor;
177         desc.shape.b = v.second.shape.b;
178         desc.shape.h = v.second.shape.h;
179         desc.shape.w = v.second.shape.w;
180         desc.shape.d = 1;
181         desc.shape.c = v.second.shape.c;
182         result.push_back({v.first, desc});
183       }
184       return result;
185     }
186 
Add(const std::vector<std::pair<ValueId,TensorDescriptor>> & tensors)187     void Add(const std::vector<std::pair<ValueId, TensorDescriptor>>& tensors) {
188       for (auto& v : tensors) {
189         DummyTensor dummy;
190         dummy.descriptor = v.second;
191         dummy.shape.b = v.second.shape.b;
192         dummy.shape.h = v.second.shape.h;
193         dummy.shape.w = v.second.shape.w;
194         dummy.shape.c = v.second.shape.c;
195         Add(v.first, dummy);
196       }
197     }
198 
199    private:
200     absl::flat_hash_map<ValueId, DummyTensor> reservations_;
201     ValueId next_;
202   };
203   TensorReserver tensor_reserver_;
204 
205   std::vector<MetalNode> nodes_;
206   // contains indexes of compute_tasks_
207   std::vector<int> task_ids_with_preallocated_tensors_;
208   std::vector<ValueId> input_ids_;
209   std::vector<ValueId> output_ids_;
210   CalculationsPrecision precision_;
211   std::map<ValueId, MetalSpatialTensor> preallocated_tensors_;
212 
213   std::map<ValueId, TensorDescriptor> const_tensors_descs_;
214   std::map<ValueId, MetalSpatialTensor> const_tensors_;
215 
216   std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;
217   std::vector<id<MTLBuffer>> shared_buffers_;
218   std::vector<MetalSpatialTensor>
219       shared_buffer_tensors_;  // use references to memory
220                                // from _sharedBuffers
221 
222   std::map<ValueId, MetalSpatialTensor> strong_shape_tensors_;
223   std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;
224 };
225 
// Runs specific transforms for the graph.
// `graph` is taken by non-const pointer, so it is presumably rewritten in
// place. As noted on InferenceContext::InitFromGraph, a graph must have had
// these transforms applied before it is passed to InitFromGraph.
// Returns an absl::Status describing the outcome.
absl::Status RunGraphTransforms(GraphFloat32* graph);
228 
229 }  // namespace metal
230 }  // namespace gpu
231 }  // namespace tflite
232 
233 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
234