1# Lint as: python3
2# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for metrics collecting in coordinator."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import time
23from tensorflow.python.distribute import multi_worker_test_base
24from tensorflow.python.distribute import parameter_server_strategy_v2
25from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
26from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
27from tensorflow.python.distribute.coordinator import metric_utils
28from tensorflow.python.eager import def_function
29from tensorflow.python.eager import test
30from tensorflow.python.training.server_lib import ClusterSpec
31
32
class MetricUtilsTest(test.TestCase):
  """Tests the metrics recorded by the cluster coordinator via metric_utils."""

  def get_rpc_layer(self):
    """Returns the RPC layer used for the in-process cluster and resolver."""
    return 'grpc'

  def testClusterCoordinatorMetrics(self):
    """Schedules one tf.function twice and checks the recorded metrics."""
    # `enable_metrics` is a module-level switch on `metric_utils`. Restore its
    # previous value when the test finishes so the setting does not leak into
    # other tests running in the same process.
    previous_enable_metrics = metric_utils.enable_metrics
    self.addCleanup(setattr, metric_utils, 'enable_metrics',
                    previous_enable_metrics)
    metric_utils.enable_metrics = True

    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=1, num_ps=1, rpc_layer=self.get_rpc_layer())
    # Add a chief task on an unused local port next to the in-process
    # worker and parameter server.
    cluster_def['chief'] = [
        'localhost:%d' % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    cluster = coordinator_lib.Cluster(strategy)

    @def_function.function
    def func():
      # Python-side sleep: it only runs while the function is being traced,
      # which makes the tracing time measurably long.
      time.sleep(0.5)
      return 3

    # Schedule the same function twice; only the second RemoteValue is
    # fetched and checked.
    unused_first_result = cluster.schedule(func, args=None, kwargs=None)
    result = cluster.schedule(func, args=None, kwargs=None)
    cluster.join()
    self.assertEqual(result.fetch(), 3)

    # The function is traced only once; the second schedule call reuses the
    # traced function.
    metric_tracing = metric_utils.get_metric_summary('function_tracing')
    self.assertEqual(metric_tracing['num'], 1)
    # Tracing time should be longer than the sleep time in Python function.
    self.assertGreater(metric_tracing['sum'], 0.5)
    # Each of the two scheduled calls is executed as its own closure.
    metric_closure = metric_utils.get_metric_summary('closure_execution')
    self.assertEqual(metric_closure['num'], 2)
    # NOTE(review): two remote-value fetches are expected even though only
    # the second result is fetched explicitly above; presumably the
    # coordinator records one per scheduled closure -- confirm against the
    # `remote_value_fetch` call sites in the coordinator.
    metric_remote_value = metric_utils.get_metric_summary('remote_value_fetch')
    self.assertEqual(metric_remote_value['num'], 2)
73
74
if __name__ == '__main__':
  # Run this module's tests through the TensorFlow test entry point.
  test.main()
77