1# Copyright (C) 2014 The Android Open Source Project 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15from __future__ import print_function 16 17from collections import deque, OrderedDict 18from hashlib import sha1 19import heapq 20import itertools 21import multiprocessing 22import os 23import pprint 24import re 25import subprocess 26import sys 27import threading 28import tempfile 29 30from rangelib import * 31 32__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"] 33 34def compute_patch(src, tgt, imgdiff=False): 35 srcfd, srcfile = tempfile.mkstemp(prefix="src-") 36 tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-") 37 patchfd, patchfile = tempfile.mkstemp(prefix="patch-") 38 os.close(patchfd) 39 40 try: 41 with os.fdopen(srcfd, "wb") as f_src: 42 for p in src: 43 f_src.write(p) 44 45 with os.fdopen(tgtfd, "wb") as f_tgt: 46 for p in tgt: 47 f_tgt.write(p) 48 try: 49 os.unlink(patchfile) 50 except OSError: 51 pass 52 if imgdiff: 53 p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile], 54 stdout=open("/dev/null", "a"), 55 stderr=subprocess.STDOUT) 56 else: 57 p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile]) 58 59 if p: 60 raise ValueError("diff failed: " + str(p)) 61 62 with open(patchfile, "rb") as f: 63 return f.read() 64 finally: 65 try: 66 os.unlink(srcfile) 67 os.unlink(tgtfile) 68 os.unlink(patchfile) 69 except OSError: 70 pass 71 72class EmptyImage(object): 73 """A zero-length image.""" 74 blocksize = 4096 75 care_map = RangeSet() 76 total_blocks = 0 77 file_map = {} 78 def ReadRangeSet(self, ranges): 79 return () 80 def TotalSha1(self): 81 return sha1().hexdigest() 82 83 84class DataImage(object): 85 """An image wrapped around a single string of data.""" 86 87 def __init__(self, data, trim=False, pad=False): 88 self.data = data 89 self.blocksize = 4096 90 91 assert not (trim and pad) 92 93 partial = len(self.data) % self.blocksize 94 if partial > 0: 95 if trim: 96 self.data = self.data[:-partial] 97 elif pad: 98 self.data += '\0' * (self.blocksize - partial) 99 else: 100 raise ValueError(("data for DataImage must be multiple of %d bytes " 101 "unless trim or pad is specified") % 102 (self.blocksize,)) 103 104 assert len(self.data) % self.blocksize == 0 105 106 self.total_blocks = len(self.data) / self.blocksize 107 self.care_map = RangeSet(data=(0, self.total_blocks)) 108 109 zero_blocks = [] 110 nonzero_blocks = [] 111 reference = '\0' * self.blocksize 112 113 for i in range(self.total_blocks): 114 d = self.data[i*self.blocksize : (i+1)*self.blocksize] 115 if d == reference: 116 zero_blocks.append(i) 117 zero_blocks.append(i+1) 118 else: 119 nonzero_blocks.append(i) 120 nonzero_blocks.append(i+1) 121 122 self.file_map = {"__ZERO": RangeSet(zero_blocks), 123 "__NONZERO": RangeSet(nonzero_blocks)} 124 125 def ReadRangeSet(self, ranges): 126 return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges] 127 128 def TotalSha1(self): 129 if not hasattr(self, "sha1"): 130 self.sha1 = sha1(self.data).hexdigest() 131 return self.sha1 132 133 134class Transfer(object): 135 def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id): 136 self.tgt_name = tgt_name 137 self.src_name = src_name 138 self.tgt_ranges = tgt_ranges 139 self.src_ranges = src_ranges 140 self.style = style 141 self.intact = (getattr(tgt_ranges, "monotonic", False) and 142 getattr(src_ranges, "monotonic", False)) 143 self.goes_before = {} 144 self.goes_after = {} 145 146 self.stash_before = [] 147 self.use_stash = [] 148 149 self.id = len(by_id) 150 by_id.append(self) 151 152 def NetStashChange(self): 153 return (sum(sr.size() for (_, sr) in self.stash_before) - 154 sum(sr.size() for (_, sr) in self.use_stash)) 155 156 def __str__(self): 157 return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style + 158 " to " + str(self.tgt_ranges) + ">") 159 160 161# BlockImageDiff works on two image objects. An image object is 162# anything that provides the following attributes: 163# 164# blocksize: the size in bytes of a block, currently must be 4096. 165# 166# total_blocks: the total size of the partition/image, in blocks. 167# 168# care_map: a RangeSet containing which blocks (in the range [0, 169# total_blocks) we actually care about; i.e. which blocks contain 170# data. 171# 172# file_map: a dict that partitions the blocks contained in care_map 173# into smaller domains that are useful for doing diffs on. 174# (Typically a domain is a file, and the key in file_map is the 175# pathname.) 176# 177# ReadRangeSet(): a function that takes a RangeSet and returns the 178# data contained in the image blocks of that RangeSet. The data 179# is returned as a list or tuple of strings; concatenating the 180# elements together should produce the requested data. 181# Implementations are free to break up the data into list/tuple 182# elements in any way that is convenient. 183# 184# TotalSha1(): a function that returns (as a hex string) the SHA-1 185# hash of all the data in the image (ie, all the blocks in the 186# care_map) 187# 188# When creating a BlockImageDiff, the src image may be None, in which 189# case the list of transfers produced will never read from the 190# original image. 191 192class BlockImageDiff(object): 193 def __init__(self, tgt, src=None, threads=None, version=2): 194 if threads is None: 195 threads = multiprocessing.cpu_count() // 2 196 if threads == 0: threads = 1 197 self.threads = threads 198 self.version = version 199 200 assert version in (1, 2) 201 202 self.tgt = tgt 203 if src is None: 204 src = EmptyImage() 205 self.src = src 206 207 # The updater code that installs the patch always uses 4k blocks. 208 assert tgt.blocksize == 4096 209 assert src.blocksize == 4096 210 211 # The range sets in each filemap should comprise a partition of 212 # the care map. 213 self.AssertPartition(src.care_map, src.file_map.values()) 214 self.AssertPartition(tgt.care_map, tgt.file_map.values()) 215 216 def Compute(self, prefix): 217 # When looking for a source file to use as the diff input for a 218 # target file, we try: 219 # 1) an exact path match if available, otherwise 220 # 2) a exact basename match if available, otherwise 221 # 3) a basename match after all runs of digits are replaced by 222 # "#" if available, otherwise 223 # 4) we have no source for this target. 224 self.AbbreviateSourceNames() 225 self.FindTransfers() 226 227 # Find the ordering dependencies among transfers (this is O(n^2) 228 # in the number of transfers). 229 self.GenerateDigraph() 230 # Find a sequence of transfers that satisfies as many ordering 231 # dependencies as possible (heuristically). 232 self.FindVertexSequence() 233 # Fix up the ordering dependencies that the sequence didn't 234 # satisfy. 235 if self.version == 1: 236 self.RemoveBackwardEdges() 237 else: 238 self.ReverseBackwardEdges() 239 self.ImproveVertexSequence() 240 241 # Double-check our work. 242 self.AssertSequenceGood() 243 244 self.ComputePatches(prefix) 245 self.WriteTransfers(prefix) 246 247 def WriteTransfers(self, prefix): 248 out = [] 249 250 total = 0 251 performs_read = False 252 253 stashes = {} 254 stashed_blocks = 0 255 max_stashed_blocks = 0 256 257 free_stash_ids = [] 258 next_stash_id = 0 259 260 for xf in self.transfers: 261 262 if self.version < 2: 263 assert not xf.stash_before 264 assert not xf.use_stash 265 266 for s, sr in xf.stash_before: 267 assert s not in stashes 268 if free_stash_ids: 269 sid = heapq.heappop(free_stash_ids) 270 else: 271 sid = next_stash_id 272 next_stash_id += 1 273 stashes[s] = sid 274 stashed_blocks += sr.size() 275 out.append("stash %d %s\n" % (sid, sr.to_string_raw())) 276 277 if stashed_blocks > max_stashed_blocks: 278 max_stashed_blocks = stashed_blocks 279 280 if self.version == 1: 281 src_string = xf.src_ranges.to_string_raw() 282 elif self.version == 2: 283 284 # <# blocks> <src ranges> 285 # OR 286 # <# blocks> <src ranges> <src locs> <stash refs...> 287 # OR 288 # <# blocks> - <stash refs...> 289 290 size = xf.src_ranges.size() 291 src_string = [str(size)] 292 293 unstashed_src_ranges = xf.src_ranges 294 mapped_stashes = [] 295 for s, sr in xf.use_stash: 296 sid = stashes.pop(s) 297 stashed_blocks -= sr.size() 298 unstashed_src_ranges = unstashed_src_ranges.subtract(sr) 299 sr = xf.src_ranges.map_within(sr) 300 mapped_stashes.append(sr) 301 src_string.append("%d:%s" % (sid, sr.to_string_raw())) 302 heapq.heappush(free_stash_ids, sid) 303 304 if unstashed_src_ranges: 305 src_string.insert(1, unstashed_src_ranges.to_string_raw()) 306 if xf.use_stash: 307 mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges) 308 src_string.insert(2, mapped_unstashed.to_string_raw()) 309 mapped_stashes.append(mapped_unstashed) 310 self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes) 311 else: 312 src_string.insert(1, "-") 313 self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes) 314 315 src_string = " ".join(src_string) 316 317 # both versions: 318 # zero <rangeset> 319 # new <rangeset> 320 # erase <rangeset> 321 # 322 # version 1: 323 # bsdiff patchstart patchlen <src rangeset> <tgt rangeset> 324 # imgdiff patchstart patchlen <src rangeset> <tgt rangeset> 325 # move <src rangeset> <tgt rangeset> 326 # 327 # version 2: 328 # bsdiff patchstart patchlen <tgt rangeset> <src_string> 329 # imgdiff patchstart patchlen <tgt rangeset> <src_string> 330 # move <tgt rangeset> <src_string> 331 332 tgt_size = xf.tgt_ranges.size() 333 334 if xf.style == "new": 335 assert xf.tgt_ranges 336 out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw())) 337 total += tgt_size 338 elif xf.style == "move": 339 performs_read = True 340 assert xf.tgt_ranges 341 assert xf.src_ranges.size() == tgt_size 342 if xf.src_ranges != xf.tgt_ranges: 343 if self.version == 1: 344 out.append("%s %s %s\n" % ( 345 xf.style, 346 xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw())) 347 elif self.version == 2: 348 out.append("%s %s %s\n" % ( 349 xf.style, 350 xf.tgt_ranges.to_string_raw(), src_string)) 351 total += tgt_size 352 elif xf.style in ("bsdiff", "imgdiff"): 353 performs_read = True 354 assert xf.tgt_ranges 355 assert xf.src_ranges 356 if self.version == 1: 357 out.append("%s %d %d %s %s\n" % ( 358 xf.style, xf.patch_start, xf.patch_len, 359 xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw())) 360 elif self.version == 2: 361 out.append("%s %d %d %s %s\n" % ( 362 xf.style, xf.patch_start, xf.patch_len, 363 xf.tgt_ranges.to_string_raw(), src_string)) 364 total += tgt_size 365 elif xf.style == "zero": 366 assert xf.tgt_ranges 367 to_zero = xf.tgt_ranges.subtract(xf.src_ranges) 368 if to_zero: 369 out.append("%s %s\n" % (xf.style, to_zero.to_string_raw())) 370 total += to_zero.size() 371 else: 372 raise ValueError, "unknown transfer style '%s'\n" % (xf.style,) 373 374 375 # sanity check: abort if we're going to need more than 512 MB if 376 # stash space 377 assert max_stashed_blocks * self.tgt.blocksize < (512 << 20) 378 379 all_tgt = RangeSet(data=(0, self.tgt.total_blocks)) 380 if performs_read: 381 # if some of the original data is used, then at the end we'll 382 # erase all the blocks on the partition that don't contain data 383 # in the new image. 384 new_dontcare = all_tgt.subtract(self.tgt.care_map) 385 if new_dontcare: 386 out.append("erase %s\n" % (new_dontcare.to_string_raw(),)) 387 else: 388 # if nothing is read (ie, this is a full OTA), then we can start 389 # by erasing the entire partition. 390 out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),)) 391 392 out.insert(0, "%d\n" % (self.version,)) # format version number 393 out.insert(1, str(total) + "\n") 394 if self.version >= 2: 395 # version 2 only: after the total block count, we give the number 396 # of stash slots needed, and the maximum size needed (in blocks) 397 out.insert(2, str(next_stash_id) + "\n") 398 out.insert(3, str(max_stashed_blocks) + "\n") 399 400 with open(prefix + ".transfer.list", "wb") as f: 401 for i in out: 402 f.write(i) 403 404 if self.version >= 2: 405 print("max stashed blocks: %d (%d bytes)\n" % ( 406 max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize)) 407 408 def ComputePatches(self, prefix): 409 print("Reticulating splines...") 410 diff_q = [] 411 patch_num = 0 412 with open(prefix + ".new.dat", "wb") as new_f: 413 for xf in self.transfers: 414 if xf.style == "zero": 415 pass 416 elif xf.style == "new": 417 for piece in self.tgt.ReadRangeSet(xf.tgt_ranges): 418 new_f.write(piece) 419 elif xf.style == "diff": 420 src = self.src.ReadRangeSet(xf.src_ranges) 421 tgt = self.tgt.ReadRangeSet(xf.tgt_ranges) 422 423 # We can't compare src and tgt directly because they may have 424 # the same content but be broken up into blocks differently, eg: 425 # 426 # ["he", "llo"] vs ["h", "ello"] 427 # 428 # We want those to compare equal, ideally without having to 429 # actually concatenate the strings (these may be tens of 430 # megabytes). 431 432 src_sha1 = sha1() 433 for p in src: 434 src_sha1.update(p) 435 tgt_sha1 = sha1() 436 tgt_size = 0 437 for p in tgt: 438 tgt_sha1.update(p) 439 tgt_size += len(p) 440 441 if src_sha1.digest() == tgt_sha1.digest(): 442 # These are identical; we don't need to generate a patch, 443 # just issue copy commands on the device. 444 xf.style = "move" 445 else: 446 # For files in zip format (eg, APKs, JARs, etc.) we would 447 # like to use imgdiff -z if possible (because it usually 448 # produces significantly smaller patches than bsdiff). 449 # This is permissible if: 450 # 451 # - the source and target files are monotonic (ie, the 452 # data is stored with blocks in increasing order), and 453 # - we haven't removed any blocks from the source set. 454 # 455 # If these conditions are satisfied then appending all the 456 # blocks in the set together in order will produce a valid 457 # zip file (plus possibly extra zeros in the last block), 458 # which is what imgdiff needs to operate. (imgdiff is 459 # fine with extra zeros at the end of the file.) 460 imgdiff = (xf.intact and 461 xf.tgt_name.split(".")[-1].lower() 462 in ("apk", "jar", "zip")) 463 xf.style = "imgdiff" if imgdiff else "bsdiff" 464 diff_q.append((tgt_size, src, tgt, xf, patch_num)) 465 patch_num += 1 466 467 else: 468 assert False, "unknown style " + xf.style 469 470 if diff_q: 471 if self.threads > 1: 472 print("Computing patches (using %d threads)..." % (self.threads,)) 473 else: 474 print("Computing patches...") 475 diff_q.sort() 476 477 patches = [None] * patch_num 478 479 lock = threading.Lock() 480 def diff_worker(): 481 while True: 482 with lock: 483 if not diff_q: return 484 tgt_size, src, tgt, xf, patchnum = diff_q.pop() 485 patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff")) 486 size = len(patch) 487 with lock: 488 patches[patchnum] = (patch, xf) 489 print("%10d %10d (%6.2f%%) %7s %s" % ( 490 size, tgt_size, size * 100.0 / tgt_size, xf.style, 491 xf.tgt_name if xf.tgt_name == xf.src_name else ( 492 xf.tgt_name + " (from " + xf.src_name + ")"))) 493 494 threads = [threading.Thread(target=diff_worker) 495 for i in range(self.threads)] 496 for th in threads: 497 th.start() 498 while threads: 499 threads.pop().join() 500 else: 501 patches = [] 502 503 p = 0 504 with open(prefix + ".patch.dat", "wb") as patch_f: 505 for patch, xf in patches: 506 xf.patch_start = p 507 xf.patch_len = len(patch) 508 patch_f.write(patch) 509 p += len(patch) 510 511 def AssertSequenceGood(self): 512 # Simulate the sequences of transfers we will output, and check that: 513 # - we never read a block after writing it, and 514 # - we write every block we care about exactly once. 515 516 # Start with no blocks having been touched yet. 517 touched = RangeSet() 518 519 # Imagine processing the transfers in order. 520 for xf in self.transfers: 521 # Check that the input blocks for this transfer haven't yet been touched. 522 523 x = xf.src_ranges 524 if self.version >= 2: 525 for _, sr in xf.use_stash: 526 x = x.subtract(sr) 527 528 assert not touched.overlaps(x) 529 # Check that the output blocks for this transfer haven't yet been touched. 530 assert not touched.overlaps(xf.tgt_ranges) 531 # Touch all the blocks written by this transfer. 532 touched = touched.union(xf.tgt_ranges) 533 534 # Check that we've written every target block. 535 assert touched == self.tgt.care_map 536 537 def ImproveVertexSequence(self): 538 print("Improving vertex order...") 539 540 # At this point our digraph is acyclic; we reversed any edges that 541 # were backwards in the heuristically-generated sequence. The 542 # previously-generated order is still acceptable, but we hope to 543 # find a better order that needs less memory for stashed data. 544 # Now we do a topological sort to generate a new vertex order, 545 # using a greedy algorithm to choose which vertex goes next 546 # whenever we have a choice. 547 548 # Make a copy of the edge set; this copy will get destroyed by the 549 # algorithm. 550 for xf in self.transfers: 551 xf.incoming = xf.goes_after.copy() 552 xf.outgoing = xf.goes_before.copy() 553 554 L = [] # the new vertex order 555 556 # S is the set of sources in the remaining graph; we always choose 557 # the one that leaves the least amount of stashed data after it's 558 # executed. 559 S = [(u.NetStashChange(), u.order, u) for u in self.transfers 560 if not u.incoming] 561 heapq.heapify(S) 562 563 while S: 564 _, _, xf = heapq.heappop(S) 565 L.append(xf) 566 for u in xf.outgoing: 567 del u.incoming[xf] 568 if not u.incoming: 569 heapq.heappush(S, (u.NetStashChange(), u.order, u)) 570 571 # if this fails then our graph had a cycle. 572 assert len(L) == len(self.transfers) 573 574 self.transfers = L 575 for i, xf in enumerate(L): 576 xf.order = i 577 578 def RemoveBackwardEdges(self): 579 print("Removing backward edges...") 580 in_order = 0 581 out_of_order = 0 582 lost_source = 0 583 584 for xf in self.transfers: 585 lost = 0 586 size = xf.src_ranges.size() 587 for u in xf.goes_before: 588 # xf should go before u 589 if xf.order < u.order: 590 # it does, hurray! 591 in_order += 1 592 else: 593 # it doesn't, boo. trim the blocks that u writes from xf's 594 # source, so that xf can go after u. 595 out_of_order += 1 596 assert xf.src_ranges.overlaps(u.tgt_ranges) 597 xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges) 598 xf.intact = False 599 600 if xf.style == "diff" and not xf.src_ranges: 601 # nothing left to diff from; treat as new data 602 xf.style = "new" 603 604 lost = size - xf.src_ranges.size() 605 lost_source += lost 606 607 print((" %d/%d dependencies (%.2f%%) were violated; " 608 "%d source blocks removed.") % 609 (out_of_order, in_order + out_of_order, 610 (out_of_order * 100.0 / (in_order + out_of_order)) 611 if (in_order + out_of_order) else 0.0, 612 lost_source)) 613 614 def ReverseBackwardEdges(self): 615 print("Reversing backward edges...") 616 in_order = 0 617 out_of_order = 0 618 stashes = 0 619 stash_size = 0 620 621 for xf in self.transfers: 622 lost = 0 623 size = xf.src_ranges.size() 624 for u in xf.goes_before.copy(): 625 # xf should go before u 626 if xf.order < u.order: 627 # it does, hurray! 628 in_order += 1 629 else: 630 # it doesn't, boo. modify u to stash the blocks that it 631 # writes that xf wants to read, and then require u to go 632 # before xf. 633 out_of_order += 1 634 635 overlap = xf.src_ranges.intersect(u.tgt_ranges) 636 assert overlap 637 638 u.stash_before.append((stashes, overlap)) 639 xf.use_stash.append((stashes, overlap)) 640 stashes += 1 641 stash_size += overlap.size() 642 643 # reverse the edge direction; now xf must go after u 644 del xf.goes_before[u] 645 del u.goes_after[xf] 646 xf.goes_after[u] = None # value doesn't matter 647 u.goes_before[xf] = None 648 649 print((" %d/%d dependencies (%.2f%%) were violated; " 650 "%d source blocks stashed.") % 651 (out_of_order, in_order + out_of_order, 652 (out_of_order * 100.0 / (in_order + out_of_order)) 653 if (in_order + out_of_order) else 0.0, 654 stash_size)) 655 656 def FindVertexSequence(self): 657 print("Finding vertex sequence...") 658 659 # This is based on "A Fast & Effective Heuristic for the Feedback 660 # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth. Think of 661 # it as starting with the digraph G and moving all the vertices to 662 # be on a horizontal line in some order, trying to minimize the 663 # number of edges that end up pointing to the left. Left-pointing 664 # edges will get removed to turn the digraph into a DAG. In this 665 # case each edge has a weight which is the number of source blocks 666 # we'll lose if that edge is removed; we try to minimize the total 667 # weight rather than just the number of edges. 668 669 # Make a copy of the edge set; this copy will get destroyed by the 670 # algorithm. 671 for xf in self.transfers: 672 xf.incoming = xf.goes_after.copy() 673 xf.outgoing = xf.goes_before.copy() 674 675 # We use an OrderedDict instead of just a set so that the output 676 # is repeatable; otherwise it would depend on the hash values of 677 # the transfer objects. 678 G = OrderedDict() 679 for xf in self.transfers: 680 G[xf] = None 681 s1 = deque() # the left side of the sequence, built from left to right 682 s2 = deque() # the right side of the sequence, built from right to left 683 684 while G: 685 686 # Put all sinks at the end of the sequence. 687 while True: 688 sinks = [u for u in G if not u.outgoing] 689 if not sinks: break 690 for u in sinks: 691 s2.appendleft(u) 692 del G[u] 693 for iu in u.incoming: 694 del iu.outgoing[u] 695 696 # Put all the sources at the beginning of the sequence. 697 while True: 698 sources = [u for u in G if not u.incoming] 699 if not sources: break 700 for u in sources: 701 s1.append(u) 702 del G[u] 703 for iu in u.outgoing: 704 del iu.incoming[u] 705 706 if not G: break 707 708 # Find the "best" vertex to put next. "Best" is the one that 709 # maximizes the net difference in source blocks saved we get by 710 # pretending it's a source rather than a sink. 711 712 max_d = None 713 best_u = None 714 for u in G: 715 d = sum(u.outgoing.values()) - sum(u.incoming.values()) 716 if best_u is None or d > max_d: 717 max_d = d 718 best_u = u 719 720 u = best_u 721 s1.append(u) 722 del G[u] 723 for iu in u.outgoing: 724 del iu.incoming[u] 725 for iu in u.incoming: 726 del iu.outgoing[u] 727 728 # Now record the sequence in the 'order' field of each transfer, 729 # and by rearranging self.transfers to be in the chosen sequence. 730 731 new_transfers = [] 732 for x in itertools.chain(s1, s2): 733 x.order = len(new_transfers) 734 new_transfers.append(x) 735 del x.incoming 736 del x.outgoing 737 738 self.transfers = new_transfers 739 740 def GenerateDigraph(self): 741 print("Generating digraph...") 742 for a in self.transfers: 743 for b in self.transfers: 744 if a is b: continue 745 746 # If the blocks written by A are read by B, then B needs to go before A. 747 i = a.tgt_ranges.intersect(b.src_ranges) 748 if i: 749 if b.src_name == "__ZERO": 750 # the cost of removing source blocks for the __ZERO domain 751 # is (nearly) zero. 752 size = 0 753 else: 754 size = i.size() 755 b.goes_before[a] = size 756 a.goes_after[b] = size 757 758 def FindTransfers(self): 759 self.transfers = [] 760 empty = RangeSet() 761 for tgt_fn, tgt_ranges in self.tgt.file_map.items(): 762 if tgt_fn == "__ZERO": 763 # the special "__ZERO" domain is all the blocks not contained 764 # in any file and that are filled with zeros. We have a 765 # special transfer style for zero blocks. 766 src_ranges = self.src.file_map.get("__ZERO", empty) 767 Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges, 768 "zero", self.transfers) 769 continue 770 771 elif tgt_fn in self.src.file_map: 772 # Look for an exact pathname match in the source. 773 Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn], 774 "diff", self.transfers) 775 continue 776 777 b = os.path.basename(tgt_fn) 778 if b in self.src_basenames: 779 # Look for an exact basename match in the source. 780 src_fn = self.src_basenames[b] 781 Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn], 782 "diff", self.transfers) 783 continue 784 785 b = re.sub("[0-9]+", "#", b) 786 if b in self.src_numpatterns: 787 # Look for a 'number pattern' match (a basename match after 788 # all runs of digits are replaced by "#"). (This is useful 789 # for .so files that contain version numbers in the filename 790 # that get bumped.) 791 src_fn = self.src_numpatterns[b] 792 Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn], 793 "diff", self.transfers) 794 continue 795 796 Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers) 797 798 def AbbreviateSourceNames(self): 799 self.src_basenames = {} 800 self.src_numpatterns = {} 801 802 for k in self.src.file_map.keys(): 803 b = os.path.basename(k) 804 self.src_basenames[b] = k 805 b = re.sub("[0-9]+", "#", b) 806 self.src_numpatterns[b] = k 807 808 @staticmethod 809 def AssertPartition(total, seq): 810 """Assert that all the RangeSets in 'seq' form a partition of the 811 'total' RangeSet (ie, they are nonintersecting and their union 812 equals 'total').""" 813 so_far = RangeSet() 814 for i in seq: 815 assert not so_far.overlaps(i) 816 so_far = so_far.union(i) 817 assert so_far == total 818