1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PLATFORM_TRACING_H_
17 #define TENSORFLOW_CORE_PLATFORM_TRACING_H_
18 
19 // Tracing interface
20 
21 #include <array>
22 #include <atomic>
23 #include <map>
24 #include <memory>
25 
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/platform.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 namespace tracing {
35 
// Identifiers of every TensorFlow CPU profiler event category. Whenever an
// enumerator is added, removed, or renumbered, GetEventCategoryName() must be
// updated to match.
enum class EventCategory : unsigned {
  kScheduleClosure = 0,
  kRunClosure = 1,
  kCompute = 2,
  // Sentinel equal to the number of real categories; must remain last.
  kNumCategories = 3
};

// Number of real event categories, i.e. the numeric value of the
// kNumCategories sentinel. Usable in constant expressions (array bounds).
constexpr unsigned GetNumEventCategories() {
  return static_cast<unsigned>(EventCategory::kNumCategories);
}

// Returns a printable name for the given category (defined out of line).
const char* GetEventCategoryName(EventCategory);
48 
49 // Interface for CPU profiler events.
50 class EventCollector {
51  public:
~EventCollector()52   virtual ~EventCollector() {}
53   virtual void RecordEvent(uint64 arg) const = 0;
54   virtual void StartRegion(uint64 arg) const = 0;
55   virtual void StopRegion() const = 0;
56 
57   // Annotates the current thread with a name.
58   static void SetCurrentThreadName(const char* name);
59   // Returns whether event collection is enabled.
60   static bool IsEnabled();
61 
62  private:
63   friend void SetEventCollector(EventCategory, const EventCollector*);
64   friend const EventCollector* GetEventCollector(EventCategory);
65 
66   static std::array<const EventCollector*, GetNumEventCategories()> instances_;
67 };
68 // Set the callback for RecordEvent and ScopedRegion of category.
69 // Not thread safe. Only call while EventCollector::IsEnabled returns false.
70 void SetEventCollector(EventCategory category, const EventCollector* collector);
71 
72 // Returns the callback for RecordEvent and ScopedRegion of category if
73 // EventCollector::IsEnabled(), otherwise returns null.
GetEventCollector(EventCategory category)74 inline const EventCollector* GetEventCollector(EventCategory category) {
75   if (EventCollector::IsEnabled()) {
76     return EventCollector::instances_[static_cast<unsigned>(category)];
77   }
78   return nullptr;
79 }
80 
81 // Returns a unique id to pass to RecordEvent/ScopedRegion. Never returns zero.
82 uint64 GetUniqueArg();
83 
84 // Returns an id for name to pass to RecordEvent/ScopedRegion.
85 uint64 GetArgForName(StringPiece name);
86 
87 // Records an atomic event through the currently registered EventCollector.
RecordEvent(EventCategory category,uint64 arg)88 inline void RecordEvent(EventCategory category, uint64 arg) {
89   if (auto collector = GetEventCollector(category)) {
90     collector->RecordEvent(arg);
91   }
92 }
93 
94 // Records an event for the duration of the instance lifetime through the
95 // currently registered EventCollector.
96 class ScopedRegion {
97   ScopedRegion(ScopedRegion&) = delete;             // Not copy-constructible.
98   ScopedRegion& operator=(ScopedRegion&) = delete;  // Not assignable.
99 
100  public:
ScopedRegion(ScopedRegion && other)101   ScopedRegion(ScopedRegion&& other) noexcept  // Move-constructible.
102       : collector_(other.collector_) {
103     other.collector_ = nullptr;
104   }
105 
ScopedRegion(EventCategory category,uint64 arg)106   ScopedRegion(EventCategory category, uint64 arg)
107       : collector_(GetEventCollector(category)) {
108     if (collector_) {
109       collector_->StartRegion(arg);
110     }
111   }
112 
113   // Same as ScopedRegion(category, GetUniqueArg()), but faster if
114   // EventCollector::IsEnaled() returns false.
ScopedRegion(EventCategory category)115   ScopedRegion(EventCategory category)
116       : collector_(GetEventCollector(category)) {
117     if (collector_) {
118       collector_->StartRegion(GetUniqueArg());
119     }
120   }
121 
122   // Same as ScopedRegion(category, GetArgForName(name)), but faster if
123   // EventCollector::IsEnaled() returns false.
ScopedRegion(EventCategory category,StringPiece name)124   ScopedRegion(EventCategory category, StringPiece name)
125       : collector_(GetEventCollector(category)) {
126     if (collector_) {
127       collector_->StartRegion(GetArgForName(name));
128     }
129   }
130 
~ScopedRegion()131   ~ScopedRegion() {
132     if (collector_) {
133       collector_->StopRegion();
134     }
135   }
136 
IsEnabled()137   bool IsEnabled() const { return collector_ != nullptr; }
138 
139  private:
140   const EventCollector* collector_;
141 };
142 
143 // Interface for accelerator profiler annotations.
144 class TraceCollector {
145  public:
146   class Handle {
147    public:
~Handle()148     virtual ~Handle() {}
149   };
150 
~TraceCollector()151   virtual ~TraceCollector() {}
152   virtual std::unique_ptr<Handle> CreateAnnotationHandle(
153       StringPiece name_part1, StringPiece name_part2) const = 0;
154   virtual std::unique_ptr<Handle> CreateActivityHandle(
155       StringPiece name_part1, StringPiece name_part2,
156       bool is_expensive) const = 0;
157 
158   // Returns true if this annotation tracing is enabled for any op.
159   virtual bool IsEnabledForAnnotations() const = 0;
160 
161   // Returns true if this activity handle tracking is enabled for an op of the
162   // given expensiveness.
163   virtual bool IsEnabledForActivities(bool is_expensive) const = 0;
164 
165  protected:
166   static string ConcatenateNames(StringPiece first, StringPiece second);
167 
168  private:
169   friend void SetTraceCollector(const TraceCollector*);
170   friend const TraceCollector* GetTraceCollector();
171 };
172 // Set the callback for ScopedAnnotation and ScopedActivity.
173 void SetTraceCollector(const TraceCollector* collector);
174 // Returns the callback for ScopedAnnotation and ScopedActivity.
175 const TraceCollector* GetTraceCollector();
176 
177 // Adds an annotation to all activities for the duration of the instance
178 // lifetime through the currently registered TraceCollector.
179 //
180 // Usage: {
181 //          ScopedAnnotation annotation("my kernels");
182 //          Kernel1<<<x,y>>>;
183 //          LaunchKernel2(); // Launches a CUDA kernel.
184 //        }
185 // This will add 'my kernels' to both kernels in the profiler UI
186 class ScopedAnnotation {
187  public:
ScopedAnnotation(StringPiece name)188   explicit ScopedAnnotation(StringPiece name)
189       : ScopedAnnotation(name, StringPiece()) {}
190 
191   // If tracing is enabled, add a name scope of
192   // "<name_part1>:<name_part2>".  This can be cheaper than the
193   // single-argument constructor because the concatenation of the
194   // label string is only done if tracing is enabled.
ScopedAnnotation(StringPiece name_part1,StringPiece name_part2)195   ScopedAnnotation(StringPiece name_part1, StringPiece name_part2)
196       : handle_([&] {
197           auto trace_collector = GetTraceCollector();
198           return trace_collector ? trace_collector->CreateAnnotationHandle(
199                                        name_part1, name_part2)
200                                  : nullptr;
201         }()) {}
202 
IsEnabled()203   bool IsEnabled() const { return static_cast<bool>(handle_); }
204 
205  private:
206   std::unique_ptr<TraceCollector::Handle> handle_;
207 };
208 
209 // Adds an activity through the currently registered TraceCollector.
210 // The activity starts when an object of this class is created and stops when
211 // the object is destroyed.
212 class ScopedActivity {
213  public:
214   explicit ScopedActivity(StringPiece name, bool is_expensive = true)
ScopedActivity(name,StringPiece (),is_expensive)215       : ScopedActivity(name, StringPiece(), is_expensive) {}
216 
217   // If tracing is enabled, set up an activity with a label of
218   // "<name_part1>:<name_part2>".  This can be cheaper than the
219   // single-argument constructor because the concatenation of the
220   // label string is only done if tracing is enabled.
221   ScopedActivity(StringPiece name_part1, StringPiece name_part2,
222                  bool is_expensive = true)
223       : handle_([&] {
224           auto trace_collector = GetTraceCollector();
225           return trace_collector ? trace_collector->CreateActivityHandle(
226                                        name_part1, name_part2, is_expensive)
227                                  : nullptr;
228         }()) {}
229 
IsEnabled()230   bool IsEnabled() const { return static_cast<bool>(handle_); }
231 
232  private:
233   std::unique_ptr<TraceCollector::Handle> handle_;
234 };
235 
// Returns the path of the directory that log files are written to.
const char* GetLogDir();
238 
239 }  // namespace tracing
240 }  // namespace tensorflow
241 
242 #if defined(PLATFORM_GOOGLE)
243 #include "tensorflow/core/platform/google/tracing_impl.h"
244 #else
245 #include "tensorflow/core/platform/default/tracing_impl.h"
246 #endif
247 
248 #endif  // TENSORFLOW_CORE_PLATFORM_TRACING_H_
249