1#!/usr/bin/env python2.7
2# Copyright 2015 gRPC authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Starts a local DNS server for use in tests"""
17
18import argparse
19import sys
20import yaml
21import signal
22import os
23import threading
24import time
25
26import twisted
27import twisted.internet
28import twisted.internet.reactor
29import twisted.internet.threads
30import twisted.internet.defer
31import twisted.internet.protocol
32import twisted.names
33import twisted.names.client
34import twisted.names.dns
35import twisted.names.server
36from twisted.names import client, server, common, authority, dns
37import argparse
38import platform
39
# Name and address of the health-check A record that the server always
# serves in addition to the test records.
# The name intentionally has no trailing '.': record names handed to the
# twisted authority are stored without it (the test record names get their
# trailing '.' stripped the same way before being pushed).
_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp'
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'
42
class NoFileAuthority(authority.FileAuthority):
  """A twisted FileAuthority whose SOA and records are supplied in memory.

  Presumably FileAuthority's own constructor expects a zone file to load;
  this subclass skips it and assigns ``soa`` and ``records`` directly.
  """
  def __init__(self, soa, records):
    # skip FileAuthority.__init__ and run only the base resolver setup.
    common.ResolverBase.__init__(self)
    # soa: (zone name, dns.Record_SOA) pair. records: dict mapping a record
    # name to a list of dns.Record_* objects (this is how
    # start_local_dns_server builds both arguments).
    self.soa = soa
    self.records = records
49
def start_local_dns_server(args):
  """Load the test DNS records and serve them over TCP and UDP.

  Reads the YAML file at ``args.records_config_path``. Builds twisted DNS
  records (A, AAAA, SRV, TXT) for every entry under
  ``resolver_component_tests`` and adds the server health-check A record.
  Then runs the twisted reactor on ``args.port``. This call blocks until the
  reactor is stopped.

  Args:
    args: parsed command-line namespace; only ``records_config_path`` and
      ``port`` are read.

  Raises:
    ValueError: if a fully-qualified record name does not end with '.'
      (i.e. ``resolver_tests_common_zone_name`` is missing its trailing dot).
  """
  all_records = {}

  def _push_record(name, r):
    # Several records may share a name, so each name maps to a list.
    print('pushing record: |%s|' % name)
    all_records.setdefault(name, []).append(r)

  def _maybe_split_up_txt_data(name, txt_data, r_ttl):
    # A DNS <character-string> holds at most 255 octets (RFC 1035), so long
    # TXT data is split into consecutive chunks carried by one TXT record.
    chunk_len = 255
    txt_data_list = [txt_data[i:i + chunk_len]
                     for i in range(0, len(txt_data), chunk_len)]
    _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))

  with open(args.records_config_path) as config:
    # The config only needs plain YAML types. yaml.load() without an
    # explicit Loader is unsafe and is rejected outright by PyYAML >= 6.
    test_records_config = yaml.safe_load(config)
  common_zone_name = test_records_config['resolver_tests_common_zone_name']
  for group in test_records_config['resolver_component_tests']:
    for name, records in group['records'].items():
      for record in records:
        r_type = record['type']
        r_data = record['data']
        r_ttl = int(record['TTL'])
        record_full_name = '%s.%s' % (name, common_zone_name)
        if not record_full_name.endswith('.'):
          raise ValueError(
              'expected fully-qualified record name ending in ".", got %r '
              '(check resolver_tests_common_zone_name)' % record_full_name)
        # The twisted authority stores names without the trailing '.'.
        record_full_name = record_full_name[:-1]
        if r_type == 'A':
          _push_record(record_full_name, dns.Record_A(r_data, ttl=r_ttl))
        elif r_type == 'AAAA':
          _push_record(record_full_name, dns.Record_AAAA(r_data, ttl=r_ttl))
        elif r_type == 'SRV':
          # SRV data is "<priority> <weight> <port> <target>".
          p, w, port, target = r_data.split()
          target_full_name = '%s.%s' % (target, common_zone_name)
          _push_record(
              record_full_name,
              dns.Record_SRV(int(p), int(w), int(port), target_full_name,
                             ttl=r_ttl))
        elif r_type == 'TXT':
          _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
        # Other record types are ignored.
  # Server health check record
  _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
               dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
  soa_record = dns.Record_SOA(mname=common_zone_name)
  test_domain_com = NoFileAuthority(
      soa=(common_zone_name, soa_record),
      records=all_records,
  )
  # Named server_factory so it does not shadow the imported twisted
  # `server` module.
  server_factory = twisted.names.server.DNSServerFactory(
      authorities=[test_domain_com], verbose=2)
  server_factory.noisy = 2
  # NOTE(review): listenTCP/listenUDP get no `interface` argument, so
  # twisted's default applies; 127.0.0.1 in the message below may not be
  # the only address bound -- confirm.
  twisted.internet.reactor.listenTCP(args.port, server_factory)
  dns_proto = twisted.names.dns.DNSDatagramProtocol(server_factory)
  dns_proto.noisy = 2
  twisted.internet.reactor.listenUDP(args.port, dns_proto)
  print('starting local dns server on 127.0.0.1:%s' % args.port)
  print('starting twisted.internet.reactor')
  twisted.internet.reactor.suggestThreadPoolSize(1)
  twisted.internet.reactor.run()
