Searched refs:cast_dtype (Results 1 – 3 of 3) sorted by relevance
/external/tensorflow/tensorflow/contrib/distribute/python/examples/ |
D | keras_mnist.py | 42 cast_dtype = tf.bfloat16 if use_bfloat16 else tf.float32 68 train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) 74 eval_ds = eval_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y))
|
/external/tensorflow/tensorflow/core/graph/ |
D | graph_partition.cc | 197 const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype; in AddSend() local 211 if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) { in AddSend() 219 cast_builder.Attr("DstT", cast_dtype); in AddSend() 221 if (cast_dtype == DT_BFLOAT16) { in AddSend() 233 send_from.Reset(cast->name(), 0, cast_dtype); in AddSend() 257 DataType cast_dtype = dtype; in AddRecv() local 261 cast_dtype = opts.should_cast(edge); in AddRecv() 278 .Attr("tensor_type", cast_dtype); in AddRecv() 285 if (dtype != cast_dtype) { in AddRecv() 291 .Input(recv->name(), 0, cast_dtype); in AddRecv() [all …]
|
/external/tensorflow/tensorflow/python/framework/ |
D | tensor_util.py | 718 cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT")) 719 return pre_cast.astype(cast_dtype.as_numpy_dtype)
|