Searched refs:dot_dims (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | resampler_ops.cc | 401 xla::DotDimensionNumbers dot_dims; in CalculateGradWarp() local 403 dot_dims.add_lhs_batch_dimensions(i); in CalculateGradWarp() 404 dot_dims.add_rhs_batch_dimensions(i); in CalculateGradWarp() 406 dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); in CalculateGradWarp() 407 dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); in CalculateGradWarp() 415 neighbors_data, dot_dims, /*precision_config=*/nullptr); in CalculateGradWarp() 423 neighbors_data, dot_dims, /*precision_config=*/nullptr); in CalculateGradWarp() 431 neighbors_data, dot_dims, /*precision_config=*/nullptr); in CalculateGradWarp() 439 neighbors_data, dot_dims, /*precision_config=*/nullptr); in CalculateGradWarp() 538 xla::DotDimensionNumbers dot_dims; in Compile() local [all …]
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | dot_decomposer.cc | 265 std::vector<int64> dot_dims = batch_dim_sizes; in CanonicalizeDot() local 266 dot_dims.push_back(lhs_non_contracting_size); in CanonicalizeDot() 267 dot_dims.push_back(rhs_non_contracting_size); in CanonicalizeDot() 278 ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims), in CanonicalizeDot()
|