1#
2# Copyright (C) 2013 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17"""Applying a Chrome OS update payload.
18
19This module is used internally by the main Payload class for applying an update
20payload. The interface for invoking the applier is as follows:
21
22  applier = PayloadApplier(payload)
23  applier.Run(...)
24
25"""
26
27from __future__ import absolute_import
28from __future__ import print_function
29
30import array
31import bz2
32import hashlib
33# Not everywhere we can have the lzma library so we ignore it if we didn't have
34# it because it is not going to be used. For example, 'cros flash' uses
35# devserver code which eventually loads this file, but the lzma library is not
36# included in the client test devices, and it is not necessary to do so. But
37# lzma is not used in 'cros flash' so it should be fine. Python 3.x include
38# lzma, but for backward compatibility with Python 2.7, backports-lzma is
39# needed.
40try:
41  import lzma
42except ImportError:
43  try:
44    from backports import lzma
45  except ImportError:
46    pass
47import os
48import subprocess
49import sys
50import tempfile
51
52from update_payload import common
53from update_payload.error import PayloadError
54
55#
56# Helper functions.
57#
58def _VerifySha256(file_obj, expected_hash, name, length=-1):
59  """Verifies the SHA256 hash of a file.
60
61  Args:
62    file_obj: file object to read
63    expected_hash: the hash digest we expect to be getting
64    name: name string of this hash, for error reporting
65    length: precise length of data to verify (optional)
66
67  Raises:
68    PayloadError if computed hash doesn't match expected one, or if fails to
69    read the specified length of data.
70  """
71  hasher = hashlib.sha256()
72  block_length = 1024 * 1024
73  max_length = length if length >= 0 else sys.maxsize
74
75  while max_length > 0:
76    read_length = min(max_length, block_length)
77    data = file_obj.read(read_length)
78    if not data:
79      break
80    max_length -= len(data)
81    hasher.update(data)
82
83  if length >= 0 and max_length > 0:
84    raise PayloadError(
85        'insufficient data (%d instead of %d) when verifying %s' %
86        (length - max_length, length, name))
87
88  actual_hash = hasher.digest()
89  if actual_hash != expected_hash:
90    raise PayloadError('%s hash (%s) not as expected (%s)' %
91                       (name, common.FormatSha256(actual_hash),
92                        common.FormatSha256(expected_hash)))
93
94
95def _ReadExtents(file_obj, extents, block_size, max_length=-1):
96  """Reads data from file as defined by extent sequence.
97
98  This tries to be efficient by not copying data as it is read in chunks.
99
100  Args:
101    file_obj: file object
102    extents: sequence of block extents (offset and length)
103    block_size: size of each block
104    max_length: maximum length to read (optional)
105
106  Returns:
107    A character array containing the concatenated read data.
108  """
109  data = array.array('B')
110  if max_length < 0:
111    max_length = sys.maxsize
112  for ex in extents:
113    if max_length == 0:
114      break
115    read_length = min(max_length, ex.num_blocks * block_size)
116
117    file_obj.seek(ex.start_block * block_size)
118    data.fromfile(file_obj, read_length)
119
120    max_length -= read_length
121
122  return data
123
124
125def _WriteExtents(file_obj, data, extents, block_size, base_name):
126  """Writes data to file as defined by extent sequence.
127
128  This tries to be efficient by not copy data as it is written in chunks.
129
130  Args:
131    file_obj: file object
132    data: data to write
133    extents: sequence of block extents (offset and length)
134    block_size: size of each block
135    base_name: name string of extent sequence for error reporting
136
137  Raises:
138    PayloadError when things don't add up.
139  """
140  data_offset = 0
141  data_length = len(data)
142  for ex, ex_name in common.ExtentIter(extents, base_name):
143    if not data_length:
144      raise PayloadError('%s: more write extents than data' % ex_name)
145    write_length = min(data_length, ex.num_blocks * block_size)
146    file_obj.seek(ex.start_block * block_size)
147    file_obj.write(data[data_offset:(data_offset + write_length)])
148
149    data_offset += write_length
150    data_length -= write_length
151
152  if data_length:
153    raise PayloadError('%s: more data than write extents' % base_name)
154
155
156def _ExtentsToBspatchArg(extents, block_size, base_name, data_length=-1):
157  """Translates an extent sequence into a bspatch-compatible string argument.
158
159  Args:
160    extents: sequence of block extents (offset and length)
161    block_size: size of each block
162    base_name: name string of extent sequence for error reporting
163    data_length: the actual total length of the data in bytes (optional)
164
165  Returns:
166    A tuple consisting of (i) a string of the form
167    "off_1:len_1,...,off_n:len_n", (ii) an offset where zero padding is needed
168    for filling the last extent, (iii) the length of the padding (zero means no
169    padding is needed and the extents cover the full length of data).
170
171  Raises:
172    PayloadError if data_length is too short or too long.
173  """
174  arg = ''
175  pad_off = pad_len = 0
176  if data_length < 0:
177    data_length = sys.maxsize
178  for ex, ex_name in common.ExtentIter(extents, base_name):
179    if not data_length:
180      raise PayloadError('%s: more extents than total data length' % ex_name)
181
182    start_byte = ex.start_block * block_size
183    num_bytes = ex.num_blocks * block_size
184    if data_length < num_bytes:
185      # We're only padding a real extent.
186      pad_off = start_byte + data_length
187      pad_len = num_bytes - data_length
188      num_bytes = data_length
189
190    arg += '%s%d:%d' % (arg and ',', start_byte, num_bytes)
191    data_length -= num_bytes
192
193  if data_length:
194    raise PayloadError('%s: extents not covering full data length' % base_name)
195
196  return arg, pad_off, pad_len
197
198
199#
200# Payload application.
201#
202class PayloadApplier(object):
203  """Applying an update payload.
204
205  This is a short-lived object whose purpose is to isolate the logic used for
206  applying an update payload.
207  """
208
209  def __init__(self, payload, bsdiff_in_place=True, bspatch_path=None,
210               puffpatch_path=None, truncate_to_expected_size=True):
211    """Initialize the applier.
212
213    Args:
214      payload: the payload object to check
215      bsdiff_in_place: whether to perform BSDIFF operation in-place (optional)
216      bspatch_path: path to the bspatch binary (optional)
217      puffpatch_path: path to the puffpatch binary (optional)
218      truncate_to_expected_size: whether to truncate the resulting partitions
219                                 to their expected sizes, as specified in the
220                                 payload (optional)
221    """
222    assert payload.is_init, 'uninitialized update payload'
223    self.payload = payload
224    self.block_size = payload.manifest.block_size
225    self.minor_version = payload.manifest.minor_version
226    self.bsdiff_in_place = bsdiff_in_place
227    self.bspatch_path = bspatch_path or 'bspatch'
228    self.puffpatch_path = puffpatch_path or 'puffin'
229    self.truncate_to_expected_size = truncate_to_expected_size
230
231  def _ApplyReplaceOperation(self, op, op_name, out_data, part_file, part_size):
232    """Applies a REPLACE{,_BZ,_XZ} operation.
233
234    Args:
235      op: the operation object
236      op_name: name string for error reporting
237      out_data: the data to be written
238      part_file: the partition file object
239      part_size: the size of the partition
240
241    Raises:
242      PayloadError if something goes wrong.
243    """
244    block_size = self.block_size
245    data_length = len(out_data)
246
247    # Decompress data if needed.
248    if op.type == common.OpType.REPLACE_BZ:
249      out_data = bz2.decompress(out_data)
250      data_length = len(out_data)
251    elif op.type == common.OpType.REPLACE_XZ:
252      # pylint: disable=no-member
253      out_data = lzma.decompress(out_data)
254      data_length = len(out_data)
255
256    # Write data to blocks specified in dst extents.
257    data_start = 0
258    for ex, ex_name in common.ExtentIter(op.dst_extents,
259                                         '%s.dst_extents' % op_name):
260      start_block = ex.start_block
261      num_blocks = ex.num_blocks
262      count = num_blocks * block_size
263
264      data_end = data_start + count
265
266      # Make sure we're not running past partition boundary.
267      if (start_block + num_blocks) * block_size > part_size:
268        raise PayloadError(
269            '%s: extent (%s) exceeds partition size (%d)' %
270            (ex_name, common.FormatExtent(ex, block_size),
271             part_size))
272
273      # Make sure that we have enough data to write.
274      if data_end >= data_length + block_size:
275        raise PayloadError(
276            '%s: more dst blocks than data (even with padding)')
277
278      # Pad with zeros if necessary.
279      if data_end > data_length:
280        padding = data_end - data_length
281        out_data += b'\0' * padding
282
283      self.payload.payload_file.seek(start_block * block_size)
284      part_file.seek(start_block * block_size)
285      part_file.write(out_data[data_start:data_end])
286
287      data_start += count
288
289    # Make sure we wrote all data.
290    if data_start < data_length:
291      raise PayloadError('%s: wrote fewer bytes (%d) than expected (%d)' %
292                         (op_name, data_start, data_length))
293
294  def _ApplyZeroOperation(self, op, op_name, part_file):
295    """Applies a ZERO operation.
296
297    Args:
298      op: the operation object
299      op_name: name string for error reporting
300      part_file: the partition file object
301
302    Raises:
303      PayloadError if something goes wrong.
304    """
305    block_size = self.block_size
306    base_name = '%s.dst_extents' % op_name
307
308    # Iterate over the extents and write zero.
309    # pylint: disable=unused-variable
310    for ex, ex_name in common.ExtentIter(op.dst_extents, base_name):
311      part_file.seek(ex.start_block * block_size)
312      part_file.write(b'\0' * (ex.num_blocks * block_size))
313
314  def _ApplySourceCopyOperation(self, op, op_name, old_part_file,
315                                new_part_file):
316    """Applies a SOURCE_COPY operation.
317
318    Args:
319      op: the operation object
320      op_name: name string for error reporting
321      old_part_file: the old partition file object
322      new_part_file: the new partition file object
323
324    Raises:
325      PayloadError if something goes wrong.
326    """
327    if not old_part_file:
328      raise PayloadError(
329          '%s: no source partition file provided for operation type (%d)' %
330          (op_name, op.type))
331
332    block_size = self.block_size
333
334    # Gather input raw data from src extents.
335    in_data = _ReadExtents(old_part_file, op.src_extents, block_size)
336
337    # Dump extracted data to dst extents.
338    _WriteExtents(new_part_file, in_data, op.dst_extents, block_size,
339                  '%s.dst_extents' % op_name)
340
341  def _BytesInExtents(self, extents, base_name):
342    """Counts the length of extents in bytes.
343
344    Args:
345      extents: The list of Extents.
346      base_name: For error reporting.
347
348    Returns:
349      The number of bytes in extents.
350    """
351
352    length = 0
353    # pylint: disable=unused-variable
354    for ex, ex_name in common.ExtentIter(extents, base_name):
355      length += ex.num_blocks * self.block_size
356    return length
357
358  def _ApplyDiffOperation(self, op, op_name, patch_data, old_part_file,
359                          new_part_file):
360    """Applies a SOURCE_BSDIFF, BROTLI_BSDIFF or PUFFDIFF operation.
361
362    Args:
363      op: the operation object
364      op_name: name string for error reporting
365      patch_data: the binary patch content
366      old_part_file: the source partition file object
367      new_part_file: the target partition file object
368
369    Raises:
370      PayloadError if something goes wrong.
371    """
372    if not old_part_file:
373      raise PayloadError(
374          '%s: no source partition file provided for operation type (%d)' %
375          (op_name, op.type))
376
377    block_size = self.block_size
378
379    # Dump patch data to file.
380    with tempfile.NamedTemporaryFile(delete=False) as patch_file:
381      patch_file_name = patch_file.name
382      patch_file.write(patch_data)
383
384    if (hasattr(new_part_file, 'fileno') and
385        ((not old_part_file) or hasattr(old_part_file, 'fileno'))):
386      # Construct input and output extents argument for bspatch.
387
388      in_extents_arg, _, _ = _ExtentsToBspatchArg(
389          op.src_extents, block_size, '%s.src_extents' % op_name,
390          data_length=op.src_length if op.src_length else
391          self._BytesInExtents(op.src_extents, "%s.src_extents"))
392      out_extents_arg, pad_off, pad_len = _ExtentsToBspatchArg(
393          op.dst_extents, block_size, '%s.dst_extents' % op_name,
394          data_length=op.dst_length if op.dst_length else
395          self._BytesInExtents(op.dst_extents, "%s.dst_extents"))
396
397      new_file_name = '/dev/fd/%d' % new_part_file.fileno()
398      # Diff from source partition.
399      old_file_name = '/dev/fd/%d' % old_part_file.fileno()
400
401      # In python3, file descriptors(fd) are not passed to child processes by
402      # default. To pass the fds to the child processes, we need to set the flag
403      # 'inheritable' in the fds and make the subprocess calls with the argument
404      # close_fds set to False.
405      if sys.version_info.major >= 3:
406        os.set_inheritable(new_part_file.fileno(), True)
407        os.set_inheritable(old_part_file.fileno(), True)
408
409      if op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.BROTLI_BSDIFF):
410        # Invoke bspatch on partition file with extents args.
411        bspatch_cmd = [self.bspatch_path, old_file_name, new_file_name,
412                       patch_file_name, in_extents_arg, out_extents_arg]
413        subprocess.check_call(bspatch_cmd, close_fds=False)
414      elif op.type == common.OpType.PUFFDIFF:
415        # Invoke puffpatch on partition file with extents args.
416        puffpatch_cmd = [self.puffpatch_path,
417                         "--operation=puffpatch",
418                         "--src_file=%s" % old_file_name,
419                         "--dst_file=%s" % new_file_name,
420                         "--patch_file=%s" % patch_file_name,
421                         "--src_extents=%s" % in_extents_arg,
422                         "--dst_extents=%s" % out_extents_arg]
423        subprocess.check_call(puffpatch_cmd, close_fds=False)
424      else:
425        raise PayloadError("Unknown operation %s" % op.type)
426
427      # Pad with zeros past the total output length.
428      if pad_len:
429        new_part_file.seek(pad_off)
430        new_part_file.write(b'\0' * pad_len)
431    else:
432      # Gather input raw data and write to a temp file.
433      input_part_file = old_part_file if old_part_file else new_part_file
434      in_data = _ReadExtents(input_part_file, op.src_extents, block_size,
435                             max_length=op.src_length if op.src_length else
436                             self._BytesInExtents(op.src_extents,
437                                                  "%s.src_extents"))
438      with tempfile.NamedTemporaryFile(delete=False) as in_file:
439        in_file_name = in_file.name
440        in_file.write(in_data)
441
442      # Allocate temporary output file.
443      with tempfile.NamedTemporaryFile(delete=False) as out_file:
444        out_file_name = out_file.name
445
446      if op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.BROTLI_BSDIFF):
447        # Invoke bspatch.
448        bspatch_cmd = [self.bspatch_path, in_file_name, out_file_name,
449                       patch_file_name]
450        subprocess.check_call(bspatch_cmd)
451      elif op.type == common.OpType.PUFFDIFF:
452        # Invoke puffpatch.
453        puffpatch_cmd = [self.puffpatch_path,
454                         "--operation=puffpatch",
455                         "--src_file=%s" % in_file_name,
456                         "--dst_file=%s" % out_file_name,
457                         "--patch_file=%s" % patch_file_name]
458        subprocess.check_call(puffpatch_cmd)
459      else:
460        raise PayloadError("Unknown operation %s" % op.type)
461
462      # Read output.
463      with open(out_file_name, 'rb') as out_file:
464        out_data = out_file.read()
465        if len(out_data) != op.dst_length:
466          raise PayloadError(
467              '%s: actual patched data length (%d) not as expected (%d)' %
468              (op_name, len(out_data), op.dst_length))
469
470      # Write output back to partition, with padding.
471      unaligned_out_len = len(out_data) % block_size
472      if unaligned_out_len:
473        out_data += b'\0' * (block_size - unaligned_out_len)
474      _WriteExtents(new_part_file, out_data, op.dst_extents, block_size,
475                    '%s.dst_extents' % op_name)
476
477      # Delete input/output files.
478      os.remove(in_file_name)
479      os.remove(out_file_name)
480
481    # Delete patch file.
482    os.remove(patch_file_name)
483
484  def _ApplyOperations(self, operations, base_name, old_part_file,
485                       new_part_file, part_size):
486    """Applies a sequence of update operations to a partition.
487
488    Args:
489      operations: the sequence of operations
490      base_name: the name of the operation sequence
491      old_part_file: the old partition file object, open for reading/writing
492      new_part_file: the new partition file object, open for reading/writing
493      part_size: the partition size
494
495    Raises:
496      PayloadError if anything goes wrong while processing the payload.
497    """
498    for op, op_name in common.OperationIter(operations, base_name):
499      # Read data blob.
500      data = self.payload.ReadDataBlob(op.data_offset, op.data_length)
501
502      if op.type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ,
503                     common.OpType.REPLACE_XZ):
504        self._ApplyReplaceOperation(op, op_name, data, new_part_file, part_size)
505      elif op.type == common.OpType.ZERO:
506        self._ApplyZeroOperation(op, op_name, new_part_file)
507      elif op.type == common.OpType.SOURCE_COPY:
508        self._ApplySourceCopyOperation(op, op_name, old_part_file,
509                                       new_part_file)
510      elif op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.PUFFDIFF,
511                       common.OpType.BROTLI_BSDIFF):
512        self._ApplyDiffOperation(op, op_name, data, old_part_file,
513                                 new_part_file)
514      else:
515        raise PayloadError('%s: unknown operation type (%d)' %
516                           (op_name, op.type))
517
518  def _ApplyToPartition(self, operations, part_name, base_name,
519                        new_part_file_name, new_part_info,
520                        old_part_file_name=None, old_part_info=None):
521    """Applies an update to a partition.
522
523    Args:
524      operations: the sequence of update operations to apply
525      part_name: the name of the partition, for error reporting
526      base_name: the name of the operation sequence
527      new_part_file_name: file name to write partition data to
528      new_part_info: size and expected hash of dest partition
529      old_part_file_name: file name of source partition (optional)
530      old_part_info: size and expected hash of source partition (optional)
531
532    Raises:
533      PayloadError if anything goes wrong with the update.
534    """
535    # Do we have a source partition?
536    if old_part_file_name:
537      # Verify the source partition.
538      with open(old_part_file_name, 'rb') as old_part_file:
539        _VerifySha256(old_part_file, old_part_info.hash,
540                      'old ' + part_name, length=old_part_info.size)
541      new_part_file_mode = 'r+b'
542      open(new_part_file_name, 'w').close()
543
544    else:
545      # We need to create/truncate the dst partition file.
546      new_part_file_mode = 'w+b'
547
548    # Apply operations.
549    with open(new_part_file_name, new_part_file_mode) as new_part_file:
550      old_part_file = (open(old_part_file_name, 'r+b')
551                       if old_part_file_name else None)
552      try:
553        self._ApplyOperations(operations, base_name, old_part_file,
554                              new_part_file, new_part_info.size)
555      finally:
556        if old_part_file:
557          old_part_file.close()
558
559      # Truncate the result, if so instructed.
560      if self.truncate_to_expected_size:
561        new_part_file.seek(0, 2)
562        if new_part_file.tell() > new_part_info.size:
563          new_part_file.seek(new_part_info.size)
564          new_part_file.truncate()
565
566    # Verify the resulting partition.
567    with open(new_part_file_name, 'rb') as new_part_file:
568      _VerifySha256(new_part_file, new_part_info.hash,
569                    'new ' + part_name, length=new_part_info.size)
570
571  def Run(self, new_parts, old_parts=None):
572    """Applier entry point, invoking all update operations.
573
574    Args:
575      new_parts: map of partition name to dest partition file
576      old_parts: map of partition name to source partition file (optional)
577
578    Raises:
579      PayloadError if payload application failed.
580    """
581    if old_parts is None:
582      old_parts = {}
583
584    self.payload.ResetFile()
585
586    new_part_info = {}
587    old_part_info = {}
588    install_operations = []
589
590    manifest = self.payload.manifest
591    for part in manifest.partitions:
592      name = part.partition_name
593      new_part_info[name] = part.new_partition_info
594      old_part_info[name] = part.old_partition_info
595      install_operations.append((name, part.operations))
596
597    part_names = set(new_part_info.keys())  # Equivalently, old_part_info.keys()
598
599    # Make sure the arguments are sane and match the payload.
600    new_part_names = set(new_parts.keys())
601    if new_part_names != part_names:
602      raise PayloadError('missing dst partition(s) %s' %
603                         ', '.join(part_names - new_part_names))
604
605    old_part_names = set(old_parts.keys())
606    if part_names - old_part_names:
607      if self.payload.IsDelta():
608        raise PayloadError('trying to apply a delta update without src '
609                           'partition(s) %s' %
610                           ', '.join(part_names - old_part_names))
611    elif old_part_names == part_names:
612      if self.payload.IsFull():
613        raise PayloadError('trying to apply a full update onto src partitions')
614    else:
615      raise PayloadError('not all src partitions provided')
616
617    for name, operations in install_operations:
618      # Apply update to partition.
619      self._ApplyToPartition(
620          operations, name, '%s_install_operations' % name, new_parts[name],
621          new_part_info[name], old_parts.get(name, None), old_part_info[name])
622