Home
last modified time | relevance | path

Searched refs:grad_warp (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/contrib/resampler/kernels/
Dresampler_ops_gpu.cu.cc147 T* __restrict__ grad_warp, const int batch_size, const int data_height, in ResamplerGrad2DKernel() argument
209 atomicAdd(grad_warp + warp_id_x, in ResamplerGrad2DKernel()
212 atomicAdd(grad_warp + warp_id_y, in ResamplerGrad2DKernel()
246 T* __restrict__ grad_warp, const int batch_size, in operator ()()
259 grad_warp_size, grad_warp)); in operator ()()
271 warp, grad_output, grad_data, grad_warp, in operator ()()
Dresampler_ops.cc209 T* __restrict__ grad_warp, const int batch_size, in operator ()()
220 memset(grad_warp, 0, sizeof(T) * grad_warp_size); in operator ()()
255 grad_warp[batch_id * warp_batch_stride + sample_id * 2 + channel] += in operator ()()
376 ::tensorflow::Tensor* grad_warp = nullptr; in Compute()
378 OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp)); in Compute()
384 grad_data->flat<T>().data(), grad_warp->flat<T>().data(), batch_size, in Compute()
Dresampler_ops.h46 T* __restrict__ grad_warp, const int batch_size,
/external/tensorflow/tensorflow/contrib/resampler/xla/
Dresampler_ops_xla_test.py49 grad_data, grad_warp = gen_resampler_ops.resampler_grad(
52 grad_data_tf, grad_warp_tf = sess.run([grad_data, grad_warp], {
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dresampler_ops.cc662 auto grad_warp = in Compile() local
672 broadcasted_dims, last_warp_dim, data_shape, grad_warp); in Compile()