Searched refs:warp_dims (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
resampler_ops.cc
59 auto warp_dims = warp_shape.dim_sizes(); in BilinearWeights() local
60 std::vector<int64> broadcast_dims(warp_dims.begin(), warp_dims.end() - 1); in BilinearWeights()
139 int64 data_channels, int warp_dims) { in Gather2by2Neighbors() argument
141 const int64 neighbor_data_dimensions = warp_dims + 2; in Gather2by2Neighbors()
148 gather_dim_numbers.set_index_vector_dim(warp_dims - 1); in Gather2by2Neighbors()
161 return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims}); in Gather2by2Neighbors()
168 XlaOp updates, int64 warp_dims, in ScatterToGradData() argument
171 const int64 neighbor_data_dimensions = warp_dims + 2; in ScatterToGradData()
177 scatter_dim_numbers.set_index_vector_dim(warp_dims - 1); in ScatterToGradData()
256 auto warp_dims = warp_shape.dim_sizes(); in CalculateGradData() local
[all …]
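The hits use the name `warp_dims` in two ways. In `BilinearWeights()` and `CalculateGradData()` it is a local holding `warp_shape.dim_sizes()`. In `Gather2by2Neighbors()` and `ScatterToGradData()` it is an integer argument. The C++ sketch that follows is not TensorFlow code. It only reproduces the arithmetic visible in the snippets: line 60 copies every size except the last, lines 141/171 and 148/177 derive `warp_dims + 2` and `warp_dims - 1`, and line 161 collapses dimensions `warp_dims - 1` and `warp_dims`. Several things are assumptions rather than facts from the source: the helper names (`BroadcastDimsFromWarpSizes`, `DimNumbersForWarpRank`, `CollapseAdjacent`, `CheckWorkedExample`), treating the integer `warp_dims` as the warp rank, and the example shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring line 60: keep every warp dimension size except
// the last one (assumed here to be the size-2 coordinate axis).
std::vector<int64_t> BroadcastDimsFromWarpSizes(const std::vector<int64_t>& warp_sizes) {
  if (warp_sizes.empty()) {
    throw std::invalid_argument("warp must have at least one dimension");
  }
  return std::vector<int64_t>(warp_sizes.begin(), warp_sizes.end() - 1);
}

// Hypothetical struct collecting the integer arithmetic from lines 141/171,
// 148/177 and 161, where warp_dims is an integer (assumed to be the warp rank).
struct NeighborDimNumbers {
  int64_t neighbor_data_dimensions;  // warp_dims + 2
  int64_t index_vector_dim;          // warp_dims - 1
  int64_t collapse_first;            // first dimension passed to xla::Collapse
  int64_t collapse_second;           // second dimension passed to xla::Collapse
};

NeighborDimNumbers DimNumbersForWarpRank(int64_t warp_dims) {
  return {warp_dims + 2, warp_dims - 1, warp_dims - 1, warp_dims};
}

// Shape-only model of collapsing two consecutive dimensions starting at
// `first`: they are replaced by one dimension whose size is their product.
std::vector<int64_t> CollapseAdjacent(const std::vector<int64_t>& dims, std::size_t first) {
  if (first + 1 >= dims.size()) {
    throw std::out_of_range("need two dimensions starting at `first`");
  }
  std::vector<int64_t> out;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (i == first) {
      out.push_back(dims[i] * dims[i + 1]);
      ++i;  // skip the second collapsed dimension
    } else {
      out.push_back(dims[i]);
    }
  }
  return out;
}

// Worked example with a hypothetical warp of shape [8, 32, 32, 2] (rank 4).
void CheckWorkedExample() {
  assert((BroadcastDimsFromWarpSizes({8, 32, 32, 2}) == std::vector<int64_t>{8, 32, 32}));
  NeighborDimNumbers d = DimNumbersForWarpRank(4);
  assert(d.neighbor_data_dimensions == 6);
  assert(d.index_vector_dim == 3);
  // Illustrative rank-6 neighbor tensor with two size-2 neighbor axes at 3 and 4.
  assert((CollapseAdjacent({8, 32, 32, 2, 2, 3}, static_cast<std::size_t>(d.collapse_first)) ==
          std::vector<int64_t>{8, 32, 32, 4, 3}));
}
```

`CollapseAdjacent` models only the shape effect of collapsing two consecutive dimensions, where their sizes multiply. It builds no XLA ops. Under the rank assumption, a rank-4 warp yields a rank-6 neighbor tensor, and collapsing dimensions 3 and 4 merges the 2×2 neighbor axes into a single axis of size 4.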