1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_
17 #define TENSORFLOW_CC_OPS_WHILE_LOOP_H_
18 
#include <functional>
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
21 
22 namespace tensorflow {
23 namespace ops {
24 
25 // Function that takes cond graph inputs and returns cond graph boolean output.
26 // 'output' need not be set if an error is returned.
27 typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
28                              Output* output)>
29     CondGraphBuilderFn;
30 
31 // Function that takes body graph inputs and returns body graph outputs.
32 // 'outputs' need not be populated if an error is returned.
33 typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
34                              std::vector<Output>* outputs)>
35     BodyGraphBuilderFn;
36 
37 // Constructs a while loop.
38 //
39 // Arguments:
40 // * scope: used to construct the while loop.
41 // * inputs: the initial values of the loop variables. Must be non-empty.
42 // * cond: a function that builds the condition graph of the loop. Takes the
43 //     current loop variables as inputs and returns a scalar boolean Output
44 //     indicating whether the loop should continue.
45 // * body: a function that builds the body graph of the loop. Takes the current
46 //     loop variables as inputs and returns the updated loop variables.
47 // * frame_name: the frame name to use for this while loop. This should be a
48 //     unique name. This will be used as a prefix for created operations.
49 // * outputs: output param that returns final loop variable outputs in non-error
50 //     case. Must be non-null and empty.
51 // * create_while_ctx: if true, a WhileContext is created and populated for this
52 //     loop. See core/graph/while_context.h for more details on
53 //     WhileContexts. This is set to false for loops used as part of gradient
54 //     computations, since they're part of the gradient for a loop in the
55 //     forward-pass.
56 //     TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
57 //     even if we don't need them?
58 // * cond_output: if non-null, the output of the predicate is returned. This
59 //     will always be a LoopCond node.
60 //
61 // Returns an error if the while loop could not be fully constructed.
62 //
63 // TODO(skyewm): clean up partially-constructed loop in error case
64 // TODO(skyewm): create public interface to this method
65 Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
66                       const CondGraphBuilderFn& cond,
67                       const BodyGraphBuilderFn& body, const string& frame_name,
68                       OutputList* outputs, bool create_while_ctx = true,
69                       Output* cond_output = nullptr);
70 
71 }  // namespace ops
72 }  // namespace tensorflow
73 
74 #endif  // TENSORFLOW_CC_OPS_WHILE_LOOP_H_
75