1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import heapq
20import itertools
21import multiprocessing
22import os
23import re
24import subprocess
25import threading
26import tempfile
27
28from rangelib import RangeSet
29
30
31__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
32
33
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that turns the src data into the tgt data.

  Both inputs are spooled to temporary files and handed to an external
  tool: "imgdiff -z" when imgdiff is true, "bsdiff" otherwise.

  Args:
    src: iterable of byte strings; written out in order as the source file.
    tgt: iterable of byte strings; written out in order as the target file.
    imgdiff: selects "imgdiff -z" (true) or "bsdiff" (false).

  Returns:
    The contents of the patch file produced by the tool.

  Raises:
    ValueError: if the tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    # Wrap both descriptors up front so that neither one leaks if writing
    # the other file fails part way through.
    with os.fdopen(srcfd, "wb") as f_src, os.fdopen(tgtfd, "wb") as f_tgt:
      for piece in src:
        f_src.write(piece)
      for piece in tgt:
        f_tgt.write(piece)

    # Remove the placeholder that mkstemp() created; the tool is run
    # against an output path that does not exist yet.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # imgdiff's output is discarded; close the null device afterwards
      # instead of leaking one handle per patch.
      with open(os.devnull, "a") as devnull:
        status = subprocess.call(
            ["imgdiff", "-z", srcfile, tgtfile, patchfile],
            stdout=devnull, stderr=subprocess.STDOUT)
    else:
      status = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if status:
      raise ValueError("diff failed: " + str(status))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup.  Each file is removed independently so that one
    # missing file does not leave the others behind.
    for name in (srcfile, tgtfile, patchfile):
      try:
        os.unlink(name)
      except OSError:
        pass
71
72
class Image(object):
  """Base class for image objects; both methods must be overridden."""

  def ReadRangeSet(self, ranges):
    """Return the data for the blocks in 'ranges' (not implemented here)."""
    raise NotImplementedError()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 of the image as a hex string (not implemented here)."""
    raise NotImplementedError()
79
80
class EmptyImage(Image):
  """An image with no blocks in it."""

  # Geometry: 4k blocks, and none of them.
  blocksize = 4096
  total_blocks = 0

  # Every block set is empty and there are no file domains.
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  file_map = {}

  def ReadRangeSet(self, ranges):
    """Return an empty tuple regardless of 'ranges'."""
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of an empty input."""
    # clobbered_blocks is always empty for an EmptyImage, which makes
    # include_clobbered_blocks irrelevant.
    assert self.clobbered_blocks.size() == 0
    digest = sha1()
    return digest.hexdigest()
96
97
class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' as an image made of 4096-byte blocks.

    Args:
      data: the image contents as one string (bytes).
      trim: if true, drop a trailing partial block.
      pad: if true, extend a trailing partial block with zero bytes.
        trim and pad are mutually exclusive.

    Raises:
      ValueError: if the length of data is not a multiple of the block
        size and neither trim nor pad is given.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    # Use a zero byte of the same kind as the data so that padding and the
    # all-zero comparison below work for bytes on Python 3 as well as for
    # str on Python 2 (where bytes is str).
    zero = b'\0' if isinstance(self.data, (bytes, bytearray)) else '\0'

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += zero * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps total_blocks an int on Python 3 too; it is used
    # with range() below.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    # Classify each block as all-zero or not.  Each block i contributes the
    # pair (i, i+1) to the flat list handed to RangeSet.
    zero_blocks = []
    nonzero_blocks = []
    reference = zero * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    """Return a list with one slice of the data per (start, end) in ranges."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of the whole data string."""
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()
149
150
class Transfer(object):
  """A single transfer: how to produce tgt_ranges, possibly from src_ranges."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    """Record the transfer and register it at the end of the by_id list."""
    # The id of a transfer is its index in by_id.
    self.id = len(by_id)
    by_id.append(self)

    self.tgt_name, self.src_name = tgt_name, src_name
    self.tgt_ranges, self.src_ranges = tgt_ranges, src_ranges
    self.style = style

    # "intact" requires both range sets to report themselves as monotonic;
    # objects without that attribute count as not monotonic.
    tgt_monotonic = getattr(tgt_ranges, "monotonic", False)
    src_monotonic = getattr(src_ranges, "monotonic", False)
    self.intact = tgt_monotonic and src_monotonic

    # OrderedDict (not dict) keeps the output repeatable; a plain dict
    # would make it depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    # Lists of (stash id, RangeSet) pairs.
    self.stash_before = []
    self.use_stash = []

  def NetStashChange(self):
    """Return blocks stashed before this transfer minus blocks it uses."""
    stashed = sum(blocks.size() for (_, blocks) in self.stash_before)
    consumed = sum(blocks.size() for (_, blocks) in self.use_stash)
    return stashed - consumed

  def __str__(self):
    pieces = [str(self.id), ": <", str(self.src_ranges), " ", self.style,
              " to ", str(self.tgt_ranges), ">"]
    return "".join(pieces)
179
180
181# BlockImageDiff works on two image objects.  An image object is
182# anything that provides the following attributes:
183#
184#    blocksize: the size in bytes of a block, currently must be 4096.
185#
186#    total_blocks: the total size of the partition/image, in blocks.
187#
188#    care_map: a RangeSet containing which blocks (in the range [0,
189#      total_blocks) we actually care about; i.e. which blocks contain
190#      data.
191#
192#    file_map: a dict that partitions the blocks contained in care_map
193#      into smaller domains that are useful for doing diffs on.
194#      (Typically a domain is a file, and the key in file_map is the
195#      pathname.)
196#
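#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    extended: a RangeSet of blocks that the generated transfer list
#      always zeroes out and that are excluded from the final erase
#      command.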
200#
201#    ReadRangeSet(): a function that takes a RangeSet and returns the
202#      data contained in the image blocks of that RangeSet.  The data
203#      is returned as a list or tuple of strings; concatenating the
204#      elements together should produce the requested data.
205#      Implementations are free to break up the data into list/tuple
206#      elements in any way that is convenient.
207#
208#    TotalSha1(): a function that returns (as a hex string) the SHA-1
209#      hash of all the data in the image (ie, all the blocks in the
210#      care_map minus clobbered_blocks, or including the clobbered
211#      blocks if include_clobbered_blocks is True).
212#
213# When creating a BlockImageDiff, the src image may be None, in which
214# case the list of transfers produced will never read from the
215# original image.
216
217class BlockImageDiff(object):
218  def __init__(self, tgt, src=None, threads=None, version=3):
219    if threads is None:
220      threads = multiprocessing.cpu_count() // 2
221      if threads == 0:
222        threads = 1
223    self.threads = threads
224    self.version = version
225    self.transfers = []
226    self.src_basenames = {}
227    self.src_numpatterns = {}
228
229    assert version in (1, 2, 3)
230
231    self.tgt = tgt
232    if src is None:
233      src = EmptyImage()
234    self.src = src
235
236    # The updater code that installs the patch always uses 4k blocks.
237    assert tgt.blocksize == 4096
238    assert src.blocksize == 4096
239
240    # The range sets in each filemap should comprise a partition of
241    # the care map.
242    self.AssertPartition(src.care_map, src.file_map.values())
243    self.AssertPartition(tgt.care_map, tgt.file_map.values())
244
245  def Compute(self, prefix):
246    # When looking for a source file to use as the diff input for a
247    # target file, we try:
248    #   1) an exact path match if available, otherwise
249    #   2) a exact basename match if available, otherwise
250    #   3) a basename match after all runs of digits are replaced by
251    #      "#" if available, otherwise
252    #   4) we have no source for this target.
253    self.AbbreviateSourceNames()
254    self.FindTransfers()
255
256    # Find the ordering dependencies among transfers (this is O(n^2)
257    # in the number of transfers).
258    self.GenerateDigraph()
259    # Find a sequence of transfers that satisfies as many ordering
260    # dependencies as possible (heuristically).
261    self.FindVertexSequence()
262    # Fix up the ordering dependencies that the sequence didn't
263    # satisfy.
264    if self.version == 1:
265      self.RemoveBackwardEdges()
266    else:
267      self.ReverseBackwardEdges()
268      self.ImproveVertexSequence()
269
270    # Double-check our work.
271    self.AssertSequenceGood()
272
273    self.ComputePatches(prefix)
274    self.WriteTransfers(prefix)
275
276  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
277    data = source.ReadRangeSet(ranges)
278    ctx = sha1()
279
280    for p in data:
281      ctx.update(p)
282
283    return ctx.hexdigest()
284
  def WriteTransfers(self, prefix):
    """Write the command list for self.transfers to <prefix>.transfer.list.

    The file starts with the format version and the total number of blocks
    written by the commands.  For version 2 and later these are followed
    by the number of stash slots allocated and the maximum number of
    blocks stashed at any one time.  The commands themselves follow, one
    per line.

    Asserts that the stash never needs 512 MB or more.
    """
    out = []

    total = 0
    # NOTE(review): performs_read is set below but never read in this
    # method.
    performs_read = False

    # Keyed by stash id (value: the slot id assigned to it).  For version 3
    # the same dict additionally holds entries keyed by block hash (value:
    # a reference count).
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    # Min-heap of slot ids that have been released and can be reused
    # before a fresh id is taken from next_stash_id.
    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      # Emit the stash commands that must run before this transfer.
      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          # Version 3 names a stash by the hash of its source blocks; a
          # hash that is already stashed only gets its count bumped.
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      # "free" commands collected here are emitted after this transfer's
      # own command.
      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          # From here on sr is expressed relative to xf.src_ranges.
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            # The stashed and unstashed pieces together must cover the
            # source exactly once.
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move onto the same blocks produces no command and is not
        # counted in the total.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only the target blocks outside the transfer's source ranges are
        # zeroed.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need 512 MB or more of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    # The header lines are inserted last because their values are only
    # known once every transfer has been processed.
    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 and later: after the total block count, we give the
      # number of stash slots needed, and the maximum size needed (in
      # blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))
494
495  def ComputePatches(self, prefix):
496    print("Reticulating splines...")
497    diff_q = []
498    patch_num = 0
499    with open(prefix + ".new.dat", "wb") as new_f:
500      for xf in self.transfers:
501        if xf.style == "zero":
502          pass
503        elif xf.style == "new":
504          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
505            new_f.write(piece)
506        elif xf.style == "diff":
507          src = self.src.ReadRangeSet(xf.src_ranges)
508          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)
509
510          # We can't compare src and tgt directly because they may have
511          # the same content but be broken up into blocks differently, eg:
512          #
513          #    ["he", "llo"]  vs  ["h", "ello"]
514          #
515          # We want those to compare equal, ideally without having to
516          # actually concatenate the strings (these may be tens of
517          # megabytes).
518
519          src_sha1 = sha1()
520          for p in src:
521            src_sha1.update(p)
522          tgt_sha1 = sha1()
523          tgt_size = 0
524          for p in tgt:
525            tgt_sha1.update(p)
526            tgt_size += len(p)
527
528          if src_sha1.digest() == tgt_sha1.digest():
529            # These are identical; we don't need to generate a patch,
530            # just issue copy commands on the device.
531            xf.style = "move"
532          else:
533            # For files in zip format (eg, APKs, JARs, etc.) we would
534            # like to use imgdiff -z if possible (because it usually
535            # produces significantly smaller patches than bsdiff).
536            # This is permissible if:
537            #
538            #  - the source and target files are monotonic (ie, the
539            #    data is stored with blocks in increasing order), and
540            #  - we haven't removed any blocks from the source set.
541            #
542            # If these conditions are satisfied then appending all the
543            # blocks in the set together in order will produce a valid
544            # zip file (plus possibly extra zeros in the last block),
545            # which is what imgdiff needs to operate.  (imgdiff is
546            # fine with extra zeros at the end of the file.)
547            imgdiff = (xf.intact and
548                       xf.tgt_name.split(".")[-1].lower()
549                       in ("apk", "jar", "zip"))
550            xf.style = "imgdiff" if imgdiff else "bsdiff"
551            diff_q.append((tgt_size, src, tgt, xf, patch_num))
552            patch_num += 1
553
554        else:
555          assert False, "unknown style " + xf.style
556
557    if diff_q:
558      if self.threads > 1:
559        print("Computing patches (using %d threads)..." % (self.threads,))
560      else:
561        print("Computing patches...")
562      diff_q.sort()
563
564      patches = [None] * patch_num
565
566      # TODO: Rewrite with multiprocessing.ThreadPool?
567      lock = threading.Lock()
568      def diff_worker():
569        while True:
570          with lock:
571            if not diff_q:
572              return
573            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
574          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
575          size = len(patch)
576          with lock:
577            patches[patchnum] = (patch, xf)
578            print("%10d %10d (%6.2f%%) %7s %s" % (
579                size, tgt_size, size * 100.0 / tgt_size, xf.style,
580                xf.tgt_name if xf.tgt_name == xf.src_name else (
581                    xf.tgt_name + " (from " + xf.src_name + ")")))
582
583      threads = [threading.Thread(target=diff_worker)
584                 for _ in range(self.threads)]
585      for th in threads:
586        th.start()
587      while threads:
588        threads.pop().join()
589    else:
590      patches = []
591
592    p = 0
593    with open(prefix + ".patch.dat", "wb") as patch_f:
594      for patch, xf in patches:
595        xf.patch_start = p
596        xf.patch_len = len(patch)
597        patch_f.write(patch)
598        p += len(patch)
599
600  def AssertSequenceGood(self):
601    # Simulate the sequences of transfers we will output, and check that:
602    # - we never read a block after writing it, and
603    # - we write every block we care about exactly once.
604
605    # Start with no blocks having been touched yet.
606    touched = RangeSet()
607
608    # Imagine processing the transfers in order.
609    for xf in self.transfers:
610      # Check that the input blocks for this transfer haven't yet been touched.
611
612      x = xf.src_ranges
613      if self.version >= 2:
614        for _, sr in xf.use_stash:
615          x = x.subtract(sr)
616
617      assert not touched.overlaps(x)
618      # Check that the output blocks for this transfer haven't yet been touched.
619      assert not touched.overlaps(xf.tgt_ranges)
620      # Touch all the blocks written by this transfer.
621      touched = touched.union(xf.tgt_ranges)
622
623    # Check that we've written every target block.
624    assert touched == self.tgt.care_map
625
626  def ImproveVertexSequence(self):
627    print("Improving vertex order...")
628
629    # At this point our digraph is acyclic; we reversed any edges that
630    # were backwards in the heuristically-generated sequence.  The
631    # previously-generated order is still acceptable, but we hope to
632    # find a better order that needs less memory for stashed data.
633    # Now we do a topological sort to generate a new vertex order,
634    # using a greedy algorithm to choose which vertex goes next
635    # whenever we have a choice.
636
637    # Make a copy of the edge set; this copy will get destroyed by the
638    # algorithm.
639    for xf in self.transfers:
640      xf.incoming = xf.goes_after.copy()
641      xf.outgoing = xf.goes_before.copy()
642
643    L = []   # the new vertex order
644
645    # S is the set of sources in the remaining graph; we always choose
646    # the one that leaves the least amount of stashed data after it's
647    # executed.
648    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
649         if not u.incoming]
650    heapq.heapify(S)
651
652    while S:
653      _, _, xf = heapq.heappop(S)
654      L.append(xf)
655      for u in xf.outgoing:
656        del u.incoming[xf]
657        if not u.incoming:
658          heapq.heappush(S, (u.NetStashChange(), u.order, u))
659
660    # if this fails then our graph had a cycle.
661    assert len(L) == len(self.transfers)
662
663    self.transfers = L
664    for i, xf in enumerate(L):
665      xf.order = i
666
667  def RemoveBackwardEdges(self):
668    print("Removing backward edges...")
669    in_order = 0
670    out_of_order = 0
671    lost_source = 0
672
673    for xf in self.transfers:
674      lost = 0
675      size = xf.src_ranges.size()
676      for u in xf.goes_before:
677        # xf should go before u
678        if xf.order < u.order:
679          # it does, hurray!
680          in_order += 1
681        else:
682          # it doesn't, boo.  trim the blocks that u writes from xf's
683          # source, so that xf can go after u.
684          out_of_order += 1
685          assert xf.src_ranges.overlaps(u.tgt_ranges)
686          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
687          xf.intact = False
688
689      if xf.style == "diff" and not xf.src_ranges:
690        # nothing left to diff from; treat as new data
691        xf.style = "new"
692
693      lost = size - xf.src_ranges.size()
694      lost_source += lost
695
696    print(("  %d/%d dependencies (%.2f%%) were violated; "
697           "%d source blocks removed.") %
698          (out_of_order, in_order + out_of_order,
699           (out_of_order * 100.0 / (in_order + out_of_order))
700           if (in_order + out_of_order) else 0.0,
701           lost_source))
702
703  def ReverseBackwardEdges(self):
704    print("Reversing backward edges...")
705    in_order = 0
706    out_of_order = 0
707    stashes = 0
708    stash_size = 0
709
710    for xf in self.transfers:
711      for u in xf.goes_before.copy():
712        # xf should go before u
713        if xf.order < u.order:
714          # it does, hurray!
715          in_order += 1
716        else:
717          # it doesn't, boo.  modify u to stash the blocks that it
718          # writes that xf wants to read, and then require u to go
719          # before xf.
720          out_of_order += 1
721
722          overlap = xf.src_ranges.intersect(u.tgt_ranges)
723          assert overlap
724
725          u.stash_before.append((stashes, overlap))
726          xf.use_stash.append((stashes, overlap))
727          stashes += 1
728          stash_size += overlap.size()
729
730          # reverse the edge direction; now xf must go after u
731          del xf.goes_before[u]
732          del u.goes_after[xf]
733          xf.goes_after[u] = None    # value doesn't matter
734          u.goes_before[xf] = None
735
736    print(("  %d/%d dependencies (%.2f%%) were violated; "
737           "%d source blocks stashed.") %
738          (out_of_order, in_order + out_of_order,
739           (out_of_order * 100.0 / (in_order + out_of_order))
740           if (in_order + out_of_order) else 0.0,
741           stash_size))
742
743  def FindVertexSequence(self):
744    print("Finding vertex sequence...")
745
746    # This is based on "A Fast & Effective Heuristic for the Feedback
747    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
748    # it as starting with the digraph G and moving all the vertices to
749    # be on a horizontal line in some order, trying to minimize the
750    # number of edges that end up pointing to the left.  Left-pointing
751    # edges will get removed to turn the digraph into a DAG.  In this
752    # case each edge has a weight which is the number of source blocks
753    # we'll lose if that edge is removed; we try to minimize the total
754    # weight rather than just the number of edges.
755
756    # Make a copy of the edge set; this copy will get destroyed by the
757    # algorithm.
758    for xf in self.transfers:
759      xf.incoming = xf.goes_after.copy()
760      xf.outgoing = xf.goes_before.copy()
761
762    # We use an OrderedDict instead of just a set so that the output
763    # is repeatable; otherwise it would depend on the hash values of
764    # the transfer objects.
765    G = OrderedDict()
766    for xf in self.transfers:
767      G[xf] = None
768    s1 = deque()  # the left side of the sequence, built from left to right
769    s2 = deque()  # the right side of the sequence, built from right to left
770
771    while G:
772
773      # Put all sinks at the end of the sequence.
774      while True:
775        sinks = [u for u in G if not u.outgoing]
776        if not sinks:
777          break
778        for u in sinks:
779          s2.appendleft(u)
780          del G[u]
781          for iu in u.incoming:
782            del iu.outgoing[u]
783
784      # Put all the sources at the beginning of the sequence.
785      while True:
786        sources = [u for u in G if not u.incoming]
787        if not sources:
788          break
789        for u in sources:
790          s1.append(u)
791          del G[u]
792          for iu in u.outgoing:
793            del iu.incoming[u]
794
795      if not G:
796        break
797
798      # Find the "best" vertex to put next.  "Best" is the one that
799      # maximizes the net difference in source blocks saved we get by
800      # pretending it's a source rather than a sink.
801
802      max_d = None
803      best_u = None
804      for u in G:
805        d = sum(u.outgoing.values()) - sum(u.incoming.values())
806        if best_u is None or d > max_d:
807          max_d = d
808          best_u = u
809
810      u = best_u
811      s1.append(u)
812      del G[u]
813      for iu in u.outgoing:
814        del iu.incoming[u]
815      for iu in u.incoming:
816        del iu.outgoing[u]
817
818    # Now record the sequence in the 'order' field of each transfer,
819    # and by rearranging self.transfers to be in the chosen sequence.
820
821    new_transfers = []
822    for x in itertools.chain(s1, s2):
823      x.order = len(new_transfers)
824      new_transfers.append(x)
825      del x.incoming
826      del x.outgoing
827
828    self.transfers = new_transfers
829
830  def GenerateDigraph(self):
831    print("Generating digraph...")
832    for a in self.transfers:
833      for b in self.transfers:
834        if a is b:
835          continue
836
837        # If the blocks written by A are read by B, then B needs to go before A.
838        i = a.tgt_ranges.intersect(b.src_ranges)
839        if i:
840          if b.src_name == "__ZERO":
841            # the cost of removing source blocks for the __ZERO domain
842            # is (nearly) zero.
843            size = 0
844          else:
845            size = i.size()
846          b.goes_before[a] = size
847          a.goes_after[b] = size
848
849  def FindTransfers(self):
850    empty = RangeSet()
851    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
852      if tgt_fn == "__ZERO":
853        # the special "__ZERO" domain is all the blocks not contained
854        # in any file and that are filled with zeros.  We have a
855        # special transfer style for zero blocks.
856        src_ranges = self.src.file_map.get("__ZERO", empty)
857        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
858                 "zero", self.transfers)
859        continue
860
861      elif tgt_fn == "__COPY":
862        # "__COPY" domain includes all the blocks not contained in any
863        # file and that need to be copied unconditionally to the target.
864        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
865        continue
866
867      elif tgt_fn in self.src.file_map:
868        # Look for an exact pathname match in the source.
869        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
870                 "diff", self.transfers)
871        continue
872
873      b = os.path.basename(tgt_fn)
874      if b in self.src_basenames:
875        # Look for an exact basename match in the source.
876        src_fn = self.src_basenames[b]
877        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
878                 "diff", self.transfers)
879        continue
880
881      b = re.sub("[0-9]+", "#", b)
882      if b in self.src_numpatterns:
883        # Look for a 'number pattern' match (a basename match after
884        # all runs of digits are replaced by "#").  (This is useful
885        # for .so files that contain version numbers in the filename
886        # that get bumped.)
887        src_fn = self.src_numpatterns[b]
888        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
889                 "diff", self.transfers)
890        continue
891
892      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
893
894  def AbbreviateSourceNames(self):
895    for k in self.src.file_map.keys():
896      b = os.path.basename(k)
897      self.src_basenames[b] = k
898      b = re.sub("[0-9]+", "#", b)
899      self.src_numpatterns[b] = k
900
901  @staticmethod
902  def AssertPartition(total, seq):
903    """Assert that all the RangeSets in 'seq' form a partition of the
904    'total' RangeSet (ie, they are nonintersecting and their union
905    equals 'total')."""
906    so_far = RangeSet()
907    for i in seq:
908      assert not so_far.overlaps(i)
909      so_far = so_far.union(i)
910    assert so_far == total
911