Home
last modified time | relevance | path

Searched refs:HloInstruction (Results 1 – 25 of 381) sorted by relevance

12345678910>>...16

/external/tensorflow/tensorflow/compiler/xla/service/
Dbfloat16_propagation_test.cc39 bool SupportsBF16Operand(const HloInstruction& hlo, in SupportsBF16Operand()
44 bool SupportsBF16Output(const HloInstruction& hlo) const override { in SupportsBF16Output()
48 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { in SupportsMixedPrecisions()
52 bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, in EffectiveOperandPrecisionIsBF16()
76 bool OutputsBF16(const HloInstruction* inst) { in OutputsBF16()
85 std::unique_ptr<HloInstruction> CreateDot(const Shape& shape, in CreateDot()
86 HloInstruction* lhs, in CreateDot()
87 HloInstruction* rhs) { in CreateDot()
91 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, in CreateDot()
102 HloInstruction* a = in TEST_F()
[all …]
Dhlo_creation_utils.h32 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
33 HloInstruction* rhs);
37 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction,
38 HloInstruction* lhs,
39 HloInstruction* rhs);
44 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
45 HloInstruction* padding_value,
50 StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
57 StatusOr<HloInstruction*> MakeConvolveHlo(
58 HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
[all …]
Dhlo_instruction.h279 class HloInstruction {
352 virtual ~HloInstruction();
362 static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
364 const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
368 static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
373 static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
376 static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
380 static std::unique_ptr<HloInstruction> CreateGetTupleElement(
381 const Shape& shape, HloInstruction* operand, int64 index);
384 static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
[all …]
Dhlo_computation.h72 HloInstruction* fusion_instruction = nullptr)
82 HloInstruction* root_instruction = nullptr);
84 HloInstruction* AddInstruction( in AddInstruction()
85 std::unique_ptr<HloInstruction> instruction) { in AddInstruction()
92 const std::function<Status(const HloInstruction*)>& func) const { in ForEachInstruction()
101 HloInstruction* last_added_instruction_;
102 HloInstruction* fusion_instruction_;
103 std::vector<std::unique_ptr<HloInstruction>> instructions_;
108 HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
124 HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);
[all …]
Dtuple_simplifier_test.cc61 HloInstruction* param0 = builder.AddInstruction( in TEST_F()
62 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); in TEST_F()
63 HloInstruction* param1 = builder.AddInstruction( in TEST_F()
64 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); in TEST_F()
65 HloInstruction* param2 = builder.AddInstruction( in TEST_F()
66 HloInstruction::CreateParameter(2, scalar_shape_, "param2")); in TEST_F()
67 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); in TEST_F()
77 HloInstruction* param = builder.AddInstruction( in TEST_F()
78 HloInstruction::CreateParameter(0, tuple_shape_, "param")); in TEST_F()
80 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); in TEST_F()
[all …]
Dhlo_cost_analysis.h52 Status HandleElementwiseUnary(const HloInstruction* hlo) override;
53 Status HandleElementwiseBinary(const HloInstruction* hlo) override;
54 Status HandleConstant(const HloInstruction* constant) override;
55 Status HandleIota(const HloInstruction* iota) override;
57 const HloInstruction* get_tuple_element) override;
58 Status HandleSelect(const HloInstruction* hlo) override;
59 Status HandleTupleSelect(const HloInstruction* hlo) override;
60 Status HandleCompare(const HloInstruction* compare) override;
61 Status HandleClamp(const HloInstruction* clamp) override;
62 Status HandleReducePrecision(const HloInstruction* hlo) override;
[all …]
Dbfloat16_conversion_folding_test.cc35 bool SupportsBF16Operand(const HloInstruction& hlo, in SupportsBF16Operand()
47 bool SupportsBF16Output(const HloInstruction& hlo) const override { in SupportsBF16Output()
58 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { in SupportsMixedPrecisions()
88 HloInstruction* a = builder.AddInstruction( in TEST_F()
89 HloInstruction::CreateParameter(0, f32_shape, "a")); in TEST_F()
90 HloInstruction* b = builder.AddInstruction( in TEST_F()
91 HloInstruction::CreateParameter(1, f32_shape, "b")); in TEST_F()
92 HloInstruction* c = builder.AddInstruction( in TEST_F()
93 HloInstruction::CreateParameter(2, f32_shape, "c")); in TEST_F()
95 HloInstruction* add0 = builder.AddInstruction( in TEST_F()
[all …]
Dhlo_verifier.h40 Status Preprocess(HloInstruction* hlo) override;
42 Status HandleElementwiseUnary(HloInstruction* hlo) override;
43 Status HandleElementwiseBinary(HloInstruction* hlo) override;
44 Status HandleClamp(HloInstruction* clamp) override;
45 Status HandleSelect(HloInstruction* select) override;
46 Status HandleTupleSelect(HloInstruction* tuple_select) override;
47 Status HandleConcatenate(HloInstruction* concatenate) override;
48 Status HandleIota(HloInstruction* iota) override;
49 Status HandleConvert(HloInstruction* convert) override;
50 Status HandleBitcastConvert(HloInstruction* convert) override;
[all …]
Dhlo_instructions.h26 class HloBatchNormInstruction : public HloInstruction {
41 HloInstruction* operand,
42 HloInstruction* scale, float epsilon,
49 const HloInstruction& other,
62 HloInstruction* operand,
63 HloInstruction* scale,
64 HloInstruction* offset,
69 std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
70 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
77 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
[all …]
Dwhile_loop_invariant_code_motion_test.cc38 HloInstruction** while_instruction) { in FindOnlyWhileInstruction()
54 HloInstruction::CreateParameter(0, param_shape, "param")); in MakeAlwaysTrueComputation()
56 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))); in MakeAlwaysTrueComputation()
68 HloInstruction* param = builder.AddInstruction( in TEST_F()
69 HloInstruction::CreateParameter(0, while_shape, "param")); in TEST_F()
70 HloInstruction* gte_0 = builder.AddInstruction( in TEST_F()
71 HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); in TEST_F()
72 HloInstruction* gte_1 = builder.AddInstruction( in TEST_F()
73 HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); in TEST_F()
74 HloInstruction* add_result = in TEST_F()
[all …]
Dhlo_instruction_test.cc53 Status DefaultAction(HloInstruction* hlo_instruction) override { in DefaultAction()
58 Status HandleParameter(HloInstruction* parameter) override { in HandleParameter()
64 Status HandleConstant(HloInstruction* constant) override { in HandleConstant()
70 Status HandleAdd(HloInstruction* add) override { in HandleAdd()
80 Status HandleNegate(HloInstruction* negate) override { in HandleNegate()
88 Status HandleMap(HloInstruction* map) override { in HandleMap()
90 for (HloInstruction* arg : map->operands()) { in HandleMap()
97 Status HandleReduce(HloInstruction* reduce) override { in HandleReduce()
107 int64 NumOperands(const HloInstruction* node) { in NumOperands()
113 int64 NumUsers(const HloInstruction* node) { in NumUsers()
[all …]
Dcopy_insertion_test.cc90 HloInstruction* x = builder.AddInstruction( in TEST_F()
91 HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); in TEST_F()
92 HloInstruction* tuple = in TEST_F()
93 builder.AddInstruction(HloInstruction::CreateTuple({x})); in TEST_F()
110 HloInstruction* constant = builder.AddInstruction( in TEST_F()
111 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); in TEST_F()
112 HloInstruction* tuple = in TEST_F()
113 builder.AddInstruction(HloInstruction::CreateTuple({constant})); in TEST_F()
133 HloInstruction* constant = in TEST_F()
134 builder.AddInstruction(HloInstruction::CreateConstant( in TEST_F()
[all …]
Dhlo_instruction.cc62 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( in CreateFromProto()
64 const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, in CreateFromProto()
98 std::unique_ptr<HloInstruction> instruction; in CreateFromProto()
103 std::vector<HloInstruction*> result(proto.operand_ids_size()); in CreateFromProto()
593 instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); in CreateFromProto()
637 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter( in CreateParameter()
643 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace( in CreateTrace()
644 const string& tag, HloInstruction* operand) { in CreateTrace()
648 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant( in CreateConstant()
653 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota( in CreateIota()
[all …]
Ddynamic_dimension_inference.cc33 Status DefaultAction(HloInstruction* hlo) override;
42 Status HandleParameter(HloInstruction* hlo) override;
44 Status HandleReduce(HloInstruction* hlo) override;
46 Status HandleDot(HloInstruction* hlo) override;
48 Status HandleTuple(HloInstruction* hlo) override;
50 Status HandleTranspose(HloInstruction* hlo) override;
52 Status HandleReshape(HloInstruction* hlo) override;
54 Status HandlePad(HloInstruction* hlo) override;
56 Status HandleBroadcast(HloInstruction* hlo) override;
58 Status HandleGetDimensionSize(HloInstruction* hlo) override;
[all …]
Dreduce_precision_insertion_test.cc39 const std::function<bool(const HloInstruction*)>& filter) { in InsertOps()
52 HloInstruction* a = in TEST_F()
53 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); in TEST_F()
54 HloInstruction* b = builder.AddInstruction( in TEST_F()
55 HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); in TEST_F()
65 [](const HloInstruction* instruction) { in TEST_F()
79 HloInstruction* a = in TEST_F()
80 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); in TEST_F()
81 HloInstruction* b = builder.AddInstruction( in TEST_F()
82 HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); in TEST_F()
[all …]
Dinstruction_fusion.h38 std::function<bool(const HloInstruction& instruction)> is_expensive,
51 static bool IsExpensive(const HloInstruction& instruction);
71 virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index);
76 virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, in ShouldFuseIntoMultiOutput()
83 virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
84 const HloInstruction* consumer);
87 virtual HloInstruction* Fuse(HloInstruction* producer,
88 HloInstruction* consumer);
94 virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
95 HloInstruction* consumer);
[all …]
Dhlo_module_group_util.cc41 std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( in GlobalPredecessors()
42 HloInstruction* instruction) { in GlobalPredecessors()
43 std::vector<HloInstruction*> in GlobalPredecessors()
45 absl::flat_hash_set<HloInstruction*> unique; in GlobalPredecessors()
51 auto add_unique_predecessor = [&](HloInstruction* predecessor) { in GlobalPredecessors()
56 for (HloInstruction* instr : metadata_.Companions(predecessor)) { in GlobalPredecessors()
64 for (HloInstruction* instr : in GlobalPredecessors()
79 std::vector<HloInstruction*> instruction_group; in GlobalPredecessors()
81 for (HloInstruction* companion : metadata_.Companions(instruction)) { in GlobalPredecessors()
91 for (HloInstruction* hlo : instruction_group) { in GlobalPredecessors()
[all …]
Dbatchnorm_expander.cc52 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { in DefaultAction() argument
56 Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
58 Status HandleBatchNormInference(HloInstruction* batch_norm) override;
60 Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
86 HloInstruction::CreateParameter(0, shape, "scalar_lhs")); in GetOrCreateScalarAddComputation()
88 HloInstruction::CreateParameter(1, shape, "scalar_rhs")); in GetOrCreateScalarAddComputation()
89 auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( in GetOrCreateScalarAddComputation()
94 std::unique_ptr<HloInstruction> Rsqrt( in Rsqrt()
95 HloInstruction* operand, in Rsqrt()
96 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>& in Rsqrt()
[all …]
Dmulti_output_fusion.h65 virtual bool ShapesCompatibleForFusion(HloInstruction* instr1,
66 HloInstruction* instr2) = 0;
69 virtual bool IsFusible(HloInstruction* instr) = 0;
73 virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0;
76 virtual bool IsProfitableOperand(HloInstruction* instr);
79 virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
83 virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
96 HloInstruction* instr1, HloInstruction* instr2,
97 absl::Span<HloInstruction* const> instrs_to_update,
98 const std::function<bool(HloInstruction*)>& skip = nullptr);
[all …]
Dscatter_expander.cc31 static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast( in TransposeIndexVectorDimToLast()
32 HloInstruction* scatter_indices, int64 index_vector_dim) { in TransposeIndexVectorDimToLast()
56 static StatusOr<HloInstruction*> CanonicalizeScatterIndices( in CanonicalizeScatterIndices()
57 HloInstruction* scatter_indices, int64 index_vector_dim) { in CanonicalizeScatterIndices()
60 HloInstruction * transposed_scatter_indices, in CanonicalizeScatterIndices()
94 static StatusOr<HloInstruction*> PermuteScatterAndWindowDims( in PermuteScatterAndWindowDims()
95 HloInstruction* updates, absl::Span<const int64> update_window_dims) { in PermuteScatterAndWindowDims()
114 static StatusOr<HloInstruction*> AdjustScatterDims( in AdjustScatterDims()
115 const Shape& scatter_indices_shape, HloInstruction* updates, in AdjustScatterDims()
132 static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace( in ExpandIndexVectorIntoOperandSpace()
[all …]
Dhlo_creation_utils.cc35 StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, in MakeBinaryHlo()
36 HloInstruction* rhs) { in MakeBinaryHlo()
42 HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); in MakeBinaryHlo()
45 StatusOr<HloInstruction*> MakeCompareHlo(ComparisonDirection direction, in MakeCompareHlo()
46 HloInstruction* lhs, in MakeCompareHlo()
47 HloInstruction* rhs) { in MakeCompareHlo()
54 HloInstruction::CreateCompare(binary_op_shape, lhs, rhs, direction)); in MakeCompareHlo()
57 StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand, in MakePadHlo()
58 HloInstruction* padding_value, in MakePadHlo()
66 return computation->AddInstruction(HloInstruction::CreatePad( in MakePadHlo()
[all …]
Dalgebraic_simplifier_test.cc60 HloInstruction* param0 = builder.AddInstruction( in TEST_F()
61 HloInstruction::CreateParameter(0, r0f32, "param0")); in TEST_F()
62 HloInstruction* zero = builder.AddInstruction( in TEST_F()
63 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); in TEST_F()
65 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); in TEST_F()
68 HloInstruction* root = computation->root_instruction(); in TEST_F()
283 HloInstruction* param0 = builder.AddInstruction( in TEST_F()
284 HloInstruction::CreateParameter(0, r0s32, "param0")); in TEST_F()
285 HloInstruction* zero = builder.AddInstruction( in TEST_F()
286 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); in TEST_F()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Dcpu_instruction_fusion_test.cc37 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs, in MakeDot()
38 HloInstruction* rhs) { in MakeDot()
45 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, in MakeDot()
51 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( in TEST_F()
53 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( in TEST_F()
56 HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( in TEST_F()
58 HloInstruction* dot = builder.AddInstruction( in TEST_F()
70 HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( in TEST_F()
72 HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( in TEST_F()
75 HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( in TEST_F()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dcudnn_batchnorm_rewriter.cc35 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { in DefaultAction() argument
39 Status HandleBatchNormInference(HloInstruction* batch_norm) override;
40 Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
41 Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
50 bool EpsilonInRange(HloInstruction* batch_norm) { in EpsilonInRange()
54 Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { in HandleBatchNormInference()
70 HloInstruction* epsilon = in HandleBatchNormInference()
71 computation_->AddInstruction(HloInstruction::CreateConstant( in HandleBatchNormInference()
73 HloInstruction* feature_index = in HandleBatchNormInference()
74 computation_->AddInstruction(HloInstruction::CreateConstant( in HandleBatchNormInference()
[all …]
Dir_emitter_unnested.h100 std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
103 std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
114 HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
157 Status DefaultAction(HloInstruction* hlo) override;
161 Status HandleCopy(HloInstruction* copy) override;
162 Status HandleConditional(HloInstruction* conditional) override;
163 Status HandleConvolution(HloInstruction* convolution) override;
164 Status HandleCustomCall(HloInstruction* custom_call) override;
165 Status HandleDot(HloInstruction* dot) override;
166 Status HandleFft(HloInstruction* fft) override;
[all …]

12345678910>>...16