1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_ 17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_ 18 19 #include "tensorflow/core/platform/types.h" 20 #include "tensorflow/stream_executor/platform.h" 21 #include "tensorflow/stream_executor/tpu/c_api_decl.h" 22 #include "tensorflow/stream_executor/tpu/tpu_topology.h" 23 24 namespace tensorflow { 25 namespace tpu { 26 27 // TODO(skyewm): get rid of TpuTopologyPtr and either use SE_TpuTopology* or 28 // return a TpuTopologyExternal. 29 typedef SE_TpuTopology* TpuTopologyPtr; 30 31 class TpuPlatformInterface : public stream_executor::Platform { 32 public: 33 using Status = stream_executor::port::Status; 34 35 // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are 36 // registered, finds the most suitable one. Returns nullptr if no TPU platform 37 // is registered or an error occurred. 38 // 39 // 'initialize_platform' can be set to false to not initialize a platform if 40 // not necessary. 'num_tries' specifies the number of tries if the TPU 41 // platform isn't initialized yet, with a 1-second delay between each try 42 // (num_tries == 1 means try once with no retries). 43 static TpuPlatformInterface* GetRegisteredPlatform( 44 bool initialize_platform = true, int num_tries = 5); 45 46 virtual Status Reset(bool only_tear_down, absl::string_view reason) = 0; 47 Reset(absl::string_view reason)48 Status Reset(absl::string_view reason) { return Reset(false, reason); } 49 Reset()50 Status Reset() { return Reset(false, {}); } 51 52 virtual int64 TpuMemoryLimit() = 0; 53 54 virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0; 55 56 virtual const TpuTopologyPtr GetTopologyPtr() = 0; 57 58 virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0; 59 60 virtual TpuRuntimeVersion version() const = 0; 61 topology()62 TpuTopologyExternal topology() { 63 return TpuTopologyExternal(GetTopologyPtr()); 64 } 65 }; 66 67 } // namespace tpu 68 } // namespace tensorflow 69 70 #endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_ 71