1# Lint as: python2, python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# =============================================================================
"""A module to support operations on ipynb files."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import copy
24import json
25import re
26import shutil
27import tempfile
28
29import six
30
# One line of notebook code, tagged with the number of the cell it belongs to.
CodeLine = collections.namedtuple("CodeLine", "cell_number code")
32
def is_python(cell):
  """Checks if the cell consists of Python code.

  Args:
      cell: A notebook cell dict with "cell_type" and "source" entries;
        "source" may be a list of lines or a single string

  Returns:
    True for a non-empty code cell whose first line does not start with `%%`
    (cell magic, eg: %%bash), False otherwise
  """
  if cell["cell_type"] != "code":  # code cells only
    return False

  source = cell["source"]
  if not source:  # non-empty cells
    return False

  # A string source is tested as a whole: indexing it would only give its
  # first character, which can never start with the two-character `%%`.
  first_line = source if isinstance(source, six.string_types) else source[0]

  # multiline eg: %%bash
  return not six.ensure_str(first_line).startswith("%%")
39
40
def process_file(in_filename, out_filename, upgrader):
  """The function where we inject the support for ipynb upgrade.

  Args:
      in_filename: Path of the notebook to upgrade
      out_filename: Path the upgraded notebook is written to
      upgrader: Object providing `update_string_pasta` and `_format_log`

  Returns:
    A tuple (files_processed, report_text, errors)

  Raises:
    SyntaxError: If the upgrader was not able to process the extracted code
  """
  print("Extracting code lines from original notebook")
  raw_code, notebook = _get_code(in_filename)
  raw_lines = [cl.code for cl in raw_code]

  # The function follows the original flow from `upgrader.process_file`
  processed_file, new_file_content, log, process_errors = (
      upgrader.update_string_pasta("\n".join(raw_lines), in_filename))

  if not processed_file:
    raise SyntaxError(
        "Was not able to process the file: \n%s\n" % "".join(log))

  new_notebook = _update_notebook(
      notebook, raw_code,
      six.ensure_str(new_file_content).split("\n"))
  report_text = upgrader._format_log(log, in_filename, out_filename)

  # Everything that can fail during conversion has already run, so a failed
  # upgrade no longer leaves an orphaned temporary file behind (the file is
  # created with delete=False because it is moved into place afterwards).
  with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
    json.dump(new_notebook, temp_file)

  shutil.move(temp_file.name, out_filename)

  return processed_file, report_text, process_errors
69
70
def skip_magic(code_line, magic_list):
  """Reports whether a line starts with a jupyter "magic" prefix.

  Args:
      code_line: A single line taken from a notebook code cell
      magic_list: Prefixes marking a line as jupyter "magic", not Python

  Returns:
    True if `code_line` starts with any prefix in `magic_list`, else False

  >>> skip_magic('!ls -laF', ['%', '!', '?'])
  True
  """
  return any(
      six.ensure_str(code_line).startswith(prefix) for prefix in magic_list)
90
91
def check_line_split(code_line):
  r"""Checks if a line was split with `\`.

  Args:
      code_line: A line of code, including its trailing newline

  Returns:
    True if the line ends with `\` (optionally followed by whitespace) and a
    newline, False otherwise

  >>> check_line_split("!gcloud ml-engine models create ${MODEL} \\\n")
  True
  >>> check_line_split("print('done')\n")
  False
  """

  # Return a real boolean, as documented, rather than the Match object.
  return re.search(r"\\\s*\n$", code_line) is not None
106
107
108def _get_code(input_file):
109  """Loads the ipynb file and returns a list of CodeLines."""
110
111  raw_code = []
112
113  with open(input_file) as in_file:
114    notebook = json.load(in_file)
115
116  cell_index = 0
117  for cell in notebook["cells"]:
118    if is_python(cell):
119      cell_lines = cell["source"]
120
121      is_line_split = False
122      for line_idx, code_line in enumerate(cell_lines):
123
124        # Sometimes, jupyter has more than python code
125        # Idea is to comment these lines, for upgrade time
126        if skip_magic(code_line, ["%", "!", "?"]) or is_line_split:
127          # Found a special character, need to "encode"
128          code_line = "###!!!" + six.ensure_str(code_line)
129
130          # if this cell ends with `\` -> skip the next line
131          is_line_split = check_line_split(code_line)
132
133        if is_line_split:
134          is_line_split = check_line_split(code_line)
135
136        # Sometimes, people leave \n at the end of cell
137        # in order to migrate only related things, and make the diff
138        # the smallest -> here is another hack
139        if (line_idx == len(cell_lines) -
140            1) and six.ensure_str(code_line).endswith("\n"):
141          code_line = six.ensure_str(code_line).replace("\n", "###===")
142
143        # sometimes a line would start with `\n` and content after
144        # that's the hack for this
145        raw_code.append(
146            CodeLine(cell_index,
147                     six.ensure_str(code_line.rstrip()).replace("\n",
148                                                                "###===")))
149
150      cell_index += 1
151
152  return raw_code, notebook
153
154
155def _update_notebook(original_notebook, original_raw_lines, updated_code_lines):
156  """Updates notebook, once migration is done."""
157
158  new_notebook = copy.deepcopy(original_notebook)
159
160  # validate that the number of lines is the same
161  assert len(original_raw_lines) == len(updated_code_lines), \
162    ("The lengths of input and converted files are not the same: "
163     "{} vs {}".format(len(original_raw_lines), len(updated_code_lines)))
164
165  code_cell_idx = 0
166  for cell in new_notebook["cells"]:
167    if not is_python(cell):
168      continue
169
170    applicable_lines = [
171        idx for idx, code_line in enumerate(original_raw_lines)
172        if code_line.cell_number == code_cell_idx
173    ]
174
175    new_code = [updated_code_lines[idx] for idx in applicable_lines]
176
177    cell["source"] = "\n".join(new_code).replace("###!!!", "").replace(
178        "###===", "\n")
179    code_cell_idx += 1
180
181  return new_notebook
182