1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
17 #define TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
18 
19 #include "tensorflow/core/graph/graph.h"
20 
21 namespace tensorflow {
22 
23 // Information about a while loop. Every user-defined while loop has an
24 // associated WhileContext, i.e., there is a WhileContext for every execution
// frame. It is created along with the while loop and used during gradient
// construction. Note that the gradient graph of a while loop itself contains
// while loops, but these do not generate separate WhileContexts.
28 //
29 // TODO(skyewm): this is currently insufficient to handle nested loops and
30 // conditionals (and possibly other requirements). This may change a lot in the
31 // future to support these features.
32 //
33 // TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be
34 // differentiable. Figure out backwards compatibility story.
35 class WhileContext {
36  public:
37   WhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
38                std::vector<Node*> exit_nodes, OutputTensor cond_output,
39                std::vector<OutputTensor> body_inputs,
40                std::vector<OutputTensor> body_outputs);
41 
frame_name()42   const string& frame_name() const { return frame_name_; }
enter_nodes()43   const std::vector<Node*>& enter_nodes() const { return enter_nodes_; }
exit_nodes()44   const std::vector<Node*>& exit_nodes() const { return exit_nodes_; }
cond_output()45   const OutputTensor& cond_output() const { return cond_output_; }
body_inputs()46   const std::vector<OutputTensor>& body_inputs() const { return body_inputs_; }
body_outputs()47   const std::vector<OutputTensor>& body_outputs() const {
48     return body_outputs_;
49   }
50 
51  private:
52   // Each user-defined while loop defines a new execution frame, which is
53   // uniquely identified by its frame name. Frames are used by the executor to
54   // manage the iterations of a loop. See the FrameState comment in
55   // core/common_runtime/executor.cc for more details.
56   const string frame_name_;
57 
58   // The enter nodes defining the input loop variables to the while loop. This
59   // vector defines the order of the loop variables.
60   const std::vector<Node*> enter_nodes_;
61 
62   // The exit nodes defining the outputs of the while loop. These are in loop
63   // variable order.
64   const std::vector<Node*> exit_nodes_;
65 
66   // The boolean output of the loop predicate.
67   const OutputTensor cond_output_;
68 
69   // The inputs and outputs to the loop body.
70   const std::vector<OutputTensor> body_inputs_;
71   const std::vector<OutputTensor> body_outputs_;
72 };
73 
74 }  // namespace tensorflow
75 
76 #endif  // TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
77