Home
last modified time | relevance | path

Searched refs:replicate_op (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
replicate_invariant_op_hoisting.cc
76 void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, in MakeShapeOpInvariant() argument
84 shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument( in MakeShapeOpInvariant()
109 replicate_op.GetReplicaOperandForBlockArgument(block_arg, in MakeShapeOpInvariant()
154 void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { in HoistReplicateInvariantOps() argument
155 const int num_replicas = replicate_op.n(); in HoistReplicateInvariantOps()
156 Block* replicate_block = &replicate_op.GetBody(); in HoistReplicateInvariantOps()
158 replicate_op.walk([&](TF::ShapeOp shape_op) { in HoistReplicateInvariantOps()
159 MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op); in HoistReplicateInvariantOps()
162 Region* replicate_region = &replicate_op.body(); in HoistReplicateInvariantOps()
163 Optional<DictionaryAttr> virtual_device_list = replicate_op.devices(); in HoistReplicateInvariantOps()
[all …]
tf_device_replication_pass.cc
42 module.walk([&](tf_device::ReplicateOp replicate_op) { in runOnOperation() argument
43 OpBuilder builder(replicate_op); in runOnOperation()
48 llvm::Optional<DictionaryAttr> devices = replicate_op.devices(); in runOnOperation()
49 const int replicate_num = replicate_op.n(); in runOnOperation()
58 for (BlockArgument &arg : replicate_op.GetBody().getArguments()) { in runOnOperation()
60 replicate_op.GetReplicaOperandForBlockArgument(arg, i); in runOnOperation()
63 for (Operation &op : replicate_op.GetBody().without_terminator()) { in runOnOperation()
89 for (Value v : replicate_op.GetBody().getTerminator()->getOperands()) { in runOnOperation()
109 replicate_op.replaceAllUsesWith(new_results); in runOnOperation()
110 replicate_op.erase(); in runOnOperation()
replicate_to_island.cc
140 tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op, in ExpandReplicateIntoReplicas() argument
143 auto devices = replicate_op.devices(); in ExpandReplicateIntoReplicas()
146 Operation& terminator = replicate_op.GetBody().back(); in ExpandReplicateIntoReplicas()
166 for (auto& block_arg : replicate_op.GetBody().getArguments()) in ExpandReplicateIntoReplicas()
168 replicate_op.GetReplicaOperandForBlockArgument(block_arg, i)); in ExpandReplicateIntoReplicas()
171 replicate_op.body().cloneInto(&replica.body(), mapping); in ExpandReplicateIntoReplicas()
173 if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(), in ExpandReplicateIntoReplicas()
237 tf_device::ReplicateOp replicate_op) { in CreateIslandsFromReplicate() argument
239 const int num_replicas = replicate_op.n(); in CreateIslandsFromReplicate()
244 replicate_op, num_replicas, replicas))) in CreateIslandsFromReplicate()
[all …]
tpu_colocate_composite_resource_ops.cc
92 tf_device::ReplicateOp replicate_op, OpBuilder* builder) { in ColocateCompositeResourceOpsInReplicate() argument
93 auto devices = replicate_op.devices(); in ColocateCompositeResourceOpsInReplicate()
99 GetResourceOpsUsingCompositeArgsInReplicate(replicate_op); in ColocateCompositeResourceOpsInReplicate()
tpu_cluster_formation.cc
424 auto replicate_op = builder.create<tf_device::ReplicateOp>( in ReplicateCluster() local
429 replicate_op->setAttr(kReplicatedInputIndicesAttr, in ReplicateCluster()
433 replicate_op->setAttr(kMirroredVariableIndicesAttr, in ReplicateCluster()
441 std::next(replicate_op.result_begin(), idx * num_replicas), in ReplicateCluster()
442 std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); in ReplicateCluster()
467 llvm::zip(replicated_input_ops, replicate_op.GetBody().getArguments())) { in ReplicateCluster()
484 builder.setInsertionPointToEnd(&replicate_op.GetBody()); in ReplicateCluster()
485 auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(), in ReplicateCluster()
tpu_reorder_replicate_and_partitioned_inputs.cc
95 auto replicate_op = builder.create<TF::TPUReplicatedInputOp>( in ReorderReplicateAndPartitionedInputs() local
98 operands_per_core.push_back(replicate_op); in ReorderReplicateAndPartitionedInputs()
executor_tpuv1_outline_tpu_island.cc
80 getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) { in runOnOperation() argument
81 auto island_op = cast<IslandOp>(replicate_op->getParentOp()); in runOnOperation()
tpu_variable_runtime_reformatting.cc
526 while_op.body().walk([&](tf_device::ReplicateOp replicate_op) { in runOnOperation() argument
528 replicate = replicate_op; in runOnOperation()