Searched refs:flattened_list (Results 1 – 2 of 2) sorted by relevance
109 std::vector<HloSharding> flattened_list; in Tuple() local110 flattened_list.reserve(sub_shardings.leaf_count()); in Tuple()112 flattened_list.push_back(index_to_sharding.second); in Tuple()114 if (flattened_list.empty()) { in Tuple()120 flattened_list.push_back(sub_shardings.element(ShapeIndex({}))); in Tuple()122 return HloSharding(flattened_list); in Tuple()131 std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end()); in Tuple() local132 CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) in Tuple()133 << "Flat list has " << flattened_list.size() << ", required " in Tuple()135 return HloSharding(flattened_list); in Tuple()[all …]
1464 flattened_list = nest.flatten(replicate_inputs[0])1465 for input_tensor in flattened_list: