1#!/usr/bin/env python
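# Merge or print the coverage data collected by asan's coverage.
# Input files start with an 8-byte magic header that records whether PCs are
# 32- or 64-bit wide, followed by a sequence of PCs stored as native-endian
# 4-byte or 8-byte unsigned integers.
# We need to merge these PCs into a set and then
# either print them (as hex) or dump them into another file.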
6import array
7import bisect
8import glob
9import os.path
10import struct
11import subprocess
12import sys
13
# Name the script was invoked as (assigned from sys.argv[0] in the __main__
# block); used as the prefix of every diagnostic written to stderr.
prog_name = ''
15
def Usage():
  """Write the command-line synopsis to stderr and exit with status 1.

  Called whenever the command line (or, for `merge`, the output target)
  cannot be used.
  """
  commands = [
    "merge FILE [FILE...] > OUTPUT",
    "print FILE [FILE...]",
    "unpack FILE [FILE...]",
    "rawunpack FILE [FILE ...]",
    "missing BINARY < LIST_OF_PCS",
  ]
  sys.stderr.write(
    "Usage: \n" +
    "".join(" %s %s\n" % (prog_name, command) for command in commands) +
    "\n")
  # sys.exit rather than the builtin exit(): the latter is injected by the
  # `site` module for interactive use and is unavailable under `python -S`.
  sys.exit(1)
26
def CheckBits(bits):
  """Raise an Exception unless `bits` is a supported PC width (32 or 64)."""
  if bits not in (32, 64):
    raise Exception("Wrong bitness: %d" % bits)
30
def TypeCodeForBits(bits):
  """Return the `array` module type code for a `bits`-wide unsigned word.

  `bits` is validated with CheckBits first (32 or 64 only).
  """
  CheckBits(bits)
  # NOTE(review): 'L' is C `unsigned long`, which is 64-bit only on LP64
  # platforms (e.g. 32-bit on Windows).
  if bits == 64:
    return 'L'
  return 'I'
34
def TypeCodeForStruct(bits):
  """Return the `struct` format character for a `bits`-wide unsigned word.

  'Q' (unsigned long long) for 64 bits, 'I' (unsigned int) for 32 bits;
  other widths are rejected by CheckBits.
  """
  CheckBits(bits)
  return 'I' if bits != 64 else 'Q'
38
# Sancov file magic. Read as one native 64-bit value, its high 32 bits are
# kMagicFirstHalf and its low 32 bits are one of the *SecondHalf words, which
# record whether the PCs that follow are 32- or 64-bit wide.
kMagic32SecondHalf = 0xFFFFFF32
kMagic64SecondHalf = 0xFFFFFF64
kMagicFirstHalf = 0xC0BFFFFF
42
def MagicForBits(bits):
  """Return the file magic for `bits` as two 32-bit words in native order.

  The words are ordered so that writing them as native 32-bit integers gives
  the same bytes as the 64-bit value (kMagicFirstHalf << 32) | second_half
  written in native byte order.
  """
  CheckBits(bits)
  second_half = kMagic64SecondHalf if bits == 64 else kMagic32SecondHalf
  words = [kMagicFirstHalf, second_half]
  if sys.byteorder == 'little':
    # Little-endian: the low (second) half comes first in memory.
    words.reverse()
  return words
49
def ReadMagicAndReturnBitness(f, path):
  """Consume the 8-byte magic header from `f` and return the PC width.

  Returns 32 or 64. `path` is used only in the error message.
  Raises Exception if the header is not a recognised sancov magic.
  """
  # Position of kMagicFirstHalf among the two native 32-bit words.
  first_idx = 1 if sys.byteorder == 'little' else 0
  words = struct.unpack('II', f.read(8))
  if words[first_idx] != kMagicFirstHalf:
    raise Exception('Bad magic word in %s' % path)
  second = words[1 - first_idx]
  if second == kMagic64SecondHalf:
    return 64
  if second == kMagic32SecondHalf:
    return 32
  raise Exception('Bad magic word in %s' % path)
63
def ReadOneFile(path):
  """Read a sancov file and return its PCs as a tuple, in file order.

  The file must begin with the 8-byte magic header; the rest is decoded as
  native-endian unsigned integers of the width the header specifies.
  Trailing bytes that do not form a whole PC are ignored. The PC count is
  logged to stderr.

  Raises Exception if the file is shorter than 8 bytes or has a bad magic.
  """
  with open(path, mode="rb") as f:
    f.seek(0, 2)
    total = f.tell()
    f.seek(0, 0)
    if total < 8:
      raise Exception('File %s is short (< 8 bytes)' % path)
    bits = ReadMagicAndReturnBitness(f, path)
    payload_size = total - 8
    count = payload_size * 8 // bits
    fmt = TypeCodeForStruct(bits) * count
    pcs = struct.unpack_from(fmt, f.read(payload_size))
  sys.stderr.write(
    "%s: read %d %d-bit PCs from %s\n" % (prog_name, count, bits, path))
  return pcs
78
def Merge(files):
  """Return the sorted, de-duplicated union of the PCs in all `files`.

  Logs the number of files merged and the number of distinct PCs to stderr.
  An empty `files` yields an empty list.
  """
  pcs = set()
  for path in files:
    # In-place update instead of building a new set for every file.
    pcs.update(ReadOneFile(path))
  sys.stderr.write(
    "%s: %d files merged; %d PCs total\n" % (prog_name, len(files), len(pcs))
  )
  return sorted(pcs)
87
def PrintFiles(files):
  """Print the PCs from `files` to stdout, one per line, as ``0x``-hex.

  With exactly one file the PCs are printed in file order (duplicates kept).
  Otherwise, including an empty `files`, they are merged into a sorted,
  de-duplicated list first.
  """
  if len(files) == 1:
    # If there is just one file, print the PCs in their original order.
    s = ReadOneFile(files[0])
    sys.stderr.write("%s: 1 file merged; %d PCs total\n" % (prog_name, len(s)))
  else:
    # Merge also copes with an empty list (no IndexError on files[0]).
    s = Merge(files)
  for i in s:
    print("0x%x" % i)
96
def MergeAndPrint(files):
  """Merge the PCs of `files` and write the result to stdout as a sancov file.

  The output is the magic header followed by the sorted, de-duplicated PCs
  as native-endian unsigned integers. 64-bit words are used only if some PC
  does not fit in 32 bits. Calls Usage (which exits) when stdout is a
  terminal, to avoid dumping binary data on it.
  """
  if sys.stdout.isatty():
    Usage()
  s = Merge(files)
  # Guard the empty case: max() of an empty sequence raises ValueError.
  bits = 64 if s and max(s) > 0xFFFFFFFF else 32
  # Python 3 needs the binary buffer; Python 2's sys.stdout has no .buffer.
  stdout_buf = getattr(sys.stdout, 'buffer', sys.stdout)
  array.array('I', MagicForBits(bits)).tofile(stdout_buf)
  stdout_buf.write(struct.pack(TypeCodeForStruct(bits) * len(s), *s))
108
109
def UnpackOneFile(path):
  """Split a packed coverage file into ``<module>.<pid>.sancov`` files.

  The packed file is a sequence of records. Each record has a 12-byte
  native-endian header (signed pid, unsigned module-name length, unsigned
  blob size), followed by the UTF-8 module name and then the blob. Each blob
  is appended to the file named ``<module>.<pid>.sancov``.

  Raises Exception if a record is truncated.
  """
  with open(path, mode="rb") as f:
    sys.stderr.write("%s: unpacking %s\n" % (prog_name, path))
    while True:
      header = f.read(12)
      if not header:
        return  # Clean end of file.
      if len(header) < 12:
        break
      pid, module_length, blob_size = struct.unpack('iII', header)
      module_bytes = f.read(module_length)
      blob = f.read(blob_size)
      # Compare byte counts before decoding: a decoded name is shorter than
      # module_length when it contains multi-byte UTF-8 characters. Explicit
      # checks (not asserts) keep the validation active under `python -O`.
      if len(module_bytes) != module_length or len(blob) != blob_size:
        break
      module = module_bytes.decode('utf-8')
      extracted_file = "%s.%d.sancov" % (module, pid)
      sys.stderr.write("%s: extracting %s\n" % (prog_name, extracted_file))
      # The packed file may contain multiple blobs for the same pid/module
      # pair. Append to the end of the file instead of overwriting.
      with open(extracted_file, 'ab') as f2:
        f2.write(blob)
    # Reached only on a truncated header, name or blob.
    raise Exception('Error reading file %s' % path)
131
132
def Unpack(files):
  """Unpack each packed coverage file in `files`, in the given order."""
  for packed_path in files:
    UnpackOneFile(packed_path)
136
def _ReadRawMap(map_path):
  """Parse a ``.sancov.map`` file into (bits, mappings sorted by start).

  The first line holds the PC width (32 or 64). Every other non-blank line
  is ``START END BASE MODULE_PATH`` with hex START/END/BASE; MODULE_PATH may
  contain spaces. Raises Exception for an unsupported width.
  """
  mem_map = []
  with open(map_path, mode="rt") as f_map:
    sys.stderr.write("%s: reading map %s\n" % (prog_name, map_path))
    bits = int(f_map.readline())
    if bits != 32 and bits != 64:
      raise Exception('Wrong bits size in the map')
    for line in f_map:
      parts = line.split()
      if not parts:
        continue  # Tolerate blank lines (e.g. a trailing empty line).
      mem_map.append((int(parts[0], 16),
                      int(parts[1], 16),
                      int(parts[2], 16),
                      ' '.join(parts[3:])))
  mem_map.sort(key=lambda m: m[0])
  return bits, mem_map


def UnpackOneRawFile(path, map_path):
  """Convert a raw PC dump into per-module sancov files.

  `path` must end in ``.sancov.raw``. It holds absolute PCs as native
  unsigned integers whose width comes from the first line of `map_path`.
  Each non-zero PC is attributed to the mapping [start, end) that contains
  it and rebased by subtracting that mapping's base. The sorted offsets are
  appended to ``<module_path>.<basename of path without ".raw">``. The magic
  header is written only when that file is still empty. PCs outside every
  mapping are reported on stderr and skipped.

  Raises Exception if `path` does not end in ``.sancov.raw`` or the map
  declares an unsupported width.
  """
  if not path.endswith('.sancov.raw'):
    raise Exception('Unexpected raw file name %s' % path)
  bits, mem_map = _ReadRawMap(map_path)
  mem_map_keys = [m[0] for m in mem_map]
  type_code = TypeCodeForStruct(bits)

  with open(path, mode="rb") as f:
    sys.stderr.write("%s: unpacking %s\n" % (prog_name, path))
    data = f.read()
  pcs = struct.unpack_from(type_code * (len(data) * 8 // bits), data)

  mem_map_pcs = [[] for _ in mem_map]
  for pc in pcs:
    if pc == 0:
      continue
    map_idx = bisect.bisect(mem_map_keys, pc) - 1
    # map_idx < 0 means pc lies below every mapping (or the map is empty);
    # indexing with -1 would wrongly select the last mapping.
    if map_idx < 0 or pc >= mem_map[map_idx][1]:
      sys.stderr.write("warning: %s: pc %x outside of any known mapping\n" % (prog_name, pc))
      continue
    mem_map_pcs[map_idx].append(pc - mem_map[map_idx][2])

  base_name = os.path.basename(path)[:-4]  # Drop the ".raw" suffix.
  magic = array.array('I', MagicForBits(bits))
  for mapping, pc_list in zip(mem_map, mem_map_pcs):
    if not pc_list:
      continue
    dst_path = mapping[3] + '.' + base_name
    sys.stderr.write("%s: writing %d PCs to %s\n" % (prog_name, len(pc_list), dst_path))
    pc_buffer = struct.pack(type_code * len(pc_list), *sorted(pc_list))
    with open(dst_path, 'ab') as f2:
      # Several mappings (or earlier runs) may append to the same module
      # file. A header repeated mid-file would be read back as a bogus PC,
      # so write the magic only when the file is empty.
      f2.seek(0, 2)
      if f2.tell() == 0:
        magic.tofile(f2)
      f2.write(pc_buffer)
183
def RawUnpack(files):
  """Unpack each ``*.sancov.raw`` file using its sibling ``*.sancov.map``.

  Raises Exception for any name that does not end in ``.sancov.raw``.
  """
  for raw_path in files:
    if not raw_path.endswith('.sancov.raw'):
      raise Exception('Unexpected raw file name %s' % raw_path)
    # "x.sancov.raw" -> "x.sancov.map"
    map_path = raw_path[:-len('raw')] + 'map'
    UnpackOneRawFile(raw_path, map_path)
190
def GetInstrumentedPCs(binary):
  """Return the set of PCs in `binary` that call a sanitizer coverage hook.

  Disassembles `binary` with objdump and collects every `call`/`callq`
  whose target is __sanitizer_cov, __sanitizer_cov_with_check or
  __sanitizer_cov_trace_pc_guard, called directly or via the PLT. Each
  address is shifted by 4 bytes (see below). A binary with no such calls
  yields an empty set.

  Raises subprocess.CalledProcessError if objdump fails, or OSError if it
  cannot be started.
  """
  import re
  # objdump runs without a shell, so `binary` is passed verbatim. Paths with
  # spaces or shell metacharacters are therefore neither split nor executed.
  output = subprocess.check_output(
      ["objdump", "--no-show-raw-insn", "-d", binary])
  call_re = re.compile(
      r"^\s+([0-9a-f]+):\s+callq?\s+(?:0x)?[0-9a-f]+ "
      r"<__sanitizer_cov(?:_with_check|_trace_pc_guard)?(?:@plt)?>")
  pcs = set()
  for line in output.decode("utf-8", "replace").splitlines():
    m = call_re.match(line)
    if m:
      # The PCs we get from objdump are off by 4 bytes, as they point to the
      # beginning of the callq instruction. Empirically this is true on x86
      # and x86_64.
      pcs.add(int(m.group(1), 16) + 4)
  return pcs
204
def PrintMissing(binary):
  """Print, as hex, the instrumented PCs of `binary` that were not covered.

  Covered PCs are read from stdin, one hex number per line (a ``0x`` prefix
  is accepted); blank lines are ignored. A warning goes to stderr if stdin
  lists PCs that are not instrumentation points of `binary`.

  Raises Exception if `binary` is not an existing file.
  """
  if not os.path.isfile(binary):
    raise Exception('File not found: %s' % binary)
  instrumented = GetInstrumentedPCs(binary)
  sys.stderr.write("%s: found %d instrumented PCs in %s\n" % (prog_name,
                                                              len(instrumented),
                                                              binary))
  covered = set(int(line, 16) for line in sys.stdin if line.strip())
  sys.stderr.write("%s: read %d PCs from stdin\n" % (prog_name, len(covered)))
  missing = instrumented - covered
  sys.stderr.write("%s: %d PCs missing from coverage\n" % (prog_name, len(missing)))
  if not covered <= instrumented:
    sys.stderr.write(
      "%s: WARNING: stdin contains PCs not found in binary\n" % prog_name
    )
  for pc in sorted(missing):
    print("0x%x" % pc)
222
if __name__ == '__main__':
  prog_name = sys.argv[0]
  if len(sys.argv) <= 2:
    Usage()

  command = sys.argv[1]
  if command == "missing":
    if len(sys.argv) != 3:
      Usage()
    PrintMissing(sys.argv[2])
    # sys.exit rather than the interactive-only builtin exit().
    sys.exit(0)

  # Every remaining argument is expanded as a glob pattern.
  file_list = []
  for pattern in sys.argv[2:]:
    file_list += glob.glob(pattern)
  if not file_list:
    Usage()

  handlers = {
    "print": PrintFiles,
    "merge": MergeAndPrint,
    "unpack": Unpack,
    "rawunpack": RawUnpack,
  }
  handler = handlers.get(command)
  if handler is None:
    Usage()
  handler(file_list)
250