1# Lint as: python3 2# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Metrics collecting utilities for single client training.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import time 23 24from tensorflow.python.eager import monitoring 25from tensorflow.python.util import tf_contextlib 26 27enable_metrics = False 28_METRICS_MAPPING = {} 29 30 31def _init(): 32 """Initialize the metrics mapping.""" 33 global _METRICS_MAPPING 34 35 # Time in seconds to bucket the distribution of execution time. Range from 36 # 0.001s (i.e., 1ms) to 1000s. 37 time_buckets = monitoring.ExponentialBuckets(0.001, 10, 6) 38 39 function_tracing_sampler = monitoring.Sampler( 40 '/tensorflow/api/ps_strategy/coordinator/function_tracing', time_buckets, 41 'Sampler to track the time (in seconds) for tracing functions.') 42 43 closure_execution_sampler = monitoring.Sampler( 44 '/tensorflow/api/ps_strategy/coordinator/closure_execution', 45 time_buckets, 46 'Sampler to track the time (in seconds) for executing closures.') 47 48 remote_value_fetch_sampler = monitoring.Sampler( 49 '/tensorflow/api/ps_strategy/coordinator/remote_value_fetch', 50 time_buckets, 51 'Sampler to track the time (in seconds) for fetching remote_value.') 52 53 _METRICS_MAPPING = { 54 'function_tracing': function_tracing_sampler, 55 'closure_execution': closure_execution_sampler, 56 'remote_value_fetch': remote_value_fetch_sampler 57 } 58 59 60@tf_contextlib.contextmanager 61def monitored_timer(metric_name, state_tracker=None): 62 """Monitor the execution time and collect it into the specified metric.""" 63 if not enable_metrics: 64 yield 65 else: 66 if not _METRICS_MAPPING: 67 _init() 68 start_time = time.time() 69 start_state = state_tracker() if state_tracker else None 70 yield 71 duration_sec = time.time() - start_time 72 # If a state_checker is provided, record the metric only if the end state is 73 # different from the start state. 74 if state_tracker is None or state_tracker() != start_state: 75 metric = _METRICS_MAPPING[metric_name] 76 metric.get_cell().add(duration_sec) 77 78 79def get_metric_summary(metric_name): 80 """Get summary for the specified metric.""" 81 metric = _METRICS_MAPPING[metric_name] 82 histogram_proto = metric.get_cell().value() 83 ret = dict() 84 ret['min'] = histogram_proto.min 85 ret['max'] = histogram_proto.max 86 ret['num'] = histogram_proto.num 87 ret['sum'] = histogram_proto.sum 88 # TODO(haoyuzhang): consider reporting the distribution in buckets. 89 return ret 90