Home
last modified time | relevance | path

Searched refs:cast_dtype (Results 1 – 3 of 3) sorted by relevance

/external/tensorflow/tensorflow/contrib/distribute/python/examples/
Dkeras_mnist.py42 cast_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
68 train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y))
74 eval_ds = eval_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y))
/external/tensorflow/tensorflow/core/graph/
Dgraph_partition.cc197 const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype; in AddSend() local
211 if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) { in AddSend()
219 cast_builder.Attr("DstT", cast_dtype); in AddSend()
221 if (cast_dtype == DT_BFLOAT16) { in AddSend()
233 send_from.Reset(cast->name(), 0, cast_dtype); in AddSend()
257 DataType cast_dtype = dtype; in AddRecv() local
261 cast_dtype = opts.should_cast(edge); in AddRecv()
278 .Attr("tensor_type", cast_dtype); in AddRecv()
285 if (dtype != cast_dtype) { in AddRecv()
291 .Input(recv->name(), 0, cast_dtype); in AddRecv()
[all …]
/external/tensorflow/tensorflow/python/framework/
Dtensor_util.py718 cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
719 return pre_cast.astype(cast_dtype.as_numpy_dtype)