1#!/usr/bin/python2
2#
3# Copyright (C) 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Send an A/B update to an Android device over adb."""
19
import argparse
import BaseHTTPServer
import hashlib
import logging
import os
import socket
import struct
import subprocess
import sys
import threading
import xml.etree.ElementTree
import zipfile

import update_payload.payload
33
34
# Directory on the device where the OTA package is stored when it is applied
# from a file (--file) instead of being served over the network.
OTA_PACKAGE_PATH = '/data/ota_package'

# Path on the device of the public key used to verify the payload; --public-key
# temporarily overrides the file at this path.
PAYLOAD_KEY_PATH = '/etc/update_engine/update-payload-key.pub.pem'

# The port on the device that update_engine should connect to. "adb reverse"
# forwards it to the port of the local HTTP server.
DEVICE_PORT = 1234
43
def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None):
  """Copy from a file object to another.

  This function is similar to shutil.copyfileobj except that it allows to copy
  less than the full source file.

  Args:
    fsrc: source file object where to read from.
    fdst: destination file object where to write to.
    buffer_size: size of the copy buffer in memory.
    copy_length: maximum number of bytes to copy, or None to copy everything.
      A zero or negative value copies nothing.

  Returns:
    the number of bytes copied.
  """
  copied = 0
  while True:
    chunk_size = buffer_size
    if copy_length is not None:
      chunk_size = min(chunk_size, copy_length - copied)
      # A non-positive chunk means the limit is reached (or was never
      # positive). Never pass a negative size to read(): that reads to EOF.
      if chunk_size <= 0:
        break
    buf = fsrc.read(chunk_size)
    if not buf:
      break
    fdst.write(buf)
    copied += len(buf)
  return copied
72
73
class AndroidOTAPackage(object):
  """Android update payload using the .zip format.

  Android OTA packages traditionally used a .zip file to store the payload. When
  applying A/B updates over the network, a payload binary is stored RAW
  (uncompressed) inside this .zip file which is used by update_engine to apply
  the payload. To do this, an offset and size inside the .zip file are
  provided.

  Attributes:
    otafilename: path of the OTA package on the host.
    offset: byte offset of the payload.bin data within the .zip file.
    size: size in bytes of the payload.bin data.
    properties: raw contents of payload_properties.txt.
  """

  # Android OTA package file paths.
  OTA_PAYLOAD_BIN = 'payload.bin'
  OTA_PAYLOAD_PROPERTIES_TXT = 'payload_properties.txt'

  # Zip local file header layout (PKWARE APPNOTE.TXT 4.3.7): a fixed 30-byte
  # record starting with a signature. Its last two 16-bit little-endian fields
  # (at offset 26) are the file name and extra field lengths. Those two
  # variable-length fields follow, then the entry data.
  _LOCAL_HEADER_SIZE = 30
  _LOCAL_HEADER_MAGIC = b'PK\x03\x04'
  _LOCAL_HEADER_LENGTHS_OFFSET = 26

  def __init__(self, otafilename):
    self.otafilename = otafilename

    with zipfile.ZipFile(otafilename, 'r') as otazip:
      payload_info = otazip.getinfo(self.OTA_PAYLOAD_BIN)
      # update_engine reads the payload bytes straight out of the .zip, which
      # only works when the entry is not compressed.
      if payload_info.compress_type != zipfile.ZIP_STORED:
        raise ValueError('%s in %s must be stored uncompressed' %
                         (self.OTA_PAYLOAD_BIN, otafilename))
      self.properties = otazip.read(self.OTA_PAYLOAD_PROPERTIES_TXT)

    self.offset = self._DataOffset(otafilename, payload_info.header_offset)
    self.size = payload_info.file_size

  @classmethod
  def _DataOffset(cls, filename, header_offset):
    """Return the file offset where an entry's data starts.

    The zip format does not require the local header's extra field to match the
    one in the central directory (which is what ZipInfo exposes). The lengths
    are therefore read from the local header itself.

    Args:
      filename: path of the .zip file.
      header_offset: offset of the entry's local file header.

    Returns:
      The offset of the first byte of the entry's data.

    Raises:
      ValueError: if no valid local file header is found at header_offset.
    """
    with open(filename, 'rb') as f:
      f.seek(header_offset)
      header = f.read(cls._LOCAL_HEADER_SIZE)
    if (len(header) != cls._LOCAL_HEADER_SIZE or
        header[:4] != cls._LOCAL_HEADER_MAGIC):
      raise ValueError('Bad zip local file header at offset %d in %s' %
                       (header_offset, filename))
    name_len, extra_len = struct.unpack_from(
        '<HH', header, cls._LOCAL_HEADER_LENGTHS_OFFSET)
    return header_offset + cls._LOCAL_HEADER_SIZE + name_len + extra_len
