1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/core/framework/attr_value.pb.h" 24 #include "tensorflow/core/framework/types.h" 25 26 namespace tensorflow { 27 28 // This TensorFlow op provides a functional switch/case primitive. 29 // 30 // The outputs of the branches must agree on the number, types, and 31 // shapes of the Tensors carried around the two bodies. 32 // 33 // Computations in branch bodies may read from and write to resource variables. 34 // Resource variables may be passed as arguments to the branch function's 35 // bodies. The XlaCompiler converts resource variable arguments 36 // into parameters to the XLA computation and moves them to the end of the 37 // parameter list, and by using the `return_updated_values_for_all_variables` 38 // we ensure that all variables that appear in the input also appear at the 39 // end of the branch bodies output. This ensures the branch bodies output 40 // signatures match. 41 // 42 // It is the user's responsibility to ensure that each non-variable _Arg matches 43 // the corresponding _Retval. 44 class XlaCaseOp : public XlaOpKernel { 45 public: 46 explicit XlaCaseOp(OpKernelConstruction* ctx); 47 48 void Compile(XlaOpKernelContext* ctx) override; 49 50 private: 51 TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp); 52 53 // If the branch_index input is a constant: prunes out all but the branch 54 // corrresponding to that constant branch index, and returns that branch and 55 // the literal 0 (as the first and second component of the pair). 56 // 57 // If the branch_index input is not a constant: returns unpruned_branches_ and 58 // the branch_index input. 59 std::pair<std::vector<NameAttrList>, xla::XlaOp> GetPrunedBranchesAndIndex( 60 XlaOpKernelContext* ctx); 61 62 std::vector<NameAttrList> unpruned_branches_; 63 DataTypeVector input_types_; 64 DataTypeVector output_types_; 65 bool has_token_input_output_; 66 std::vector<string> token_input_nodes_; 67 // Whether to propagate compile time consts into the cond branches. 68 // This is not supported by default now since it may cause HBM memory 69 // overheads. 70 bool propagate_compile_time_consts_ = false; 71 }; 72 73 } // namespace tensorflow 74 75 #endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ 76