114
def _quit_on_signal(signum, _frame):
  """Signal handler: stop the twisted reactor and exit with status 0.

  main() installs this handler before the reactor starts. A signal can
  therefore arrive while records are still being loaded; reactor.stop()
  then raises ReactorNotRunning. Without a guard, the process would die with
  a traceback instead of exiting 0 as announced.

  Args:
    signum: number of the received signal.
    _frame: current stack frame (unused).
  """
  from twisted.internet import error as twisted_error
  print('Received SIGNAL %d. Quitting with exit code 0' % signum)
  try:
    twisted.internet.reactor.stop()
  except twisted_error.ReactorNotRunning:
    # Reactor not started yet (or already stopping): nothing to stop.
    pass
  sys.stdout.flush()
  sys.exit(0)
120
def flush_stdout_loop(sleep_time=1, max_timeouts=60 * 2):
  """Periodically flush stdout, then send SIGTERM to this process.

  Intended to run in a background thread (main() starts it as a daemon
  thread). After ``max_timeouts`` flush/sleep iterations, the process
  signals itself with SIGTERM. With the defaults this takes roughly two
  minutes.

  Args:
    sleep_time: seconds to sleep between flushes.
    max_timeouts: number of flush/sleep iterations before the SIGTERM.
  """
  # Prevent zombies. Tests that use this server are short-lived, so a server
  # outliving its test should not linger.
  for _ in range(max_timeouts):
    sys.stdout.flush()
    time.sleep(sleep_time)
  print('Process timeout reached, or cancelled. Exitting 0.')
  os.kill(os.getpid(), signal.SIGTERM)
132
def main():
  """Parse flags, install signal handlers and run the local DNS server.

  A daemon thread runs flush_stdout_loop (periodic stdout flushing plus a
  self-SIGTERM timeout). The main thread then blocks inside
  start_local_dns_server until the reactor stops.
  """
  argp = argparse.ArgumentParser(description='Local DNS Server for resolver tests')
  # Both flags are mandatory: start_local_dns_server opens the config path
  # and listens on the port, so a missing value would only crash later with
  # a less helpful error.
  argp.add_argument('-p', '--port', default=None, type=int, required=True,
                    help='Port for DNS server to listen on for TCP and UDP.')
  argp.add_argument('-r', '--records_config_path', default=None, type=str,
                    required=True,
                    help='Path to the resolver_test_record_groups.yaml file.')
  args = argp.parse_args()
  signal.signal(signal.SIGTERM, _quit_on_signal)
  signal.signal(signal.SIGINT, _quit_on_signal)
  output_flush_thread = threading.Thread(target=flush_stdout_loop)
  # Attribute form works on Python 2.7 and 3; setDaemon() is deprecated.
  output_flush_thread.daemon = True
  output_flush_thread.start()
  start_local_dns_server(args)

if __name__ == '__main__':
  main()
150