1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"

#include <cstdlib>
#include <cstring>
#include <mutex>  // NOLINT(build/c++11)

#include "absl/base/call_once.h"
21
// We need a pair of compile time and runtime flags to disable compilation of
// custom contraction kernels for unsupported architectures (e.g. Android,
// iOS, ARM and PPC CPUs), and to be able to fall back on default Eigen
// matrix multiplication at runtime.
//
// The absl flags library is not allowed in TensorFlow, so we have to pass
// the configuration through an environment variable.
//
// Example:
//   bazel test \
//     --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \
//     //path/to:test
34
35 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
36
37 namespace Eigen {
38 namespace internal {
39
40 // TODO(ezhulenev): This is a temporary workaround for disabling custom kernels
41 // at runtime in tests. We should always rely on compile time flags for that.
42 //
43 // Example:
44 // bazel test \
45 // --test_env=TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL=false \
46 // //path/to:test
UseCustomContractionKernels()47 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE bool UseCustomContractionKernels() {
48 static bool use_custom_contraction_kernel = true;
49
50 // This subroutine should not be used in GPU. In case it is, a custom kernel
51 // should always be used
52 #if !defined __NVCC__ && !defined __HIP_DEVICE_COMPILE__
53 static absl::once_flag initialized;
54 absl::call_once(initialized, [&] {
55 char* flag = std::getenv("TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL");
56 if (flag && (strcmp(flag, "false") == 0 || strcmp(flag, "0") == 0)) {
57 use_custom_contraction_kernel = false;
58 }
59 });
60 #endif
61
62 return use_custom_contraction_kernel;
63 }
64
65 } // namespace internal
66 } // namespace Eigen
67 #endif
68