/external/tensorflow/tensorflow/compiler/xla/service/ |
D | gather_expander.cc | 30 HloInstruction* start_indices, int64 index_vector_dim) { in TransposeIndexVectorDimToLast() argument 33 if (start_indices_shape.dimensions_size() == index_vector_dim) { in TransposeIndexVectorDimToLast() 37 if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) { in TransposeIndexVectorDimToLast() 44 if (i != index_vector_dim) { in TransposeIndexVectorDimToLast() 48 permutation.push_back(index_vector_dim); in TransposeIndexVectorDimToLast() 57 HloInstruction* start_indices, int64 index_vector_dim) { in CanonicalizeGatherIndices() argument 61 TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); in CanonicalizeGatherIndices() 63 index_vector_dim == start_indices->shape().dimensions_size(); in CanonicalizeGatherIndices() 88 int64 index_vector_dim) { in AdjustBatchDimsInAccumulator() argument 92 if (i != index_vector_dim) { in AdjustBatchDimsInAccumulator() [all …]
|
D | scatter_expander.cc | 32 HloInstruction* scatter_indices, int64 index_vector_dim) { in TransposeIndexVectorDimToLast() argument 35 if (scatter_indices_shape.dimensions_size() == index_vector_dim) { in TransposeIndexVectorDimToLast() 39 if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { in TransposeIndexVectorDimToLast() 46 if (i != index_vector_dim) { in TransposeIndexVectorDimToLast() 50 permutation.push_back(index_vector_dim); in TransposeIndexVectorDimToLast() 57 HloInstruction* scatter_indices, int64 index_vector_dim) { in CanonicalizeScatterIndices() argument 61 TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); in CanonicalizeScatterIndices() 62 if (scatter_indices->shape().rank() == index_vector_dim + 1 && in CanonicalizeScatterIndices() 63 scatter_indices->shape().dimensions(index_vector_dim) == 1) { in CanonicalizeScatterIndices() 65 ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); in CanonicalizeScatterIndices() [all …]
|
D | shape_inference.cc | 2898 start_indices_shape[dim_numbers.index_vector_dim()]) { in ValidateGatherDimensionNumbers() 2903 dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), in ValidateGatherDimensionNumbers() 2904 start_indices_shape[dim_numbers.index_vector_dim()]); in ValidateGatherDimensionNumbers() 2977 gather_dim_numbers.index_vector_dim() || in InferGatherShape() 2978 gather_dim_numbers.index_vector_dim() < 0) { in InferGatherShape() 2984 gather_dim_numbers.index_vector_dim()); in InferGatherShape() 2992 gather_dim_numbers.index_vector_dim()) { in InferGatherShape() 3055 if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { in InferGatherShape() 3126 scatter_indices_shape[dim_numbers.index_vector_dim()]) { in ValidateScatterDimensionNumbers() 3132 dim_numbers.index_vector_dim(), in ValidateScatterDimensionNumbers() [all …]
|
D | hlo_evaluator.cc | 855 start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); in OutputBatchIndexToInputIndex() 894 if (index_vector_index_i == dim_numbers_.index_vector_dim()) { in PropagateOutputIndexGatherDimsToIndexVectorIndex() 905 int64 index_vector_dim = dim_numbers_.index_vector_dim(); in FetchIndexVector() local 907 index_vector_index_[index_vector_dim] = i; in FetchIndexVector() 1033 int64 index_vector_dim, const Literal& start_indices, in ReshapedGatherIndices() argument 1035 if (start_indices.shape().dimensions_size() != index_vector_dim) { in ReshapedGatherIndices() 1056 ReshapedGatherIndices(dim_numbers.index_vector_dim(), in HandleGather()
|
D | hlo_instructions.cc | 2332 string index_vector_dim = StrCat( in GatherDimensionNumbersToString() local 2333 "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); in GatherDimensionNumbersToString() 2336 {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, in GatherDimensionNumbersToString() 2343 absl::Span<const int64> start_index_map, int64 index_vector_dim) { in MakeGatherDimNumbers() argument 2355 gather_dim_numbers.set_index_vector_dim(index_vector_dim); in MakeGatherDimNumbers() 2419 string index_vector_dim = StrCat( in ScatterDimensionNumbersToString() local 2420 "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); in ScatterDimensionNumbersToString() 2424 index_vector_dim}, in ScatterDimensionNumbersToString() 2433 int64 index_vector_dim) { in MakeScatterDimNumbers() argument 2445 scatter_dim_numbers.set_index_vector_dim(index_vector_dim); in MakeScatterDimNumbers()
|
D | hlo_evaluator_typed_visitor.h | 2042 int64 index_vector_dim, const Literal& indices, in ReshapedScatterIndices() argument 2044 if (indices.shape().dimensions_size() != index_vector_dim) { in ReshapedScatterIndices() 2129 scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); in UpdateScatterIndexToInputIndex() 2168 if (index_vector_index_i == dim_numbers_.index_vector_dim()) { in PropagateUpdateIndexScatterDimsToIndexVectorIndex() 2179 int64 index_vector_dim = dim_numbers_.index_vector_dim(); in FetchIndexVector() local 2181 index_vector_index_[index_vector_dim] = i; in FetchIndexVector() 2314 ReshapedScatterIndices(dim_numbers.index_vector_dim(), in HandleScatter()
|
D | hlo_parser.cc | 1624 optional<int64> index_vector_dim; in ParseInstructionRhs() local 1626 &index_vector_dim}; in ParseInstructionRhs() 1641 /*index_vector_dim=*/*index_vector_dim); in ParseInstructionRhs() 1659 optional<int64> index_vector_dim; in ParseInstructionRhs() local 1661 &index_vector_dim}; in ParseInstructionRhs() 1677 /*index_vector_dim=*/*index_vector_dim); in ParseInstructionRhs()
|
D | hlo_parser_test.cc | 853 …={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_size… in CreateTestCases() 872 … inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%… in CreateTestCases() 1333 …={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_size… in CreateTestCases()
|
D | elemental_ir_emitter.cc | 1907 dim_numbers.index_vector_dim(), in EmitElementalGather() 1941 if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { in EmitElementalGather() 1949 indices_shape.dimensions(dim_numbers.index_vector_dim()); in EmitElementalGather() 1951 gather_index_index_components[dim_numbers.index_vector_dim()] = in EmitElementalGather()
|
D | hlo_instructions.h | 1382 absl::Span<const int64> start_index_map, int64 index_vector_dim); 1420 int64 index_vector_dim);
|
D | indexed_array_analysis.cc | 256 if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { in ComputeArrayForGather()
|
/external/tensorflow/tensorflow/compiler/tf2xla/lib/ |
D | scatter.cc | 193 VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); in XlaScatter()
|
/external/tensorflow/tensorflow/compiler/xla/g3doc/ |
D | operation_semantics.md | 1278 | `index_vector_dim` | `int64` | The dimension in | 1304 If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider 1306 shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the 1313 `start_indices.shape`, skipping `index_vector_dim` (i.e. pick 1314 `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and 1328 Combine(A, b) inserts b at position `index_vector_dim` into A. Note that 1374 `index_vector_dim` is set to `start_indices.rank` - `1` in all of the 1375 examples that follow. More interesting values for `index_vector_dim` does not 2115 <b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window… 2123 `index_vector_dim` | `int64` | The dimension in `scatter_indices` that cont… [all …]
|
/external/tensorflow/tensorflow/compiler/xla/python/ |
D | xla_client.py | 1958 self.index_vector_dim = 0 1970 self.index_vector_dim = 0
|
D | xla_client_test.py | 1224 dnums.index_vector_dim = 2 1687 dnums.index_vector_dim = 1
|
/external/tensorflow/tensorflow/compiler/xla/ |
D | xla_data.proto | 462 int64 index_vector_dim = 4; field 476 int64 index_vector_dim = 4; field
|
/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
D | ir_emitter_unnested.cc | 1124 if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) { in EmitScatter() 1127 dim_numbers.index_vector_dim()); in EmitScatter() 1135 raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(), in EmitScatter() 1142 raw_scatter_index_multidim[dim_numbers.index_vector_dim()] = in EmitScatter()
|