1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/resource_op_kernel.h"
19 #include "tensorflow/core/framework/stats_aggregator.h"
20 #include "tensorflow/core/framework/summary.pb.h"
21 #include "tensorflow/core/kernels/summary_interface.h"
22 #include "tensorflow/core/lib/core/refcount.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/histogram/histogram.h"
25 #include "tensorflow/core/lib/monitoring/counter.h"
26 #include "tensorflow/core/lib/monitoring/gauge.h"
27 #include "tensorflow/core/lib/monitoring/sampler.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/util/events_writer.h"
30 
31 namespace tensorflow {
32 namespace data {
33 namespace experimental {
34 namespace {
35 
get_counters_map_lock()36 static mutex* get_counters_map_lock() {
37   static mutex counters_map_lock(LINKER_INITIALIZED);
38   return &counters_map_lock;
39 }
40 
get_counters_map()41 static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
42   static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
43       new std::unordered_map<string, monitoring::Counter<1>*>;
44   return counters_map;
45 }
46 
47 class StatsAggregatorImpl : public StatsAggregator {
48  public:
StatsAggregatorImpl()49   StatsAggregatorImpl() {}
50 
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64 steps)51   void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
52                       const int64 steps) override {
53     mutex_lock l(mu_);
54     histogram::Histogram& histogram = histograms_[name];
55     for (double value : values) {
56       histogram.Add(value);
57     }
58   }
59 
AddScalar(const string & name,float value,const int64 steps)60   void AddScalar(const string& name, float value, const int64 steps) override {
61     mutex_lock l(mu_);
62     scalars_[name] = value;
63   }
64 
EncodeToProto(Summary * out_summary)65   void EncodeToProto(Summary* out_summary) override {
66     mutex_lock l(mu_);
67     for (const auto& pair : histograms_) {
68       const string& name = pair.first;
69       const histogram::Histogram& histogram = pair.second;
70 
71       Summary::Value* value = out_summary->add_value();
72       value->set_tag(name);
73       histogram.EncodeToProto(value->mutable_histo(),
74                               false /* doesn't preserve zero buckets */);
75     }
76     for (const auto& pair : scalars_) {
77       Summary::Value* value = out_summary->add_value();
78       value->set_tag(pair.first);
79       value->set_simple_value(pair.second);
80     }
81   }
82 
83   // StatsAggregator implementation for V2 is based on push-based summary, no-op
84   // in V1.
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)85   Status SetSummaryWriter(
86       SummaryWriterInterface* summary_writer_interface) override {
87     return Status::OK();
88   }
89 
IncrementCounter(const string & name,const string & label,int64 val)90   void IncrementCounter(const string& name, const string& label,
91                         int64 val) override {
92     mutex_lock l(*get_counters_map_lock());
93     auto counters_map = get_counters_map();
94     if (counters_map->find(name) == counters_map->end()) {
95       counters_map->emplace(
96           name,
97           monitoring::Counter<1>::New(
98               /*streamz name*/ name,
99               /*streamz description*/
100               strings::StrCat(name, " generated or consumed by the component."),
101               /*streamz label name*/ "component_descriptor"));
102     }
103     counters_map->at(name)->GetCell(label)->IncrementBy(val);
104   }
105 
106  private:
107   mutex mu_;
108   std::unordered_map<string, histogram::Histogram> histograms_
109       TF_GUARDED_BY(mu_);
110   std::unordered_map<string, float> scalars_ TF_GUARDED_BY(mu_);
111   TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
112 };
113 
114 class StatsAggregatorHandleOp
115     : public ResourceOpKernel<StatsAggregatorResource> {
116  public:
StatsAggregatorHandleOp(OpKernelConstruction * ctx)117   explicit StatsAggregatorHandleOp(OpKernelConstruction* ctx)
118       : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
119 
120  private:
CreateResource(StatsAggregatorResource ** ret)121   Status CreateResource(StatsAggregatorResource** ret) override
122       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
123     *ret =
124         new StatsAggregatorResource(absl::make_unique<StatsAggregatorImpl>());
125     return Status::OK();
126   }
127 };
128 
129 class StatsAggregatorImplV2 : public StatsAggregator {
130  public:
StatsAggregatorImplV2()131   StatsAggregatorImplV2() {}
132 
~StatsAggregatorImplV2()133   ~StatsAggregatorImplV2() override {
134     if (summary_writer_interface_) {
135       summary_writer_interface_->Unref();
136     }
137   }
138 
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64 steps)139   void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
140                       const int64 steps) override {
141     mutex_lock l(mu_);
142     histogram::Histogram& histogram = histograms_[name];
143     for (double value : values) {
144       histogram.Add(value);
145     }
146     AddToEvents(name, steps, histogram);
147   }
148 
AddScalar(const string & name,float value,const int64 steps)149   void AddScalar(const string& name, float value, const int64 steps) override {
150     mutex_lock l(mu_);
151     AddToEvents(name, steps, value);
152   }
153 
154   // TODO(b/116314787): expose this is public API to manually flush summary.
Flush()155   Status Flush() {
156     mutex_lock l(mu_);
157     if (summary_writer_interface_)
158       TF_RETURN_IF_ERROR(summary_writer_interface_->Flush());
159     return Status::OK();
160   }
161 
IncrementCounter(const string & name,const string & label,int64 val)162   void IncrementCounter(const string& name, const string& label,
163                         int64 val) override {
164     mutex_lock l(*get_counters_map_lock());
165     auto counters_map = get_counters_map();
166     if (counters_map->find(name) == counters_map->end()) {
167       counters_map->emplace(
168           name, monitoring::Counter<1>::New(
169                     /*streamz name*/ "/tensorflow/" + name,
170                     /*streamz description*/
171                     name + " generated or consumed by the component.",
172                     /*streamz label name*/ "component_descriptor"));
173     }
174     counters_map->at(name)->GetCell(label)->IncrementBy(val);
175   }
176 
177   // StatsAggregator implementation for V1 is based on pull-based summary, no-op
178   // in V2.
EncodeToProto(Summary * out_summary)179   void EncodeToProto(Summary* out_summary) override {}
180 
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)181   Status SetSummaryWriter(
182       SummaryWriterInterface* summary_writer_interface) override {
183     mutex_lock l(mu_);
184     if (summary_writer_interface_) {
185       summary_writer_interface_->Unref();
186       // If we create stats_aggregator twice in a program, we would end up with
187       // already existing resource. In this case emitting an error if a
188       // `summary_writer_resource` is present is not the intended behavior, we
189       // could either Unref the existing summary_writer_resource or not set the
190       // new resource at all.
191     }
192     summary_writer_interface_ = summary_writer_interface;
193     summary_writer_interface_->Ref();
194     return Status::OK();
195   }
196 
197  private:
AddToEvents(const string & name,const int64 steps,const float scalar_value)198   void AddToEvents(const string& name, const int64 steps,
199                    const float scalar_value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
200     if (summary_writer_interface_ == nullptr) {
201       return;
202     }
203     std::unique_ptr<Event> e{new Event};
204     e->set_step(steps);
205     e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
206     // maybe expose GetWallTime in SummaryWriterInterface
207     Summary::Value* v = e->mutable_summary()->add_value();
208     v->set_tag(name);
209     v->set_simple_value(scalar_value);
210     TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
211   }
212 
AddToEvents(const string & name,const int64 steps,const histogram::Histogram & histogram)213   void AddToEvents(const string& name, const int64 steps,
214                    const histogram::Histogram& histogram)
215       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
216     if (summary_writer_interface_ == nullptr) {
217       return;
218     }
219     std::unique_ptr<Event> e{new Event};
220     e->set_step(steps);
221     e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
222     Summary::Value* v = e->mutable_summary()->add_value();
223     v->set_tag(name);
224     histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
225     TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
226   }
227 
228   mutex mu_;
229   SummaryWriterInterface* summary_writer_interface_ TF_GUARDED_BY(mu_) =
230       nullptr;
231   // not owned, we might be associating the default summary_writer from the
232   // context
233   std::unordered_map<string, histogram::Histogram> histograms_
234       TF_GUARDED_BY(mu_);
235   TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2);
236 };
237 
238 class StatsAggregatorHandleOpV2
239     : public ResourceOpKernel<StatsAggregatorResource> {
240  public:
StatsAggregatorHandleOpV2(OpKernelConstruction * ctx)241   explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx)
242       : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
243 
244  private:
CreateResource(StatsAggregatorResource ** ret)245   Status CreateResource(StatsAggregatorResource** ret) override
246       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
247     *ret =
248         new StatsAggregatorResource(absl::make_unique<StatsAggregatorImplV2>());
249     return Status::OK();
250   }
251 };
252 
253 class StatsAggregatorSummaryOp : public OpKernel {
254  public:
StatsAggregatorSummaryOp(OpKernelConstruction * ctx)255   explicit StatsAggregatorSummaryOp(OpKernelConstruction* ctx)
256       : OpKernel(ctx) {}
257 
Compute(OpKernelContext * ctx)258   void Compute(OpKernelContext* ctx) override {
259     const Tensor& resource_handle_t = ctx->input(0);
260     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
261                 errors::InvalidArgument("resource_handle must be a scalar"));
262 
263     core::RefCountPtr<StatsAggregatorResource> resource;
264     OP_REQUIRES_OK(ctx,
265                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
266 
267     Tensor* summary_t;
268     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
269     Summary summary;
270     resource->stats_aggregator()->EncodeToProto(&summary);
271     summary_t->scalar<tstring>()() = summary.SerializeAsString();
272   }
273 };
274 
275 class StatsAggregatorSetSummaryWriterOp : public OpKernel {
276  public:
StatsAggregatorSetSummaryWriterOp(OpKernelConstruction * ctx)277   explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx)
278       : OpKernel(ctx) {}
279 
Compute(OpKernelContext * ctx)280   void Compute(OpKernelContext* ctx) override {
281     const Tensor& resource_handle_t = ctx->input(0);
282     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
283                 errors::InvalidArgument("resource_handle must be a scalar"));
284 
285     core::RefCountPtr<StatsAggregatorResource> resource;
286     OP_REQUIRES_OK(ctx,
287                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
288 
289     const Tensor& summary_resource_handle_t = ctx->input(1);
290     OP_REQUIRES(ctx,
291                 TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
292                 errors::InvalidArgument("resource_handle must be a scalar"));
293     core::RefCountPtr<SummaryWriterInterface> summary_resource;
294     OP_REQUIRES_OK(
295         ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &summary_resource));
296     TF_CHECK_OK(
297         resource->stats_aggregator()->SetSummaryWriter(summary_resource.get()));
298   }
299 };
300 
301 REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandle").Device(DEVICE_CPU),
302                         StatsAggregatorHandleOp);
303 REGISTER_KERNEL_BUILDER(
304     Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
305     StatsAggregatorHandleOp);
306 
307 REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
308                         StatsAggregatorHandleOpV2);
309 
310 REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
311                         StatsAggregatorSummaryOp);
312 REGISTER_KERNEL_BUILDER(
313     Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
314     StatsAggregatorSummaryOp);
315 
316 REGISTER_KERNEL_BUILDER(
317     Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
318     StatsAggregatorSetSummaryWriterOp);
319 
320 }  // namespace
321 }  // namespace experimental
322 }  // namespace data
323 }  // namespace tensorflow
324