Home
last modified time | relevance | path

Searched refs:case_op (Results 1 – 12 of 12) sorted by relevance

/external/tensorflow/tensorflow/python/ops/
Dcond_v2.py1029 case_op = op.outputs[0].op
1030 branch_graphs = get_func_graphs(case_op)
1035 assert branch_graph.outer_graph == case_op.graph
1083 case_op._set_func_list_attr("branches", [
1087 case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
1088 case_op._set_shape_list_attr("output_shapes",
1090 case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
1104 lowering = case_op._get_attr_bool("_lower_using_switch_merge")
1109 case_op.inputs[0],
1168 case_op, tensors = util.get_op_and_outputs(op_fn(
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf_control_flow.cc136 auto case_op = in LowerCase() local
143 ImportXlaRegion(branch_func, &case_op.branches()[i], loc, in LowerCase()
147 op.replaceAllUsesWith(case_op.getResults()); in LowerCase()
301 auto case_op = in LowerCaseRegion() local
304 for (auto region : llvm::zip(case_op.branches(), op.branches())) in LowerCaseRegion()
307 op.replaceAllUsesWith(case_op.getResults()); in LowerCaseRegion()
387 if (auto case_op = dyn_cast<TF::CaseOp>(op)) { in runOnOperation() local
388 LowerCase(case_op); in runOnOperation()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dtensor_list_ops_decomposition.cc388 TF::CaseRegionOp case_op, ModuleOp module, in HandleCaseRegionOp() argument
393 RegionRange branches = case_op.getRegions(); in HandleCaseRegionOp()
415 auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>( in HandleCaseRegionOp()
416 case_op.getLoc(), in HandleCaseRegionOp()
418 case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions()); in HandleCaseRegionOp()
424 for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) { in HandleCaseRegionOp()
427 case_op.replaceAllUsesWith( in HandleCaseRegionOp()
428 new_op.getResults().take_front(case_op.getNumResults())); in HandleCaseRegionOp()
429 case_op.erase(); in HandleCaseRegionOp()
899 } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) { in DecomposeTensorListOpsInternal() local
[all …]
Dresource_op_lifting_cleanup.cc422 } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) { in CleanupAndCanonicalize() local
424 case_op.get_branch_functions(branches); in CleanupAndCanonicalize()
425 result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input()); in CleanupAndCanonicalize()
Dresource_op_lifting.cc1256 } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) { in HoistForControlFlow() local
1258 case_op.get_branch_functions(branch_functions); in HoistForControlFlow()
1264 if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); in HoistForControlFlow()
Dshape_inference.cc1330 } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) { in PropagateShapeIntoAttachedFunctions() local
1332 case_op.get_branch_functions(branches); in PropagateShapeIntoAttachedFunctions()
1333 return PropagateShapeToFunctions(module, case_op.input().getTypes(), in PropagateShapeIntoAttachedFunctions()
/external/tensorflow/tensorflow/core/common_runtime/
Dlower_case_op.cc41 CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
98 CaseBuilder::CaseBuilder(Node* case_op, in CaseBuilder() argument
101 : case_op_(case_op), in CaseBuilder()
104 name_(case_op->name()), in CaseBuilder()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
DBUILD145 ":case_op",
330 name = "case_op",
331 srcs = ["case_op.cc"],
332 hdrs = ["case_op.h"],
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/analysis/
Dresource_alias_analysis.cc322 } else if (auto case_op = dyn_cast<CaseOp>(op)) { in ResourceAliasAnalysisInfo() local
324 case_op.get_branch_functions(functions); in ResourceAliasAnalysisInfo()
325 AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis); in ResourceAliasAnalysisInfo()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/
Dtensor_list_ops_decomposition.mlir259 …%case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = fal…
262 …%pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor<!tf.variant<tensor<f32>>>, tensor…
407 %case_op = "tf.CaseRegion"(%arg0) ({
443 …%pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor<!tf.variant<tensor<f32>>>, tensor…
Dtensor_array_ops_decomposition.mlir406 %case_op = "tf.IfRegion"(%arg0) ({
427 …%read_val = "tf.TensorArrayReadV3"(%ta#0, %idx, %case_op) : (tensor<!tf.resource>, tensor<i32>, te…
Dstack_ops_decomposition.mlir189 %case_op = "tf.CaseRegion"(%arg0) ({