1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
17 #define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
18 
// TFLITE_X86_PLATFORM is defined when any of the compiler-provided x86
// macros is present: __x86_64__ / _M_X64 (64-bit) or __i386 / _M_IX86
// (32-bit).
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
    defined(_M_IX86)
#define TFLITE_X86_PLATFORM
#endif
23 
24 #include <memory>
25 
26 #include "public/gemmlowp.h"
27 #include "ruy/context.h"  // from @ruy
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/external_cpu_backend_context.h"
30 
31 namespace tflite {
32 
33 class CpuBackendContext final : public TfLiteInternalBackendContext {
34  public:
35   static CpuBackendContext* GetFromContext(TfLiteContext* context);
36 
37   CpuBackendContext();
38   ~CpuBackendContext() override;
39 
ruy_context()40   ruy::Context* ruy_context() const { return ruy_context_.get(); }
41 
gemmlowp_context()42   gemmlowp::GemmContext* gemmlowp_context() const {
43     return gemmlowp_context_.get();
44   }
45 
46   // Sets the maximum-number-of-threads-to-use parameter, only as a means of
47   // passing around this information.
48   void SetMaxNumThreads(int max_num_threads) override;
49 
max_num_threads()50   int max_num_threads() const { return max_num_threads_; }
51 
52   void SetUseCaching(bool flag);
53 
use_caching()54   bool use_caching() const { return use_caching_; }
55 
ClearCaches()56   void ClearCaches() override { ruy_context_->ClearPrepackedCache(); }
57 
58   bool HasAvxOrAbove();
59 
60   // Gemmlowp on x86 is a deprecated path but some clients may still use
61   // this path based on link time dependencies.
62   bool PreferGemmlowpOnX86();
63 
64  private:
65   // Copy the wrapper class for cpuinfo from Ruy.
66   class CpuInfo final {
67    public:
CpuInfo()68     CpuInfo() {}
69     ~CpuInfo();
70 
71     // X86 features
72     bool Avx();
73     bool Avx2Fma();
74     bool Avx512();
75 
76    private:
77     enum class InitStatus {
78       kNotYetAttempted,
79       kInitialized,
80       kFailed,
81     };
82 
83     InitStatus init_status_ = InitStatus::kNotYetAttempted;
84 
85     bool EnsureInitialized();
86     InitStatus Initialize();
87     CpuInfo(const CpuInfo&) = delete;
88     CpuInfo& operator=(const CpuInfo&) = delete;
89   };
90 
91   // To enable a smooth transition from the current direct usage
92   // of the underlying gemmlowp context to going through abstractions
93   // (see :cpu_backend_gemm), for now a CpuBackendContext always
94   // stores both a gemmlowp context and a ruy context.
95   // TODO(b/131416458): Once call sites all go through abstractions,
96   // elide what can be elided based on TFLITE_WITH_RUY.
97   const std::unique_ptr<ruy::Context> ruy_context_;
98   const std::unique_ptr<gemmlowp::GemmContext> gemmlowp_context_;
99   CpuInfo cpuinfo_;
100 
101   // The maximum of threads used for parallelizing TfLite ops. However,
102   // cpu_backend_threadpool::Execute creates as many threads as it's
103   // asked to, regardless of this. Typically a call site would query
104   // cpu_backend_context->max_num_threads() and used that to determine
105   // the number of tasks to create and to give to
106   // cpu_backend_threadpool::Execute.
107   //
108   // This value also gets propagated to back-ends, where it plays the same
109   // information-only role.
110   int max_num_threads_;
111   // For matrix muliplications with constants parameters (i.e. weights), we can
112   // sometimes provide speedups by caching the "prepacked" data, for some
113   // additional memory cost. This flag permits the user to route all
114   // CpuBackendGem operations to a library that permits such an optimization
115   // (currently the Ruy library only).
116   bool use_caching_;
117 
118   CpuBackendContext(const CpuBackendContext&) = delete;
119 };
120 
121 }  // namespace tflite
122 
123 #endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_CONTEXT_H_
124