1# Lint as: python3
2# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Metrics collecting utilities for single client training."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import time
23
24from tensorflow.python.eager import monitoring
25from tensorflow.python.util import tf_contextlib
26
# Global switch for metric collection. While this is False, `monitored_timer`
# simply yields without measuring anything.
enable_metrics = False
# Maps short metric names to their Sampler objects. It starts out empty and is
# filled in by `_init()` the first time a timer runs with metrics enabled.
_METRICS_MAPPING = dict()
29
30
def _init():
  """Create the coordinator Sampler metrics and publish them globally.

  Builds one `monitoring.Sampler` per tracked operation (function tracing,
  closure execution, remote value fetching). It then rebinds the module-level
  `_METRICS_MAPPING` to a dict keyed by the short metric names. The global is
  only rebound after every sampler has been constructed successfully.
  """
  global _METRICS_MAPPING

  # Bucket boundaries for durations in seconds: the first boundary is 0.001s
  # (1ms), and each subsequent boundary is 10x larger, for 6 buckets in total.
  buckets = monitoring.ExponentialBuckets(0.001, 10, 6)

  # The samplers are created in the same order as they are listed here.
  _METRICS_MAPPING = {
      'function_tracing':
          monitoring.Sampler(
              '/tensorflow/api/ps_strategy/coordinator/function_tracing',
              buckets,
              'Sampler to track the time (in seconds) for tracing functions.'),
      'closure_execution':
          monitoring.Sampler(
              '/tensorflow/api/ps_strategy/coordinator/closure_execution',
              buckets,
              'Sampler to track the time (in seconds) for executing closures.'),
      'remote_value_fetch':
          monitoring.Sampler(
              '/tensorflow/api/ps_strategy/coordinator/remote_value_fetch',
              buckets,
              'Sampler to track the time (in seconds) for fetching remote_value.'),
  }
58
59
@tf_contextlib.contextmanager
def monitored_timer(metric_name, state_tracker=None):
  """Monitor the execution time and collect it into the specified metric.

  When the module-level `enable_metrics` flag is False, this context manager
  does nothing beyond yielding. Otherwise it:

  1. Lazily creates the samplers.
  2. Measures how long the wrapped block takes.
  3. Adds that duration, in seconds, to the sampler registered under
     `metric_name`.

  No duration is recorded if the wrapped block raises. The exception
  propagates out of the `yield` and the recording code below it is skipped.

  Args:
    metric_name: Key of the metric to record into, e.g. 'function_tracing',
      'closure_execution' or 'remote_value_fetch'.
    state_tracker: Optional zero-argument callable. If given, it is called
      before and after the wrapped block. The duration is recorded only if the
      two return values compare unequal.

  Yields:
    None.

  Raises:
    KeyError: If metrics are enabled and `metric_name` is not a known metric.
      The lookup happens before the wrapped block runs, so a bad name is
      reported up front rather than after the timed work has completed.
  """
  if not enable_metrics:
    yield
    return

  if not _METRICS_MAPPING:
    _init()
  # Resolve the sampler before running the timed block. An unknown metric name
  # then fails fast, instead of surfacing only after the work has been done
  # (or never, when the state tracker reports no change).
  metric = _METRICS_MAPPING[metric_name]

  # time.perf_counter() is monotonic and high-resolution. Measured durations
  # are therefore not skewed, or made negative, by wall-clock adjustments the
  # way time.time() deltas can be.
  start_time = time.perf_counter()
  start_state = state_tracker() if state_tracker is not None else None
  yield
  duration_sec = time.perf_counter() - start_time
  # If a state_tracker is provided, record the metric only if the end state
  # differs from the start state.
  if state_tracker is None or state_tracker() != start_state:
    metric.get_cell().add(duration_sec)
77
78
def get_metric_summary(metric_name):
  """Get summary statistics for the specified metric.

  The samplers are created on first use, the same way `monitored_timer` does
  it. Asking for a known metric before anything has been recorded therefore
  returns a summary instead of raising `KeyError`.

  Args:
    metric_name: One of 'function_tracing', 'closure_execution' or
      'remote_value_fetch'.

  Returns:
    A dict with keys 'min', 'max', 'num' and 'sum', copied from the sampler's
    histogram proto. If nothing has been recorded yet, 'num' is presumably 0
    and 'min'/'max' hold whatever the empty histogram reports. Confirm this
    before relying on those two values.

  Raises:
    KeyError: If `metric_name` is not a registered metric.
  """
  # Mirror monitored_timer: make sure the samplers exist before looking one up.
  if not _METRICS_MAPPING:
    _init()
  histogram_proto = _METRICS_MAPPING[metric_name].get_cell().value()
  # TODO(haoyuzhang): consider reporting the distribution in buckets.
  return {
      'min': histogram_proto.min,
      'max': histogram_proto.max,
      'num': histogram_proto.num,
      'sum': histogram_proto.sum,
  }
90