Searched refs:batch_dim_indices (Results 1 – 1 of 1) sorted by relevance
173 std::vector<int64> batch_dim_indices(num_batch_dims); in QRBlock() local174 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); in QRBlock()197 /*broadcast_dimensions=*/batch_dim_indices); in QRBlock()209 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); in QRBlock()255 std::vector<int64> batch_dim_indices(batch_dims.size()); in ComputeWYRepresentation() local256 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); in ComputeWYRepresentation()280 /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); in ComputeWYRepresentation()297 /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); in ComputeWYRepresentation()