Home
last modified time | relevance | path

Searched refs:index_vector_dim (Results 1 – 17 of 17) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dgather_expander.cc30 HloInstruction* start_indices, int64 index_vector_dim) { in TransposeIndexVectorDimToLast() argument
33 if (start_indices_shape.dimensions_size() == index_vector_dim) { in TransposeIndexVectorDimToLast()
37 if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) { in TransposeIndexVectorDimToLast()
44 if (i != index_vector_dim) { in TransposeIndexVectorDimToLast()
48 permutation.push_back(index_vector_dim); in TransposeIndexVectorDimToLast()
57 HloInstruction* start_indices, int64 index_vector_dim) { in CanonicalizeGatherIndices() argument
61 TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); in CanonicalizeGatherIndices()
63 index_vector_dim == start_indices->shape().dimensions_size(); in CanonicalizeGatherIndices()
88 int64 index_vector_dim) { in AdjustBatchDimsInAccumulator() argument
92 if (i != index_vector_dim) { in AdjustBatchDimsInAccumulator()
[all …]
Dscatter_expander.cc32 HloInstruction* scatter_indices, int64 index_vector_dim) { in TransposeIndexVectorDimToLast() argument
35 if (scatter_indices_shape.dimensions_size() == index_vector_dim) { in TransposeIndexVectorDimToLast()
39 if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { in TransposeIndexVectorDimToLast()
46 if (i != index_vector_dim) { in TransposeIndexVectorDimToLast()
50 permutation.push_back(index_vector_dim); in TransposeIndexVectorDimToLast()
57 HloInstruction* scatter_indices, int64 index_vector_dim) { in CanonicalizeScatterIndices() argument
61 TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); in CanonicalizeScatterIndices()
62 if (scatter_indices->shape().rank() == index_vector_dim + 1 && in CanonicalizeScatterIndices()
63 scatter_indices->shape().dimensions(index_vector_dim) == 1) { in CanonicalizeScatterIndices()
65 ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); in CanonicalizeScatterIndices()
[all …]
Dshape_inference.cc2898 start_indices_shape[dim_numbers.index_vector_dim()]) { in ValidateGatherDimensionNumbers()
2903 dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), in ValidateGatherDimensionNumbers()
2904 start_indices_shape[dim_numbers.index_vector_dim()]); in ValidateGatherDimensionNumbers()
2977 gather_dim_numbers.index_vector_dim() || in InferGatherShape()
2978 gather_dim_numbers.index_vector_dim() < 0) { in InferGatherShape()
2984 gather_dim_numbers.index_vector_dim()); in InferGatherShape()
2992 gather_dim_numbers.index_vector_dim()) { in InferGatherShape()
3055 if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { in InferGatherShape()
3126 scatter_indices_shape[dim_numbers.index_vector_dim()]) { in ValidateScatterDimensionNumbers()
3132 dim_numbers.index_vector_dim(), in ValidateScatterDimensionNumbers()
[all …]
Dhlo_evaluator.cc855 start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); in OutputBatchIndexToInputIndex()
894 if (index_vector_index_i == dim_numbers_.index_vector_dim()) { in PropagateOutputIndexGatherDimsToIndexVectorIndex()
905 int64 index_vector_dim = dim_numbers_.index_vector_dim(); in FetchIndexVector() local
907 index_vector_index_[index_vector_dim] = i; in FetchIndexVector()
1033 int64 index_vector_dim, const Literal& start_indices, in ReshapedGatherIndices() argument
1035 if (start_indices.shape().dimensions_size() != index_vector_dim) { in ReshapedGatherIndices()
1056 ReshapedGatherIndices(dim_numbers.index_vector_dim(), in HandleGather()
Dhlo_instructions.cc2332 string index_vector_dim = StrCat( in GatherDimensionNumbersToString() local
2333 "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); in GatherDimensionNumbersToString()
2336 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, in GatherDimensionNumbersToString()
2343 absl::Span<const int64> start_index_map, int64 index_vector_dim) { in MakeGatherDimNumbers() argument
2355 gather_dim_numbers.set_index_vector_dim(index_vector_dim); in MakeGatherDimNumbers()
2419 string index_vector_dim = StrCat( in ScatterDimensionNumbersToString() local
2420 "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); in ScatterDimensionNumbersToString()
2424 index_vector_dim}, in ScatterDimensionNumbersToString()
2433 int64 index_vector_dim) { in MakeScatterDimNumbers() argument
2445 scatter_dim_numbers.set_index_vector_dim(index_vector_dim); in MakeScatterDimNumbers()
Dhlo_evaluator_typed_visitor.h2042 int64 index_vector_dim, const Literal& indices, in ReshapedScatterIndices() argument
2044 if (indices.shape().dimensions_size() != index_vector_dim) { in ReshapedScatterIndices()
2129 scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); in UpdateScatterIndexToInputIndex()
2168 if (index_vector_index_i == dim_numbers_.index_vector_dim()) { in PropagateUpdateIndexScatterDimsToIndexVectorIndex()
2179 int64 index_vector_dim = dim_numbers_.index_vector_dim(); in FetchIndexVector() local
2181 index_vector_index_[index_vector_dim] = i; in FetchIndexVector()
2314 ReshapedScatterIndices(dim_numbers.index_vector_dim(), in HandleScatter()
Dhlo_parser.cc1624 optional<int64> index_vector_dim; in ParseInstructionRhs() local
1626 &index_vector_dim}; in ParseInstructionRhs()
1641 /*index_vector_dim=*/*index_vector_dim); in ParseInstructionRhs()
1659 optional<int64> index_vector_dim; in ParseInstructionRhs() local
1661 &index_vector_dim}; in ParseInstructionRhs()
1677 /*index_vector_dim=*/*index_vector_dim); in ParseInstructionRhs()
Dhlo_parser_test.cc853 …={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_size… in CreateTestCases()
872 … inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%… in CreateTestCases()
1333 …={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_size… in CreateTestCases()
Delemental_ir_emitter.cc1907 dim_numbers.index_vector_dim(), in EmitElementalGather()
1941 if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { in EmitElementalGather()
1949 indices_shape.dimensions(dim_numbers.index_vector_dim()); in EmitElementalGather()
1951 gather_index_index_components[dim_numbers.index_vector_dim()] = in EmitElementalGather()
Dhlo_instructions.h1382 absl::Span<const int64> start_index_map, int64 index_vector_dim);
1420 int64 index_vector_dim);
Dindexed_array_analysis.cc256 if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { in ComputeArrayForGather()
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Dscatter.cc193 VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); in XlaScatter()
/external/tensorflow/tensorflow/compiler/xla/g3doc/
Doperation_semantics.md1278 | `index_vector_dim` | `int64` | The dimension in |
1304 If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
1306 shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
1313 `start_indices.shape`, skipping `index_vector_dim` (i.e. pick
1314 `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
1328 Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
1374 `index_vector_dim` is set to `start_indices.rank` - `1` in all of the
1375 examples that follow. More interesting values for `index_vector_dim` does not
2115 <b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window…
2123 `index_vector_dim` | `int64` | The dimension in `scatter_indices` that cont…
[all …]
/external/tensorflow/tensorflow/compiler/xla/python/
Dxla_client.py1958 self.index_vector_dim = 0
1970 self.index_vector_dim = 0
Dxla_client_test.py1224 dnums.index_vector_dim = 2
1687 dnums.index_vector_dim = 1
/external/tensorflow/tensorflow/compiler/xla/
Dxla_data.proto462 int64 index_vector_dim = 4; field
476 int64 index_vector_dim = 4; field
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dir_emitter_unnested.cc1124 if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { in EmitScatter()
1127 dim_numbers.index_vector_dim()); in EmitScatter()
1135 raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(), in EmitScatter()
1142 raw_scatter_index_multidim[dim_numbers.index_vector_dim()] = in EmitScatter()