1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
17 
18 #include <map>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/substitute.h"
23 #include "absl/time/clock.h"
24 #include "tensorflow/lite/delegates/gpu/common/memory_management.h"
25 #include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/operations.h"
28 #include "tensorflow/lite/delegates/gpu/common/precision.h"
29 #include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
30 #include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
31 #include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
32 #include "tensorflow/lite/delegates/gpu/common/shape.h"
33 #include "tensorflow/lite/delegates/gpu/common/status.h"
34 #include "tensorflow/lite/delegates/gpu/common/task/storage_type_util.h"
35 #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h"
36 #include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h"
37 #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
38 #include "tensorflow/lite/delegates/gpu/common/util.h"
39 #include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
40 #include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"
41 
42 namespace tflite {
43 namespace gpu {
44 namespace metal {
45 namespace {
46 
47 // returns true if actual memory for this storage type is buffer
bool IsBufferBased(const TensorStorageType& type) {
  switch (type) {
    // Storage types whose underlying memory is a buffer.
    case TensorStorageType::BUFFER:
    case TensorStorageType::IMAGE_BUFFER:
      return true;
    default:
      return false;
  }
}
52 
// Returns true if at least one id in `vec_ids` is also present in `ids`.
// Stops scanning at the first shared id.
bool HasIntersection(const std::vector<ValueId>& vec_ids,
                     const std::set<ValueId>& ids) {
  auto cursor = vec_ids.begin();
  while (cursor != vec_ids.end() && ids.count(*cursor) == 0) {
    ++cursor;
  }
  return cursor != vec_ids.end();
}
62 
// Returns true when every input tensor of `node` is already in
// `ready_tensors` (vacuously true for a node without inputs).
bool IsReady(const std::set<ValueId>& ready_tensors, const MetalNode& node) {
  for (size_t idx = 0; idx < node.inputs.size(); ++idx) {
    if (ready_tensors.count(node.inputs[idx]) == 0) {
      return false;
    }
  }
  return true;
}
71 
// Extends the usage interval of tensor `id` to include `task_index`.
// .x holds the first task index that touched the tensor and .y the latest.
void AddUsage(ValueId id, int task_index,
              std::map<ValueId, int2>* usage_records) {
  auto existing = usage_records->find(id);
  if (existing != usage_records->end()) {
    // Already tracked: only the end index moves.
    existing->second.y = task_index;
    return;
  }
  // First sighting: the interval starts and ends at this task.
  int2& record = (*usage_records)[id];
  record.x = task_index;
  record.y = task_index;
}
84 
// A "generic add" is an ADD with several runtime inputs, none of which is
// broadcast, i.e. a pointwise add of N tensors where N > 1.
bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
                  const std::vector<Value*>& outputs) {
  // Needs at least two runtime inputs. An empty input list must be rejected
  // too: Compile() reads inputs[0] whenever this returns true. outputs[0] is
  // read below, so an output must exist as well.
  if (inputs.size() < 2 || outputs.empty()) {
    return false;
  }
  const OperationType op_type = OperationTypeFromString(node.operation.type);
  if (op_type != OperationType::ADD) {
    return false;
  }

  const auto& dst_shape = outputs[0]->tensor.shape;
  for (const Value* input : inputs) {
    const auto& src_shape = input->tensor.shape;
    // A size-1 source dimension that differs from the output dimension is
    // broadcast; any broadcast input disqualifies the add.
    const bool broadcast_b = src_shape.b == 1 && dst_shape.b != src_shape.b;
    const bool broadcast_h = src_shape.h == 1 && dst_shape.h != src_shape.h;
    const bool broadcast_w = src_shape.w == 1 && dst_shape.w != src_shape.w;
    const bool broadcast_c = src_shape.c == 1 && dst_shape.c != src_shape.c;
    if (broadcast_b || broadcast_h || broadcast_w || broadcast_c) {
      return false;
    }
  }
  return true;
}
115 
// Folds `src` into `dst` so that a single node performs both operations.
// In Merge() src->inputs[0] is the tensor produced by dst, so only src's
// remaining inputs are appended to dst. dst takes over src's output and name
// suffix, and src's task is attached to dst's task via AddTask.
absl::Status MergeNodes(MetalNode* src, MetalNode* dst) {
  if (src->inputs.size() > 1) {
    dst->inputs.insert(dst->inputs.end(), src->inputs.begin() + 1,
                       src->inputs.end());
  }
  dst->outputs[0] = src->outputs[0];
  dst->name.append(" linked : ").append(src->name);
  return dst->task.AddTask(&src->task);
}
124 }  // namespace
125 
// Applies RunGraphTransforms to `graph` in place, then initializes the
// context from the transformed graph. Returns the first error encountered.
absl::Status InferenceContext::InitFromGraphWithTransforms(
    const CreateInferenceInfo& create_info, GraphFloat32* graph,
    id<MTLDevice> device_id) {
  RETURN_IF_ERROR(RunGraphTransforms(graph));
  return InitFromGraph(create_info, *graph, device_id);
}
133 
// Builds the inference context for `graph` on `device_id`.
//
// Graph inputs and outputs are treated as preallocated tensors: BUFFER
// storage is requested for them in ReserveGraphTensors, and they are kept
// out of internal memory planning (see GetTensorMemoryType). The pipeline
// then runs in this order:
//   1. reserve tensor descriptors
//   2. lower graph nodes (Compile)
//   3. link fusable nodes (Merge)
//   4. compile tasks
//   5. allocate memory
//   6. bind tensors
//   7. update params
//   8. fast tuning
absl::Status InferenceContext::InitFromGraph(
    const CreateInferenceInfo& create_info, const GraphFloat32& graph,
    id<MTLDevice> device_id) {
  std::set<ValueId> preallocated_ids;
  for (const auto& graph_input : graph.inputs()) {
    input_ids_.push_back(graph_input->id);
    preallocated_ids.insert(graph_input->id);
  }
  for (const auto& graph_output : graph.outputs()) {
    output_ids_.push_back(graph_output->id);
    preallocated_ids.insert(graph_output->id);
  }
  precision_ = create_info.precision;

  MetalDevice metal_device(device_id);
  ReserveGraphTensors(create_info, metal_device.GetInfo(), graph,
                      preallocated_ids);
  RETURN_IF_ERROR(Compile(graph, metal_device.GetInfo(), create_info.hints));
  RETURN_IF_ERROR(Merge());
  RETURN_IF_ERROR(CompileOperations(&metal_device));
  RETURN_IF_ERROR(AllocateTensors(&metal_device, preallocated_ids));
  BindTensorsToOperations();
  RETURN_IF_ERROR(UpdateParams(metal_device.GetInfo()));
  return Tune(TuningType::kFast, &metal_device);
}
163 
// Reserves a shape and TensorDescriptor for every value in `graph`.
//
// Preallocated ids request BUFFER storage; all other values request
// create_info.storage_type. The final choice is made by
// SelectBestStorageType. Afterwards the reserver's next free id is set just
// past the largest graph value id, so later reservations do not collide.
void InferenceContext::ReserveGraphTensors(
    const CreateInferenceInfo& create_info, const GpuInfo& gpu_info,
    const GraphFloat32& graph, const std::set<ValueId>& preallocated_ids) {
  const auto data_type = DeduceDataTypeFromPrecision(create_info.precision);
  ValueId largest_id = 0;
  for (const auto& value : graph.values()) {
    const auto shape = graph.GetValue(value->id)->tensor.shape;
    const Layout layout = shape.b == 1 ? Layout::HWC : Layout::BHWC;
    const bool is_preallocated = preallocated_ids.count(value->id) != 0;
    const TensorStorageType requested_storage =
        is_preallocated ? TensorStorageType::BUFFER : create_info.storage_type;
    // NOTE: choosing SINGLE_TEXTURE_2D for small-channel (c < 4) graph
    // inputs/outputs is temporarily disabled because Metal currently supports
    // only BUFFER storage here.
    const TensorStorageType storage_type = SelectBestStorageType(
        gpu_info, shape, requested_storage, data_type, layout);
    tensor_reserver_.Add(
        value->id, {shape, TensorDescriptor{data_type, storage_type, layout}});
    largest_id = std::max(largest_id, value->id);
  }
  tensor_reserver_.SetNext(largest_id + 1);
}
197 
// Lowers every graph node to one or more MetalNodes appended to nodes_.
//
// Node handling:
// - CONSTANT nodes are not lowered. Their output tensor's descriptor is stored
//   in const_tensors_descs_ together with the uploaded data.
// - When kAllowSpecialKernels is set, GPUSubgraphFromGraph may map a node
//   together with its neighbours onto a special subgraph. It is handed
//   `consumed_nodes` so it can record the nodes it absorbed.
// - Otherwise the node is mapped on its own by GPUOperationFromNode.
//
// Nodes already in `consumed_nodes` are skipped, so fused nodes are never
// lowered twice. Temporary tensors requested by a subgraph (new_tensors) are
// reserved in tensor_reserver_. Operations refer to them through negative ids:
// -1 is new_tensors[0], -2 is new_tensors[1], and so on.
//
// Returns InvalidArgumentError if batch sizes differ across values, otherwise
// the first error from operation selection.
absl::Status InferenceContext::Compile(const GraphFloat32& graph,
                                       const GpuInfo& gpu_info,
                                       ModelHints hints) {
  if (!IsBatchMatchesForAllValues(graph)) {
    return absl::InvalidArgumentError(
        "Only identical batch dimension is supported");
  }
  std::map<ValueId, TensorDescriptor> tensor_descriptors;
  const auto values = graph.values();
  for (auto value : values) {
    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
  }
  // Ids of graph nodes that have already been lowered to GPU operations.
  std::set<NodeId> consumed_nodes;
  // Latest graph-node index whose operation wrote each tensor. Graph inputs
  // count as written before node 0, so they are marked -1. Tensors never
  // written here read back as 0 through operator[].
  std::map<ValueId, int> tensor_usages;
  for (const auto& input_id : input_ids_) {
    tensor_usages[input_id] = -1;
  }
  std::vector<Node*> graph_nodes = graph.nodes();
  for (int i = 0; i < graph_nodes.size(); ++i) {
    const Node& node = *graph_nodes[i];
    // Skip nodes already covered by an earlier (special-kernel) subgraph;
    // lowering them again would emit duplicate operations.
    if (consumed_nodes.find(node.id) != consumed_nodes.end()) {
      continue;
    }
    auto op_type = OperationTypeFromString(node.operation.type);
    if (op_type == OperationType::CONSTANT) {
      auto attr =
          absl::any_cast<ConstTensorAttributes>(node.operation.attributes);
      auto outputs = graph.FindOutputs(node.id);
      const_tensors_descs_[outputs[0]->id] =
          tensor_reserver_.Get(outputs[0]->id).descriptor;
      const_tensors_descs_[outputs[0]->id].UploadData(attr.tensor);
      continue;
    }
    std::string op_name = node.operation.type + " " + std::to_string(node.id);
    GPUOperationsSubgraph gpu_subgraph;
    if (hints.Check(ModelHints::kAllowSpecialKernels) &&
        GPUSubgraphFromGraph(gpu_info, precision_, graph, node.id,
                             tensor_descriptors, &consumed_nodes, &gpu_subgraph,
                             &op_name)
            .ok()) {
      // Mapping of subgraph (set of nodes) to GPU operations. Should happen
      // before straightforward mapping.
    } else {
      // Straightforward mapping of one graph node to GPU operations.
      auto inputs = graph.FindInputs(node.id);
      auto outputs = graph.FindOutputs(node.id);
      // Reorder input ids (and consult tensor_usages) before building the
      // OperationDef, which depends on input order. The input in position 0
      // may later become a "linking" tensor that Merge() eliminates. This is
      // done only for ADD, because ADD is associative and can be linked. The
      // linking tensor can only be the most recently written input (in linear
      // execution order).
      if (IsGenericAdd(node, inputs, outputs)) {
        int latest_written_tensor_index = 0;
        int last_usage = tensor_usages[inputs[0]->id];
        for (int j = 1; j < inputs.size(); ++j) {
          if (tensor_usages[inputs[j]->id] > last_usage) {
            last_usage = tensor_usages[inputs[j]->id];
            latest_written_tensor_index = j;
          }
        }
        std::swap(inputs[0], inputs[latest_written_tensor_index]);
      }
      consumed_nodes.insert(node.id);
      OperationDef op_def;
      op_def.precision = precision_;
      for (int j = 0; j < inputs.size(); ++j) {
        op_def.src_tensors.push_back(
            tensor_reserver_.Get(inputs[j]->id).descriptor);
      }
      for (int j = 0; j < outputs.size(); ++j) {
        op_def.dst_tensors.push_back(
            tensor_reserver_.Get(outputs[j]->id).descriptor);
      }
      RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, hints, inputs,
                                           outputs, node, &gpu_subgraph));
    }
    // Reserve subgraph-local temporaries; negative id -(j + 1) maps to
    // mapping_to_global_ids[j].
    std::map<int, ValueId> mapping_to_global_ids;
    for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
      const auto& t = gpu_subgraph.new_tensors[j];
      auto global_id = tensor_reserver_.Add({t.first, t.second});
      mapping_to_global_ids[j] = global_id;
    }
    for (auto& gpu_op : gpu_subgraph.operations) {
      MetalNode metal_node;
      metal_node.task.Init(std::move(gpu_op.operation));
      metal_node.inputs.resize(gpu_op.input_ids.size());
      for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
        int id = gpu_op.input_ids[j];
        if (id >= 0) {
          metal_node.inputs[j] = id;
        } else {
          metal_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
        }
      }
      metal_node.outputs.resize(gpu_op.output_ids.size());
      for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
        int id = gpu_op.output_ids[j];
        if (id >= 0) {
          metal_node.outputs[j] = id;
          tensor_usages[id] = i;
        } else {
          metal_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
        }
      }
      metal_node.name = op_name;
      nodes_.push_back(std::move(metal_node));
    }
  }
  return absl::OkStatus();
}
310 
// Links ("fuses") chains of operations in nodes_.
//
// Node i is merged with a later node when all of the following hold:
// - that later node is the only consumer of node i's single output;
// - it reads that output as its input 0;
// - it is linkable and has a single output;
// - all of its inputs are already produced by node i or earlier nodes;
// - both nodes have identical destination tensor descriptors.
// After a merge the consumer is erased and index i is revisited, so the fused
// node can be linked again.
absl::Status InferenceContext::Merge() {
  // Tensors produced so far while walking nodes_ in order; graph inputs are
  // ready from the start.
  std::set<ValueId> ready_tensors;
  for (const auto& input_id : input_ids_) {
    ready_tensors.insert(input_id);
  }
  for (int i = 0; i < nodes_.size(); ++i) {
    auto& node = nodes_[i];
    for (const auto& out_id : node.outputs) {
      ready_tensors.insert(out_id);
    }
    // Only single-output nodes are candidates for linking.
    if (node.outputs.size() != 1) {
      continue;
    }
    // Collect every later node that reads this node's output. link_index ends
    // up as the input slot of the last match found.
    std::vector<int> next_nodes;
    int link_index = 0;
    for (int j = i + 1; j < nodes_.size(); ++j) {
      for (int k = 0; k < nodes_[j].inputs.size(); ++k) {
        if (nodes_[j].inputs[k] == node.outputs[0]) {
          next_nodes.push_back(j);
          link_index = k;
        }
      }
    }
    // Require exactly one consumer reading the tensor as input 0;
    // MergeNodes drops the consumer's input 0.
    if (next_nodes.size() != 1 || link_index != 0) {
      continue;
    }
    auto& linkable_node = nodes_[next_nodes[0]];
    if (!linkable_node.task.IsLinkable() || linkable_node.outputs.size() != 1 ||
        !IsReady(ready_tensors, linkable_node)) {
      continue;
    }
    const auto& original_dst_def = node.task.GetDefinition().dst_tensors[0];
    const auto& link_dst_def =
        linkable_node.task.GetDefinition().dst_tensors[0];
    if (original_dst_def != link_dst_def) {
      continue;
    }
    RETURN_IF_ERROR(MergeNodes(&linkable_node, &node));
    // next_nodes[0] > i, so erasing it leaves `node` (at index i) valid.
    nodes_.erase(nodes_.begin() + next_nodes[0]);
    // Revisit index i so the fused node can be linked with its new consumer.
    i -= 1;
  }
  return absl::OkStatus();
}
354 
// Compiles each node's task on `device` in execution order and returns the
// first failing status, if any.
absl::Status InferenceContext::CompileOperations(MetalDevice* device) {
  for (size_t idx = 0; idx < nodes_.size(); ++idx) {
    absl::Status compile_status = nodes_[idx].task.Compile(device);
    if (!compile_status.ok()) {
      return compile_status;
    }
  }
  return absl::OkStatus();
}
361 
// Allocates memory for every tensor used by nodes_.
//
// First, the indices of tasks that read or write a preallocated tensor are
// recorded, so that UpdatePreallocatedTensors only needs to rebind those.
// Preallocated tensors are created without a backing buffer (nil); a buffer
// can be attached via UpdatePreallocatedTensors. Then constant, shared-buffer
// and strong-shape tensors are allocated, in that order.
absl::Status InferenceContext::AllocateTensors(
    MetalDevice* device, const std::set<ValueId>& preallocated_ids) {
  for (int task_index = 0; task_index < nodes_.size(); ++task_index) {
    const auto& task_node = nodes_[task_index];
    const bool touches_preallocated =
        HasIntersection(task_node.inputs, preallocated_ids) ||
        HasIntersection(task_node.outputs, preallocated_ids);
    if (touches_preallocated) {
      task_ids_with_preallocated_tensors_.push_back(task_index);
    }
  }

  for (const ValueId tensor_id : preallocated_ids) {
    const auto& reserved = tensor_reserver_.Get(tensor_id);
    RETURN_IF_ERROR(CreateSharedBufferTensor(nil, reserved.shape,
                                             reserved.descriptor,
                                             &preallocated_tensors_[tensor_id]));
  }

  RETURN_IF_ERROR(AllocateMemoryForConstTensors(device));
  RETURN_IF_ERROR(AllocateMemoryForBuffers(device));
  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device));
  return absl::OkStatus();
}
383 
// Resolves a tensor id to its MetalSpatialTensor. The lookup order is:
// preallocated, then constant, then shared-buffer, then strong-shape.
// Returns nullptr for an unknown id.
MetalSpatialTensor* InferenceContext::GetTensor(ValueId tensor_id) {
  auto preallocated_it = preallocated_tensors_.find(tensor_id);
  if (preallocated_it != preallocated_tensors_.end()) {
    return &preallocated_it->second;
  }
  auto const_it = const_tensors_.find(tensor_id);
  if (const_it != const_tensors_.end()) {
    return &const_it->second;
  }
  auto shared_it = graph_ids_to_shared_buffer_tensors_.find(tensor_id);
  if (shared_it != graph_ids_to_shared_buffer_tensors_.end()) {
    return &shared_buffer_tensors_[shared_it->second];
  }
  auto strong_it = graph_ids_to_strong_shape_tensors_.find(tensor_id);
  if (strong_it != graph_ids_to_strong_shape_tensors_.end()) {
    return &strong_shape_tensors_[strong_it->second];
  }
  return nullptr;
}
400 
// Gives every task pointers to its input and output tensors, slot by slot,
// as resolved by GetTensor.
void InferenceContext::BindTensorsToOperations() {
  for (auto& metal_node : nodes_) {
    auto& task = metal_node.task;
    for (int slot = 0; slot < metal_node.inputs.size(); ++slot) {
      task.SetSrcTensor(GetTensor(metal_node.inputs[slot]), slot);
    }
    for (int slot = 0; slot < metal_node.outputs.size(); ++slot) {
      task.SetDstTensor(GetTensor(metal_node.outputs[slot]), slot);
    }
  }
}
413 
// Calls UpdateParams() on every task in execution order and returns the
// first error. `gpu_info` is currently unused but is kept for interface
// compatibility. The per-node src/dst shape vectors built by the previous
// version were never read, so they were removed.
absl::Status InferenceContext::UpdateParams(const GpuInfo& gpu_info) {
  for (auto& node : nodes_) {
    RETURN_IF_ERROR(node.task.UpdateParams());
  }
  return absl::OkStatus();
}
428 
// Classifies how the memory of tensor `id` is managed. Priority order:
//   1. preallocated
//   2. constant
//   3. buffer-based storage (shared buffers)
//   4. everything else (strong-shape objects)
InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
    ValueId id) {
  if (preallocated_tensors_.count(id) != 0) {
    return TensorMemoryType::kPreallocated;
  }
  if (const_tensors_.count(id) != 0) {
    return TensorMemoryType::kConst;
  }
  const auto storage = tensor_reserver_.Get(id).descriptor.storage_type;
  return IsBufferBased(storage) ? TensorMemoryType::kBuffer
                                : TensorMemoryType::kStrongShape;
}
441 
// For every tensor accepted by `functor`, records in `usages` the first (.x)
// and last (.y) task index that touches it. Graph inputs are pinned to task
// 0. Graph outputs are pinned to one past the last task, so their intervals
// span to the very end.
void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
                                 std::map<ValueId, int2>* usages) {
  const auto record_if_selected = [&](ValueId tensor_id, int task_index) {
    if (functor(tensor_id)) {
      AddUsage(tensor_id, task_index, usages);
    }
  };
  for (ValueId in_id : input_ids_) {
    record_if_selected(in_id, 0);
  }
  for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
    const auto& metal_node = nodes_[op_index];
    for (ValueId tensor_id : metal_node.inputs) {
      record_if_selected(tensor_id, op_index);
    }
    for (ValueId tensor_id : metal_node.outputs) {
      record_if_selected(tensor_id, op_index);
    }
  }
  const int past_last_task = static_cast<int>(nodes_.size());
  for (ValueId out_id : output_ids_) {
    record_if_selected(out_id, past_last_task);
  }
}
467 
// Creates a device tensor for every constant recorded by Compile() from its
// stored descriptor (which carries the uploaded data). The descriptors are
// then discarded.
absl::Status InferenceContext::AllocateMemoryForConstTensors(
    MetalDevice* device) {
  for (auto entry = const_tensors_descs_.begin();
       entry != const_tensors_descs_.end(); ++entry) {
    RETURN_IF_ERROR(const_tensors_[entry->first].CreateFromDescriptor(
        entry->second, device->device()));
  }
  const_tensors_descs_.clear();
  return absl::OkStatus();
}
477 
// Plans and allocates the shared MTLBuffers that back kBuffer tensors.
//
// Steps:
// 1. GetUsages yields a [first, last] task interval per kBuffer tensor. Each
//    tensor's element count (channels padded to 4) becomes a usage record.
// 2. The records are assigned to shared objects with GREEDY_BEST.
// 3. Each object becomes one MTLBuffer of object_sizes[i] elements times the
//    element size implied by precision_.
// 4. A MetalSpatialTensor is created per kBuffer tensor over its assigned
//    buffer.
//
// Returns ResourceExhaustedError if a buffer exceeds the device's maximum
// buffer length, or any error from assignment or tensor creation.
absl::Status InferenceContext::AllocateMemoryForBuffers(MetalDevice* device) {
  std::map<ValueId, int2> buffer_usages;
  GetUsages(
      [this](ValueId id) {
        return GetTensorMemoryType(id) == TensorMemoryType::kBuffer;
      },
      &buffer_usages);

  std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
  for (auto& usage : buffer_usages) {
    const auto& shape = tensor_reserver_.Get(usage.first).shape;
    // Element count, with channels padded to a multiple of 4.
    const size_t buffer_size =
        shape.b * shape.w * shape.h * AlignByN(shape.c, 4);
    graph_ids_to_shared_buffer_tensors_[usage.first] =
        buffer_usage_records.size();
    buffer_usage_records.push_back({buffer_size,
                                    static_cast<TaskId>(usage.second.x),
                                    static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<size_t> buffer_assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      buffer_usage_records, MemoryStrategy::GREEDY_BEST, &buffer_assignment));

  const bool f32_storage = precision_ == CalculationsPrecision::F32;
  size_t dataTypeSize = f32_storage ? sizeof(float) : sizeof(HalfBits);
  shared_buffers_.resize(buffer_assignment.object_sizes.size());
  for (int i = 0; i < buffer_assignment.object_sizes.size(); ++i) {
    // Initialize metal buffer
    NSUInteger bufferSize = dataTypeSize * buffer_assignment.object_sizes[i];

    if (bufferSize > device->GetInfo().GetMaxBufferSize()) {
      // NOTE(review): object_ids is indexed by tensor record, not by buffer,
      // so the id reported here does not identify the oversized buffer.
      // Consider reporting buffer index `i` instead.
      std::string error("Tensor id: ");
      error += std::to_string(buffer_assignment.object_ids[i]) +
               " with size: " + std::to_string(bufferSize) +
               " exceeds MTLDevice maxBufferLength: " +
               std::to_string(device->GetInfo().GetMaxBufferSize());
      return absl::ResourceExhaustedError(error);
    }

    shared_buffers_[i] =
        [device->device() newBufferWithLength:bufferSize
                                      options:MTLResourceStorageModeShared];
  }

  std::vector<bool> created_tensors(buffer_usage_records.size(), false);
  shared_buffer_tensors_.resize(buffer_usage_records.size());
  for (auto& node : nodes_) {
    std::vector<ValueId> all_ids = node.inputs;
    all_ids.insert(all_ids.end(), node.outputs.begin(), node.outputs.end());
    for (auto& tensor_id : all_ids) {
      // Only kBuffer tensors were planned above. Skipping merely preallocated
      // ids is not enough: constant and strong-shape ids have no entry here.
      // Looking them up with operator[] would insert a bogus mapping to slot
      // 0, which GetTensor() checks before strong-shape tensors. It could
      // also build slot 0 with the wrong shape, or index an empty
      // created_tensors.
      if (GetTensorMemoryType(tensor_id) != TensorMemoryType::kBuffer) {
        continue;
      }
      const auto mapping = graph_ids_to_shared_buffer_tensors_.find(tensor_id);
      if (mapping == graph_ids_to_shared_buffer_tensors_.end()) {
        continue;
      }
      const int tensor_index = mapping->second;
      if (created_tensors[tensor_index]) continue;
      const auto& tensor_dummy = tensor_reserver_.Get(tensor_id);
      const int buffer_index = buffer_assignment.object_ids[tensor_index];
      RETURN_IF_ERROR(CreateSharedBufferTensor(
          shared_buffers_[buffer_index], tensor_dummy.shape,
          tensor_dummy.descriptor, &shared_buffer_tensors_[tensor_index]));
      created_tensors[tensor_index] = true;
    }
  }
  return absl::OkStatus();
}
543 
// Allocates device tensors for kStrongShape tensors.
//
// Usage intervals from GetUsages are turned into records that carry each
// tensor's reserved shape and descriptor. AssignObjectsToTensors with the
// EQUALITY strategy then picks which tensors share an object (presumably only
// tensors with identical shape/descriptor; confirm in memory_management). One
// tensor is created per distinct object id.
absl::Status InferenceContext::AllocateMemoryForStrongShapes(
    MetalDevice* device) {
  std::map<ValueId, int2> usages;
  GetUsages(
      [this](ValueId id) {
        return GetTensorMemoryType(id) == TensorMemoryType::kStrongShape;
      },
      &usages);

  // Dense record index for every strong-shape graph id.
  std::vector<TensorUsageRecord<DummyTensor>> usage_records;
  std::map<ValueId, ValueId> remap_from_graph_ids;
  for (const auto& usage : usages) {
    const ValueId graph_id = usage.first;
    remap_from_graph_ids[graph_id] = usage_records.size();
    usage_records.push_back({tensor_reserver_.Get(graph_id),
                             static_cast<TaskId>(usage.second.x),
                             static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<DummyTensor> assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      usage_records, MemoryStrategy::EQUALITY, &assignment));

  for (auto& node : nodes_) {
    std::vector<ValueId> touched_ids = node.inputs;
    touched_ids.insert(touched_ids.end(), node.outputs.begin(),
                       node.outputs.end());
    for (const ValueId tensor_id : touched_ids) {
      if (GetTensorMemoryType(tensor_id) != TensorMemoryType::kStrongShape) {
        continue;
      }
      const auto object_id =
          assignment.object_ids[remap_from_graph_ids[tensor_id]];
      graph_ids_to_strong_shape_tensors_[tensor_id] = object_id;
      if (strong_shape_tensors_.find(object_id) !=
          strong_shape_tensors_.end()) {
        continue;  // Object already materialized by an earlier tensor.
      }
      const auto& tensor_dummy = tensor_reserver_.Get(tensor_id);
      RETURN_IF_ERROR(CreateTensor(device->device(), tensor_dummy.shape,
                                   tensor_dummy.descriptor,
                                   &strong_shape_tensors_[object_id]));
    }
  }
  return absl::OkStatus();
}
586 
// Tunes every task on `device` with `tuning_type`, in execution order.
// Stops at, and returns, the first error.
absl::Status InferenceContext::Tune(TuningType tuning_type,
                                    MetalDevice* device) {
  for (size_t idx = 0; idx < nodes_.size(); ++idx) {
    RETURN_IF_ERROR(nodes_[idx].task.Tune(tuning_type, device));
  }
  return absl::OkStatus();
}
594 
// Encodes all tasks, in execution order, into the caller's single compute
// encoder. The encoder is not ended here.
void InferenceContext::EncodeWithEncoder(
    id<MTLComputeCommandEncoder> command_encoder) {
  for (auto& metal_node : nodes_) {
    metal_node.task.Encode(command_encoder);
  }
}
602 
// Times each task in isolation on a fresh command queue.
//
// For every node, the task is encoded kRuns times into one command buffer.
// The wall-clock time of commit + waitUntilCompleted, divided by kRuns, is
// stored with the node name in result->dispatches[k]. Encoding time is not
// part of the measurement.
void InferenceContext::Profile(id<MTLDevice> device, ProfilingInfo* result) {
  const int kRuns = 500;
  result->dispatches.resize(nodes_.size());
  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  for (int k = 0; k < nodes_.size(); ++k) {
    @autoreleasepool {
      auto& profiled = nodes_[k];
      id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
      id<MTLComputeCommandEncoder> encoder =
          [command_buffer computeCommandEncoder];
      for (int run = 0; run < kRuns; ++run) {
        profiled.task.Encode(encoder);
      }
      [encoder endEncoding];
      const auto start = absl::Now();
      [command_buffer commit];
      [command_buffer waitUntilCompleted];
      const auto elapsed = absl::Now() - start;
      auto& dispatch_info = result->dispatches[k];
      dispatch_info.label = profiled.name;
      dispatch_info.duration = elapsed / static_cast<float>(kRuns);
    }
  }
}
627 
// Encodes each task into `command_buffer` using its own compute encoder.
// Each encoder is ended before the next is created. Does not commit.
void InferenceContext::EncodeWithCommandBuffer(
    id<MTLCommandBuffer> command_buffer) {
  for (auto& metal_node : nodes_) {
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoder];
    metal_node.task.Encode(encoder);
    [encoder endEncoding];
  }
}
638 
// Encodes all tasks into command buffers obtained from `command_queue`, one
// compute encoder per task.
//
// After every `flush_period` tasks the current command buffer is committed
// and a new one is started. The last command buffer is always committed, even
// if it is empty. This function does not wait for completion.
//
// A non-positive flush_period disables intermediate flushes: everything goes
// into one command buffer. The previous version computed i % 0 here, which is
// undefined behavior.
void InferenceContext::EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
                                              int flush_period) {
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  for (int i = 0; i < nodes_.size(); ++i) {
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoder];
    auto& task = nodes_[i].task;
    task.Encode(encoder);
    [encoder endEncoding];
    if (flush_period > 0 && i % flush_period == (flush_period - 1)) {
      [command_buffer commit];
      command_buffer = [command_queue commandBuffer];
    }
  }
  [command_buffer commit];
}
655 
// Attaches the caller's MTLBuffers to the preallocated (graph input/output)
// tensors. It then rebinds those tensors in every task recorded in
// task_ids_with_preallocated_tensors_.
//
// Ids in `preallocated` that were not preallocated at init time are ignored.
// The previous operator[] lookup inserted a stray default tensor for them,
// which the rebinding loop below would then bind to any listed task that
// reads or writes that id. SetBufferHandle's status cannot be propagated
// through this void interface and is explicitly discarded.
void InferenceContext::UpdatePreallocatedTensors(
    const std::map<ValueId, id<MTLBuffer>>& preallocated) {
  for (const auto& it : preallocated) {
    auto tensor_it = preallocated_tensors_.find(it.first);
    if (tensor_it == preallocated_tensors_.end()) {
      continue;
    }
    auto status = tensor_it->second.SetBufferHandle(it.second);
    (void)status;  // Best effort: no way to report failure from here.
  }
  for (auto& task_index : task_ids_with_preallocated_tensors_) {
    auto& task = nodes_[task_index].task;
    const auto& src_ids = nodes_[task_index].inputs;
    for (int i = 0; i < src_ids.size(); ++i) {
      const auto& it = preallocated_tensors_.find(src_ids[i]);
      if (it != preallocated_tensors_.end()) {
        task.SetSrcTensor(&it->second, i);
      }
    }
    const auto& dst_ids = nodes_[task_index].outputs;
    for (int i = 0; i < dst_ids.size(); ++i) {
      const auto& it = preallocated_tensors_.find(dst_ids[i]);
      if (it != preallocated_tensors_.end()) {
        task.SetDstTensor(&it->second, i);
      }
    }
  }
}
679 
// Applies the Metal-specific graph rewrites in a fixed order:
//   1. add_bias
//   2. merge_padding (padding merged into ADD)
//   3. global pooling to reduce op
// The first failing transform aborts with an InternalError naming it.
absl::Status RunGraphTransforms(GraphFloat32* graph) {
  auto merge_padding = NewMergePaddingWithAdd();
  auto add_bias = NewAddBias();
  auto global_pooling_to_reduce = NewGlobalPoolingToReduceOp();
  ModelTransformer transformer(graph, /*reporter=*/nullptr);

  const bool bias_ok = transformer.Apply("add_bias", add_bias.get());
  if (!bias_ok) {
    return absl::InternalError("Invalid add_bias transform");
  }
  const bool padding_ok =
      transformer.Apply("merge_padding", merge_padding.get());
  if (!padding_ok) {
    return absl::InternalError("Invalid merge_padding transform");
  }
  const bool pooling_ok = transformer.Apply("global pooling to mean",
                                            global_pooling_to_reduce.get());
  if (!pooling_ok) {
    return absl::InternalError("Invalid global pooling to mean transform");
  }
  return absl::OkStatus();
}
697 
698 }  // namespace metal
699 }  // namespace gpu
700 }  // namespace tflite
701