1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
18 
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
26 
27 namespace tensorflow {
28 
29 // The session state remembers the tensors we choose to keep across
30 // multiple run calls.
31 class SessionState {
32  public:
33   // Get a tensor from the session state.
34   Status GetTensor(const std::string& handle, Tensor* tensor);
35 
36   // Store a tensor in the session state.
37   Status AddTensor(const std::string& handle, const Tensor& tensor);
38 
39   // Delete a tensor from the session state.
40   Status DeleteTensor(const std::string& handle);
41 
42   int64 GetNewId();
43 
44   static const char* kTensorHandleResourceTypeName;
45 
46  private:
47   mutex state_lock_;
48 
49   // For generating unique ids for tensors stored in the session.
50   int64 tensor_id_ = 0;
51 
52   // The live tensors in the session. A map from tensor handle to tensor.
53   std::unordered_map<string, Tensor> tensors_;
54 };
55 
56 // The tensor store remembers the tensors we choose to keep for the
57 // current run call. It is available to every op kernel.
58 class TensorStore {
59  public:
60   struct TensorAndKey {
61     Tensor tensor;
62     int64 id;
63     std::string device_name;
64 
GetHandleTensorAndKey65     std::string GetHandle(const std::string& tensor_name) {
66       return strings::StrCat(tensor_name, ";", id, ";", device_name);
67     }
68   };
69 
70   // Add the named tensor to the tensor store for this run.
71   Status AddTensor(const std::string& name, const TensorAndKey& tk);
72 
73   // Save the tensors in the tensor store of this run to the session.
74   Status SaveTensors(const std::vector<string>& output_names,
75                      SessionState* session_state);
76 
77   // Returns true if no tensors have been added to this store.
empty()78   bool empty() TF_NO_THREAD_SAFETY_ANALYSIS { return !dirty_; }
79 
80  private:
81   mutex lock_;
TF_GUARDED_BY(lock_)82   std::atomic<bool> dirty_ TF_GUARDED_BY(lock_){false};
83 
84   // The tensors that will be saved to session state when this run completes.
85   // A map from tensor string name to tensor.
86   std::unordered_map<string, TensorAndKey> tensors_ TF_GUARDED_BY(lock_);
87 };
88 
89 }  // namespace tensorflow
90 
91 #endif  // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
92