95
96
class UpdateHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """An HTTP request handler that serves a single payload.

  GET /payload serves the payload, supporting single-range requests. POST
  /update answers an Omaha update check that points back at /payload.

  Attributes:
    serving_payload: path to the only payload file we are serving.
    serving_range: the start offset and size tuple of the payload.
  """

  # Both attributes are assigned on the class by ServerThread, because a new
  # handler instance is created for every request. None means "not set yet".
  serving_payload = None
  serving_range = None

  @staticmethod
  def _parse_range(range_str, file_size):
    """Parse an HTTP range string.

    Only a single range is supported: "bytes=first-last", "bytes=first-" or
    "bytes=-suffix_length".

    Args:
      range_str: HTTP Range header in the request, not including "Header:".
      file_size: total size of the serving file.

    Returns:
      A tuple (start_range, end_range) with the half-open range of bytes
      requested. end_range is clamped to file_size. start_range >= end_range
      means the range cannot be satisfied.

    Raises:
      ValueError: if range_str is malformed (including multi-range requests).
    """
    start_range = 0
    end_range = file_size

    if range_str:
      _, sep, spec = range_str.partition('=')
      if not sep or '-' not in spec:
        raise ValueError('Malformed Range header: %r' % range_str)
      s, e = spec.split('-', 1)
      if s:
        start_range = int(s)
        if e:
          # last-byte-pos is inclusive. Clamp it so we never serve bytes past
          # the end of the payload.
          end_range = min(int(e) + 1, file_size)
      elif e:
        # Suffix range: the last int(e) bytes, or the whole file if larger.
        start_range = max(file_size - int(e), 0)
      else:
        raise ValueError('Malformed Range header: %r' % range_str)
    return start_range, end_range

  @staticmethod
  def _hash_range(f, start, size):
    """Return the SHA-256 hex digest of `size` bytes of `f` from `start`.

    Returns:
      The hex digest, or None if the file ends before `size` bytes were read.
    """
    sha256 = hashlib.sha256()
    f.seek(start)
    remaining = size
    while remaining:
      buf = f.read(min(remaining, 1024 * 1024))
      if not buf:
        return None
      sha256.update(buf)
      remaining -= len(buf)
    return sha256.hexdigest()

  def _read_omaha_appid(self):
    """Read the Omaha request body and return the appid of its first <app>.

    Raises:
      ValueError: if the body length is missing or invalid, the body is not
        valid XML, or no appid is present.
    """
    try:
      content_length = int(self.headers.get('Content-Length'))
    except (TypeError, ValueError):
      raise ValueError('Missing or invalid Content-Length')
    if content_length < 0:
      # read(-1) would block until the client closes the connection.
      raise ValueError('Missing or invalid Content-Length')
    request_xml = self.rfile.read(content_length)
    try:
      xml_root = xml.etree.ElementTree.fromstring(request_xml)
    except xml.etree.ElementTree.ParseError:
      raise ValueError('Malformed Omaha request')
    appid = next((app.attrib['appid'] for app in xml_root.iter('app')
                  if 'appid' in app.attrib), None)
    if not appid:
      raise ValueError('No appid in Omaha request')
    return appid

  def do_GET(self):  # pylint: disable=invalid-name
    """Reply with the requested payload file (or the requested range of it)."""
    if self.path != '/payload':
      self.send_error(404, 'Unknown request')
      return

    if not self.serving_payload:
      self.send_error(500, 'No serving payload set')
      return

    try:
      f = open(self.serving_payload, 'rb')
    except IOError:
      self.send_error(404, 'File not found')
      return

    with f:
      serving_start, serving_size = self.serving_range
      range_header = self.headers.get('Range')
      # Validate the range before sending the status line, so errors can still
      # be reported to the client.
      try:
        start_range, end_range = self._parse_range(range_header, serving_size)
      except ValueError:
        self.send_error(416, 'Invalid range')
        return
      if range_header and start_range >= end_range:
        self.send_error(416, 'Requested range not satisfiable')
        return

      length = end_range - start_range
      logging.info('Serving request for %s from %s [%d, %d) length: %d',
                   self.path, self.serving_payload, serving_start + start_range,
                   serving_start + end_range, length)

      # Handle the range request.
      if range_header:
        self.send_response(206)
        # The value after '/' is the complete length of the resource (the
        # payload), not the length of this partial response.
        self.send_header('Content-Range', 'bytes %d-%d/%d' %
                         (start_range, end_range - 1, serving_size))
      else:
        self.send_response(200)
      self.send_header('Accept-Ranges', 'bytes')
      self.send_header('Content-Length', length)

      stat = os.fstat(f.fileno())
      self.send_header('Last-Modified', self.date_time_string(stat.st_mtime))
      self.send_header('Content-type', 'application/octet-stream')
      self.end_headers()

      f.seek(serving_start + start_range)
      CopyFileObjLength(f, self.wfile, copy_length=length)

  def do_POST(self):  # pylint: disable=invalid-name
    """Reply with the omaha response xml."""
    if self.path != '/update':
      self.send_error(404, 'Unknown request')
      return

    if not self.serving_payload:
      self.send_error(500, 'No serving payload set')
      return

    try:
      f = open(self.serving_payload, 'rb')
    except IOError:
      self.send_error(404, 'File not found')
      return

    # Everything that can fail runs before the 200 status line is sent, so
    # failures are reported with a proper error response.
    with f:
      try:
        appid = self._read_omaha_appid()
      except ValueError as e:
        self.send_error(400, str(e))
        return

      serving_start, serving_size = self.serving_range
      payload_hash = self._hash_range(f, serving_start, serving_size)
      if payload_hash is None:
        self.send_error(500, 'Payload too small')
        return

      # Rewind to the payload start after hashing. This is harmless if Payload
      # seeks to payload_file_offset on its own.
      f.seek(serving_start)
      payload = update_payload.Payload(f, payload_file_offset=serving_start)
      payload.Init()

      response_xml = '''
        <?xml version="1.0" encoding="UTF-8"?>
        <response protocol="3.0">
          <app appid="{appid}">
            <updatecheck status="ok">
              <urls>
                <url codebase="http://127.0.0.1:{port}/"/>
              </urls>
              <manifest version="0.0.0.1">
                <actions>
                  <action event="install" run="payload"/>
                  <action event="postinstall" MetadataSize="{metadata_size}"/>
                </actions>
                <packages>
                  <package hash_sha256="{payload_hash}" name="payload" size="{payload_size}"/>
                </packages>
              </manifest>
            </updatecheck>
          </app>
        </response>
    '''.format(appid=appid, port=DEVICE_PORT,
               metadata_size=payload.metadata_size,
               payload_hash=payload_hash,
               payload_size=serving_size)

    self.send_response(200)
    self.send_header("Content-type", "text/xml")
    self.end_headers()
    self.wfile.write(response_xml.strip())
