1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
18 
19 #include <map>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
28 #include "tensorflow/compiler/xla/service/local_service.h"
29 #include "tensorflow/compiler/xla/service/platform_util.h"
30 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
31 #include "tensorflow/compiler/xla/service/transfer_manager.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
37 #include "tensorflow/core/platform/thread_annotations.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 
42 class TestAllocator : public StreamExecutorMemoryAllocator {
43  public:
TestAllocator(se::Platform * platform)44   explicit TestAllocator(se::Platform* platform)
45       : StreamExecutorMemoryAllocator(
46             platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
47   }
48 
49   StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
50                                         bool retry_on_failure) override;
51   Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
52 
53   // Return the number of allocations that have been performed.
54   int64 allocation_count() const;
55   int64 allocation_count(int device_ordinal) const;
56 
57   // Return the number of deallocations that have been performed.
58   int64 deallocation_count() const;
59   int64 deallocation_count(int device_ordinal) const;
60 
61  private:
62   mutable tensorflow::mutex count_mutex_;
63 
64   // Global counts of allocations and deallocations.
65   int64 allocation_count_ GUARDED_BY(count_mutex_) = 0;
66   int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0;
67 
68   // Per-device counts of allocations and deallocations.
69   std::map<int, int64> device_allocation_count_ GUARDED_BY(count_mutex_);
70   std::map<int, int64> device_deallocation_count_ GUARDED_BY(count_mutex_);
71 };
72 
73 // A base class for tests which exercise the LocalClient interface.
74 class LocalClientTestBase : public ::testing::Test {
75  protected:
76   struct EigenThreadPoolWrapper;
77   explicit LocalClientTestBase(se::Platform* platform = nullptr);
78   virtual ~LocalClientTestBase();
79 
80   static TestAllocator* GetOrCreateAllocator(se::Platform* platform);
81 
82   // Copy the given literal onto the default device and return a
83   // ScopedShapedBuffer. Convenience wrapper around
84   // LocalClient::LiteralToShapedBuffer.
85   ScopedShapedBuffer LiteralToShapedBuffer(const Literal& literal);
86 
87   // Construct and return a literal containing the array represented by
88   // shaped_buffer.
89   Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
90 
91   // Execute the given computation on the local client. With and without
92   // options.
93   StatusOr<ScopedShapedBuffer> ExecuteLocally(
94       const XlaComputation& computation,
95       absl::Span<const ShapedBuffer* const> arguments);
96   StatusOr<ScopedShapedBuffer> ExecuteLocally(
97       const XlaComputation& computation,
98       absl::Span<const ShapedBuffer* const> arguments,
99       const ExecutableBuildOptions& build_options,
100       const ExecutableRunOptions& run_options);
101 
102   ScopedShapedBuffer ExecuteLocallyOrDie(
103       const XlaComputation& computation,
104       absl::Span<const ShapedBuffer* const> arguments);
105   ScopedShapedBuffer ExecuteLocallyOrDie(
106       const XlaComputation& computation,
107       absl::Span<const ShapedBuffer* const> arguments,
108       const ExecutableBuildOptions& build_options,
109       const ExecutableRunOptions& run_options);
110 
111   // Returns a default set of execute options.
112   ExecutableBuildOptions DefaultExecutableBuildOptions() const;
113 
114   // Returns a default set of execute options, configured to use allocator_
115   // as the allocator.
116   ExecutableRunOptions DefaultExecutableRunOptions() const;
117 
TestName()118   string TestName() const {
119     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
120   }
121 
122   // The allocator must live as long as the service, which lives until the end
123   // of the process. So make the allocator static.
124   static TestAllocator* allocator_;
125 
126   se::StreamExecutor* stream_executor_;
127   TransferManager* transfer_manager_;
128 
129   LocalClient* local_client_;
130 
131   std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
132 };
133 
134 }  // namespace xla
135 
136 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
137