1# Lint as: python2, python3
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23
24import six
25
26from tensorflow.tools.compatibility import ast_edits
27
28
class TFAPIChangeSpec(ast_edits.APIChangeSpec):
  """Tables describing the API changes from pre-1.0 TensorFlow to 1.0.

  An instance is handed to `ast_edits.ASTCodeUpgrader`, which uses these
  tables to rewrite user code. Every key is a fully qualified symbol name
  spelled exactly as it appears in user source (e.g. "tf.reduce_sum").

  Attributes:
    function_keyword_renames: Maps a function name to a dict from an old
      keyword-argument name to its new keyword-argument name.
    symbol_renames: Maps an old symbol name to its new symbol name.
    change_to_function: Set of symbol names handed to the change-to-function
      handling in `ast_edits`. NOTE(review): the exact rewrite applied is
      defined in `ast_edits`; confirm there.
    function_reorders: Maps a function name to its argument names in the
      order the old API accepted them positionally, so that positional calls
      can be rewritten as keyword calls.
    function_warnings: Maps a function name to a `(level, message)` pair that
      is reported whenever that function is used.
    module_deprecations: Deprecated modules; empty for this upgrade.
  """

  def __init__(self):
    # Maps from a function name to a dictionary that describes how to
    # map from an old argument keyword to the new argument keyword.
    self.function_keyword_renames = {
        "tf.batch_matmul": {
            "adj_x": "adjoint_a",
            "adj_y": "adjoint_b",
        },
        "tf.count_nonzero": {
            "reduction_indices": "axis"
        },
        "tf.reduce_all": {
            "reduction_indices": "axis"
        },
        "tf.reduce_any": {
            "reduction_indices": "axis"
        },
        "tf.reduce_max": {
            "reduction_indices": "axis"
        },
        "tf.reduce_mean": {
            "reduction_indices": "axis"
        },
        "tf.reduce_min": {
            "reduction_indices": "axis"
        },
        "tf.reduce_prod": {
            "reduction_indices": "axis"
        },
        "tf.reduce_sum": {
            "reduction_indices": "axis"
        },
        "tf.reduce_logsumexp": {
            "reduction_indices": "axis"
        },
        "tf.expand_dims": {
            "dim": "axis"
        },
        "tf.argmax": {
            "dimension": "axis"
        },
        "tf.argmin": {
            "dimension": "axis"
        },
        "tf.reduce_join": {
            "reduction_indices": "axis"
        },
        "tf.sparse_concat": {
            "concat_dim": "axis"
        },
        "tf.sparse_split": {
            "split_dim": "axis"
        },
        "tf.sparse_reduce_sum": {
            "reduction_axes": "axis"
        },
        "tf.reverse_sequence": {
            "seq_dim": "seq_axis",
            "batch_dim": "batch_axis"
        },
        "tf.sparse_reduce_sum_sparse": {
            "reduction_axes": "axis"
        },
        "tf.squeeze": {
            "squeeze_dims": "axis"
        },
        "tf.split": {
            "split_dim": "axis",
            "num_split": "num_or_size_splits"
        },
        "tf.concat": {
            "concat_dim": "axis"
        },
    }

    # Mapping from function to the new name of the function
    self.symbol_renames = {
        "tf.inv": "tf.reciprocal",
        "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar",
        "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram",
        "tf.listdiff": "tf.setdiff1d",
        "tf.list_diff": "tf.setdiff1d",
        "tf.mul": "tf.multiply",
        "tf.neg": "tf.negative",
        "tf.sub": "tf.subtract",
        "tf.train.SummaryWriter": "tf.summary.FileWriter",
        "tf.scalar_summary": "tf.summary.scalar",
        "tf.histogram_summary": "tf.summary.histogram",
        "tf.audio_summary": "tf.summary.audio",
        "tf.image_summary": "tf.summary.image",
        "tf.merge_summary": "tf.summary.merge",
        "tf.merge_all_summaries": "tf.summary.merge_all",
        "tf.image.per_image_whitening": "tf.image.per_image_standardization",
        "tf.all_variables": "tf.global_variables",
        "tf.VARIABLES": "tf.GLOBAL_VARIABLES",
        "tf.initialize_all_variables": "tf.global_variables_initializer",
        "tf.initialize_variables": "tf.variables_initializer",
        "tf.initialize_local_variables": "tf.local_variables_initializer",
        "tf.batch_matrix_diag": "tf.matrix_diag",
        # NOTE(review): "tf.batch_band_part"/"tf.band_part" and
        # "tf.batch_set_diag"/"tf.set_diag" appear not to be real TensorFlow
        # symbols (confirm). They are kept for backward compatibility; the
        # documented pre-1.0 names are the tf.batch_matrix_* entries that
        # follow them.
        "tf.batch_band_part": "tf.band_part",
        "tf.batch_set_diag": "tf.set_diag",
        "tf.batch_matrix_band_part": "tf.matrix_band_part",
        "tf.batch_matrix_set_diag": "tf.matrix_set_diag",
        "tf.batch_matrix_diag_part": "tf.matrix_diag_part",
        "tf.batch_matrix_transpose": "tf.matrix_transpose",
        "tf.batch_matrix_determinant": "tf.matrix_determinant",
        "tf.batch_matrix_inverse": "tf.matrix_inverse",
        "tf.batch_cholesky": "tf.cholesky",
        "tf.batch_cholesky_solve": "tf.cholesky_solve",
        "tf.batch_matrix_solve": "tf.matrix_solve",
        "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve",
        "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls",
        "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig",
        "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals",
        "tf.batch_svd": "tf.svd",
        "tf.batch_fft": "tf.fft",
        "tf.batch_ifft": "tf.ifft",
        "tf.batch_fft2d": "tf.fft2d",
        "tf.batch_ifft2d": "tf.ifft2d",
        "tf.batch_fft3d": "tf.fft3d",
        "tf.batch_ifft3d": "tf.ifft3d",
        "tf.select": "tf.where",
        "tf.complex_abs": "tf.abs",
        "tf.batch_matmul": "tf.matmul",
        "tf.pack": "tf.stack",
        "tf.unpack": "tf.unstack",
        "tf.op_scope": "tf.name_scope",
    }

    # Symbols given the change-to-function treatment by ast_edits (see the
    # class docstring).
    self.change_to_function = {
        "tf.ones_initializer",
        "tf.zeros_initializer",
    }

    # Functions that were reordered should be changed to the new keyword args
    # for safety, if positional arguments are used. If you have reversed the
    # positional arguments yourself, this could do the wrong thing.
    self.function_reorders = {
        "tf.split": ["axis", "num_or_size_splits", "value", "name"],
        "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"],
        "tf.concat": ["concat_dim", "values", "name"],
        "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"],
        "tf.nn.softmax_cross_entropy_with_logits": [
            "logits", "labels", "dim", "name"
        ],
        "tf.nn.sparse_softmax_cross_entropy_with_logits": [
            "logits", "labels", "name"
        ],
        "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"],
        "tf.op_scope": ["values", "name", "default_name"],
    }

    # Warnings that should be printed if corresponding functions are used.
    self.function_warnings = {
        "tf.reverse": (
            ast_edits.ERROR,
            "tf.reverse has had its argument semantics changed "
            "significantly. The converter cannot detect this reliably, so "
            "you need to inspect this usage manually.\n"),
    }

    # No whole-module deprecations apply to the 1.0 upgrade.
    self.module_deprecations = {}
