1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tensor Handle Operations."""
17
18# pylint: disable=g-bad-name
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import numpy as np
24
25from tensorflow.core.framework import resource_handle_pb2
26from tensorflow.python.client import pywrap_tf_session
27from tensorflow.python.framework import device as pydev
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_data_flow_ops
32from tensorflow.python.util import compat
33from tensorflow.python.util.tf_export import tf_export
34
35
def encode_resource_handle(resource_handle):
  """Serialize a ResourceHandle proto into a numpy array of resource dtype.

  Args:
    resource_handle: A ResourceHandle proto. Only its `SerializeToString()`
      method is used here.

  Returns:
    A numpy array with dtype `dtypes.np_resource` built from the serialized
    bytes of `resource_handle`.
  """
  serialized = resource_handle.SerializeToString()
  # Wrap the serialized bytes in a bytearray before handing them to numpy.
  raw_bytes = bytearray(serialized)
  return np.asarray(raw_bytes, dtype=dtypes.np_resource)
40
41
class TensorHandle(object):
  """Represents a handle for a live tensor in a session.

  The handle is kept as a python string of the form
  "tensor_name;unique_id;device_name". While automatic garbage collection is
  enabled, the handle is reported to the owning session when this object is
  finalized; `delete()` and `get_raw_handle()` turn that behavior off.
  """

  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle. It is converted to `str` with
        `compat.as_str_any` before being stored.
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    # Built lazily by `_get_resource_handle()`.
    self._resource_handle = None
    self._dtype = dtype
    self._session = session
    # Assigned last: once this flag is set, `_handle` and `_session` exist.
    self._auto_gc_enabled = True

  def __del__(self):
    # `__del__` also runs on partially constructed instances (for example
    # when `__init__` raised before assigning its attributes), so read the
    # flag defensively instead of raising AttributeError during finalization.
    if getattr(self, "_auto_gc_enabled", False):
      self._session._register_dead_handle(self.handle)

  def __str__(self):
    return self._handle

  def _get_resource_handle(self):
    """The ResourceHandle representation of this handle.

    The proto is created on first use and cached on the instance. Its device
    is the last ";"-separated field of the handle string, its container is
    `pywrap_tf_session.TENSOR_HANDLE_KEY` and its name is the handle string.
    """
    if self._resource_handle is None:
      proto = resource_handle_pb2.ResourceHandleProto()
      proto.device = self._handle.split(";")[-1]
      proto.container = pywrap_tf_session.TENSOR_HANDLE_KEY
      proto.name = self._handle
      self._resource_handle = proto
    return self._resource_handle

  def to_numpy_array(self):
    """Convert a TensorHandle object to a feedable numpy value.

    Returns:
      A numpy array of a custom struct type that can be used as a feed value
      to run().
    """
    return encode_resource_handle(self._get_resource_handle())

  @property
  def handle(self):
    """The string representation of this handle."""
    return self._handle

  def eval(self):
    """Return the value of the tensor represented by this handle.

    Returns:
      The result of running this handle's reader subgraph in the owning
      session, feeding the handle string to the reader's placeholder.

    Raises:
      TypeError: If automatic garbage collection has been disabled for this
        handle (by `delete()` or `get_raw_handle()`), because the tensor may
        no longer exist.
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    holder, reader = _get_handle_reader(self._session.graph, self._handle,
                                        self._dtype)
    return self._session.run(reader, feed_dict={holder: self._handle})

  def delete(self):
    """Force the deletion of this persistent tensor.

    Raises:
      TypeError: If automatic garbage collection has already been disabled
        for this handle (by an earlier `delete()` or `get_raw_handle()`).
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    # Disable auto GC first so that finalization does not report the handle
    # to the session a second time.
    self._auto_gc_enabled = False
    holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    self._session.run(deleter, feed_dict={holder: self.handle})

  def get_raw_handle(self):
    """Return the raw handle of the tensor.

    Note that the method disables the automatic garbage collection of this
    persistent tensor. The caller is now responsible for managing the life
    time of the tensor.
    """
    self._auto_gc_enabled = False
    return self._handle

  @staticmethod
  def _get_device_name(handle):
    """The device name encoded in the handle.

    Returns the last ";"-separated field of `handle`, passed through
    `pydev.canonical_name`.
    """
    handle_str = compat.as_str_any(handle)
    return pydev.canonical_name(handle_str.split(";")[-1])

  @staticmethod
  def _get_reader_key(handle):
    """The graph key for reader.

    The key joins the first and last ";"-separated fields of the handle
    (the tensor name and the device name), leaving out the unique id.
    """
    # Decode exactly like `_get_device_name` does, so that a bytes handle and
    # the equal str handle map to the same key instead of "b'...'".
    handle_parts = compat.as_str_any(handle).split(";")
    return handle_parts[0] + ";" + handle_parts[-1]

  @staticmethod
  def _get_mover_key(feeder, handle):
    """The graph key for mover: the feeder's op name plus the reader key."""
    return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
