1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import array
20import common
21import functools
22import heapq
23import itertools
24import multiprocessing
25import os
26import re
27import subprocess
28import threading
29import time
30import tempfile
31
32from rangelib import RangeSet
33
34
35__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
36
37
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that transforms src into tgt.

  Args:
    src: an iterable of strings; concatenated, they form the source data.
    tgt: an iterable of strings; concatenated, they form the target data.
    imgdiff: if True, run "imgdiff -z" (for zip-like data); otherwise
      run "bsdiff".

  Returns:
    The patch contents as a string.

  Raises:
    ValueError: if the external diff tool exits with a non-zero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # The diff tools create the patch file themselves; remove the
    # placeholder that mkstemp() created.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Silence imgdiff's output; only its exit code matters. Use a context
      # manager so the devnull handle is closed (the original leaked it),
      # and os.devnull for portability.
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup of all three temporary files.
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass
75
76
class Image(object):
  """Abstract interface shared by all image objects used by BlockImageDiff."""

  def ReadRangeSet(self, ranges):
    """Return the data in the given ranges as a list/tuple of strings."""
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of all the data this image cares about."""
    raise NotImplementedError
83
84
class EmptyImage(Image):
  """A zero-length image."""

  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}

  def ReadRangeSet(self, ranges):
    """An empty image holds no data, whatever ranges are requested."""
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 of zero bytes of data.

    EmptyImage always carries empty clobbered_blocks, so
    include_clobbered_blocks can be ignored.
    """
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()
100
101
class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap a raw string of data as an image.

    Args:
      data: the raw image contents.
      trim: if True, drop any partial block at the end of data.
      pad: if True, zero-pad any partial block at the end of data.

    Raises:
      ValueError: if data is not block-aligned and neither trim nor pad
        is requested.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Use floor division so total_blocks stays an integer even under
    # Python 3 "true division" semantics.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
    else:
      clobbered_blocks = []
    # Store a RangeSet (not the raw list) so that set operations such as
    # care_map.subtract(self.clobbered_blocks) work -- matching the
    # interface of the other Image implementations. The original stored
    # the bare list, which broke TotalSha1() below.
    self.clobbered_blocks = (RangeSet(data=clobbered_blocks)
                             if clobbered_blocks else RangeSet())
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # The padded block (if any) is handled via clobbered_blocks above, so
    # it is excluded from the zero/nonzero classification.
    for i in range(self.total_blocks-1 if padded else self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def ReadRangeSet(self, ranges):
    """Return the data for the given ranges, one string per range."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the hex SHA-1 of the image contents.

    Args:
      include_clobbered_blocks: if False, clobbered (padded) blocks are
        excluded from the hash.
    """
    if not include_clobbered_blocks:
      ranges = self.care_map.subtract(self.clobbered_blocks)
      # sha1() cannot digest a list of strings; feed it the pieces one at
      # a time. (The original passed the list returned by ReadRangeSet()
      # directly to sha1(), a TypeError.)
      ctx = sha1()
      for piece in self.ReadRangeSet(ranges):
        ctx.update(piece)
      return ctx.hexdigest()
    else:
      return sha1(self.data).hexdigest()
172
173
class Transfer(object):
  """One command in the transfer list: produce tgt_ranges, possibly by
  reading src_ranges, in the given style."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # "Intact" means both range sets store their blocks in increasing
    # order (the monotonic attribute), a prerequisite for imgdiff.
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    # Register this transfer in the caller-supplied registry; its index
    # there becomes its id.
    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    """Blocks this command stashes minus blocks it consumes from stash."""
    stored = sum(sr.size() for (_, sr) in self.stash_before)
    consumed = sum(sr.size() for (_, sr) in self.use_stash)
    return stored - consumed

  def ConvertToNew(self):
    """Rewrite this command as a "new" transfer reading no source data."""
    assert self.style != "new"
    self.use_stash = []
    self.style = "new"
    self.src_ranges = RangeSet()

  def __str__(self):
    return "%s: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
208
209
@functools.total_ordering
class HeapItem(object):
  """Heap entry ordering items by descending score on a min-heap.

  Supports lazy removal: clear() marks the entry dead, after which the
  entry evaluates as falsy so consumers can skip it when popped.
  """

  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want
    # the maximum score.
    self.score = -item.score

  def clear(self):
    # Mark this entry as removed from the heap.
    self.item = None

  def __bool__(self):
    # A cleared entry must be falsy; the original returned
    # "self.item is None", inverting the truth value and defeating
    # the lazy-removal scheme.
    return self.item is not None

  # Python 2 uses __nonzero__ for truth testing.
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self.score == other.score

  def __le__(self, other):
    return self.score <= other.score
225
226
227# BlockImageDiff works on two image objects.  An image object is
228# anything that provides the following attributes:
229#
230#    blocksize: the size in bytes of a block, currently must be 4096.
231#
232#    total_blocks: the total size of the partition/image, in blocks.
233#
234#    care_map: a RangeSet containing which blocks (in the range [0,
235#      total_blocks) we actually care about; i.e. which blocks contain
236#      data.
237#
238#    file_map: a dict that partitions the blocks contained in care_map
239#      into smaller domains that are useful for doing diffs on.
240#      (Typically a domain is a file, and the key in file_map is the
241#      pathname.)
242#
243#    clobbered_blocks: a RangeSet containing which blocks contain data
244#      but may be altered by the FS. They need to be excluded when
245#      verifying the partition integrity.
246#
247#    ReadRangeSet(): a function that takes a RangeSet and returns the
248#      data contained in the image blocks of that RangeSet.  The data
249#      is returned as a list or tuple of strings; concatenating the
250#      elements together should produce the requested data.
251#      Implementations are free to break up the data into list/tuple
252#      elements in any way that is convenient.
253#
254#    TotalSha1(): a function that returns (as a hex string) the SHA-1
255#      hash of all the data in the image (ie, all the blocks in the
256#      care_map minus clobbered_blocks, or including the clobbered
257#      blocks if include_clobbered_blocks is True).
258#
259# When creating a BlockImageDiff, the src image may be None, in which
260# case the list of transfers produced will never read from the
261# original image.
262
263class BlockImageDiff(object):
264  def __init__(self, tgt, src=None, threads=None, version=4,
265               disable_imgdiff=False):
266    if threads is None:
267      threads = multiprocessing.cpu_count() // 2
268      if threads == 0:
269        threads = 1
270    self.threads = threads
271    self.version = version
272    self.transfers = []
273    self.src_basenames = {}
274    self.src_numpatterns = {}
275    self._max_stashed_size = 0
276    self.touched_src_ranges = RangeSet()
277    self.touched_src_sha1 = None
278    self.disable_imgdiff = disable_imgdiff
279
280    assert version in (1, 2, 3, 4)
281
282    self.tgt = tgt
283    if src is None:
284      src = EmptyImage()
285    self.src = src
286
287    # The updater code that installs the patch always uses 4k blocks.
288    assert tgt.blocksize == 4096
289    assert src.blocksize == 4096
290
291    # The range sets in each filemap should comprise a partition of
292    # the care map.
293    self.AssertPartition(src.care_map, src.file_map.values())
294    self.AssertPartition(tgt.care_map, tgt.file_map.values())
295
  @property
  def max_stashed_size(self):
    # Read-only view of the peak stash usage in bytes, as computed by
    # WriteTransfers() (0 until then, and for version-1 transfer lists).
    return self._max_stashed_size
299
  def Compute(self, prefix):
    """Compute the transfer list and write all output files.

    Runs the full pipeline: find transfers, order them, repair ordering
    violations, bound the runtime stash size, verify the sequence, then
    emit <prefix>.new.dat / <prefix>.patch.dat (via ComputePatches) and
    <prefix>.transfer.list (via WriteTransfers).
    """
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Ensure the runtime stash size is under the limit.
    if self.version >= 2 and common.OPTIONS.cache_size is not None:
      self.ReviseStashSize()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)
334
335  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
336    data = source.ReadRangeSet(ranges)
337    ctx = sha1()
338
339    for p in data:
340      ctx.update(p)
341
342    return ctx.hexdigest()
343
  def WriteTransfers(self, prefix):
    """Write the transfer list for this diff to <prefix>.transfer.list.

    Emits the header (format version, total blocks written, and for
    version >= 2 the stash slot count and maximum stashed blocks), one
    command line per transfer, and "erase" commands for target blocks
    that carry no useful data. Also updates self._max_stashed_size, and
    for version >= 3 records self.touched_src_sha1 over every source
    block that was read.
    """
    def WriteTransfersZero(out, to_zero):
      """Limit the number of blocks in command zero to 1024 blocks.

      This prevents the target size of one command from being too large; and
      might help to avoid fsync errors on some devices."""

      zero_blocks_limit = 1024
      total = 0
      while to_zero:
        zero_blocks = to_zero.first(zero_blocks_limit)
        out.append("zero %s\n" % (zero_blocks.to_string_raw(),))
        total += zero_blocks.size()
        to_zero = to_zero.subtract(zero_blocks)
      return total

    out = []

    total = 0

    # Bookkeeping: stashes maps each stash key to its allocated slot id;
    # for version >= 3 it additionally maps a stash's content hash to a
    # reference count, so identical source ranges are stashed only once.
    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        if self.version == 2:
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            stashed_blocks += sr.size()
            self.touched_src_ranges = self.touched_src_ranges.union(sr)
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []
      free_size = 0

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw() if xf.src_ranges else ""
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
            # A stash will be used only once. We need to free the stash
            # immediately after the use, instead of waiting for the automatic
            # clean-up at the end. Because otherwise it may take up extra space
            # and lead to OTA failures.
            # Bug: 23119955
            free_string.append("free %d\n" % (sid,))
            free_size += sr.size()
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_size += sr.size()
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op; emit nothing.
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            self.touched_src_ranges = self.touched_src_ranges.union(
                xf.src_ranges)

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          self.touched_src_ranges = self.touched_src_ranges.union(
              xf.src_ranges)

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        assert WriteTransfersZero(out, to_zero) == to_zero.size()
        total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))
        stashed_blocks -= free_size

      if self.version >= 2 and common.OPTIONS.cache_size is not None:
        # Sanity check: abort if we're going to need more stash space than
        # the allowed size (cache_size * threshold). There are two purposes
        # of having a threshold here. a) Part of the cache may have been
        # occupied by some recovery logs. b) It will buy us some time to deal
        # with the oversize issue.
        cache_size = common.OPTIONS.cache_size
        stash_threshold = common.OPTIONS.stash_threshold
        max_allowed = cache_size * stash_threshold
        assert max_stashed_blocks * self.tgt.blocksize < max_allowed, \
               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
                   self.tgt.blocksize, max_allowed, cache_size,
                   stash_threshold)

    if self.version >= 3:
      self.touched_src_sha1 = self.HashBlocks(
          self.src, self.touched_src_ranges)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      assert (WriteTransfersZero(out, self.tgt.extended) ==
              self.tgt.extended.size())
      total += self.tgt.extended.size()

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image; b) will not be touched by dm-verity. Out of those
    # blocks, we erase the ones that won't be used in this update at the
    # beginning of an update. The rest would be erased at the end. This is to
    # work around the eMMC issue observed on some devices, which may otherwise
    # get starving for clean blocks and thus fail the update. (b/28347095)
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)

    erase_first = new_dontcare.subtract(self.touched_src_ranges)
    if erase_first:
      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))

    erase_last = new_dontcare.subtract(erase_first)
    if erase_last:
      out.append("erase %s\n" % (erase_last.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, "%d\n" % (total,))
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
      OPTIONS = common.OPTIONS
      if OPTIONS.cache_size is not None:
        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
        print("max stashed blocks: %d  (%d bytes), "
              "limit: %d bytes (%.2f%%)\n" % (
              max_stashed_blocks, self._max_stashed_size, max_allowed,
              self._max_stashed_size * 100.0 / max_allowed))
      else:
        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
              max_stashed_blocks, self._max_stashed_size))
617
  def ReviseStashSize(self):
    """Keep the simulated stash usage within the allowed cache budget.

    Walks the ordered transfer list, simulating stash usage. Any stash
    (explicit stash_before, or a v3+ implicit stash from an overlapping
    "diff") that would exceed cache_size * stash_threshold causes the
    command consuming it to be converted to a "new" command, and the
    corresponding stash definitions to be removed.
    """
    print("Revising stash size...")
    stashes = {}

    # Create the map between a stash and its def/use points. For example, for a
    # given stash of (idx, sr), stashes[idx] = (sr, def_cmd, use_cmd).
    for xf in self.transfers:
      # Command xf defines (stores) all the stashes in stash_before.
      for idx, sr in xf.stash_before:
        stashes[idx] = (sr, xf)

      # Record all the stashes command xf uses.
      for idx, _ in xf.use_stash:
        stashes[idx] += (xf,)

    # Compute the maximum blocks available for stash based on /cache size and
    # the threshold.
    cache_size = common.OPTIONS.cache_size
    stash_threshold = common.OPTIONS.stash_threshold
    max_allowed = cache_size * stash_threshold / self.tgt.blocksize

    stashed_blocks = 0
    new_blocks = 0

    # Now go through all the commands. Compute the required stash size on the
    # fly. If a command requires excess stash than available, it deletes the
    # stash by replacing the command that uses the stash with a "new" command
    # instead.
    for xf in self.transfers:
      replaced_cmds = []

      # xf.stash_before generates explicit stash commands.
      for idx, sr in xf.stash_before:
        if stashed_blocks + sr.size() > max_allowed:
          # We cannot stash this one for a later command. Find out the command
          # that will use this stash and replace the command with "new".
          use_cmd = stashes[idx][2]
          replaced_cmds.append(use_cmd)
          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
        else:
          stashed_blocks += sr.size()

      # xf.use_stash generates free commands.
      for _, sr in xf.use_stash:
        stashed_blocks -= sr.size()

      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
      # ComputePatches(), they both have the style of "diff".
      if xf.style == "diff" and self.version >= 3:
        assert xf.tgt_ranges and xf.src_ranges
        if xf.src_ranges.overlaps(xf.tgt_ranges):
          if stashed_blocks + xf.src_ranges.size() > max_allowed:
            replaced_cmds.append(xf)
            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))

      # Replace the commands in replaced_cmds with "new"s.
      for cmd in replaced_cmds:
        # It no longer uses any commands in "use_stash". Remove the def points
        # for all those stashes.
        for idx, sr in cmd.use_stash:
          def_cmd = stashes[idx][1]
          assert (idx, sr) in def_cmd.stash_before
          def_cmd.stash_before.remove((idx, sr))

        # Add up blocks that violate the space limit and print the total
        # number to screen later.
        new_blocks += cmd.tgt_ranges.size()
        cmd.ConvertToNew()

    num_of_bytes = new_blocks * self.tgt.blocksize
    print("  Total %d blocks (%d bytes) are packed as new blocks due to "
          "insufficient cache size." % (new_blocks, num_of_bytes))
690
  def ComputePatches(self, prefix):
    """Compute binary patches for all "diff" transfers.

    Writes the raw data of "new" transfers to <prefix>.new.dat and the
    concatenated patches to <prefix>.patch.dat, setting patch_start and
    patch_len on each patched transfer. A "diff" transfer whose source
    and target bytes are identical is converted to "move"; the rest are
    retyped as "bsdiff" or "imgdiff" and their patches are computed on
    self.threads worker threads.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - imgdiff is not disabled, and
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (not self.disable_imgdiff and xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sort so the largest targets are popped (and thus started) first.
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        # Repeatedly pull work items off diff_q under the lock; the heavy
        # compute_patch() call runs outside the lock.
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate the patches (in patch_num order) and record each one's
    # offset/length on its transfer.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
796
797  def AssertSequenceGood(self):
798    # Simulate the sequences of transfers we will output, and check that:
799    # - we never read a block after writing it, and
800    # - we write every block we care about exactly once.
801
802    # Start with no blocks having been touched yet.
803    touched = array.array("B", "\0" * self.tgt.total_blocks)
804
805    # Imagine processing the transfers in order.
806    for xf in self.transfers:
807      # Check that the input blocks for this transfer haven't yet been touched.
808
809      x = xf.src_ranges
810      if self.version >= 2:
811        for _, sr in xf.use_stash:
812          x = x.subtract(sr)
813
814      for s, e in x:
815        # Source image could be larger. Don't check the blocks that are in the
816        # source image only. Since they are not in 'touched', and won't ever
817        # be touched.
818        for i in range(s, min(e, self.tgt.total_blocks)):
819          assert touched[i] == 0
820
821      # Check that the output blocks for this transfer haven't yet
822      # been touched, and touch all the blocks written by this
823      # transfer.
824      for s, e in xf.tgt_ranges:
825        for i in range(s, e):
826          assert touched[i] == 0
827          touched[i] = 1
828
829    # Check that we've written every target block.
830    for s, e in self.tgt.care_map:
831      for i in range(s, e):
832        assert touched[i] == 1
833
834  def ImproveVertexSequence(self):
835    print("Improving vertex order...")
836
837    # At this point our digraph is acyclic; we reversed any edges that
838    # were backwards in the heuristically-generated sequence.  The
839    # previously-generated order is still acceptable, but we hope to
840    # find a better order that needs less memory for stashed data.
841    # Now we do a topological sort to generate a new vertex order,
842    # using a greedy algorithm to choose which vertex goes next
843    # whenever we have a choice.
844
845    # Make a copy of the edge set; this copy will get destroyed by the
846    # algorithm.
847    for xf in self.transfers:
848      xf.incoming = xf.goes_after.copy()
849      xf.outgoing = xf.goes_before.copy()
850
851    L = []   # the new vertex order
852
853    # S is the set of sources in the remaining graph; we always choose
854    # the one that leaves the least amount of stashed data after it's
855    # executed.
856    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
857         if not u.incoming]
858    heapq.heapify(S)
859
860    while S:
861      _, _, xf = heapq.heappop(S)
862      L.append(xf)
863      for u in xf.outgoing:
864        del u.incoming[xf]
865        if not u.incoming:
866          heapq.heappush(S, (u.NetStashChange(), u.order, u))
867
868    # if this fails then our graph had a cycle.
869    assert len(L) == len(self.transfers)
870
871    self.transfers = L
872    for i, xf in enumerate(L):
873      xf.order = i
874
875  def RemoveBackwardEdges(self):
876    print("Removing backward edges...")
877    in_order = 0
878    out_of_order = 0
879    lost_source = 0
880
881    for xf in self.transfers:
882      lost = 0
883      size = xf.src_ranges.size()
884      for u in xf.goes_before:
885        # xf should go before u
886        if xf.order < u.order:
887          # it does, hurray!
888          in_order += 1
889        else:
890          # it doesn't, boo.  trim the blocks that u writes from xf's
891          # source, so that xf can go after u.
892          out_of_order += 1
893          assert xf.src_ranges.overlaps(u.tgt_ranges)
894          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
895          xf.intact = False
896
897      if xf.style == "diff" and not xf.src_ranges:
898        # nothing left to diff from; treat as new data
899        xf.style = "new"
900
901      lost = size - xf.src_ranges.size()
902      lost_source += lost
903
904    print(("  %d/%d dependencies (%.2f%%) were violated; "
905           "%d source blocks removed.") %
906          (out_of_order, in_order + out_of_order,
907           (out_of_order * 100.0 / (in_order + out_of_order))
908           if (in_order + out_of_order) else 0.0,
909           lost_source))
910
911  def ReverseBackwardEdges(self):
912    print("Reversing backward edges...")
913    in_order = 0
914    out_of_order = 0
915    stashes = 0
916    stash_size = 0
917
918    for xf in self.transfers:
919      for u in xf.goes_before.copy():
920        # xf should go before u
921        if xf.order < u.order:
922          # it does, hurray!
923          in_order += 1
924        else:
925          # it doesn't, boo.  modify u to stash the blocks that it
926          # writes that xf wants to read, and then require u to go
927          # before xf.
928          out_of_order += 1
929
930          overlap = xf.src_ranges.intersect(u.tgt_ranges)
931          assert overlap
932
933          u.stash_before.append((stashes, overlap))
934          xf.use_stash.append((stashes, overlap))
935          stashes += 1
936          stash_size += overlap.size()
937
938          # reverse the edge direction; now xf must go after u
939          del xf.goes_before[u]
940          del u.goes_after[xf]
941          xf.goes_after[u] = None    # value doesn't matter
942          u.goes_before[xf] = None
943
944    print(("  %d/%d dependencies (%.2f%%) were violated; "
945           "%d source blocks stashed.") %
946          (out_of_order, in_order + out_of_order,
947           (out_of_order * 100.0 / (in_order + out_of_order))
948           if (in_order + out_of_order) else 0.0,
949           stash_size))
950
  def FindVertexSequence(self):
    """Order self.transfers to (approximately) minimize backward edges.

    Produces a linear order of the dependency digraph's vertices that
    tries to minimize the total weight of edges that end up pointing
    "backward" (left); those are the dependencies that must later be
    removed or reversed. Edge weight is the number of source blocks
    lost if the edge is dropped. Rewrites each transfer's 'order' field
    and reorders self.transfers.
    """
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()
      # score = (weight of outgoing edges) - (weight of incoming edges);
      # a high score means the vertex is "source-like" and should go early.
      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    # Priority queue of vertices keyed by score (ordering defined by the
    # project's HeapItem class). Entries are invalidated lazily: when a
    # score changes, the old entry is clear()ed and a fresh one pushed.
    heap = []
    for xf in self.transfers:
      xf.heap_item = HeapItem(xf)
      heap.append(xf.heap_item)
    heapq.heapify(heap)

    sinks = set(u for u in G if not u.outgoing)      # no outgoing edges
    sources = set(u for u in G if not u.incoming)    # no incoming edges

    def adjust_score(iu, delta):
      # Update iu's score and replace its (now stale) heap entry with a
      # fresh one; the stale entry is skipped when popped (lazy deletion).
      iu.score += delta
      iu.heap_item.clear()
      iu.heap_item = HeapItem(iu)
      heapq.heappush(heap, iu.heap_item)

    while G:
      # Put all sinks at the end of the sequence.
      while sinks:
        new_sinks = set()
        for u in sinks:
          if u not in G: continue  # already placed on s1 or s2
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            # u leaves the graph, so it no longer counts toward iu's
            # outgoing weight; iu may have become a sink.
            adjust_score(iu, -iu.outgoing.pop(u))
            if not iu.outgoing: new_sinks.add(iu)
        sinks = new_sinks

      # Put all the sources at the beginning of the sequence.
      while sources:
        new_sources = set()
        for u in sources:
          if u not in G: continue  # already placed on s1 or s2
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            adjust_score(iu, +iu.incoming.pop(u))
            if not iu.incoming: new_sources.add(iu)
        sources = new_sources

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      while True:
        u = heapq.heappop(heap)
        # Skip entries invalidated by adjust_score() or whose vertex was
        # already placed. (A clear()ed HeapItem appears to test falsy --
        # see adjust_score above; confirm against HeapItem's definition.)
        if u and u.item in G:
          u = u.item
          break

      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        adjust_score(iu, +iu.incoming.pop(u))
        if not iu.incoming: sources.add(iu)

      for iu in u.incoming:
        adjust_score(iu, -iu.outgoing.pop(u))
        if not iu.outgoing: sinks.add(iu)

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers
1053
1054  def GenerateDigraph(self):
1055    print("Generating digraph...")
1056
1057    # Each item of source_ranges will be:
1058    #   - None, if that block is not used as a source,
1059    #   - a transfer, if one transfer uses it as a source, or
1060    #   - a set of transfers.
1061    source_ranges = []
1062    for b in self.transfers:
1063      for s, e in b.src_ranges:
1064        if e > len(source_ranges):
1065          source_ranges.extend([None] * (e-len(source_ranges)))
1066        for i in range(s, e):
1067          if source_ranges[i] is None:
1068            source_ranges[i] = b
1069          else:
1070            if not isinstance(source_ranges[i], set):
1071              source_ranges[i] = set([source_ranges[i]])
1072            source_ranges[i].add(b)
1073
1074    for a in self.transfers:
1075      intersections = set()
1076      for s, e in a.tgt_ranges:
1077        for i in range(s, e):
1078          if i >= len(source_ranges): break
1079          b = source_ranges[i]
1080          if b is not None:
1081            if isinstance(b, set):
1082              intersections.update(b)
1083            else:
1084              intersections.add(b)
1085
1086      for b in intersections:
1087        if a is b: continue
1088
1089        # If the blocks written by A are read by B, then B needs to go before A.
1090        i = a.tgt_ranges.intersect(b.src_ranges)
1091        if i:
1092          if b.src_name == "__ZERO":
1093            # the cost of removing source blocks for the __ZERO domain
1094            # is (nearly) zero.
1095            size = 0
1096          else:
1097            size = i.size()
1098          b.goes_before[a] = size
1099          a.goes_after[b] = size
1100
1101  def FindTransfers(self):
1102    """Parse the file_map to generate all the transfers."""
1103
1104    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1105                    split=False):
1106      """Wrapper function for adding a Transfer().
1107
1108      For BBOTA v3, we need to stash source blocks for resumable feature.
1109      However, with the growth of file size and the shrink of the cache
1110      partition source blocks are too large to be stashed. If a file occupies
1111      too many blocks (greater than MAX_BLOCKS_PER_DIFF_TRANSFER), we split it
1112      into smaller pieces by getting multiple Transfer()s.
1113
1114      The downside is that after splitting, we may increase the package size
1115      since the split pieces don't align well. According to our experiments,
1116      1/8 of the cache size as the per-piece limit appears to be optimal.
1117      Compared to the fixed 1024-block limit, it reduces the overall package
1118      size by 30% volantis, and 20% for angler and bullhead."""
1119
1120      # We care about diff transfers only.
1121      if style != "diff" or not split:
1122        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1123        return
1124
1125      pieces = 0
1126      cache_size = common.OPTIONS.cache_size
1127      split_threshold = 0.125
1128      max_blocks_per_transfer = int(cache_size * split_threshold /
1129                                    self.tgt.blocksize)
1130
1131      # Change nothing for small files.
1132      if (tgt_ranges.size() <= max_blocks_per_transfer and
1133          src_ranges.size() <= max_blocks_per_transfer):
1134        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1135        return
1136
1137      while (tgt_ranges.size() > max_blocks_per_transfer and
1138             src_ranges.size() > max_blocks_per_transfer):
1139        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1140        src_split_name = "%s-%d" % (src_name, pieces)
1141        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1142        src_first = src_ranges.first(max_blocks_per_transfer)
1143
1144        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
1145                 by_id)
1146
1147        tgt_ranges = tgt_ranges.subtract(tgt_first)
1148        src_ranges = src_ranges.subtract(src_first)
1149        pieces += 1
1150
1151      # Handle remaining blocks.
1152      if tgt_ranges.size() or src_ranges.size():
1153        # Must be both non-empty.
1154        assert tgt_ranges.size() and src_ranges.size()
1155        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1156        src_split_name = "%s-%d" % (src_name, pieces)
1157        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
1158                 by_id)
1159
1160    empty = RangeSet()
1161    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
1162      if tgt_fn == "__ZERO":
1163        # the special "__ZERO" domain is all the blocks not contained
1164        # in any file and that are filled with zeros.  We have a
1165        # special transfer style for zero blocks.
1166        src_ranges = self.src.file_map.get("__ZERO", empty)
1167        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1168                    "zero", self.transfers)
1169        continue
1170
1171      elif tgt_fn == "__COPY":
1172        # "__COPY" domain includes all the blocks not contained in any
1173        # file and that need to be copied unconditionally to the target.
1174        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1175        continue
1176
1177      elif tgt_fn in self.src.file_map:
1178        # Look for an exact pathname match in the source.
1179        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1180                    "diff", self.transfers, self.version >= 3)
1181        continue
1182
1183      b = os.path.basename(tgt_fn)
1184      if b in self.src_basenames:
1185        # Look for an exact basename match in the source.
1186        src_fn = self.src_basenames[b]
1187        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1188                    "diff", self.transfers, self.version >= 3)
1189        continue
1190
1191      b = re.sub("[0-9]+", "#", b)
1192      if b in self.src_numpatterns:
1193        # Look for a 'number pattern' match (a basename match after
1194        # all runs of digits are replaced by "#").  (This is useful
1195        # for .so files that contain version numbers in the filename
1196        # that get bumped.)
1197        src_fn = self.src_numpatterns[b]
1198        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1199                    "diff", self.transfers, self.version >= 3)
1200        continue
1201
1202      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1203
1204  def AbbreviateSourceNames(self):
1205    for k in self.src.file_map.keys():
1206      b = os.path.basename(k)
1207      self.src_basenames[b] = k
1208      b = re.sub("[0-9]+", "#", b)
1209      self.src_numpatterns[b] = k
1210
1211  @staticmethod
1212  def AssertPartition(total, seq):
1213    """Assert that all the RangeSets in 'seq' form a partition of the
1214    'total' RangeSet (ie, they are nonintersecting and their union
1215    equals 'total')."""
1216
1217    so_far = RangeSet()
1218    for i in seq:
1219      assert not so_far.overlaps(i)
1220      so_far = so_far.union(i)
1221    assert so_far == total
1222