191
192
def parse_bool_arg(value):
  """Converts a command-line string into a bool, for use as an argparse `type`.

  `type=bool` cannot be used for this: argparse would call `bool(value)`,
  which is True for every non-empty string, so `--copyotherfiles False`
  would turn copying on.

  Args:
    value: The string given on the command line. A bool is returned
      unchanged.

  Returns:
    True for "true", "t", "yes", "y" or "1". False for "false", "f", "no",
    "n", "0" or the empty string. Matching ignores case and surrounding
    whitespace.

  Raises:
    argparse.ArgumentTypeError: If `value` is not one of the spellings above.
      argparse reports this as a usage error.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ("true", "t", "yes", "y", "1"):
    return True
  # The empty string is accepted as False because bool("") was False under
  # the previous `type=bool`.
  if normalized in ("false", "f", "no", "n", "0", ""):
    return False
  raise argparse.ArgumentTypeError(
      "expected a boolean value such as true/false, got %r" % value)


def build_arg_parser():
  """Builds the command-line parser for the TensorFlow 1.0 upgrade script.

  Returns:
    An `argparse.ArgumentParser` that defines --infile/--outfile (single-file
    mode), --intree/--outtree/--copyotherfiles (tree mode) and --reportfile.
  """
  parser = argparse.ArgumentParser(
      formatter_class=argparse.RawDescriptionHelpFormatter,
      description="""Convert a TensorFlow Python file to 1.0