136
137
@tf_export(v1=["get_session_handle"])
def get_session_handle(data, name=None):
  """Return the handle of `data`.

  This is EXPERIMENTAL and subject to change.

  The tensor `data` is kept "in-place" in the runtime, and the returned
  handle can be used to retrieve it in a later run() call.

  Together with `get_session_tensor`, this lets a tensor produced by one
  run call stay in place and serve as an input to a future run call.

  Args:
    data: A tensor to be stored in the session.
    name: Optional name prefix for the return tensor.

  Returns:
    A scalar string tensor representing a unique handle for `data`.

  Raises:
    TypeError: if `data` is not a Tensor.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.compat.v1.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  is_tensor = isinstance(data, ops.Tensor)
  if not is_tensor:
    raise TypeError("`data` must be of type Tensor.")

  # Create the handle op under a colocation scope tied to `data`.
  with ops.colocate_with(data):
    handle_tensor = gen_data_flow_ops.get_session_handle(data, name=name)
  return handle_tensor
179
180
@tf_export(v1=["get_session_tensor"])
def get_session_tensor(handle, dtype, name=None):
  """Get the tensor of type `dtype` by feeding a tensor handle.

  This is EXPERIMENTAL and subject to change.

  Retrieves the value of a tensor from its handle. The tensor was produced
  by an earlier run() and is stored in the state of the session.

  Args:
    handle: The string representation of a persistent tensor handle.
    dtype: The type of the output tensor.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of tensors. The first is a placeholder for feeding a
    tensor handle and the second is the tensor in the session state
    keyed by the tensor handle.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.compat.v1.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  # Build both ops under the device named in the handle string.
  with ops.device(TensorHandle._get_device_name(handle)):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    # Record the placeholder's expected dtype on its graph.
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    session_tensor = gen_data_flow_ops.get_session_tensor(
        feed_placeholder, dtype, name=name)
  return (feed_placeholder, session_tensor)
220
221
@tf_export(v1=["delete_session_tensor"])
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Deletes the tensor that a tensor handle refers to. The tensor was
  produced by an earlier run() and is stored in the state of the session.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of graph elements. The first is a placeholder for feeding a
    tensor handle and the second is a deletion operation.
  """
  # Build both ops under the device named in the handle string.
  with ops.device(TensorHandle._get_device_name(handle)):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(
        feed_placeholder, name=name)
  return (feed_placeholder, delete_op)
244
245
def _register_handle_feeder(graph, feeder, dtype):
  """Record `dtype` for `feeder` in `graph._handle_feeders`.

  The entry is keyed by the name of the feeder's op; an existing entry with
  the same name is overwritten.
  """
  feeder_name = feeder.op.name
  graph._handle_feeders[feeder_name] = dtype
248
249
def _get_handle_feeder(graph, feeder):
  """Look up the dtype recorded for `feeder` in `graph._handle_feeders`.

  Returns:
    The value stored under the feeder's op name, or None when there is no
    entry for it.
  """
  feeder_name = feeder.op.name
  return graph._handle_feeders.get(feeder_name)
252
253
def _get_handle_reader(graph, handle, dtype):
  """Return a read subgraph for this handle.

  Read subgraphs are cached in `graph._handle_readers` under the handle's
  reader key, so one is only built the first time a key is seen.

  Args:
    graph: The graph that owns the cache and receives any new ops.
    handle: A tensor handle string.
    dtype: The dtype to read the tensor as; only used when building.

  Returns:
    A `(holder, reader)` pair: a string placeholder for feeding the handle,
    and the op that reads the tensor for the fed handle.
  """
  cache_key = TensorHandle._get_reader_key(handle)
  cached = graph._handle_readers.get(cache_key)
  if cached is not None:
    return cached

  # Nothing cached yet: build the reader under the handle's device.
  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    read_op = gen_data_flow_ops.get_session_tensor(feed_placeholder, dtype)
  built = (feed_placeholder, read_op)
  graph._handle_readers[cache_key] = built
  return built
268
269
def _get_handle_mover(graph, feeder, handle):
  """Return a move subgraph for this pair of feeder and handle.

  Args:
    graph: The graph that owns the caches and receives any new ops.
    feeder: The placeholder that `handle` is being fed to.
    handle: A tensor handle string.

  Returns:
    None when `feeder` has no dtype recorded in the graph's handle-feeder
    table, or when the feeder op's device equals the handle's device name.
    Otherwise a `(holder, mover)` pair, cached in `graph._handle_movers`:
    the reader's placeholder and a get-session-handle op built under the
    feeder's device.
  """
  feeder_dtype = _get_handle_feeder(graph, feeder)
  # No move is needed for an unregistered feeder or when devices match.
  if (feeder_dtype is None or
      feeder.op.device == TensorHandle._get_device_name(handle)):
    return None

  # Now we know we have to move the tensor.
  cache_key = TensorHandle._get_mover_key(feeder, handle)
  cached = graph._handle_movers.get(cache_key)
  if cached is not None:
    return cached

  # Nothing cached yet: read the tensor, then re-handle it on the feeder's
  # device.
  feed_placeholder, read_op = _get_handle_reader(graph, handle, feeder_dtype)
  with graph.as_default(), graph.device(feeder.op.device):
    move_op = gen_data_flow_ops.get_session_handle(read_op)
  built = (feed_placeholder, move_op)
  graph._handle_movers[cache_key] = built
  return built
289
290
def _get_handle_deleter(graph, deleter_key, handle):
  """Return a deletion subgraph for this handle.

  Deletion subgraphs are cached in `graph._handle_deleters` under
  `deleter_key`; `handle` is only consulted (for its device name) when a
  new subgraph has to be built.

  Args:
    graph: The graph that owns the cache and receives any new ops.
    deleter_key: The cache key for the deletion subgraph.
    handle: A tensor handle string.

  Returns:
    A `(holder, deleter)` pair: a string placeholder for feeding a handle,
    and the deletion op for the fed handle.
  """
  cached = graph._handle_deleters.get(deleter_key)
  if cached is not None:
    return cached

  # Nothing cached under this key: build the deleter on the handle's device.
  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(feed_placeholder)
  built = (feed_placeholder, delete_op)
  graph._handle_deleters[deleter_key] = built
  return built
303