1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_
18 
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/types.h"
25 
26 namespace tensorflow {
27 
28 // This TensorFlow op provides a functional switch/case primitive.
29 //
// The outputs of all branches must agree on the number, types, and shapes of
// the Tensors they produce.
32 //
33 // Computations in branch bodies may read from and write to resource variables.
34 // Resource variables may be passed as arguments to the branch function's
35 // bodies. The XlaCompiler converts resource variable arguments
36 // into parameters to the XLA computation and moves them to the end of the
// parameter list. By setting the `return_updated_values_for_all_variables`
// option, we ensure that every variable that appears in the input also appears
// at the end of each branch body's outputs. This ensures that the output
// signatures of all branch bodies match.
41 //
42 // It is the user's responsibility to ensure that each non-variable _Arg matches
43 // the corresponding _Retval.
44 class XlaCaseOp : public XlaOpKernel {
45  public:
46   explicit XlaCaseOp(OpKernelConstruction* ctx);
47 
48   void Compile(XlaOpKernelContext* ctx) override;
49 
50  private:
51   TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp);
52 
53   // If the branch_index input is a constant: prunes out all but the branch
54   // corrresponding to that constant branch index, and returns that branch and
55   // the literal 0 (as the first and second component of the pair).
56   //
57   // If the branch_index input is not a constant: returns unpruned_branches_ and
58   // the branch_index input.
59   std::pair<std::vector<NameAttrList>, xla::XlaOp> GetPrunedBranchesAndIndex(
60       XlaOpKernelContext* ctx);
61 
62   std::vector<NameAttrList> unpruned_branches_;
63   DataTypeVector input_types_;
64   DataTypeVector output_types_;
65   bool has_token_input_output_;
66   std::vector<string> token_input_nodes_;
67   // Whether to propagate compile time consts into the cond branches.
68   // This is not supported by default now since it may cause HBM memory
69   // overheads.
70   bool propagate_compile_time_consts_ = false;
71 };
72 
73 }  // namespace tensorflow
74 
75 #endif  // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_
76