Simple usage:
  tf_convert.py --infile foo.py --outfile bar.py
  tf_convert.py --intree ~/code/old --outtree ~/code/new
""")
  parser.add_argument(
      "--infile",
      dest="input_file",
      help="If converting a single file, the name of the file "
      "to convert")
  parser.add_argument(
      "--outfile",
      dest="output_file",
      help="If converting a single file, the output filename.")
  parser.add_argument(
      "--intree",
      dest="input_tree",
      help="If converting a whole tree of files, the directory "
      "to read from (relative or absolute).")
  parser.add_argument(
      "--outtree",
      dest="output_tree",
      help="If converting a whole tree of files, the output "
      "directory (relative or absolute).")
  # nargs="?" with const=True lets the flag be used bare (--copyotherfiles)
  # while still accepting an explicit value (--copyotherfiles False).
  parser.add_argument(
      "--copyotherfiles",
      dest="copy_other_files",
      help=("If converting a whole tree of files, whether to "
            "copy the other files."),
      type=parse_bool_arg,
      nargs="?",
      const=True,
      default=False)
  parser.add_argument(
      "--reportfile",
      dest="report_filename",
      help=("The name of the file where the report log is "
            "stored. "
            "(default: %(default)s)"),
      default="report.txt")
  return parser


def main():
  """Parses flags, runs the 1.0 upgrader, then writes and summarizes the report.

  Converts a single file when --infile is given, otherwise a directory tree
  when --intree is given. With neither, prints the help text and does nothing.
  """
  parser = build_arg_parser()
  args = parser.parse_args()

  # Check for missing output paths up front. parser.error prints usage and
  # exits, so the user gets a clear message before any conversion starts.
  if args.input_file and not args.output_file:
    parser.error("--outfile is required when converting a single file "
                 "with --infile.")
  if not args.input_file and args.input_tree and not args.output_tree:
    parser.error("--outtree is required when converting a tree of files "
                 "with --intree.")

  upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec())
  report_text = None
  report_filename = args.report_filename
  files_processed = 0
  errors = []
  if args.input_file:
    files_processed, report_text, errors = upgrade.process_file(
        args.input_file, args.output_file)
    files_processed = 1
  elif args.input_tree:
    files_processed, report_text, errors = upgrade.process_tree(
        args.input_tree, args.output_tree, args.copy_other_files)
  else:
    parser.print_help()
  if report_text:
    # Use a context manager so the report is flushed and the handle closed.
    with open(report_filename, "w") as report_file:
      report_file.write(six.ensure_str(report_text))
    print("TensorFlow 1.0 Upgrade Script")
    print("-----------------------------")
    print("Converted %d files\n" % files_processed)
    print("Detected %d errors that require attention" % len(errors))
    print("-" * 80)
    print("\n".join(errors))
    print("\nMake sure to read the detailed log %r\n" % report_filename)


if __name__ == "__main__":
  main()
259