Home
last modified time | relevance | path

Searched refs:accumulation_type (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
lrn_ops.cc
51 auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); in Compile() local
52 auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); in Compile()
55 squared, XlaHelpers::Zero(builder, accumulation_type), in Compile()
56 *ctx->GetOrCreateAdd(accumulation_type), in Compile()
137 auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); in Compile() local
139 XlaHelpers::ConvertElementType(in_image, accumulation_type); in Compile()
142 squared, XlaHelpers::Zero(builder, accumulation_type), in Compile()
143 *ctx->GetOrCreateAdd(accumulation_type), in Compile()
157 auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type); in Compile()
159 converted_dy, XlaHelpers::Zero(builder, accumulation_type), in Compile()
[all …]
softmax_op.cc
65 const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); in Compile() local
67 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, in Compile()
73 *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); in Compile()
112 const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); in CrossEntropyWithLogits() local
114 XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type); in CrossEntropyWithLogits()
116 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in CrossEntropyWithLogits()
117 *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); in CrossEntropyWithLogits()
129 auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type), in CrossEntropyWithLogits()
130 XlaHelpers::Zero(b, accumulation_type), in CrossEntropyWithLogits()
131 *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); in CrossEntropyWithLogits()
l2loss_op.cc
38 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); in Compile() local
39 auto t = XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); in Compile()
41 auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), in Compile()
42 *ctx->GetOrCreateAdd(accumulation_type), dims); in Compile()
batch_norm_op.cc
183 const DataType accumulation_type = in Compile() local
186 XlaHelpers::ConvertElementType(grad_backprop, accumulation_type); in Compile()
188 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
189 *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); in Compile()
199 converted = XlaHelpers::ConvertElementType(mul, accumulation_type); in Compile()
201 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
202 *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); in Compile()
bias_ops.cc
111 const DataType accumulation_type = in Compile() local
114 XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); in Compile()
116 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
117 *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); in Compile()
fake_quantize_ops.cc
243 const DataType accumulation_type = in Compile() local
263 XlaHelpers::ConvertElementType(select1, accumulation_type), in Compile()
264 XlaHelpers::Zero(b, accumulation_type), in Compile()
265 *ctx->GetOrCreateAdd(accumulation_type)); in Compile()
272 XlaHelpers::ConvertElementType(select2, accumulation_type), in Compile()
273 XlaHelpers::Zero(b, accumulation_type), in Compile()
274 *ctx->GetOrCreateAdd(accumulation_type)); in Compile()
image_ops.cc
195 const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); in Compile() local
196 auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); in Compile()
197 auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
198 *context->GetOrCreateAdd(accumulation_type), in Compile()
202 reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width)); in Compile()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
pooling.cc
81 PrimitiveType accumulation_type = init_shape.element_type(); in ComputeSums() local
82 auto add_computation = CreateScalarAddComputation(accumulation_type, b); in ComputeSums()