1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_ 17 #define TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_ 18 19 #include <functional> 20 21 #include "tensorflow/core/framework/logging.h" 22 #include "tensorflow/core/platform/mutex.h" 23 #include "tensorflow/core/protobuf/config.pb.h" 24 25 namespace tensorflow { 26 27 namespace xla_config_registry { 28 29 // XlaGlobalJitLevel is used by XLA to expose its JIT level for processing 30 // single gpu and general (multi-gpu) graphs. 31 struct XlaGlobalJitLevel { 32 OptimizerOptions::GlobalJitLevel single_gpu; 33 OptimizerOptions::GlobalJitLevel general; 34 }; 35 36 // Input is the jit_level in session config, and return value is the jit_level 37 // from XLA, reflecting the effect of the environment variable flags. 38 typedef std::function<XlaGlobalJitLevel( 39 const OptimizerOptions::GlobalJitLevel&)> 40 GlobalJitLevelGetterTy; 41 42 void RegisterGlobalJitLevelGetter(GlobalJitLevelGetterTy getter); 43 44 XlaGlobalJitLevel GetGlobalJitLevel( 45 OptimizerOptions::GlobalJitLevel jit_level_in_session_opts); 46 47 #define REGISTER_XLA_CONFIG_GETTER(getter) \ 48 REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(__COUNTER__, getter) 49 50 #define REGISTER_XLA_CONFIG_GETTER_UNIQ_HELPER(ctr, getter) \ 51 REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) 52 53 #define REGISTER_XLA_CONFIG_GETTER_UNIQ(ctr, getter) \ 54 static bool xla_config_registry_registration_##ctr = \ 55 (::tensorflow::xla_config_registry::RegisterGlobalJitLevelGetter( \ 56 getter), \ 57 true) 58 59 } // namespace xla_config_registry 60 61 } // namespace tensorflow 62 63 #endif // TENSORFLOW_CORE_UTIL_XLA_CONFIG_REGISTRY_H_ 64