Searched refs:ArCrsCombiner (Results 1 – 3 of 3) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | ar_crs_combiner_test.cc | 46 EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue( in TEST_F() 48 EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 66 EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 85 EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 106 EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 126 EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 146 EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 168 EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 190 EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() 213 EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); in TEST_F() [all …]
|
D | ar_crs_combiner.cc | 41 absl::optional<ArCrsCombiner::ArCrsPair> ArCrsCombiner::MatchesArCrsPattern( in MatchesArCrsPattern() 95 absl::optional<HloInstruction*> ArCrsCombiner::WhileFromBodyParameter( in WhileFromBodyParameter() 109 std::vector<HloInstruction*> ArCrsCombiner::GetAllTuples( in GetAllTuples() 149 bool ArCrsCombiner::TupleElementsComputeSameValue( in TupleElementsComputeSameValue() 168 bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, in TestInstructionsComputeSameValue() 170 ArCrsCombiner combiner(/*num_spatial_partitions=*/2); in TestInstructionsComputeSameValue() 178 bool ArCrsCombiner::InstructionsComputeSameValue( in InstructionsComputeSameValue() 235 void ArCrsCombiner::GroupAllReducesById(HloModule* module) { in GroupAllReducesById() 291 void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { in KeepProvablyEqualInstructionGroups() 317 StatusOr<bool> ArCrsCombiner::RewriteGraph() { in RewriteGraph() [all …]
|
D | ar_crs_combiner.h | 70 class ArCrsCombiner : public HloModulePass { 72 ArCrsCombiner(int num_spatial_partitions) in ArCrsCombiner() function 101 absl::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
|