1# Lint as: python2, python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================= 16"""A module to support operations on ipynb files""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import copy 24import json 25import re 26import shutil 27import tempfile 28 29import six 30 31CodeLine = collections.namedtuple("CodeLine", ["cell_number", "code"]) 32 33def is_python(cell): 34 """Checks if the cell consists of Python code.""" 35 return (cell["cell_type"] == "code" # code cells only 36 and cell["source"] # non-empty cells 37 and not six.ensure_str(cell["source"][0]).startswith("%%") 38 ) # multiline eg: %%bash 39 40 41def process_file(in_filename, out_filename, upgrader): 42 """The function where we inject the support for ipynb upgrade.""" 43 print("Extracting code lines from original notebook") 44 raw_code, notebook = _get_code(in_filename) 45 raw_lines = [cl.code for cl in raw_code] 46 47 # The function follows the original flow from `upgrader.process_fil` 48 with tempfile.NamedTemporaryFile("w", delete=False) as temp_file: 49 50 processed_file, new_file_content, log, process_errors = ( 51 upgrader.update_string_pasta("\n".join(raw_lines), in_filename)) 52 53 if temp_file and processed_file: 54 new_notebook = _update_notebook( 55 notebook, raw_code, 56 six.ensure_str(new_file_content).split("\n")) 57 json.dump(new_notebook, temp_file) 58 else: 59 raise SyntaxError( 60 "Was not able to process the file: \n%s\n" % "".join(log)) 61 62 files_processed = processed_file 63 report_text = upgrader._format_log(log, in_filename, out_filename) 64 errors = process_errors 65 66 shutil.move(temp_file.name, out_filename) 67 68 return files_processed, report_text, errors 69 70 71def skip_magic(code_line, magic_list): 72 """Checks if the cell has magic, that is not Python-based. 73 74 Args: 75 code_line: A line of Python code 76 magic_list: A list of jupyter "magic" exceptions 77 78 Returns: 79 If the line jupyter "magic" line, not Python line 80 81 >>> skip_magic('!ls -laF', ['%', '!', '?']) 82 True 83 """ 84 85 for magic in magic_list: 86 if six.ensure_str(code_line).startswith(magic): 87 return True 88 89 return False 90 91 92def check_line_split(code_line): 93 r"""Checks if a line was split with `\`. 94 95 Args: 96 code_line: A line of Python code 97 98 Returns: 99 If the line was split with `\` 100 101 >>> skip_magic("!gcloud ml-engine models create ${MODEL} \\\n") 102 True 103 """ 104 105 return re.search(r"\\\s*\n$", code_line) 106 107 108def _get_code(input_file): 109 """Loads the ipynb file and returns a list of CodeLines.""" 110 111 raw_code = [] 112 113 with open(input_file) as in_file: 114 notebook = json.load(in_file) 115 116 cell_index = 0 117 for cell in notebook["cells"]: 118 if is_python(cell): 119 cell_lines = cell["source"] 120 121 is_line_split = False 122 for line_idx, code_line in enumerate(cell_lines): 123 124 # Sometimes, jupyter has more than python code 125 # Idea is to comment these lines, for upgrade time 126 if skip_magic(code_line, ["%", "!", "?"]) or is_line_split: 127 # Found a special character, need to "encode" 128 code_line = "###!!!" + six.ensure_str(code_line) 129 130 # if this cell ends with `\` -> skip the next line 131 is_line_split = check_line_split(code_line) 132 133 if is_line_split: 134 is_line_split = check_line_split(code_line) 135 136 # Sometimes, people leave \n at the end of cell 137 # in order to migrate only related things, and make the diff 138 # the smallest -> here is another hack 139 if (line_idx == len(cell_lines) - 140 1) and six.ensure_str(code_line).endswith("\n"): 141 code_line = six.ensure_str(code_line).replace("\n", "###===") 142 143 # sometimes a line would start with `\n` and content after 144 # that's the hack for this 145 raw_code.append( 146 CodeLine(cell_index, 147 six.ensure_str(code_line.rstrip()).replace("\n", 148 "###==="))) 149 150 cell_index += 1 151 152 return raw_code, notebook 153 154 155def _update_notebook(original_notebook, original_raw_lines, updated_code_lines): 156 """Updates notebook, once migration is done.""" 157 158 new_notebook = copy.deepcopy(original_notebook) 159 160 # validate that the number of lines is the same 161 assert len(original_raw_lines) == len(updated_code_lines), \ 162 ("The lengths of input and converted files are not the same: " 163 "{} vs {}".format(len(original_raw_lines), len(updated_code_lines))) 164 165 code_cell_idx = 0 166 for cell in new_notebook["cells"]: 167 if not is_python(cell): 168 continue 169 170 applicable_lines = [ 171 idx for idx, code_line in enumerate(original_raw_lines) 172 if code_line.cell_number == code_cell_idx 173 ] 174 175 new_code = [updated_code_lines[idx] for idx in applicable_lines] 176 177 cell["source"] = "\n".join(new_code).replace("###!!!", "").replace( 178 "###===", "\n") 179 code_cell_idx += 1 180 181 return new_notebook 182