1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Entry points for YAPF.
15
16The main APIs that YAPF exposes to drive the reformatting.
17
18  FormatFile(): reformat a file.
19  FormatCode(): reformat a string of code.
20
21These APIs have some common arguments:
22
  style_config: (string) Either a style name or a path to a file that contains
    formatting style settings. If None is specified, use the default style
    as set in style.DEFAULT_STYLE_FACTORY.
26  lines: (list of tuples of integers) A list of tuples of lines, [start, end],
27    that we want to format. The lines are 1-based indexed. It can be used by
28    third-party code (e.g., IDEs) when reformatting a snippet of code rather
29    than a whole file.
  print_diff: (bool) Instead of returning the reformatted source, return a
    unified diff that turns the original source into the reformatted source.
32  verify: (bool) True if reformatted code should be verified for syntax.
33"""
34
35import difflib
36import re
37import sys
38
39from lib2to3.pgen2 import parse
40
41from yapf.yapflib import blank_line_calculator
42from yapf.yapflib import comment_splicer
43from yapf.yapflib import continuation_splicer
44from yapf.yapflib import file_resources
45from yapf.yapflib import identify_container
46from yapf.yapflib import py3compat
47from yapf.yapflib import pytree_unwrapper
48from yapf.yapflib import pytree_utils
49from yapf.yapflib import reformatter
50from yapf.yapflib import split_penalty
51from yapf.yapflib import style
52from yapf.yapflib import subtype_assigner
53
54
def FormatFile(filename,
               style_config=None,
               lines=None,
               print_diff=False,
               verify=False,
               in_place=False,
               logger=None):
  """Format a single Python file and return the formatted code.

  Arguments:
    filename: (unicode) The file to reformat.
    in_place: (bool) If True, write the reformatted code back to the file.
    logger: (callable) Passed to ReadFile(); called with the IOError if the
      file cannot be read.
    remaining arguments: see comment at the top of this module.

  Returns:
    Tuple of (reformatted_code, encoding, changed). reformatted_code is None
    whenever in_place is True (the file is only rewritten if it changed), and
    is a diff if print_diff is True. changed is True if the source changed.

  Raises:
    IOError: raised if there was an error reading the file.
    ValueError: raised if in_place and print_diff are both specified.
  """
  _CheckPythonVersion()

  if in_place and print_diff:
    raise ValueError('Cannot pass both in_place and print_diff.')

  original_source, newline, encoding = ReadFile(filename, logger)
  reformatted_source, changed = FormatCode(
      original_source,
      style_config=style_config,
      filename=filename,
      lines=lines,
      print_diff=print_diff,
      verify=verify)

  # ReadFile normalized every line ending to '\n'; convert the result back to
  # the file's own line ending (collapsing trailing newlines to exactly one).
  stripped = reformatted_source.rstrip('\n')
  if stripped:
    reformatted_source = newline.join(stripped.split('\n')) + newline

  if in_place:
    # Rely on FormatCode's verdict rather than comparing against
    # original_source: original_source is '\n'-normalized while
    # reformatted_source now uses the file's line ending, so a '\r\n' file
    # would always look modified and be needlessly rewritten.
    if changed:
      file_resources.WriteReformattedCode(filename, reformatted_source,
                                          encoding, in_place)
    return None, encoding, changed

  return reformatted_source, encoding, changed
102
103
def FormatCode(unformatted_source,
               filename='<unknown>',
               style_config=None,
               lines=None,
               print_diff=False,
               verify=False):
  """Format a string of Python code.

  This provides an alternative entry point to YAPF.

  Arguments:
    unformatted_source: (unicode) The code to format.
    filename: (unicode) The name of the file being reformatted. It prefixes
      parse-error messages and appears in the diff header.
    remaining arguments: see comment at the top of this module.

  Returns:
    Tuple of (reformatted_source, changed). reformatted_source conforms to the
    desired formatting style; if print_diff is True it is instead a unified
    diff ('' when nothing changed). changed is True if the source changed.

  Raises:
    parse.ParseError: the code could not be parsed. The error's msg is
      prefixed with filename.
  """
  _CheckPythonVersion()
  style.SetGlobalStyle(style.CreateStyleFromConfig(style_config))
  if not unformatted_source.endswith('\n'):
    unformatted_source += '\n'

  try:
    tree = pytree_utils.ParseCodeToTree(unformatted_source)
  except parse.ParseError as e:
    # Tell the caller which input failed to parse.
    e.msg = filename + ': ' + e.msg
    raise

  # Run passes on the tree, modifying it in place.
  comment_splicer.SpliceComments(tree)
  continuation_splicer.SpliceContinuations(tree)
  subtype_assigner.AssignSubtypes(tree)
  identify_container.IdentifyContainers(tree)
  split_penalty.ComputeSplitPenalties(tree)
  blank_line_calculator.CalculateBlankLines(tree)

  uwlines = pytree_unwrapper.UnwrapPyTree(tree)
  for uwl in uwlines:
    uwl.CalculateFormattingInformation()

  line_set = _LineRangesToSet(lines)
  _MarkLinesToFormat(uwlines, line_set)
  reformatted_source = reformatter.Reformat(
      _SplitSemicolons(uwlines), verify, line_set)

  if unformatted_source == reformatted_source:
    return ('' if print_diff else reformatted_source), False

  if not print_diff:
    # The diff is only needed when it is returned; don't pay for difflib
    # on every changed file otherwise.
    return reformatted_source, True

  code_diff = _GetUnifiedDiff(
      unformatted_source, reformatted_source, filename=filename)
  # _GetUnifiedDiff always appends a trailing newline, so strip before
  # deciding whether the diff is empty.
  return code_diff, bool(code_diff.strip())
161
162
163def _CheckPythonVersion():  # pragma: no cover
164  errmsg = 'yapf is only supported for Python 2.7 or 3.4+'
165  if sys.version_info[0] == 2:
166    if sys.version_info[1] < 7:
167      raise RuntimeError(errmsg)
168  elif sys.version_info[0] == 3:
169    if sys.version_info[1] < 4:
170      raise RuntimeError(errmsg)
171
172
def ReadFile(filename, logger=None):
  """Read the contents of the file.

  An optional logger can be specified to emit messages to your favorite logging
  stream. It is notified of read errors, but the error is still re-raised.
  This is external so that it can be used by third-party applications.

  Arguments:
    filename: (unicode) The name of the file.
    logger: (function) A function or lambda that is called with the IOError
      if the file cannot be read.

  Returns:
    Tuple of (source, line_ending, encoding). source is the file's text with
    every line ending replaced by a newline character and a trailing newline
    always appended (an empty file yields a single newline). line_ending is
    the value computed by file_resources.LineEnding() from the raw lines, and
    encoding is the value from file_resources.FileEncoding().

  Raises:
    IOError: raised if there was an error reading the file.
  """
  try:
    encoding = file_resources.FileEncoding(filename)

    # newline='' keeps the original line endings in the lines read (presumably
    # forwarded to io.open), so LineEnding() below can inspect them.
    with py3compat.open_with_encoding(
        filename, mode='r', encoding=encoding, newline='') as fd:
      lines = fd.readlines()
  except IOError as err:  # pragma: no cover
    if logger:
      logger(err)
    raise

  line_ending = file_resources.LineEnding(lines)
  source = '\n'.join(line.rstrip('\r\n') for line in lines) + '\n'
  return source, line_ending, encoding
205
206
207def _SplitSemicolons(uwlines):
208  res = []
209  for uwline in uwlines:
210    res.extend(uwline.Split())
211  return res
212
213
# Directive regexes, applied with re.IGNORECASE by _MarkLinesToFormat,
# _DisableYAPF and _EnableYAPF. The leading '^#' means they only match text
# that starts with a comment marker, e.g. "# yapf: disable" / "# yapf: enable".
DISABLE_PATTERN = r'^#.*\byapf:\s*disable\b'
ENABLE_PATTERN = r'^#.*\byapf:\s*enable\b'
216
217
218def _LineRangesToSet(line_ranges):
219  """Return a set of lines in the range."""
220
221  if line_ranges is None:
222    return None
223
224  line_set = set()
225  for low, high in sorted(line_ranges):
226    line_set.update(range(low, high + 1))
227
228  return line_set
229
230
231def _MarkLinesToFormat(uwlines, lines):
232  """Skip sections of code that we shouldn't reformat."""
233  if lines:
234    for uwline in uwlines:
235      uwline.disable = not lines.intersection(
236          range(uwline.lineno, uwline.last.lineno + 1))
237
238  # Now go through the lines and disable any lines explicitly marked as
239  # disabled.
240  index = 0
241  while index < len(uwlines):
242    uwline = uwlines[index]
243    if uwline.is_comment:
244      if _DisableYAPF(uwline.first.value.strip()):
245        index += 1
246        while index < len(uwlines):
247          uwline = uwlines[index]
248          if uwline.is_comment and _EnableYAPF(uwline.first.value.strip()):
249            break
250          uwline.disable = True
251          index += 1
252    elif re.search(DISABLE_PATTERN, uwline.last.value.strip(), re.IGNORECASE):
253      uwline.disable = True
254    index += 1
255
256
def _DisableYAPF(line):
  """Check whether comment text carries a "yapf: disable" directive.

  Arguments:
    line: (unicode) Comment text; it may span several physical lines, of which
      only the first and the last are examined.

  Returns:
    The match of DISABLE_PATTERN (case-insensitive) on the first line if there
    is one, otherwise the result of matching the last line (a match or None).
  """
  physical_lines = line.split('\n')
  head = physical_lines[0].strip()
  tail = physical_lines[-1].strip()
  return (re.search(DISABLE_PATTERN, head, re.IGNORECASE) or
          re.search(DISABLE_PATTERN, tail, re.IGNORECASE))
262
263
def _EnableYAPF(line):
  """Check whether comment text carries a "yapf: enable" directive.

  Arguments:
    line: (unicode) Comment text; it may span several physical lines, of which
      only the first and the last are examined.

  Returns:
    The match of ENABLE_PATTERN (case-insensitive) on the first line if there
    is one, otherwise the result of matching the last line (a match or None).
  """
  physical_lines = line.split('\n')
  head = physical_lines[0].strip()
  tail = physical_lines[-1].strip()
  return (re.search(ENABLE_PATTERN, head, re.IGNORECASE) or
          re.search(ENABLE_PATTERN, tail, re.IGNORECASE))
269
270
271def _GetUnifiedDiff(before, after, filename='code'):
272  """Get a unified diff of the changes.
273
274  Arguments:
275    before: (unicode) The original source code.
276    after: (unicode) The reformatted source code.
277    filename: (unicode) The code's filename.
278
279  Returns:
280    The unified diff text.
281  """
282  before = before.splitlines()
283  after = after.splitlines()
284  return '\n'.join(
285      difflib.unified_diff(
286          before,
287          after,
288          filename,
289          filename,
290          '(original)',
291          '(reformatted)',
292          lineterm='')) + '\n'
293