1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ 17 18 #include "tensorflow/core/graph/graph.h" 19 #include "tensorflow/core/lib/core/status.h" 20 21 // TODO(skyewm): can this be combined with ConstantFold? 22 23 namespace tensorflow { 24 25 class GraphRunner; 26 class OpRegistryInterface; 27 class ShapeRefiner; 28 class Tensor; 29 30 // Attempts to evaluate `tensor`. This will only be possible if `tensor` doesn't 31 // depend on any graph inputs (this function is safe to call if this isn't the 32 // case though). 33 // 34 // If the evaluation is successful, `evaluated` will be set to true and 35 // `tensor`s value returned in `result`. Otherwise `evaluated` will be set to 36 // false. An error status is returned if something is wrong with the graph or 37 // input. Note that `evaluated` may set to false if Status::OK() is returned. 38 // 39 // Params: 40 // tensor - the tensor to be evaluated. 41 // refiner - used to fetch the InferenceContexts for nodes in the graph. 42 // ops - the OpRegistryInterface for the graph. 43 // graph_def_version - the producer version of the graph. 44 // evaluated - output param indicating whether evaluation was successful. 45 // result - output param containing the result if evaluated is true. 46 // graph_runner - optional. If not set, a GraphRunner will be created for 47 // evaluating tensor. This can be set to avoid creating a new GraphRunner 48 // for every call. 49 // cached_values - optional. This can be used to cache evaluated results 50 // across calls, to avoid evaluating the same parts of the graph multiple 51 // times. 52 // max_cached_value_size - optional. If `cached_values` is set, the maximum 53 // result size to cache. 54 // disable_constant_propagation - if true, only Const node values will be 55 // returned. 56 Status EvaluateConstantTensor( 57 OutputTensor tensor, const ShapeRefiner& refiner, 58 const OpRegistryInterface& ops, int32 graph_def_version, bool* evaluated, 59 Tensor* result, GraphRunner* graph_runner = nullptr, 60 std::unordered_map<string, Tensor>* cached_values = nullptr, 61 int64 max_cached_value_size = 1024, 62 bool disable_constant_propagation = false); 63 64 } // namespace tensorflow 65 66 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ 67