# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import re
import subprocess
import threading
import tempfile

from rangelib import RangeSet


__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  def ReadRangeSet(self, ranges):
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()


class DataImage(Image):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) / self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self, include_clobbered_blocks=False):
    # DataImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1(self.data).hexdigest()


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))

    # We use OrderedDict rather than dict so that the output is repeatable;
    # otherwise it would depend on the hash values of the Transfer objects.
    self.goes_before = OrderedDict()
    self.goes_after = OrderedDict()

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    clobbered_blocks: a RangeSet containing which blocks contain data
#      but may be altered by the FS. They need to be excluded when
#      verifying the partition integrity.
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
#      blocks if include_clobbered_blocks is True).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=3):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0:
        threads = 1
    self.threads = threads
    self.version = version
    self.transfers = []
    self.src_basenames = {}
    self.src_numpatterns = {}

    assert version in (1, 2, 3)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:

      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        if self.version == 2:
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
            stashes[sh] = 1
            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      free_string = []

      if self.version == 1:
        src_str = xf.src_ranges.to_string_raw()
      elif self.version >= 2:

        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>

        size = xf.src_ranges.size()
        src_str = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
            src_str.append("%d:%s" % (sid, sr.to_string_raw()))
          else:
            assert sh in stashes
            src_str.append("%s:%s" % (sh, sr.to_string_raw()))
            stashes[sh] -= 1
            if stashes[sh] == 0:
              free_string.append("free %s\n" % (sh))
              stashes.pop(sh)
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_str.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_str.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_str.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_str = " ".join(src_str)

      # all versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_str>
      #   move <tgt rangeset> <src_str>
      #
      # version 3:
      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
      #   move hash <tgt rangeset> <src_str>

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_str))
          elif self.version >= 3:
            # take into account automatic stashing of overlapping blocks
            if xf.src_ranges.overlaps(xf.tgt_ranges):
              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
              if temp_stash_usage > max_stashed_blocks:
                max_stashed_blocks = temp_stash_usage

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_str))
        elif self.version >= 3:
          # take into account automatic stashing of overlapping blocks
          if xf.src_ranges.overlaps(xf.tgt_ranges):
            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
            if temp_stash_usage > max_stashed_blocks:
              max_stashed_blocks = temp_stash_usage

          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'\n" % xf.style)

      if free_string:
        out.append("".join(free_string))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
      out.append("zero %s\n" % (self.tgt.extended.to_string_raw(),))

    # We erase all the blocks on the partition that a) don't contain useful
    # data in the new image and b) will not be touched by dm-verity.
    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
    if new_dontcare:
      out.append("erase %s\n" % (new_dontcare.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # TODO: Rewrite with multiprocessing.ThreadPool?
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks:
          break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources:
          break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G:
        break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b:
          continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn == "__COPY":
        # "__COPY" domain includes all the blocks not contained in any
        # file and that need to be copied unconditionally to the target.
        Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
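

# The block below is NOT part of the original tool; it is a minimal,
# illustrative sketch of how this module might be driven end to end,
# assuming two raw partition dumps on disk and that the bsdiff and imgdiff
# binaries are available on PATH (compute_patch shells out to them). The
# command-line handling and the use of DataImage with pad=True are
# assumptions made for this example, not behavior defined elsewhere in
# this file.
if __name__ == "__main__":
  import sys

  if len(sys.argv) != 4:
    print("usage: blockimgdiff.py <src_image> <tgt_image> <output_prefix>")
    sys.exit(1)

  src_path, tgt_path, out_prefix = sys.argv[1:]

  # Wrap the raw dumps in DataImage objects; pad=True rounds the data up
  # to a whole number of 4096-byte blocks so the blocksize assertion in
  # BlockImageDiff.__init__ holds.
  with open(src_path, "rb") as f:
    src_image = DataImage(f.read(), pad=True)
  with open(tgt_path, "rb") as f:
    tgt_image = DataImage(f.read(), pad=True)

  # Compute() writes <output_prefix>.transfer.list, <output_prefix>.new.dat
  # and <output_prefix>.patch.dat, the files the block-based OTA updater
  # consumes.
  BlockImageDiff(tgt_image, src_image, version=3).Compute(out_prefix)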