248
249
class ServerThread(threading.Thread):
  """A thread for serving HTTP requests."""

  def __init__(self, ota_filename, serving_range):
    threading.Thread.__init__(self)
    # A daemon thread never keeps the interpreter alive on its own, so a
    # server that was not stopped cannot hang the script at exit.
    self.daemon = True
    # serving_payload and serving_range are class attributes and the
    # UpdateHandler class is instantiated with every request.
    UpdateHandler.serving_payload = ota_filename
    UpdateHandler.serving_range = serving_range
    self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
    self.port = self._httpd.server_port

  def run(self):
    try:
      self._httpd.serve_forever()
    except (KeyboardInterrupt, socket.error):
      pass
    logging.info('Server Terminated')

  def StopServer(self):
    """Stop the serve_forever() loop and release the listening socket.

    shutdown() asks serve_forever() to return and waits until it has. It is
    only called while this thread is alive, because shutdown() would wait
    forever if the serving loop never runs.
    """
    if self.is_alive():
      self._httpd.shutdown()
    self._httpd.server_close()
271
272
def StartServer(ota_filename, serving_range):
  """Serve `ota_filename` from a newly started background thread.

  Args:
    ota_filename: path of the file to serve.
    serving_range: (offset, size) tuple locating the payload in that file.

  Returns:
    The running ServerThread; call its StopServer() when finished.
  """
  server = ServerThread(ota_filename, serving_range)
  server.start()
  return server
