1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import multiprocessing
16import random
17import threading
18import time
19
20from concurrent import futures
21import grpc
22from src.proto.grpc.testing import control_pb2
23from src.proto.grpc.testing import benchmark_service_pb2_grpc
24from src.proto.grpc.testing import worker_service_pb2_grpc
25from src.proto.grpc.testing import stats_pb2
26
27from tests.qps import benchmark_client
28from tests.qps import benchmark_server
29from tests.qps import client_runner
30from tests.qps import histogram
31from tests.unit import resources
32from tests.unit import test_common
33
34
class WorkerServer(worker_service_pb2_grpc.WorkerServiceServicer):
    """Python Worker Server implementation.

    Serves the WorkerService RPCs (RunServer, RunClient, CoreCount,
    QuitWorker) used to drive QPS benchmarks in this process.
    """

    def __init__(self):
        # Set by QuitWorker; wait_for_quit() blocks until it is set.
        self._quit_event = threading.Event()

    def RunServer(self, request_iterator, context):
        """Run a benchmark server and stream ServerStatus messages back.

        The first request must carry the server ``setup`` config. After the
        server starts, one status is yielded immediately, then one more per
        subsequent request. A request with ``mark.reset`` set restarts the
        measurement interval at the moment that status was taken.

        The benchmark server is stopped when the request stream ends or
        when this generator is closed early (for example, if the RPC is
        cancelled).
        """
        config = next(request_iterator).setup
        server, port = self._create_server(config)
        cores = multiprocessing.cpu_count()
        server.start()
        try:
            start_time = time.time()
            yield self._get_server_status(start_time, start_time, port, cores)

            for request in request_iterator:
                end_time = time.time()
                status = self._get_server_status(start_time, end_time, port,
                                                 cores)
                if request.mark.reset:
                    # The next interval begins exactly where this one ended.
                    start_time = end_time
                yield status
        finally:
            # Runs on normal completion and on GeneratorExit, so the
            # benchmark server is never left running.
            server.stop(None)

    def _get_server_status(self, start_time, end_time, port, cores):
        """Build a ServerStatus for the interval [start_time, end_time].

        Wall-clock elapsed time is reported for the elapsed, user and system
        fields alike; no separate CPU-time accounting is done here.
        """
        elapsed_time = end_time - start_time
        stats = stats_pb2.ServerStats(
            time_elapsed=elapsed_time,
            time_user=elapsed_time,
            time_system=elapsed_time)
        return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)

    def _create_server(self, config):
        """Create (but do not start) the benchmark server described by config.

        Returns:
          A ``(server, port)`` tuple, where ``port`` is the value returned by
          ``add_secure_port`` / ``add_insecure_port``.

        Raises:
          Exception: if ``config.server_type`` is not supported.
        """
        if config.async_server_threads == 0:
            # This is the default concurrent.futures thread pool size, but
            # None doesn't seem to work
            server_threads = multiprocessing.cpu_count() * 5
        else:
            server_threads = config.async_server_threads
        server = test_common.test_server(max_workers=server_threads)
        if config.server_type == control_pb2.ASYNC_SERVER:
            servicer = benchmark_server.BenchmarkServer()
            benchmark_service_pb2_grpc.add_BenchmarkServiceServicer_to_server(
                servicer, server)
        elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
            resp_size = config.payload_config.bytebuf_params.resp_size
            servicer = benchmark_server.GenericBenchmarkServer(resp_size)
            method_implementations = {
                'StreamingCall':
                grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
                'UnaryCall':
                grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
            }
            handler = grpc.method_handlers_generic_handler(
                'grpc.testing.BenchmarkService', method_implementations)
            server.add_generic_rpc_handlers((handler,))
        else:
            raise Exception('Unsupported server type {}'.format(
                config.server_type))

        if config.HasField('security_params'):  # Use SSL
            server_creds = grpc.ssl_server_credentials(
                ((resources.private_key(), resources.certificate_chain()),))
            port = server.add_secure_port('[::]:{}'.format(config.port),
                                          server_creds)
        else:
            port = server.add_insecure_port('[::]:{}'.format(config.port))

        return (server, port)

    def RunClient(self, request_iterator, context):
        """Run benchmark clients and stream ClientStatus messages back.

        The first request must carry the client ``setup`` config. One runner
        is created and started per configured channel, with channels assigned
        to ``server_targets`` round-robin. One status is yielded once all
        runners are started, then one per subsequent request. A request with
        ``mark.reset`` clears the latency histogram and restarts the interval.

        All started runners are stopped when the stream ends, when this
        generator is closed early, or when creating a later runner fails.
        """
        config = next(request_iterator).setup
        client_runners = []
        qps_data = histogram.Histogram(config.histogram_params.resolution,
                                       config.histogram_params.max_possible)
        start_time = time.time()
        try:
            # Create a client for each channel
            for i in range(config.client_channels):
                server = config.server_targets[i % len(config.server_targets)]
                runner = self._create_client_runner(server, config, qps_data)
                client_runners.append(runner)
                runner.start()

            end_time = time.time()
            yield self._get_client_status(start_time, end_time, qps_data)

            # Respond to stat requests
            for request in request_iterator:
                end_time = time.time()
                status = self._get_client_status(start_time, end_time,
                                                 qps_data)
                if request.mark.reset:
                    qps_data.reset()
                    start_time = time.time()
                yield status
        finally:
            # Cleanup the clients, including on cancellation (GeneratorExit).
            for runner in client_runners:
                runner.stop()

    def _get_client_status(self, start_time, end_time, qps_data):
        """Build a ClientStatus from the histogram for [start_time, end_time].

        Wall-clock elapsed time is reported for the elapsed, user and system
        fields alike; no separate CPU-time accounting is done here.
        """
        latencies = qps_data.get_data()
        elapsed_time = end_time - start_time
        stats = stats_pb2.ClientStats(
            latencies=latencies,
            time_elapsed=elapsed_time,
            time_user=elapsed_time,
            time_system=elapsed_time)
        return control_pb2.ClientStatus(stats=stats)

    def _create_client_runner(self, server, config, qps_data):
        """Create a benchmark client for ``server`` wrapped in a runner.

        Closed-loop configs keep ``outstanding_rpcs_per_channel`` RPCs in
        flight. Otherwise, the runner waits between RPCs for delays drawn
        from ``random.expovariate`` (the comment below calls this open-loop
        Poisson). The rate is the offered load divided evenly across
        ``client_channels``.

        Raises:
          Exception: if the client type / rpc type combination is not
            supported.
        """
        if config.client_type == control_pb2.SYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnarySyncBenchmarkClient(
                    server, config, qps_data)
            elif config.rpc_type == control_pb2.STREAMING:
                client = benchmark_client.StreamingSyncBenchmarkClient(
                    server, config, qps_data)
            else:
                # Without this branch `client` would be unbound below.
                raise Exception('Unsupported rpc type {}'.format(
                    config.rpc_type))
        elif config.client_type == control_pb2.ASYNC_CLIENT:
            if config.rpc_type == control_pb2.UNARY:
                client = benchmark_client.UnaryAsyncBenchmarkClient(
                    server, config, qps_data)
            else:
                raise Exception('Async streaming client not supported')
        else:
            raise Exception('Unsupported client type {}'.format(
                config.client_type))

        # In multi-channel tests, we split the load across all channels
        load_factor = float(config.client_channels)
        if config.load_params.WhichOneof('load') == 'closed_loop':
            runner = client_runner.ClosedLoopClientRunner(
                client, config.outstanding_rpcs_per_channel)
        else:  # Open loop Poisson
            alpha = config.load_params.poisson.offered_load / load_factor

            def poisson():
                # Infinite stream of exponentially distributed delays.
                while True:
                    yield random.expovariate(alpha)

            runner = client_runner.OpenLoopClientRunner(client, poisson())

        return runner

    def CoreCount(self, request, context):
        """Report this machine's CPU count as seen by multiprocessing."""
        return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())

    def QuitWorker(self, request, context):
        """Signal wait_for_quit() to return, then acknowledge with Void."""
        self._quit_event.set()
        return control_pb2.Void()

    def wait_for_quit(self):
        """Block until QuitWorker has been called."""
        self._quit_event.wait()
188