1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Demo of the tfdbg curses UI: A TF network computing Fibonacci sequence.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import argparse 21import sys 22 23import numpy as np 24from six.moves import xrange # pylint: disable=redefined-builtin 25import tensorflow as tf 26 27from tensorflow.python import debug as tf_debug 28 29FLAGS = None 30 31 32def main(_): 33 sess = tf.Session() 34 35 # Construct the TensorFlow network. 36 n0 = tf.Variable( 37 np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32, name="node_00") 38 n1 = tf.Variable( 39 np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32, name="node_01") 40 41 for i in xrange(2, FLAGS.length): 42 n0, n1 = n1, tf.add(n0, n1, name="node_%.2d" % i) 43 44 sess.run(tf.global_variables_initializer()) 45 46 # Wrap the TensorFlow Session object for debugging. 47 if FLAGS.debug and FLAGS.tensorboard_debug_address: 48 raise ValueError( 49 "The --debug and --tensorboard_debug_address flags are mutually " 50 "exclusive.") 51 if FLAGS.debug: 52 sess = tf_debug.LocalCLIDebugWrapperSession(sess) 53 54 def has_negative(_, tensor): 55 return np.any(tensor < 0) 56 57 sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) 58 sess.add_tensor_filter("has_negative", has_negative) 59 elif FLAGS.tensorboard_debug_address: 60 sess = tf_debug.TensorBoardDebugWrapperSession( 61 sess, FLAGS.tensorboard_debug_address) 62 63 print("Fibonacci number at position %d:\n%s" % 64 (FLAGS.length, sess.run(n1))) 65 66 67if __name__ == "__main__": 68 parser = argparse.ArgumentParser() 69 parser.register("type", "bool", lambda v: v.lower() == "true") 70 parser.add_argument( 71 "--tensor_size", 72 type=int, 73 default=1, 74 help="""\ 75 Size of tensor. E.g., if the value is 30, the tensors will have shape 76 [30, 30].\ 77 """) 78 parser.add_argument( 79 "--length", 80 type=int, 81 default=20, 82 help="Length of the fibonacci sequence to compute.") 83 parser.add_argument( 84 "--ui_type", 85 type=str, 86 default="curses", 87 help="Command-line user interface type (curses | readline)") 88 parser.add_argument( 89 "--debug", 90 dest="debug", 91 action="store_true", 92 help="Use TensorFlow Debugger (tfdbg). Mutually exclusive with the " 93 "--tensorboard_debug_address flag.") 94 parser.add_argument( 95 "--tensorboard_debug_address", 96 type=str, 97 default=None, 98 help="Connect to the TensorBoard Debugger Plugin backend specified by " 99 "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " 100 "--debug flag.") 101 102 FLAGS, unparsed = parser.parse_known_args() 103 with tf.Graph().as_default(): 104 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 105