1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
18 
19 #include "tensorflow/core/platform/types.h"
20 #include "tensorflow/stream_executor/platform.h"
21 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
22 #include "tensorflow/stream_executor/tpu/tpu_topology.h"
23 
24 namespace tensorflow {
25 namespace tpu {
26 
27 // TODO(skyewm): get rid of TpuTopologyPtr and either use SE_TpuTopology* or
28 // return a TpuTopologyExternal.
29 typedef SE_TpuTopology* TpuTopologyPtr;
30 
31 class TpuPlatformInterface : public stream_executor::Platform {
32  public:
33   using Status = stream_executor::port::Status;
34 
35   // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
36   // registered, finds the most suitable one. Returns nullptr if no TPU platform
37   // is registered or an error occurred.
38   //
39   // 'initialize_platform' can be set to false to not initialize a platform if
40   // not necessary. 'num_tries' specifies the number of tries if the TPU
41   // platform isn't initialized yet, with a 1-second delay between each try
42   // (num_tries == 1 means try once with no retries).
43   static TpuPlatformInterface* GetRegisteredPlatform(
44       bool initialize_platform = true, int num_tries = 5);
45 
46   virtual Status Reset(bool only_tear_down, absl::string_view reason) = 0;
47 
Reset(absl::string_view reason)48   Status Reset(absl::string_view reason) { return Reset(false, reason); }
49 
Reset()50   Status Reset() { return Reset(false, {}); }
51 
52   virtual int64 TpuMemoryLimit() = 0;
53 
54   virtual bool ShouldRegisterTpuDeviceToDeviceCopy() = 0;
55 
56   virtual const TpuTopologyPtr GetTopologyPtr() = 0;
57 
58   virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0;
59 
60   virtual TpuRuntimeVersion version() const = 0;
61 
topology()62   TpuTopologyExternal topology() {
63     return TpuTopologyExternal(GetTopologyPtr());
64   }
65 };
66 
67 }  // namespace tpu
68 }  // namespace tensorflow
69 
70 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
71