1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17import array
18import copy
19import functools
20import heapq
21import itertools
22import multiprocessing
23import os
24import os.path
25import re
26import subprocess
27import sys
28import threading
29from collections import deque, OrderedDict
30from hashlib import sha1
31
32import common
33from rangelib import RangeSet
34
35
36__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
37
38
39def compute_patch(srcfile, tgtfile, imgdiff=False):
40  patchfile = common.MakeTempFile(prefix='patch-')
41
42  cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
43  cmd.extend([srcfile, tgtfile, patchfile])
44
45  # Don't dump the bsdiff/imgdiff commands, which are not useful for the case
46  # here, since they contain temp filenames only.
47  p = common.Run(cmd, verbose=False, stdout=subprocess.PIPE,
48                 stderr=subprocess.STDOUT)
49  output, _ = p.communicate()
50
51  if p.returncode != 0:
52    raise ValueError(output)
53
54  with open(patchfile, 'rb') as f:
55    return f.read()
56
57
58class Image(object):
59  def RangeSha1(self, ranges):
60    raise NotImplementedError
61
62  def ReadRangeSet(self, ranges):
63    raise NotImplementedError
64
65  def TotalSha1(self, include_clobbered_blocks=False):
66    raise NotImplementedError
67
68  def WriteRangeDataToFd(self, ranges, fd):
69    raise NotImplementedError
70
71
72class EmptyImage(Image):
73  """A zero-length image."""
74
75  def __init__(self):
76    self.blocksize = 4096
77    self.care_map = RangeSet()
78    self.clobbered_blocks = RangeSet()
79    self.extended = RangeSet()
80    self.total_blocks = 0
81    self.file_map = {}
82
83  def RangeSha1(self, ranges):
84    return sha1().hexdigest()
85
86  def ReadRangeSet(self, ranges):
87    return ()
88
89  def TotalSha1(self, include_clobbered_blocks=False):
90    # EmptyImage always carries empty clobbered_blocks, so
91    # include_clobbered_blocks can be ignored.
92    assert self.clobbered_blocks.size() == 0
93    return sha1().hexdigest()
94
95  def WriteRangeDataToFd(self, ranges, fd):
96    raise ValueError("Can't write data from EmptyImage to file")
97
98
99class DataImage(Image):
100  """An image wrapped around a single string of data."""
101
102  def __init__(self, data, trim=False, pad=False):
103    self.data = data
104    self.blocksize = 4096
105
106    assert not (trim and pad)
107
108    partial = len(self.data) % self.blocksize
109    padded = False
110    if partial > 0:
111      if trim:
112        self.data = self.data[:-partial]
113      elif pad:
114        self.data += '\0' * (self.blocksize - partial)
115        padded = True
116      else:
117        raise ValueError(("data for DataImage must be multiple of %d bytes "
118                          "unless trim or pad is specified") %
119                         (self.blocksize,))
120
121    assert len(self.data) % self.blocksize == 0
122
123    self.total_blocks = len(self.data) / self.blocksize
124    self.care_map = RangeSet(data=(0, self.total_blocks))
125    # When the last block is padded, we always write the whole block even for
126    # incremental OTAs. Because otherwise the last block may get skipped if
127    # unchanged for an incremental, but would fail the post-install
128    # verification if it has non-zero contents in the padding bytes.
129    # Bug: 23828506
130    if padded:
131      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
132    else:
133      clobbered_blocks = []
134    self.clobbered_blocks = clobbered_blocks
135    self.extended = RangeSet()
136
137    zero_blocks = []
138    nonzero_blocks = []
139    reference = '\0' * self.blocksize
140
141    for i in range(self.total_blocks-1 if padded else self.total_blocks):
142      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
143      if d == reference:
144        zero_blocks.append(i)
145        zero_blocks.append(i+1)
146      else:
147        nonzero_blocks.append(i)
148        nonzero_blocks.append(i+1)
149
150    assert zero_blocks or nonzero_blocks or clobbered_blocks
151
152    self.file_map = dict()
153    if zero_blocks:
154      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
155    if nonzero_blocks:
156      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
157    if clobbered_blocks:
158      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)
159
160  def _GetRangeData(self, ranges):
161    for s, e in ranges:
162      yield self.data[s*self.blocksize:e*self.blocksize]
163
164  def RangeSha1(self, ranges):
165    h = sha1()
166    for data in self._GetRangeData(ranges):
167      h.update(data)
168    return h.hexdigest()
169
170  def ReadRangeSet(self, ranges):
171    return [self._GetRangeData(ranges)]
172
173  def TotalSha1(self, include_clobbered_blocks=False):
174    if not include_clobbered_blocks:
175      return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
176    else:
177      return sha1(self.data).hexdigest()
178
179  def WriteRangeDataToFd(self, ranges, fd):
180    for data in self._GetRangeData(ranges):
181      fd.write(data)
182
183
184class Transfer(object):
185  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
186               src_sha1, style, by_id):
187    self.tgt_name = tgt_name
188    self.src_name = src_name
189    self.tgt_ranges = tgt_ranges
190    self.src_ranges = src_ranges
191    self.tgt_sha1 = tgt_sha1
192    self.src_sha1 = src_sha1
193    self.style = style
194
195    # We use OrderedDict rather than dict so that the output is repeatable;
196    # otherwise it would depend on the hash values of the Transfer objects.
197    self.goes_before = OrderedDict()
198    self.goes_after = OrderedDict()
199
200    self.stash_before = []
201    self.use_stash = []
202
203    self.id = len(by_id)
204    by_id.append(self)
205
206    self._patch = None
207
208  @property
209  def patch(self):
210    return self._patch
211
212  @patch.setter
213  def patch(self, patch):
214    if patch:
215      assert self.style == "diff"
216    self._patch = patch
217
218  def NetStashChange(self):
219    return (sum(sr.size() for (_, sr) in self.stash_before) -
220            sum(sr.size() for (_, sr) in self.use_stash))
221
222  def ConvertToNew(self):
223    assert self.style != "new"
224    self.use_stash = []
225    self.style = "new"
226    self.src_ranges = RangeSet()
227    self.patch = None
228
229  def __str__(self):
230    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
231            " to " + str(self.tgt_ranges) + ">")
232
233
234@functools.total_ordering
235class HeapItem(object):
236  def __init__(self, item):
237    self.item = item
238    # Negate the score since python's heap is a min-heap and we want the
239    # maximum score.
240    self.score = -item.score
241
242  def clear(self):
243    self.item = None
244
245  def __bool__(self):
246    return self.item is not None
247
248  # Python 2 uses __nonzero__, while Python 3 uses __bool__.
249  __nonzero__ = __bool__
250
251  # The rest operations are generated by functools.total_ordering decorator.
252  def __eq__(self, other):
253    return self.score == other.score
254
255  def __le__(self, other):
256    return self.score <= other.score
257
258
259class ImgdiffStats(object):
260  """A class that collects imgdiff stats.
261
262  It keeps track of the files that will be applied imgdiff while generating
263  BlockImageDiff. It also logs the ones that cannot use imgdiff, with specific
264  reasons. The stats is only meaningful when imgdiff not being disabled by the
265  caller of BlockImageDiff. In addition, only files with supported types
266  (BlockImageDiff.FileTypeSupportedByImgdiff()) are allowed to be logged.
267  """
268
269  USED_IMGDIFF = "APK files diff'd with imgdiff"
270  USED_IMGDIFF_LARGE_APK = "Large APK files split and diff'd with imgdiff"
271
272  # Reasons for not applying imgdiff on APKs.
273  SKIPPED_TRIMMED = "Not used imgdiff due to trimmed RangeSet"
274  SKIPPED_NONMONOTONIC = "Not used imgdiff due to having non-monotonic ranges"
275  SKIPPED_SHARED_BLOCKS = "Not used imgdiff due to using shared blocks"
276  SKIPPED_INCOMPLETE = "Not used imgdiff due to incomplete RangeSet"
277
278  # The list of valid reasons, which will also be the dumped order in a report.
279  REASONS = (
280      USED_IMGDIFF,
281      USED_IMGDIFF_LARGE_APK,
282      SKIPPED_TRIMMED,
283      SKIPPED_NONMONOTONIC,
284      SKIPPED_SHARED_BLOCKS,
285      SKIPPED_INCOMPLETE,
286  )
287
288  def  __init__(self):
289    self.stats = {}
290
291  def Log(self, filename, reason):
292    """Logs why imgdiff can or cannot be applied to the given filename.
293
294    Args:
295      filename: The filename string.
296      reason: One of the reason constants listed in REASONS.
297
298    Raises:
299      AssertionError: On unsupported filetypes or invalid reason.
300    """
301    assert BlockImageDiff.FileTypeSupportedByImgdiff(filename)
302    assert reason in self.REASONS
303
304    if reason not in self.stats:
305      self.stats[reason] = set()
306    self.stats[reason].add(filename)
307
308  def Report(self):
309    """Prints a report of the collected imgdiff stats."""
310
311    def print_header(header, separator):
312      print(header)
313      print(separator * len(header) + '\n')
314
315    print_header('  Imgdiff Stats Report  ', '=')
316    for key in self.REASONS:
317      if key not in self.stats:
318        continue
319      values = self.stats[key]
320      section_header = ' {} (count: {}) '.format(key, len(values))
321      print_header(section_header, '-')
322      print(''.join(['  {}\n'.format(name) for name in values]))
323
324
325# BlockImageDiff works on two image objects.  An image object is
326# anything that provides the following attributes:
327#
328#    blocksize: the size in bytes of a block, currently must be 4096.
329#
330#    total_blocks: the total size of the partition/image, in blocks.
331#
332#    care_map: a RangeSet containing which blocks (in the range [0,
333#      total_blocks) we actually care about; i.e. which blocks contain
334#      data.
335#
336#    file_map: a dict that partitions the blocks contained in care_map
337#      into smaller domains that are useful for doing diffs on.
338#      (Typically a domain is a file, and the key in file_map is the
339#      pathname.)
340#
341#    clobbered_blocks: a RangeSet containing which blocks contain data
342#      but may be altered by the FS. They need to be excluded when
343#      verifying the partition integrity.
344#
345#    ReadRangeSet(): a function that takes a RangeSet and returns the
346#      data contained in the image blocks of that RangeSet.  The data
347#      is returned as a list or tuple of strings; concatenating the
348#      elements together should produce the requested data.
349#      Implementations are free to break up the data into list/tuple
350#      elements in any way that is convenient.
351#
352#    RangeSha1(): a function that returns (as a hex string) the SHA-1
353#      hash of all the data in the specified range.
354#
355#    TotalSha1(): a function that returns (as a hex string) the SHA-1
356#      hash of all the data in the image (ie, all the blocks in the
357#      care_map minus clobbered_blocks, or including the clobbered
358#      blocks if include_clobbered_blocks is True).
359#
360# When creating a BlockImageDiff, the src image may be None, in which
361# case the list of transfers produced will never read from the
362# original image.
363
364class BlockImageDiff(object):
365  def __init__(self, tgt, src=None, threads=None, version=4,
366               disable_imgdiff=False):
367    if threads is None:
368      threads = multiprocessing.cpu_count() // 2
369      if threads == 0:
370        threads = 1
371    self.threads = threads
372    self.version = version
373    self.transfers = []
374    self.src_basenames = {}
375    self.src_numpatterns = {}
376    self._max_stashed_size = 0
377    self.touched_src_ranges = RangeSet()
378    self.touched_src_sha1 = None
379    self.disable_imgdiff = disable_imgdiff
380    self.imgdiff_stats = ImgdiffStats() if not disable_imgdiff else None
381
382    assert version in (3, 4)
383
384    self.tgt = tgt
385    if src is None:
386      src = EmptyImage()
387    self.src = src
388
389    # The updater code that installs the patch always uses 4k blocks.
390    assert tgt.blocksize == 4096
391    assert src.blocksize == 4096
392
393    # The range sets in each filemap should comprise a partition of
394    # the care map.
395    self.AssertPartition(src.care_map, src.file_map.values())
396    self.AssertPartition(tgt.care_map, tgt.file_map.values())
397
398  @property
399  def max_stashed_size(self):
400    return self._max_stashed_size
401
402  @staticmethod
403  def FileTypeSupportedByImgdiff(filename):
404    """Returns whether the file type is supported by imgdiff."""
405    return filename.lower().endswith(('.apk', '.jar', '.zip'))
406
407  def CanUseImgdiff(self, name, tgt_ranges, src_ranges, large_apk=False):
408    """Checks whether we can apply imgdiff for the given RangeSets.
409
410    For files in ZIP format (e.g., APKs, JARs, etc.) we would like to use
411    'imgdiff -z' if possible. Because it usually produces significantly smaller
412    patches than bsdiff.
413
414    This is permissible if all of the following conditions hold.
415      - The imgdiff hasn't been disabled by the caller (e.g. squashfs);
416      - The file type is supported by imgdiff;
417      - The source and target blocks are monotonic (i.e. the data is stored with
418        blocks in increasing order);
419      - Both files don't contain shared blocks;
420      - Both files have complete lists of blocks;
421      - We haven't removed any blocks from the source set.
422
423    If all these conditions are satisfied, concatenating all the blocks in the
424    RangeSet in order will produce a valid ZIP file (plus possibly extra zeros
425    in the last block). imgdiff is fine with extra zeros at the end of the file.
426
427    Args:
428      name: The filename to be diff'd.
429      tgt_ranges: The target RangeSet.
430      src_ranges: The source RangeSet.
431      large_apk: Whether this is to split a large APK.
432
433    Returns:
434      A boolean result.
435    """
436    if self.disable_imgdiff or not self.FileTypeSupportedByImgdiff(name):
437      return False
438
439    if not tgt_ranges.monotonic or not src_ranges.monotonic:
440      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_NONMONOTONIC)
441      return False
442
443    if (tgt_ranges.extra.get('uses_shared_blocks') or
444        src_ranges.extra.get('uses_shared_blocks')):
445      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_SHARED_BLOCKS)
446      return False
447
448    if tgt_ranges.extra.get('incomplete') or src_ranges.extra.get('incomplete'):
449      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_INCOMPLETE)
450      return False
451
452    if tgt_ranges.extra.get('trimmed') or src_ranges.extra.get('trimmed'):
453      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_TRIMMED)
454      return False
455
456    reason = (ImgdiffStats.USED_IMGDIFF_LARGE_APK if large_apk
457              else ImgdiffStats.USED_IMGDIFF)
458    self.imgdiff_stats.Log(name, reason)
459    return True
460
461  def Compute(self, prefix):
462    # When looking for a source file to use as the diff input for a
463    # target file, we try:
464    #   1) an exact path match if available, otherwise
465    #   2) a exact basename match if available, otherwise
466    #   3) a basename match after all runs of digits are replaced by
467    #      "#" if available, otherwise
468    #   4) we have no source for this target.
469    self.AbbreviateSourceNames()
470    self.FindTransfers()
471
472    # Find the ordering dependencies among transfers (this is O(n^2)
473    # in the number of transfers).
474    self.GenerateDigraph()
475    # Find a sequence of transfers that satisfies as many ordering
476    # dependencies as possible (heuristically).
477    self.FindVertexSequence()
478    # Fix up the ordering dependencies that the sequence didn't
479    # satisfy.
480    self.ReverseBackwardEdges()
481    self.ImproveVertexSequence()
482
483    # Ensure the runtime stash size is under the limit.
484    if common.OPTIONS.cache_size is not None:
485      self.ReviseStashSize()
486
487    # Double-check our work.
488    self.AssertSequenceGood()
489    self.AssertSha1Good()
490
491    self.ComputePatches(prefix)
492    self.WriteTransfers(prefix)
493
494    # Report the imgdiff stats.
495    if common.OPTIONS.verbose and not self.disable_imgdiff:
496      self.imgdiff_stats.Report()
497
498  def WriteTransfers(self, prefix):
499    def WriteSplitTransfers(out, style, target_blocks):
500      """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.
501
502      This prevents the target size of one command from being too large; and
503      might help to avoid fsync errors on some devices."""
504
505      assert style == "new" or style == "zero"
506      blocks_limit = 1024
507      total = 0
508      while target_blocks:
509        blocks_to_write = target_blocks.first(blocks_limit)
510        out.append("%s %s\n" % (style, blocks_to_write.to_string_raw()))
511        total += blocks_to_write.size()
512        target_blocks = target_blocks.subtract(blocks_to_write)
513      return total
514
515    out = []
516    total = 0
517
518    # In BBOTA v3+, it uses the hash of the stashed blocks as the stash slot
519    # id. 'stashes' records the map from 'hash' to the ref count. The stash
520    # will be freed only if the count decrements to zero.
521    stashes = {}
522    stashed_blocks = 0
523    max_stashed_blocks = 0
524
525    for xf in self.transfers:
526
527      for _, sr in xf.stash_before:
528        sh = self.src.RangeSha1(sr)
529        if sh in stashes:
530          stashes[sh] += 1
531        else:
532          stashes[sh] = 1
533          stashed_blocks += sr.size()
534          self.touched_src_ranges = self.touched_src_ranges.union(sr)
535          out.append("stash %s %s\n" % (sh, sr.to_string_raw()))
536
537      if stashed_blocks > max_stashed_blocks:
538        max_stashed_blocks = stashed_blocks
539
540      free_string = []
541      free_size = 0
542
543      #   <# blocks> <src ranges>
544      #     OR
545      #   <# blocks> <src ranges> <src locs> <stash refs...>
546      #     OR
547      #   <# blocks> - <stash refs...>
548
549      size = xf.src_ranges.size()
550      src_str_buffer = [str(size)]
551
552      unstashed_src_ranges = xf.src_ranges
553      mapped_stashes = []
554      for _, sr in xf.use_stash:
555        unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
556        sh = self.src.RangeSha1(sr)
557        sr = xf.src_ranges.map_within(sr)
558        mapped_stashes.append(sr)
559        assert sh in stashes
560        src_str_buffer.append("%s:%s" % (sh, sr.to_string_raw()))
561        stashes[sh] -= 1
562        if stashes[sh] == 0:
563          free_string.append("free %s\n" % (sh,))
564          free_size += sr.size()
565          stashes.pop(sh)
566
567      if unstashed_src_ranges:
568        src_str_buffer.insert(1, unstashed_src_ranges.to_string_raw())
569        if xf.use_stash:
570          mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
571          src_str_buffer.insert(2, mapped_unstashed.to_string_raw())
572          mapped_stashes.append(mapped_unstashed)
573          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
574      else:
575        src_str_buffer.insert(1, "-")
576        self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
577
578      src_str = " ".join(src_str_buffer)
579
580      # version 3+:
581      #   zero <rangeset>
582      #   new <rangeset>
583      #   erase <rangeset>
584      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
585      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
586      #   move hash <tgt rangeset> <src_str>
587
588      tgt_size = xf.tgt_ranges.size()
589
590      if xf.style == "new":
591        assert xf.tgt_ranges
592        assert tgt_size == WriteSplitTransfers(out, xf.style, xf.tgt_ranges)
593        total += tgt_size
594      elif xf.style == "move":
595        assert xf.tgt_ranges
596        assert xf.src_ranges.size() == tgt_size
597        if xf.src_ranges != xf.tgt_ranges:
598          # take into account automatic stashing of overlapping blocks
599          if xf.src_ranges.overlaps(xf.tgt_ranges):
600            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
601            if temp_stash_usage > max_stashed_blocks:
602              max_stashed_blocks = temp_stash_usage
603
604          self.touched_src_ranges = self.touched_src_ranges.union(
605              xf.src_ranges)
606
607          out.append("%s %s %s %s\n" % (
608              xf.style,
609              xf.tgt_sha1,
610              xf.tgt_ranges.to_string_raw(), src_str))
611          total += tgt_size
612      elif xf.style in ("bsdiff", "imgdiff"):
613        assert xf.tgt_ranges
614        assert xf.src_ranges
615        # take into account automatic stashing of overlapping blocks
616        if xf.src_ranges.overlaps(xf.tgt_ranges):
617          temp_stash_usage = stashed_blocks + xf.src_ranges.size()
618          if temp_stash_usage > max_stashed_blocks:
619            max_stashed_blocks = temp_stash_usage
620
621        self.touched_src_ranges = self.touched_src_ranges.union(xf.src_ranges)
622
623        out.append("%s %d %d %s %s %s %s\n" % (
624            xf.style,
625            xf.patch_start, xf.patch_len,
626            xf.src_sha1,
627            xf.tgt_sha1,
628            xf.tgt_ranges.to_string_raw(), src_str))
629        total += tgt_size
630      elif xf.style == "zero":
631        assert xf.tgt_ranges
632        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
633        assert WriteSplitTransfers(out, xf.style, to_zero) == to_zero.size()
634        total += to_zero.size()
635      else:
636        raise ValueError("unknown transfer style '%s'\n" % xf.style)
637
638      if free_string:
639        out.append("".join(free_string))
640        stashed_blocks -= free_size
641
642      if common.OPTIONS.cache_size is not None:
643        # Sanity check: abort if we're going to need more stash space than
644        # the allowed size (cache_size * threshold). There are two purposes
645        # of having a threshold here. a) Part of the cache may have been
646        # occupied by some recovery logs. b) It will buy us some time to deal
647        # with the oversize issue.
648        cache_size = common.OPTIONS.cache_size
649        stash_threshold = common.OPTIONS.stash_threshold
650        max_allowed = cache_size * stash_threshold
651        assert max_stashed_blocks * self.tgt.blocksize <= max_allowed, \
652               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
653                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
654                   self.tgt.blocksize, max_allowed, cache_size,
655                   stash_threshold)
656
657    self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)
658
659    # Zero out extended blocks as a workaround for bug 20881595.
660    if self.tgt.extended:
661      assert (WriteSplitTransfers(out, "zero", self.tgt.extended) ==
662              self.tgt.extended.size())
663      total += self.tgt.extended.size()
664
665    # We erase all the blocks on the partition that a) don't contain useful
666    # data in the new image; b) will not be touched by dm-verity. Out of those
667    # blocks, we erase the ones that won't be used in this update at the
668    # beginning of an update. The rest would be erased at the end. This is to
669    # work around the eMMC issue observed on some devices, which may otherwise
670    # get starving for clean blocks and thus fail the update. (b/28347095)
671    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
672    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
673    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
674
675    erase_first = new_dontcare.subtract(self.touched_src_ranges)
676    if erase_first:
677      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))
678
679    erase_last = new_dontcare.subtract(erase_first)
680    if erase_last:
681      out.append("erase %s\n" % (erase_last.to_string_raw(),))
682
683    out.insert(0, "%d\n" % (self.version,))   # format version number
684    out.insert(1, "%d\n" % (total,))
685    # v3+: the number of stash slots is unused.
686    out.insert(2, "0\n")
687    out.insert(3, str(max_stashed_blocks) + "\n")
688
689    with open(prefix + ".transfer.list", "wb") as f:
690      for i in out:
691        f.write(i)
692
693    self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
694    OPTIONS = common.OPTIONS
695    if OPTIONS.cache_size is not None:
696      max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
697      print("max stashed blocks: %d  (%d bytes), "
698            "limit: %d bytes (%.2f%%)\n" % (
699                max_stashed_blocks, self._max_stashed_size, max_allowed,
700                self._max_stashed_size * 100.0 / max_allowed))
701    else:
702      print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
703          max_stashed_blocks, self._max_stashed_size))
704
705  def ReviseStashSize(self):
706    print("Revising stash size...")
707    stash_map = {}
708
709    # Create the map between a stash and its def/use points. For example, for a
710    # given stash of (raw_id, sr), stash_map[raw_id] = (sr, def_cmd, use_cmd).
711    for xf in self.transfers:
712      # Command xf defines (stores) all the stashes in stash_before.
713      for stash_raw_id, sr in xf.stash_before:
714        stash_map[stash_raw_id] = (sr, xf)
715
716      # Record all the stashes command xf uses.
717      for stash_raw_id, _ in xf.use_stash:
718        stash_map[stash_raw_id] += (xf,)
719
720    # Compute the maximum blocks available for stash based on /cache size and
721    # the threshold.
722    cache_size = common.OPTIONS.cache_size
723    stash_threshold = common.OPTIONS.stash_threshold
724    max_allowed = cache_size * stash_threshold / self.tgt.blocksize
725
726    # See the comments for 'stashes' in WriteTransfers().
727    stashes = {}
728    stashed_blocks = 0
729    new_blocks = 0
730
731    # Now go through all the commands. Compute the required stash size on the
732    # fly. If a command requires excess stash than available, it deletes the
733    # stash by replacing the command that uses the stash with a "new" command
734    # instead.
735    for xf in self.transfers:
736      replaced_cmds = []
737
738      # xf.stash_before generates explicit stash commands.
739      for stash_raw_id, sr in xf.stash_before:
740        # Check the post-command stashed_blocks.
741        stashed_blocks_after = stashed_blocks
742        sh = self.src.RangeSha1(sr)
743        if sh not in stashes:
744          stashed_blocks_after += sr.size()
745
746        if stashed_blocks_after > max_allowed:
747          # We cannot stash this one for a later command. Find out the command
748          # that will use this stash and replace the command with "new".
749          use_cmd = stash_map[stash_raw_id][2]
750          replaced_cmds.append(use_cmd)
751          print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
752        else:
753          # Update the stashes map.
754          if sh in stashes:
755            stashes[sh] += 1
756          else:
757            stashes[sh] = 1
758          stashed_blocks = stashed_blocks_after
759
760      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
761      # ComputePatches(), they both have the style of "diff".
762      if xf.style == "diff":
763        assert xf.tgt_ranges and xf.src_ranges
764        if xf.src_ranges.overlaps(xf.tgt_ranges):
765          if stashed_blocks + xf.src_ranges.size() > max_allowed:
766            replaced_cmds.append(xf)
767            print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))
768
769      # Replace the commands in replaced_cmds with "new"s.
770      for cmd in replaced_cmds:
771        # It no longer uses any commands in "use_stash". Remove the def points
772        # for all those stashes.
773        for stash_raw_id, sr in cmd.use_stash:
774          def_cmd = stash_map[stash_raw_id][1]
775          assert (stash_raw_id, sr) in def_cmd.stash_before
776          def_cmd.stash_before.remove((stash_raw_id, sr))
777
778        # Add up blocks that violates space limit and print total number to
779        # screen later.
780        new_blocks += cmd.tgt_ranges.size()
781        cmd.ConvertToNew()
782
783      # xf.use_stash may generate free commands.
784      for _, sr in xf.use_stash:
785        sh = self.src.RangeSha1(sr)
786        assert sh in stashes
787        stashes[sh] -= 1
788        if stashes[sh] == 0:
789          stashed_blocks -= sr.size()
790          stashes.pop(sh)
791
792    num_of_bytes = new_blocks * self.tgt.blocksize
793    print("  Total %d blocks (%d bytes) are packed as new blocks due to "
794          "insufficient cache size." % (new_blocks, num_of_bytes))
795    return new_blocks
796
797  def ComputePatches(self, prefix):
798    print("Reticulating splines...")
799    diff_queue = []
800    patch_num = 0
801    with open(prefix + ".new.dat", "wb") as new_f:
802      for index, xf in enumerate(self.transfers):
803        if xf.style == "zero":
804          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
805          print("%10d %10d (%6.2f%%) %7s %s %s" % (
806              tgt_size, tgt_size, 100.0, xf.style, xf.tgt_name,
807              str(xf.tgt_ranges)))
808
809        elif xf.style == "new":
810          self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
811          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
812          print("%10d %10d (%6.2f%%) %7s %s %s" % (
813              tgt_size, tgt_size, 100.0, xf.style,
814              xf.tgt_name, str(xf.tgt_ranges)))
815
816        elif xf.style == "diff":
817          # We can't compare src and tgt directly because they may have
818          # the same content but be broken up into blocks differently, eg:
819          #
820          #    ["he", "llo"]  vs  ["h", "ello"]
821          #
822          # We want those to compare equal, ideally without having to
823          # actually concatenate the strings (these may be tens of
824          # megabytes).
825          if xf.src_sha1 == xf.tgt_sha1:
826            # These are identical; we don't need to generate a patch,
827            # just issue copy commands on the device.
828            xf.style = "move"
829            xf.patch = None
830            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
831            if xf.src_ranges != xf.tgt_ranges:
832              print("%10d %10d (%6.2f%%) %7s %s %s (from %s)" % (
833                  tgt_size, tgt_size, 100.0, xf.style,
834                  xf.tgt_name if xf.tgt_name == xf.src_name else (
835                      xf.tgt_name + " (from " + xf.src_name + ")"),
836                  str(xf.tgt_ranges), str(xf.src_ranges)))
837          else:
838            if xf.patch:
839              # We have already generated the patch with imgdiff. Check if the
840              # transfer is intact.
841              assert not self.disable_imgdiff
842              imgdiff = True
843              if (xf.src_ranges.extra.get('trimmed') or
844                  xf.tgt_ranges.extra.get('trimmed')):
845                imgdiff = False
846                xf.patch = None
847            else:
848              imgdiff = self.CanUseImgdiff(
849                  xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
850            xf.style = "imgdiff" if imgdiff else "bsdiff"
851            diff_queue.append((index, imgdiff, patch_num))
852            patch_num += 1
853
854        else:
855          assert False, "unknown style " + xf.style
856
857    if diff_queue:
858      if self.threads > 1:
859        print("Computing patches (using %d threads)..." % (self.threads,))
860      else:
861        print("Computing patches...")
862
863      diff_total = len(diff_queue)
864      patches = [None] * diff_total
865      error_messages = []
866
867      # Using multiprocessing doesn't give additional benefits, due to the
868      # pattern of the code. The diffing work is done by subprocess.call, which
869      # already runs in a separate process (not affected much by the GIL -
870      # Global Interpreter Lock). Using multiprocess also requires either a)
871      # writing the diff input files in the main process before forking, or b)
872      # reopening the image file (SparseImage) in the worker processes. Doing
873      # neither of them further improves the performance.
874      lock = threading.Lock()
875      def diff_worker():
876        while True:
877          with lock:
878            if not diff_queue:
879              return
880            xf_index, imgdiff, patch_index = diff_queue.pop()
881            xf = self.transfers[xf_index]
882
883            if sys.stdout.isatty():
884              diff_left = len(diff_queue)
885              progress = (diff_total - diff_left) * 100 / diff_total
886              # '\033[K' is to clear to EOL.
887              print(' [%3d%%] %s\033[K' % (progress, xf.tgt_name), end='\r')
888              sys.stdout.flush()
889
890          patch = xf.patch
891          if not patch:
892            src_ranges = xf.src_ranges
893            tgt_ranges = xf.tgt_ranges
894
895            src_file = common.MakeTempFile(prefix="src-")
896            with open(src_file, "wb") as fd:
897              self.src.WriteRangeDataToFd(src_ranges, fd)
898
899            tgt_file = common.MakeTempFile(prefix="tgt-")
900            with open(tgt_file, "wb") as fd:
901              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
902
903            message = []
904            try:
905              patch = compute_patch(src_file, tgt_file, imgdiff)
906            except ValueError as e:
907              message.append(
908                  "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
909                      "imgdiff" if imgdiff else "bsdiff",
910                      xf.tgt_name if xf.tgt_name == xf.src_name else
911                      xf.tgt_name + " (from " + xf.src_name + ")",
912                      xf.tgt_ranges, xf.src_ranges, e.message))
913            if message:
914              with lock:
915                error_messages.extend(message)
916
917          with lock:
918            patches[patch_index] = (xf_index, patch)
919
920      threads = [threading.Thread(target=diff_worker)
921                 for _ in range(self.threads)]
922      for th in threads:
923        th.start()
924      while threads:
925        threads.pop().join()
926
927      if sys.stdout.isatty():
928        print('\n')
929
930      if error_messages:
931        print('ERROR:')
932        print('\n'.join(error_messages))
933        print('\n\n\n')
934        sys.exit(1)
935    else:
936      patches = []
937
938    offset = 0
939    with open(prefix + ".patch.dat", "wb") as patch_fd:
940      for index, patch in patches:
941        xf = self.transfers[index]
942        xf.patch_len = len(patch)
943        xf.patch_start = offset
944        offset += xf.patch_len
945        patch_fd.write(patch)
946
947        if common.OPTIONS.verbose:
948          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
949          print("%10d %10d (%6.2f%%) %7s %s %s %s" % (
950              xf.patch_len, tgt_size, xf.patch_len * 100.0 / tgt_size,
951              xf.style,
952              xf.tgt_name if xf.tgt_name == xf.src_name else (
953                  xf.tgt_name + " (from " + xf.src_name + ")"),
954              xf.tgt_ranges, xf.src_ranges))
955
956  def AssertSha1Good(self):
957    """Check the SHA-1 of the src & tgt blocks in the transfer list.
958
959    Double check the SHA-1 value to avoid the issue in b/71908713, where
960    SparseImage.RangeSha1() messed up with the hash calculation in multi-thread
961    environment. That specific problem has been fixed by protecting the
962    underlying generator function 'SparseImage._GetRangeData()' with lock.
963    """
964    for xf in self.transfers:
965      tgt_sha1 = self.tgt.RangeSha1(xf.tgt_ranges)
966      assert xf.tgt_sha1 == tgt_sha1
967      if xf.style == "diff":
968        src_sha1 = self.src.RangeSha1(xf.src_ranges)
969        assert xf.src_sha1 == src_sha1
970
971  def AssertSequenceGood(self):
972    # Simulate the sequences of transfers we will output, and check that:
973    # - we never read a block after writing it, and
974    # - we write every block we care about exactly once.
975
976    # Start with no blocks having been touched yet.
977    touched = array.array("B", "\0" * self.tgt.total_blocks)
978
979    # Imagine processing the transfers in order.
980    for xf in self.transfers:
981      # Check that the input blocks for this transfer haven't yet been touched.
982
983      x = xf.src_ranges
984      for _, sr in xf.use_stash:
985        x = x.subtract(sr)
986
987      for s, e in x:
988        # Source image could be larger. Don't check the blocks that are in the
989        # source image only. Since they are not in 'touched', and won't ever
990        # be touched.
991        for i in range(s, min(e, self.tgt.total_blocks)):
992          assert touched[i] == 0
993
994      # Check that the output blocks for this transfer haven't yet
995      # been touched, and touch all the blocks written by this
996      # transfer.
997      for s, e in xf.tgt_ranges:
998        for i in range(s, e):
999          assert touched[i] == 0
1000          touched[i] = 1
1001
1002    # Check that we've written every target block.
1003    for s, e in self.tgt.care_map:
1004      for i in range(s, e):
1005        assert touched[i] == 1
1006
1007  def ImproveVertexSequence(self):
1008    print("Improving vertex order...")
1009
1010    # At this point our digraph is acyclic; we reversed any edges that
1011    # were backwards in the heuristically-generated sequence.  The
1012    # previously-generated order is still acceptable, but we hope to
1013    # find a better order that needs less memory for stashed data.
1014    # Now we do a topological sort to generate a new vertex order,
1015    # using a greedy algorithm to choose which vertex goes next
1016    # whenever we have a choice.
1017
1018    # Make a copy of the edge set; this copy will get destroyed by the
1019    # algorithm.
1020    for xf in self.transfers:
1021      xf.incoming = xf.goes_after.copy()
1022      xf.outgoing = xf.goes_before.copy()
1023
1024    L = []   # the new vertex order
1025
1026    # S is the set of sources in the remaining graph; we always choose
1027    # the one that leaves the least amount of stashed data after it's
1028    # executed.
1029    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
1030         if not u.incoming]
1031    heapq.heapify(S)
1032
1033    while S:
1034      _, _, xf = heapq.heappop(S)
1035      L.append(xf)
1036      for u in xf.outgoing:
1037        del u.incoming[xf]
1038        if not u.incoming:
1039          heapq.heappush(S, (u.NetStashChange(), u.order, u))
1040
1041    # if this fails then our graph had a cycle.
1042    assert len(L) == len(self.transfers)
1043
1044    self.transfers = L
1045    for i, xf in enumerate(L):
1046      xf.order = i
1047
1048  def RemoveBackwardEdges(self):
1049    print("Removing backward edges...")
1050    in_order = 0
1051    out_of_order = 0
1052    lost_source = 0
1053
1054    for xf in self.transfers:
1055      lost = 0
1056      size = xf.src_ranges.size()
1057      for u in xf.goes_before:
1058        # xf should go before u
1059        if xf.order < u.order:
1060          # it does, hurray!
1061          in_order += 1
1062        else:
1063          # it doesn't, boo.  trim the blocks that u writes from xf's
1064          # source, so that xf can go after u.
1065          out_of_order += 1
1066          assert xf.src_ranges.overlaps(u.tgt_ranges)
1067          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
1068          xf.src_ranges.extra['trimmed'] = True
1069
1070      if xf.style == "diff" and not xf.src_ranges:
1071        # nothing left to diff from; treat as new data
1072        xf.style = "new"
1073
1074      lost = size - xf.src_ranges.size()
1075      lost_source += lost
1076
1077    print(("  %d/%d dependencies (%.2f%%) were violated; "
1078           "%d source blocks removed.") %
1079          (out_of_order, in_order + out_of_order,
1080           (out_of_order * 100.0 / (in_order + out_of_order))
1081           if (in_order + out_of_order) else 0.0,
1082           lost_source))
1083
1084  def ReverseBackwardEdges(self):
1085    """Reverse unsatisfying edges and compute pairs of stashed blocks.
1086
1087    For each transfer, make sure it properly stashes the blocks it touches and
1088    will be used by later transfers. It uses pairs of (stash_raw_id, range) to
1089    record the blocks to be stashed. 'stash_raw_id' is an id that uniquely
1090    identifies each pair. Note that for the same range (e.g. RangeSet("1-5")),
1091    it is possible to have multiple pairs with different 'stash_raw_id's. Each
1092    'stash_raw_id' will be consumed by one transfer. In BBOTA v3+, identical
1093    blocks will be written to the same stash slot in WriteTransfers().
1094    """
1095
1096    print("Reversing backward edges...")
1097    in_order = 0
1098    out_of_order = 0
1099    stash_raw_id = 0
1100    stash_size = 0
1101
1102    for xf in self.transfers:
1103      for u in xf.goes_before.copy():
1104        # xf should go before u
1105        if xf.order < u.order:
1106          # it does, hurray!
1107          in_order += 1
1108        else:
1109          # it doesn't, boo.  modify u to stash the blocks that it
1110          # writes that xf wants to read, and then require u to go
1111          # before xf.
1112          out_of_order += 1
1113
1114          overlap = xf.src_ranges.intersect(u.tgt_ranges)
1115          assert overlap
1116
1117          u.stash_before.append((stash_raw_id, overlap))
1118          xf.use_stash.append((stash_raw_id, overlap))
1119          stash_raw_id += 1
1120          stash_size += overlap.size()
1121
1122          # reverse the edge direction; now xf must go after u
1123          del xf.goes_before[u]
1124          del u.goes_after[xf]
1125          xf.goes_after[u] = None    # value doesn't matter
1126          u.goes_before[xf] = None
1127
1128    print(("  %d/%d dependencies (%.2f%%) were violated; "
1129           "%d source blocks stashed.") %
1130          (out_of_order, in_order + out_of_order,
1131           (out_of_order * 100.0 / (in_order + out_of_order))
1132           if (in_order + out_of_order) else 0.0,
1133           stash_size))
1134
1135  def FindVertexSequence(self):
1136    print("Finding vertex sequence...")
1137
1138    # This is based on "A Fast & Effective Heuristic for the Feedback
1139    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
1140    # it as starting with the digraph G and moving all the vertices to
1141    # be on a horizontal line in some order, trying to minimize the
1142    # number of edges that end up pointing to the left.  Left-pointing
1143    # edges will get removed to turn the digraph into a DAG.  In this
1144    # case each edge has a weight which is the number of source blocks
1145    # we'll lose if that edge is removed; we try to minimize the total
1146    # weight rather than just the number of edges.
1147
1148    # Make a copy of the edge set; this copy will get destroyed by the
1149    # algorithm.
1150    for xf in self.transfers:
1151      xf.incoming = xf.goes_after.copy()
1152      xf.outgoing = xf.goes_before.copy()
1153      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())
1154
1155    # We use an OrderedDict instead of just a set so that the output
1156    # is repeatable; otherwise it would depend on the hash values of
1157    # the transfer objects.
1158    G = OrderedDict()
1159    for xf in self.transfers:
1160      G[xf] = None
1161    s1 = deque()  # the left side of the sequence, built from left to right
1162    s2 = deque()  # the right side of the sequence, built from right to left
1163
1164    heap = []
1165    for xf in self.transfers:
1166      xf.heap_item = HeapItem(xf)
1167      heap.append(xf.heap_item)
1168    heapq.heapify(heap)
1169
1170    # Use OrderedDict() instead of set() to preserve the insertion order. Need
1171    # to use 'sinks[key] = None' to add key into the set. sinks will look like
1172    # { key1: None, key2: None, ... }.
1173    sinks = OrderedDict.fromkeys(u for u in G if not u.outgoing)
1174    sources = OrderedDict.fromkeys(u for u in G if not u.incoming)
1175
1176    def adjust_score(iu, delta):
1177      iu.score += delta
1178      iu.heap_item.clear()
1179      iu.heap_item = HeapItem(iu)
1180      heapq.heappush(heap, iu.heap_item)
1181
1182    while G:
1183      # Put all sinks at the end of the sequence.
1184      while sinks:
1185        new_sinks = OrderedDict()
1186        for u in sinks:
1187          if u not in G:
1188            continue
1189          s2.appendleft(u)
1190          del G[u]
1191          for iu in u.incoming:
1192            adjust_score(iu, -iu.outgoing.pop(u))
1193            if not iu.outgoing:
1194              new_sinks[iu] = None
1195        sinks = new_sinks
1196
1197      # Put all the sources at the beginning of the sequence.
1198      while sources:
1199        new_sources = OrderedDict()
1200        for u in sources:
1201          if u not in G:
1202            continue
1203          s1.append(u)
1204          del G[u]
1205          for iu in u.outgoing:
1206            adjust_score(iu, +iu.incoming.pop(u))
1207            if not iu.incoming:
1208              new_sources[iu] = None
1209        sources = new_sources
1210
1211      if not G:
1212        break
1213
1214      # Find the "best" vertex to put next.  "Best" is the one that
1215      # maximizes the net difference in source blocks saved we get by
1216      # pretending it's a source rather than a sink.
1217
1218      while True:
1219        u = heapq.heappop(heap)
1220        if u and u.item in G:
1221          u = u.item
1222          break
1223
1224      s1.append(u)
1225      del G[u]
1226      for iu in u.outgoing:
1227        adjust_score(iu, +iu.incoming.pop(u))
1228        if not iu.incoming:
1229          sources[iu] = None
1230
1231      for iu in u.incoming:
1232        adjust_score(iu, -iu.outgoing.pop(u))
1233        if not iu.outgoing:
1234          sinks[iu] = None
1235
1236    # Now record the sequence in the 'order' field of each transfer,
1237    # and by rearranging self.transfers to be in the chosen sequence.
1238
1239    new_transfers = []
1240    for x in itertools.chain(s1, s2):
1241      x.order = len(new_transfers)
1242      new_transfers.append(x)
1243      del x.incoming
1244      del x.outgoing
1245
1246    self.transfers = new_transfers
1247
1248  def GenerateDigraph(self):
1249    print("Generating digraph...")
1250
1251    # Each item of source_ranges will be:
1252    #   - None, if that block is not used as a source,
1253    #   - an ordered set of transfers.
1254    source_ranges = []
1255    for b in self.transfers:
1256      for s, e in b.src_ranges:
1257        if e > len(source_ranges):
1258          source_ranges.extend([None] * (e-len(source_ranges)))
1259        for i in range(s, e):
1260          if source_ranges[i] is None:
1261            source_ranges[i] = OrderedDict.fromkeys([b])
1262          else:
1263            source_ranges[i][b] = None
1264
1265    for a in self.transfers:
1266      intersections = OrderedDict()
1267      for s, e in a.tgt_ranges:
1268        for i in range(s, e):
1269          if i >= len(source_ranges):
1270            break
1271          # Add all the Transfers in source_ranges[i] to the (ordered) set.
1272          if source_ranges[i] is not None:
1273            for j in source_ranges[i]:
1274              intersections[j] = None
1275
1276      for b in intersections:
1277        if a is b:
1278          continue
1279
1280        # If the blocks written by A are read by B, then B needs to go before A.
1281        i = a.tgt_ranges.intersect(b.src_ranges)
1282        if i:
1283          if b.src_name == "__ZERO":
1284            # the cost of removing source blocks for the __ZERO domain
1285            # is (nearly) zero.
1286            size = 0
1287          else:
1288            size = i.size()
1289          b.goes_before[a] = size
1290          a.goes_after[b] = size
1291
1292  def FindTransfers(self):
1293    """Parse the file_map to generate all the transfers."""
1294
1295    def AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1296                                             src_ranges, style, by_id):
1297      """Add one or multiple Transfer()s by splitting large files.
1298
1299      For BBOTA v3, we need to stash source blocks for resumable feature.
1300      However, with the growth of file size and the shrink of the cache
1301      partition source blocks are too large to be stashed. If a file occupies
1302      too many blocks, we split it into smaller pieces by getting multiple
1303      Transfer()s.
1304
1305      The downside is that after splitting, we may increase the package size
1306      since the split pieces don't align well. According to our experiments,
1307      1/8 of the cache size as the per-piece limit appears to be optimal.
1308      Compared to the fixed 1024-block limit, it reduces the overall package
1309      size by 30% for volantis, and 20% for angler and bullhead."""
1310
1311      pieces = 0
1312      while (tgt_ranges.size() > max_blocks_per_transfer and
1313             src_ranges.size() > max_blocks_per_transfer):
1314        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1315        src_split_name = "%s-%d" % (src_name, pieces)
1316        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1317        src_first = src_ranges.first(max_blocks_per_transfer)
1318
1319        Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
1320                 self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
1321                 style, by_id)
1322
1323        tgt_ranges = tgt_ranges.subtract(tgt_first)
1324        src_ranges = src_ranges.subtract(src_first)
1325        pieces += 1
1326
1327      # Handle remaining blocks.
1328      if tgt_ranges.size() or src_ranges.size():
1329        # Must be both non-empty.
1330        assert tgt_ranges.size() and src_ranges.size()
1331        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1332        src_split_name = "%s-%d" % (src_name, pieces)
1333        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
1334                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1335                 style, by_id)
1336
1337    def AddSplitTransfers(tgt_name, src_name, tgt_ranges, src_ranges, style,
1338                          by_id):
1339      """Find all the zip files and split the others with a fixed chunk size.
1340
1341      This function will construct a list of zip archives, which will later be
1342      split by imgdiff to reduce the final patch size. For the other files,
1343      we will plainly split them based on a fixed chunk size with the potential
1344      patch size penalty.
1345      """
1346
1347      assert style == "diff"
1348
1349      # Change nothing for small files.
1350      if (tgt_ranges.size() <= max_blocks_per_transfer and
1351          src_ranges.size() <= max_blocks_per_transfer):
1352        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1353                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1354                 style, by_id)
1355        return
1356
1357      # Split large APKs with imgdiff, if possible. We're intentionally checking
1358      # file types one more time (CanUseImgdiff() checks that as well), before
1359      # calling the costly RangeSha1()s.
1360      if (self.FileTypeSupportedByImgdiff(tgt_name) and
1361          self.tgt.RangeSha1(tgt_ranges) != self.src.RangeSha1(src_ranges)):
1362        if self.CanUseImgdiff(tgt_name, tgt_ranges, src_ranges, True):
1363          large_apks.append((tgt_name, src_name, tgt_ranges, src_ranges))
1364          return
1365
1366      AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1367                                           src_ranges, style, by_id)
1368
1369    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1370                    split=False):
1371      """Wrapper function for adding a Transfer()."""
1372
1373      # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
1374      # otherwise add the Transfer() as is.
1375      if style != "diff" or not split:
1376        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1377                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1378                 style, by_id)
1379        return
1380
1381      # Handle .odex files specially to analyze the block-wise difference. If
1382      # most of the blocks are identical with only few changes (e.g. header),
1383      # we will patch the changed blocks only. This avoids stashing unchanged
1384      # blocks while patching. We limit the analysis to files without size
1385      # changes only. This is to avoid sacrificing the OTA generation cost too
1386      # much.
1387      if (tgt_name.split(".")[-1].lower() == 'odex' and
1388          tgt_ranges.size() == src_ranges.size()):
1389
1390        # 0.5 threshold can be further tuned. The tradeoff is: if only very
1391        # few blocks remain identical, we lose the opportunity to use imgdiff
1392        # that may have better compression ratio than bsdiff.
1393        crop_threshold = 0.5
1394
1395        tgt_skipped = RangeSet()
1396        src_skipped = RangeSet()
1397        tgt_size = tgt_ranges.size()
1398        tgt_changed = 0
1399        for src_block, tgt_block in zip(src_ranges.next_item(),
1400                                        tgt_ranges.next_item()):
1401          src_rs = RangeSet(str(src_block))
1402          tgt_rs = RangeSet(str(tgt_block))
1403          if self.src.ReadRangeSet(src_rs) == self.tgt.ReadRangeSet(tgt_rs):
1404            tgt_skipped = tgt_skipped.union(tgt_rs)
1405            src_skipped = src_skipped.union(src_rs)
1406          else:
1407            tgt_changed += tgt_rs.size()
1408
1409          # Terminate early if no clear sign of benefits.
1410          if tgt_changed > tgt_size * crop_threshold:
1411            break
1412
1413        if tgt_changed < tgt_size * crop_threshold:
1414          assert tgt_changed + tgt_skipped.size() == tgt_size
1415          print('%10d %10d (%6.2f%%) %s' % (
1416              tgt_skipped.size(), tgt_size,
1417              tgt_skipped.size() * 100.0 / tgt_size, tgt_name))
1418          AddSplitTransfers(
1419              "%s-skipped" % (tgt_name,),
1420              "%s-skipped" % (src_name,),
1421              tgt_skipped, src_skipped, style, by_id)
1422
1423          # Intentionally change the file extension to avoid being imgdiff'd as
1424          # the files are no longer in their original format.
1425          tgt_name = "%s-cropped" % (tgt_name,)
1426          src_name = "%s-cropped" % (src_name,)
1427          tgt_ranges = tgt_ranges.subtract(tgt_skipped)
1428          src_ranges = src_ranges.subtract(src_skipped)
1429
1430          # Possibly having no changed blocks.
1431          if not tgt_ranges:
1432            return
1433
1434      # Add the transfer(s).
1435      AddSplitTransfers(
1436          tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1437
1438    def ParseAndValidateSplitInfo(patch_size, tgt_ranges, src_ranges,
1439                                  split_info):
1440      """Parse the split_info and return a list of info tuples.
1441
1442      Args:
1443        patch_size: total size of the patch file.
1444        tgt_ranges: Ranges of the target file within the original image.
1445        src_ranges: Ranges of the source file within the original image.
1446        split_info format:
1447          imgdiff version#
1448          count of pieces
1449          <patch_size_1> <tgt_size_1> <src_ranges_1>
1450          ...
1451          <patch_size_n> <tgt_size_n> <src_ranges_n>
1452
1453      Returns:
1454        [patch_start, patch_len, split_tgt_ranges, split_src_ranges]
1455      """
1456
1457      version = int(split_info[0])
1458      assert version == 2
1459      count = int(split_info[1])
1460      assert len(split_info) - 2 == count
1461
1462      split_info_list = []
1463      patch_start = 0
1464      tgt_remain = copy.deepcopy(tgt_ranges)
1465      # each line has the format <patch_size>, <tgt_size>, <src_ranges>
1466      for line in split_info[2:]:
1467        info = line.split()
1468        assert len(info) == 3
1469        patch_length = int(info[0])
1470
1471        split_tgt_size = int(info[1])
1472        assert split_tgt_size % 4096 == 0
1473        assert split_tgt_size / 4096 <= tgt_remain.size()
1474        split_tgt_ranges = tgt_remain.first(split_tgt_size / 4096)
1475        tgt_remain = tgt_remain.subtract(split_tgt_ranges)
1476
1477        # Find the split_src_ranges within the image file from its relative
1478        # position in file.
1479        split_src_indices = RangeSet.parse_raw(info[2])
1480        split_src_ranges = RangeSet()
1481        for r in split_src_indices:
1482          curr_range = src_ranges.first(r[1]).subtract(src_ranges.first(r[0]))
1483          assert not split_src_ranges.overlaps(curr_range)
1484          split_src_ranges = split_src_ranges.union(curr_range)
1485
1486        split_info_list.append((patch_start, patch_length,
1487                                split_tgt_ranges, split_src_ranges))
1488        patch_start += patch_length
1489
1490      # Check that the sizes of all the split pieces add up to the final file
1491      # size for patch and target.
1492      assert tgt_remain.size() == 0
1493      assert patch_start == patch_size
1494      return split_info_list
1495
1496    def SplitLargeApks():
1497      """Split the large apks files.
1498
1499      Example: Chrome.apk will be split into
1500        src-0: Chrome.apk-0, tgt-0: Chrome.apk-0
1501        src-1: Chrome.apk-1, tgt-1: Chrome.apk-1
1502        ...
1503
1504      After the split, the target pieces are continuous and block aligned; and
1505      the source pieces are mutually exclusive. During the split, we also
1506      generate and save the image patch between src-X & tgt-X. This patch will
1507      be valid because the block ranges of src-X & tgt-X will always stay the
1508      same afterwards; but there's a chance we don't use the patch if we
1509      convert the "diff" command into "new" or "move" later.
1510      """
1511
1512      while True:
1513        with transfer_lock:
1514          if not large_apks:
1515            return
1516          tgt_name, src_name, tgt_ranges, src_ranges = large_apks.pop(0)
1517
1518        src_file = common.MakeTempFile(prefix="src-")
1519        tgt_file = common.MakeTempFile(prefix="tgt-")
1520        with open(src_file, "wb") as src_fd:
1521          self.src.WriteRangeDataToFd(src_ranges, src_fd)
1522        with open(tgt_file, "wb") as tgt_fd:
1523          self.tgt.WriteRangeDataToFd(tgt_ranges, tgt_fd)
1524
1525        patch_file = common.MakeTempFile(prefix="patch-")
1526        patch_info_file = common.MakeTempFile(prefix="split_info-")
1527        cmd = ["imgdiff", "-z",
1528               "--block-limit={}".format(max_blocks_per_transfer),
1529               "--split-info=" + patch_info_file,
1530               src_file, tgt_file, patch_file]
1531        p = common.Run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
1532        imgdiff_output, _ = p.communicate()
1533        assert p.returncode == 0, \
1534            "Failed to create imgdiff patch between {} and {}:\n{}".format(
1535                src_name, tgt_name, imgdiff_output)
1536
1537        with open(patch_info_file) as patch_info:
1538          lines = patch_info.readlines()
1539
1540        patch_size_total = os.path.getsize(patch_file)
1541        split_info_list = ParseAndValidateSplitInfo(patch_size_total,
1542                                                    tgt_ranges, src_ranges,
1543                                                    lines)
1544        for index, (patch_start, patch_length, split_tgt_ranges,
1545                    split_src_ranges) in enumerate(split_info_list):
1546          with open(patch_file) as f:
1547            f.seek(patch_start)
1548            patch_content = f.read(patch_length)
1549
1550          split_src_name = "{}-{}".format(src_name, index)
1551          split_tgt_name = "{}-{}".format(tgt_name, index)
1552          split_large_apks.append((split_tgt_name,
1553                                   split_src_name,
1554                                   split_tgt_ranges,
1555                                   split_src_ranges,
1556                                   patch_content))
1557
1558    print("Finding transfers...")
1559
1560    large_apks = []
1561    split_large_apks = []
1562    cache_size = common.OPTIONS.cache_size
1563    split_threshold = 0.125
1564    max_blocks_per_transfer = int(cache_size * split_threshold /
1565                                  self.tgt.blocksize)
1566    empty = RangeSet()
1567    for tgt_fn, tgt_ranges in sorted(self.tgt.file_map.items()):
1568      if tgt_fn == "__ZERO":
1569        # the special "__ZERO" domain is all the blocks not contained
1570        # in any file and that are filled with zeros.  We have a
1571        # special transfer style for zero blocks.
1572        src_ranges = self.src.file_map.get("__ZERO", empty)
1573        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1574                    "zero", self.transfers)
1575        continue
1576
1577      elif tgt_fn == "__COPY":
1578        # "__COPY" domain includes all the blocks not contained in any
1579        # file and that need to be copied unconditionally to the target.
1580        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1581        continue
1582
1583      elif tgt_fn in self.src.file_map:
1584        # Look for an exact pathname match in the source.
1585        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1586                    "diff", self.transfers, True)
1587        continue
1588
1589      b = os.path.basename(tgt_fn)
1590      if b in self.src_basenames:
1591        # Look for an exact basename match in the source.
1592        src_fn = self.src_basenames[b]
1593        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1594                    "diff", self.transfers, True)
1595        continue
1596
1597      b = re.sub("[0-9]+", "#", b)
1598      if b in self.src_numpatterns:
1599        # Look for a 'number pattern' match (a basename match after
1600        # all runs of digits are replaced by "#").  (This is useful
1601        # for .so files that contain version numbers in the filename
1602        # that get bumped.)
1603        src_fn = self.src_numpatterns[b]
1604        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1605                    "diff", self.transfers, True)
1606        continue
1607
1608      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1609
1610    transfer_lock = threading.Lock()
1611    threads = [threading.Thread(target=SplitLargeApks)
1612               for _ in range(self.threads)]
1613    for th in threads:
1614      th.start()
1615    while threads:
1616      threads.pop().join()
1617
1618    # Sort the split transfers for large apks to generate a determinate package.
1619    split_large_apks.sort()
1620    for (tgt_name, src_name, tgt_ranges, src_ranges,
1621         patch) in split_large_apks:
1622      transfer_split = Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1623                                self.tgt.RangeSha1(tgt_ranges),
1624                                self.src.RangeSha1(src_ranges),
1625                                "diff", self.transfers)
1626      transfer_split.patch = patch
1627
1628  def AbbreviateSourceNames(self):
1629    for k in self.src.file_map.keys():
1630      b = os.path.basename(k)
1631      self.src_basenames[b] = k
1632      b = re.sub("[0-9]+", "#", b)
1633      self.src_numpatterns[b] = k
1634
1635  @staticmethod
1636  def AssertPartition(total, seq):
1637    """Assert that all the RangeSets in 'seq' form a partition of the
1638    'total' RangeSet (ie, they are nonintersecting and their union
1639    equals 'total')."""
1640
1641    so_far = RangeSet()
1642    for i in seq:
1643      assert not so_far.overlaps(i)
1644      so_far = so_far.union(i)
1645    assert so_far == total
1646