Searched refs:grad_warp (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/contrib/resampler/kernels/ |
D | resampler_ops_gpu.cu.cc | 147 T* __restrict__ grad_warp, const int batch_size, const int data_height, in ResamplerGrad2DKernel() argument 209 atomicAdd(grad_warp + warp_id_x, in ResamplerGrad2DKernel() 212 atomicAdd(grad_warp + warp_id_y, in ResamplerGrad2DKernel() 246 T* __restrict__ grad_warp, const int batch_size, in operator ()() 259 grad_warp_size, grad_warp)); in operator ()() 271 warp, grad_output, grad_data, grad_warp, in operator ()()
|
D | resampler_ops.cc | 209 T* __restrict__ grad_warp, const int batch_size, in operator ()() 220 memset(grad_warp, 0, sizeof(T) * grad_warp_size); in operator ()() 255 grad_warp[batch_id * warp_batch_stride + sample_id * 2 + channel] += in operator ()() 376 ::tensorflow::Tensor* grad_warp = nullptr; in Compute() 378 OP_REQUIRES_OK(ctx, ctx->allocate_output(1, warp.shape(), &grad_warp)); in Compute() 384 grad_data->flat<T>().data(), grad_warp->flat<T>().data(), batch_size, in Compute()
|
D | resampler_ops.h | 46 T* __restrict__ grad_warp, const int batch_size,
|
/external/tensorflow/tensorflow/contrib/resampler/xla/ |
D | resampler_ops_xla_test.py | 49 grad_data, grad_warp = gen_resampler_ops.resampler_grad( 52 grad_data_tf, grad_warp_tf = sess.run([grad_data, grad_warp], {
|
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | resampler_ops.cc | 662 auto grad_warp = in Compile() local 672 broadcasted_dims, last_warp_dim, data_shape, grad_warp); in Compile()
|