1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
17 
#include <unordered_map>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
20 
21 // TODO(skyewm): can this be combined with ConstantFold?
22 
23 namespace tensorflow {
24 
25 class GraphRunner;
26 class OpRegistryInterface;
27 class ShapeRefiner;
28 class Tensor;
29 
30 // Attempts to evaluate `tensor`. This will only be possible if `tensor` doesn't
31 // depend on any graph inputs (this function is safe to call if this isn't the
32 // case though).
33 //
34 // If the evaluation is successful, `evaluated` will be set to true and
35 // `tensor`s value returned in `result`. Otherwise `evaluated` will be set to
36 // false. An error status is returned if something is wrong with the graph or
37 // input. Note that `evaluated` may set to false if Status::OK() is returned.
38 //
39 // Params:
40 //   tensor - the tensor to be evaluated.
41 //   refiner - used to fetch the InferenceContexts for nodes in the graph.
42 //   ops - the OpRegistryInterface for the graph.
43 //   graph_def_version - the producer version of the graph.
44 //   evaluated - output param indicating whether evaluation was successful.
45 //   result - output param containing the result if evaluated is true.
46 //   graph_runner - optional. If not set, a GraphRunner will be created for
47 //     evaluating tensor. This can be set to avoid creating a new GraphRunner
48 //     for every call.
49 //   cached_values - optional. This can be used to cache evaluated results
50 //     across calls, to avoid evaluating the same parts of the graph multiple
51 //     times.
52 //   max_cached_value_size - optional. If `cached_values` is set, the maximum
53 //     result size to cache.
54 //   disable_constant_propagation - if true, only Const node values will be
55 //     returned.
56 Status EvaluateConstantTensor(
57     OutputTensor tensor, const ShapeRefiner& refiner,
58     const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated,
59     Tensor* result, GraphRunner* graph_runner = nullptr,
60     std::unordered_map<string, Tensor>* cached_values = nullptr,
61     int64 max_cached_value_size = 1024,
62     bool disable_constant_propagation = false);
63 
64 }  // namespace tensorflow
65 
66 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_
67