1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
18 
19 // We replace this implementation with a null implementation for mobile
20 // platforms.
21 #include "tensorflow/core/platform/platform.h"
22 #ifdef IS_MOBILE_PLATFORM
23 #include "tensorflow/core/lib/monitoring/mobile_counter.h"
24 #else
25 
26 #include <array>
27 #include <atomic>
28 #include <map>
29 
30 #include "tensorflow/core/lib/monitoring/collection_registry.h"
31 #include "tensorflow/core/lib/monitoring/metric_def.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 
37 namespace tensorflow {
38 namespace monitoring {
39 
// CounterCell stores each value of a Counter.
41 //
42 // A cell can be passed off to a module which may repeatedly update it without
43 // needing further map-indexing computations. This improves both encapsulation
44 // (separate modules can own a cell each, without needing to know about the map
45 // to which both cells belong) and performance (since map indexing and
46 // associated locking are both avoided).
47 //
48 // This class is thread-safe.
49 class CounterCell {
50  public:
CounterCell(int64 value)51   CounterCell(int64 value) : value_(value) {}
~CounterCell()52   ~CounterCell() {}
53 
54   // Atomically increments the value by step.
55   // REQUIRES: Step be non-negative.
56   void IncrementBy(int64 step);
57 
58   // Retrieves the current value.
59   int64 value() const;
60 
61  private:
62   std::atomic<int64> value_;
63 
64   TF_DISALLOW_COPY_AND_ASSIGN(CounterCell);
65 };
66 
67 // A stateful class for updating a cumulative integer metric.
68 //
69 // This class encapsulates a set of values (or a single value for a label-less
70 // metric). Each value is identified by a tuple of labels. The class allows the
71 // user to increment each value.
72 //
73 // Counter allocates storage and maintains a cell for each value. You can
74 // retrieve an individual cell using a label-tuple and update it separately.
75 // This improves performance since operations related to retrieval, like
76 // map-indexing and locking, are avoided.
77 //
78 // This class is thread-safe.
79 template <int NumLabels>
80 class Counter {
81  public:
~Counter()82   ~Counter() {
83     // Deleted here, before the metric_def is destroyed.
84     registration_handle_.reset();
85   }
86 
87   // Creates the metric based on the metric-definition arguments.
88   //
89   // Example;
90   // auto* counter_with_label = Counter<1>::New("/tensorflow/counter",
91   //   "Tensorflow counter", "MyLabelName");
92   template <typename... MetricDefArgs>
93   static Counter* New(MetricDefArgs&&... metric_def_args);
94 
95   // Retrieves the cell for the specified labels, creating it on demand if
96   // not already present.
97   template <typename... Labels>
98   CounterCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_);
99 
100  private:
Counter(const MetricDef<MetricKind::kCumulative,int64,NumLabels> & metric_def)101   explicit Counter(
102       const MetricDef<MetricKind::kCumulative, int64, NumLabels>& metric_def)
103       : metric_def_(metric_def),
104         registration_handle_(CollectionRegistry::Default()->Register(
105             &metric_def_, [&](MetricCollectorGetter getter) {
106               auto metric_collector = getter.Get(&metric_def_);
107 
108               mutex_lock l(mu_);
109               for (const auto& cell : cells_) {
110                 metric_collector.CollectValue(cell.first, cell.second.value());
111               }
112             })) {}
113 
114   mutable mutex mu_;
115 
116   // The metric definition. This will be used to identify the metric when we
117   // register it for collection.
118   const MetricDef<MetricKind::kCumulative, int64, NumLabels> metric_def_;
119 
120   std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
121 
122   using LabelArray = std::array<string, NumLabels>;
123   std::map<LabelArray, CounterCell> cells_ GUARDED_BY(mu_);
124 
125   TF_DISALLOW_COPY_AND_ASSIGN(Counter);
126 };
127 
128 ////
129 //  Implementation details follow. API readers may skip.
130 ////
131 
IncrementBy(const int64 step)132 inline void CounterCell::IncrementBy(const int64 step) {
133   DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
134   value_ += step;
135 }
136 
value()137 inline int64 CounterCell::value() const { return value_; }
138 
139 template <int NumLabels>
140 template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)141 Counter<NumLabels>* Counter<NumLabels>::New(
142     MetricDefArgs&&... metric_def_args) {
143   return new Counter<NumLabels>(
144       MetricDef<MetricKind::kCumulative, int64, NumLabels>(
145           std::forward<MetricDefArgs>(metric_def_args)...));
146 }
147 
148 template <int NumLabels>
149 template <typename... Labels>
GetCell(const Labels &...labels)150 CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
151     LOCKS_EXCLUDED(mu_) {
152   // Provides a more informative error message than the one during array
153   // construction below.
154   static_assert(sizeof...(Labels) == NumLabels,
155                 "Mismatch between Counter<NumLabels> and number of labels "
156                 "provided in GetCell(...).");
157 
158   const LabelArray& label_array = {{labels...}};
159   mutex_lock l(mu_);
160   const auto found_it = cells_.find(label_array);
161   if (found_it != cells_.end()) {
162     return &(found_it->second);
163   }
164   return &(cells_
165                .emplace(std::piecewise_construct,
166                         std::forward_as_tuple(label_array),
167                         std::forward_as_tuple(0))
168                .first->second);
169 }
170 
171 }  // namespace monitoring
172 }  // namespace tensorflow
173 
174 #endif  // IS_MOBILE_PLATFORM
175 #endif  // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
176