277
278
def AndroidUpdateCommand(ota_filename, payload_url, extra_headers):
  """Return the command to run to start the update in the Android device.

  Args:
    ota_filename: path on the host of the OTA package (.zip) being applied.
    payload_url: URL update_engine should read the package from.
    extra_headers: additional headers appended after the built-in ones;
      presumably in the same one-KEY=VALUE-per-line format.

  Returns:
    The update_engine_client command line, as a list of strings.
  """
  ota = AndroidOTAPackage(ota_filename)
  headers = ota.properties
  # Headers are newline-separated. Make sure the last property line is
  # terminated, so it does not merge with the USER_AGENT header added below.
  if headers and not headers.endswith('\n'):
    headers += '\n'
  headers += 'USER_AGENT=Dalvik (something, something)\n'
  headers += 'NETWORK_ID=0\n'
  headers += extra_headers

  return ['update_engine_client', '--update', '--follow',
          '--payload=%s' % payload_url, '--offset=%d' % ota.offset,
          '--size=%d' % ota.size, '--headers="%s"' % headers]
290
291
def OmahaUpdateCommand(omaha_url):
  """Build the update_engine_client command line for an Omaha-driven update.

  Args:
    omaha_url: URL of the Omaha endpoint the device should query.

  Returns:
    The command as a list of strings.
  """
  base_cmd = ['update_engine_client', '--update', '--follow']
  url_flag = '--omaha_url=%s' % omaha_url
  return base_cmd + [url_flag]
296
297
class AdbHost(object):
  """Represents a device connected via ADB."""

  def __init__(self, device_serial=None):
    """Construct an instance.

    Args:
        device_serial: optional string serial number of attached device. When
          empty or None, no "-s" option is passed to adb.
    """
    self._device_serial = device_serial
    self._command_prefix = ['adb']
    if self._device_serial:
      self._command_prefix += ['-s', self._device_serial]

  def adb(self, command):
    """Run an ADB command like "adb push".

    The command's output is not captured; it goes to this process's
    stdout/stderr.

    Args:
      command: list of strings containing command and arguments to run

    Returns:
      the program's return code. A non-zero exit status is returned, not
      raised, so callers must check it.
    """
    command = self._command_prefix + command
    logging.info('Running: %s', ' '.join(str(x) for x in command))
    return subprocess.call(command)

  def adb_output(self, command):
    """Run an ADB command like "adb push" and return the output.

    Args:
      command: list of strings containing command and arguments to run

    Returns:
      the program's output as a string.

    Raises:
      subprocess.CalledProcessError on command exit != 0.
    """
    command = self._command_prefix + command
    logging.info('Running: %s', ' '.join(str(x) for x in command))
    return subprocess.check_output(command, universal_newlines=True)
