1# Lint as: python2, python3 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import argparse 23 24import six 25 26from tensorflow.tools.compatibility import ast_edits 27 28 29class TFAPIChangeSpec(ast_edits.APIChangeSpec): 30 """List of maps that describe what changed in the API.""" 31 32 def __init__(self): 33 # Maps from a function name to a dictionary that describes how to 34 # map from an old argument keyword to the new argument keyword. 35 self.function_keyword_renames = { 36 "tf.batch_matmul": { 37 "adj_x": "adjoint_a", 38 "adj_y": "adjoint_b", 39 }, 40 "tf.count_nonzero": { 41 "reduction_indices": "axis" 42 }, 43 "tf.reduce_all": { 44 "reduction_indices": "axis" 45 }, 46 "tf.reduce_any": { 47 "reduction_indices": "axis" 48 }, 49 "tf.reduce_max": { 50 "reduction_indices": "axis" 51 }, 52 "tf.reduce_mean": { 53 "reduction_indices": "axis" 54 }, 55 "tf.reduce_min": { 56 "reduction_indices": "axis" 57 }, 58 "tf.reduce_prod": { 59 "reduction_indices": "axis" 60 }, 61 "tf.reduce_sum": { 62 "reduction_indices": "axis" 63 }, 64 "tf.reduce_logsumexp": { 65 "reduction_indices": "axis" 66 }, 67 "tf.expand_dims": { 68 "dim": "axis" 69 }, 70 "tf.argmax": { 71 "dimension": "axis" 72 }, 73 "tf.argmin": { 74 "dimension": "axis" 75 }, 76 "tf.reduce_join": { 77 "reduction_indices": "axis" 78 }, 79 "tf.sparse_concat": { 80 "concat_dim": "axis" 81 }, 82 "tf.sparse_split": { 83 "split_dim": "axis" 84 }, 85 "tf.sparse_reduce_sum": { 86 "reduction_axes": "axis" 87 }, 88 "tf.reverse_sequence": { 89 "seq_dim": "seq_axis", 90 "batch_dim": "batch_axis" 91 }, 92 "tf.sparse_reduce_sum_sparse": { 93 "reduction_axes": "axis" 94 }, 95 "tf.squeeze": { 96 "squeeze_dims": "axis" 97 }, 98 "tf.split": { 99 "split_dim": "axis", 100 "num_split": "num_or_size_splits" 101 }, 102 "tf.concat": { 103 "concat_dim": "axis" 104 }, 105 } 106 107 # Mapping from function to the new name of the function 108 self.symbol_renames = { 109 "tf.inv": "tf.reciprocal", 110 "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar", 111 "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram", 112 "tf.listdiff": "tf.setdiff1d", 113 "tf.list_diff": "tf.setdiff1d", 114 "tf.mul": "tf.multiply", 115 "tf.neg": "tf.negative", 116 "tf.sub": "tf.subtract", 117 "tf.train.SummaryWriter": "tf.summary.FileWriter", 118 "tf.scalar_summary": "tf.summary.scalar", 119 "tf.histogram_summary": "tf.summary.histogram", 120 "tf.audio_summary": "tf.summary.audio", 121 "tf.image_summary": "tf.summary.image", 122 "tf.merge_summary": "tf.summary.merge", 123 "tf.merge_all_summaries": "tf.summary.merge_all", 124 "tf.image.per_image_whitening": "tf.image.per_image_standardization", 125 "tf.all_variables": "tf.global_variables", 126 "tf.VARIABLES": "tf.GLOBAL_VARIABLES", 127 "tf.initialize_all_variables": "tf.global_variables_initializer", 128 "tf.initialize_variables": "tf.variables_initializer", 129 "tf.initialize_local_variables": "tf.local_variables_initializer", 130 "tf.batch_matrix_diag": "tf.matrix_diag", 131 "tf.batch_band_part": "tf.band_part", 132 "tf.batch_set_diag": "tf.set_diag", 133 "tf.batch_matrix_transpose": "tf.matrix_transpose", 134 "tf.batch_matrix_determinant": "tf.matrix_determinant", 135 "tf.batch_matrix_inverse": "tf.matrix_inverse", 136 "tf.batch_cholesky": "tf.cholesky", 137 "tf.batch_cholesky_solve": "tf.cholesky_solve", 138 "tf.batch_matrix_solve": "tf.matrix_solve", 139 "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve", 140 "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls", 141 "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig", 142 "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals", 143 "tf.batch_svd": "tf.svd", 144 "tf.batch_fft": "tf.fft", 145 "tf.batch_ifft": "tf.ifft", 146 "tf.batch_fft2d": "tf.fft2d", 147 "tf.batch_ifft2d": "tf.ifft2d", 148 "tf.batch_fft3d": "tf.fft3d", 149 "tf.batch_ifft3d": "tf.ifft3d", 150 "tf.select": "tf.where", 151 "tf.complex_abs": "tf.abs", 152 "tf.batch_matmul": "tf.matmul", 153 "tf.pack": "tf.stack", 154 "tf.unpack": "tf.unstack", 155 "tf.op_scope": "tf.name_scope", 156 } 157 158 self.change_to_function = { 159 "tf.ones_initializer", 160 "tf.zeros_initializer", 161 } 162 163 # Functions that were reordered should be changed to the new keyword args 164 # for safety, if positional arguments are used. If you have reversed the 165 # positional arguments yourself, this could do the wrong thing. 166 self.function_reorders = { 167 "tf.split": ["axis", "num_or_size_splits", "value", "name"], 168 "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"], 169 "tf.concat": ["concat_dim", "values", "name"], 170 "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], 171 "tf.nn.softmax_cross_entropy_with_logits": [ 172 "logits", "labels", "dim", "name" 173 ], 174 "tf.nn.sparse_softmax_cross_entropy_with_logits": [ 175 "logits", "labels", "name" 176 ], 177 "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"], 178 "tf.op_scope": ["values", "name", "default_name"], 179 } 180 181 # Warnings that should be printed if corresponding functions are used. 182 self.function_warnings = { 183 "tf.reverse": ( 184 ast_edits.ERROR, 185 "tf.reverse has had its argument semantics changed " 186 "significantly. The converter cannot detect this reliably, so " 187 "you need to inspect this usage manually.\n"), 188 } 189 190 self.module_deprecations = {} 191 192 193if __name__ == "__main__": 194 parser = argparse.ArgumentParser( 195 formatter_class=argparse.RawDescriptionHelpFormatter, 196 description="""Convert a TensorFlow Python file to 1.0 197 198Simple usage: 199 tf_convert.py --infile foo.py --outfile bar.py 200 tf_convert.py --intree ~/code/old --outtree ~/code/new 201""") 202 parser.add_argument( 203 "--infile", 204 dest="input_file", 205 help="If converting a single file, the name of the file " 206 "to convert") 207 parser.add_argument( 208 "--outfile", 209 dest="output_file", 210 help="If converting a single file, the output filename.") 211 parser.add_argument( 212 "--intree", 213 dest="input_tree", 214 help="If converting a whole tree of files, the directory " 215 "to read from (relative or absolute).") 216 parser.add_argument( 217 "--outtree", 218 dest="output_tree", 219 help="If converting a whole tree of files, the output " 220 "directory (relative or absolute).") 221 parser.add_argument( 222 "--copyotherfiles", 223 dest="copy_other_files", 224 help=("If converting a whole tree of files, whether to " 225 "copy the other files."), 226 type=bool, 227 default=False) 228 parser.add_argument( 229 "--reportfile", 230 dest="report_filename", 231 help=("The name of the file where the report log is " 232 "stored." 233 "(default: %(default)s)"), 234 default="report.txt") 235 args = parser.parse_args() 236 237 upgrade = ast_edits.ASTCodeUpgrader(TFAPIChangeSpec()) 238 report_text = None 239 report_filename = args.report_filename 240 files_processed = 0 241 if args.input_file: 242 files_processed, report_text, errors = upgrade.process_file( 243 args.input_file, args.output_file) 244 files_processed = 1 245 elif args.input_tree: 246 files_processed, report_text, errors = upgrade.process_tree( 247 args.input_tree, args.output_tree, args.copy_other_files) 248 else: 249 parser.print_help() 250 if report_text: 251 open(report_filename, "w").write(six.ensure_str(report_text)) 252 print("TensorFlow 1.0 Upgrade Script") 253 print("-----------------------------") 254 print("Converted %d files\n" % files_processed) 255 print("Detected %d errors that require attention" % len(errors)) 256 print("-" * 80) 257 print("\n".join(errors)) 258 print("\nMake sure to read the detailed log %r\n" % report_filename) 259