345
346
def main():
  """Apply an A/B update to an adb-connected device.

  Returns:
    The process exit status: 0 when every update command succeeded, 1 when a
    command failed or the update was interrupted.
  """
  parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
  parser.add_argument('otafile', metavar='PAYLOAD', type=str,
                      help='the OTA package file (a .zip file) or raw payload '
                      'if device uses Omaha.')
  parser.add_argument('--file', action='store_true',
                      help='Push the file to the device before updating.')
  parser.add_argument('--no-push', action='store_true',
                      help='Skip the "push" command when using --file')
  parser.add_argument('-s', type=str, default='', metavar='DEVICE',
                      help='The specific device to use.')
  parser.add_argument('--no-verbose', action='store_true',
                      help='Less verbose output')
  parser.add_argument('--public-key', type=str, default='',
                      help='Override the public key used to verify payload.')
  parser.add_argument('--extra-headers', type=str, default='',
                      help='Extra headers to pass to the device.')
  args = parser.parse_args()
  logging.basicConfig(
      level=logging.WARNING if args.no_verbose else logging.INFO)

  dut = AdbHost(args.s)

  server_thread = None
  # List of commands to execute on exit.
  finalize_cmds = []
  # Commands to execute when canceling an update.
  cancel_cmd = ['shell', 'su', '0', 'update_engine_client', '--cancel']
  # List of commands to perform the update.
  cmds = []

  # The device uses the Omaha flow if its update_engine_client help mentions
  # "omaha".
  help_cmd = ['shell', 'su', '0', 'update_engine_client', '--help']
  use_omaha = 'omaha' in dut.adb_output(help_cmd)

  if use_omaha and args.file:
    # The Omaha flow queries /update on the local HTTP server. That server is
    # only started (and reverse-forwarded) when not using --file.
    parser.error('--file is not supported on devices using Omaha.')

  if args.file:
    # Update via pushing a file to /data.
    device_ota_file = os.path.join(OTA_PACKAGE_PATH, 'debug.zip')
    payload_url = 'file://' + device_ota_file
    if not args.no_push:
      data_local_tmp_file = '/data/local/tmp/debug.zip'
      cmds.append(['push', args.otafile, data_local_tmp_file])
      cmds.append(['shell', 'su', '0', 'mv', data_local_tmp_file,
                   device_ota_file])
      cmds.append(['shell', 'su', '0', 'chcon',
                   'u:object_r:ota_package_file:s0', device_ota_file])
    cmds.append(['shell', 'su', '0', 'chown', 'system:cache', device_ota_file])
    cmds.append(['shell', 'su', '0', 'chmod', '0660', device_ota_file])
  else:
    # Update via sending the payload over the network with an "adb reverse"
    # command.
    payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
    if use_omaha and zipfile.is_zipfile(args.otafile):
      ota = AndroidOTAPackage(args.otafile)
      serving_range = (ota.offset, ota.size)
    else:
      serving_range = (0, os.stat(args.otafile).st_size)
    server_thread = StartServer(args.otafile, serving_range)
    cmds.append(
        ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
    finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])

  if args.public_key:
    payload_key_dir = os.path.dirname(PAYLOAD_KEY_PATH)
    cmds.append(
        ['shell', 'su', '0', 'mount', '-t', 'tmpfs', 'tmpfs', payload_key_dir])
    # Allow adb push to payload_key_dir
    cmds.append(['shell', 'su', '0', 'chcon', 'u:object_r:shell_data_file:s0',
                 payload_key_dir])
    cmds.append(['push', args.public_key, PAYLOAD_KEY_PATH])
    # Allow update_engine to read it.
    cmds.append(['shell', 'su', '0', 'chcon', '-R', 'u:object_r:system_file:s0',
                 payload_key_dir])
    finalize_cmds.append(['shell', 'su', '0', 'umount', payload_key_dir])

  exit_code = 0
  try:
    # The main update command using the configured payload_url.
    if use_omaha:
      update_cmd = \
          OmahaUpdateCommand('http://127.0.0.1:%d/update' % DEVICE_PORT)
    else:
      update_cmd = \
          AndroidUpdateCommand(args.otafile, payload_url, args.extra_headers)
    cmds.append(['shell', 'su', '0'] + update_cmd)

    for cmd in cmds:
      ret = dut.adb(cmd)
      if ret != 0:
        # Each step prepares the next (the update needs the pushed file or the
        # reverse port), so stop at the first failure instead of continuing.
        logging.error('Command failed with exit code %d: %s', ret,
                      ' '.join(cmd))
        exit_code = 1
        break
  except KeyboardInterrupt:
    dut.adb(cancel_cmd)
    exit_code = 1
  finally:
    if server_thread:
      server_thread.StopServer()
    for cmd in finalize_cmds:
      dut.adb(cmd)

  return exit_code
442
if __name__ == '__main__':
  # main() returns the process exit status.
  exit_code = main()
  sys.exit(exit_code)
445