1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tensor Handle Operations."""
17
18# pylint: disable=g-bad-name
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import numpy as np
24
25from tensorflow.core.framework import resource_handle_pb2
26from tensorflow.python import pywrap_tensorflow_internal
27from tensorflow.python.framework import device as pydev
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_data_flow_ops
32from tensorflow.python.util import compat
33from tensorflow.python.util.tf_export import tf_export
34
35
def encode_resource_handle(resource_handle):
  """Packs a ResourceHandle proto into a numpy value of the resource dtype.

  Args:
    resource_handle: A ResourceHandle proto message (anything exposing
      `SerializeToString()`).

  Returns:
    A numpy array with dtype `dtypes.np_resource` built from the serialized
    bytes of `resource_handle`.
  """
  serialized_bytes = bytearray(resource_handle.SerializeToString())
  return np.asarray(serialized_bytes, dtype=dtypes.np_resource)
40
41
class TensorHandle(object):
  """Represents a handle for a live tensor in a session.

  The handle string has the form "tensor_name;unique_id;device_name". Unless
  `get_raw_handle()` or `delete()` has been called, finalizing this object
  reports the handle to the owning session via
  `session._register_dead_handle`.
  """

  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle (`str`, or `bytes` which is decoded through
        `compat.as_str_any`).
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    # Lazily built by `_get_resource_handle`.
    self._resource_handle = None
    self._dtype = dtype
    self._session = session
    # While True, `__del__` hands the handle back to the session.
    self._auto_gc_enabled = True

  def __del__(self):
    # `getattr` guards against a partially constructed object (e.g. when
    # `compat.as_str_any` raised inside `__init__`): the attributes were never
    # set, and reading them here would raise AttributeError during finalization.
    if getattr(self, "_auto_gc_enabled", False):
      # Presumably lets the session free the underlying tensor; see
      # `Session._register_dead_handle` (not shown here).
      self._session._register_dead_handle(self.handle)

  def __str__(self):
    return self._handle

  def _get_resource_handle(self):
    """The ResourceHandle representation of this handle (built once, cached).

    The proto's `device` is the last ";"-separated field of the handle, its
    `container` is `TENSOR_HANDLE_KEY`, and its `name` is the full handle.
    """
    # Compare against None explicitly instead of relying on the truthiness of
    # a protobuf message object.
    if self._resource_handle is None:
      self._resource_handle = resource_handle_pb2.ResourceHandleProto()
      self._resource_handle.device = self._handle.split(";")[-1]
      self._resource_handle.container = (
          pywrap_tensorflow_internal.TENSOR_HANDLE_KEY)
      self._resource_handle.name = self._handle
    return self._resource_handle

  def to_numpy_array(self):
    """Convert a TensorHandle object to a feedable numpy value.

    Returns:
      A numpy array of a custom struct type that can be used as a feed value
      to run().
    """
    return encode_resource_handle(self._get_resource_handle())

  @property
  def handle(self):
    """The string representation of this handle."""
    return self._handle

  def eval(self):
    """Return the value of the tensor represented by this handle.

    Builds (or reuses) a cached reader subgraph in the session's graph and
    runs it, feeding this handle.

    Raises:
      TypeError: if automatic GC has been disabled for this handle, i.e. after
        `delete()` or `get_raw_handle()` has been called.
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    holder, reader = _get_handle_reader(self._session.graph, self._handle,
                                        self._dtype)
    return self._session.run(reader, feed_dict={holder: self._handle})

  def delete(self):
    """Force the deletion of this persistent tensor.

    Raises:
      TypeError: if automatic GC has already been disabled for this handle,
        i.e. after a previous `delete()` or `get_raw_handle()` call.
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    # Disable GC first so `__del__` does not register this handle again.
    self._auto_gc_enabled = False
    # NOTE(review): deleter key 0 is a shared cache slot; the cached delete op
    # is placed on the device of whichever handle first populated slot 0.
    # Confirm that reusing it for handles on other devices is intended.
    holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    self._session.run(deleter, feed_dict={holder: self.handle})

  def get_raw_handle(self):
    """Return the raw handle of the tensor.

    Note that the method disables the automatic garbage collection of this
    persistent tensor. The caller is now responsible for managing the
    lifetime of the tensor.
    """
    self._auto_gc_enabled = False
    return self._handle

  @staticmethod
  def _get_device_name(handle):
    """The canonical device name encoded in the handle (its last field)."""
    handle_str = compat.as_str_any(handle)
    return pydev.canonical_name(handle_str.split(";")[-1])

  @staticmethod
  def _get_reader_key(handle):
    """The graph key for reader: "<tensor_name>;<device_name>" of `handle`."""
    # Decode the same way `_get_device_name` does. Plain `str()` on a `bytes`
    # handle would produce "b'..." and corrupt both key components.
    handle_parts = compat.as_str_any(handle).split(";")
    return handle_parts[0] + ";" + handle_parts[-1]

  @staticmethod
  def _get_mover_key(feeder, handle):
    """The graph key for mover: the feeder op name plus the reader key."""
    return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
137
138
@tf_export(v1=["get_session_handle"])
def get_session_handle(data, name=None):
  """Return the handle of `data`.

  This is EXPERIMENTAL and subject to change.

  Keeps `data` "in-place" in the runtime and produces a handle through which
  `data` can be retrieved in a later run().

  Together with `get_session_tensor`, this lets a tensor computed in one run
  call stay where it is and serve as an input to a subsequent run call.

  Args:
    data: A tensor to be stored in the session.
    name: Optional name prefix for the return tensor.

  Returns:
    A scalar string tensor representing a unique handle for `data`.

  Raises:
    TypeError: if `data` is not a Tensor.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  if isinstance(data, ops.Tensor):
    # Place the handle op alongside `data` so the tensor stays where it was
    # produced.
    with ops.colocate_with(data):
      return gen_data_flow_ops.get_session_handle(data, name=name)
  raise TypeError("`data` must be of type Tensor.")
180
181
@tf_export(v1=["get_session_tensor"])
def get_session_tensor(handle, dtype, name=None):
  """Get the tensor of type `dtype` by feeding a tensor handle.

  This is EXPERIMENTAL and subject to change.

  Looks up, via a tensor handle, a tensor that an earlier run() produced and
  left in the session's state.

  Args:
    handle: The string representation of a persistent tensor handle.
    dtype: The type of the output tensor.
    name: Optional name prefix for the return tensor.

  Returns:
    A `(placeholder, tensor)` pair. The placeholder is fed with a tensor
    handle; the tensor is the session-state value keyed by that handle.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  # Both ops are placed on the device encoded in the handle.
  with ops.device(TensorHandle._get_device_name(handle)):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    # Record the expected dtype so `_get_handle_mover` can build a reader
    # when a handle from another device is fed into this placeholder.
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    session_tensor = gen_data_flow_ops.get_session_tensor(
        feed_placeholder, dtype, name=name)
  return feed_placeholder, session_tensor
221
222
@tf_export(v1=["delete_session_tensor"])
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Removes the tensor referenced by a tensor handle. That tensor was produced
  by an earlier run() and kept in the session's state.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name for the deletion op.

  Returns:
    A `(placeholder, deleter)` pair. The placeholder is fed with a tensor
    handle; the deleter is the deletion operation.
  """
  # The deletion op runs on the device encoded in the handle.
  with ops.device(TensorHandle._get_device_name(handle)):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(
        feed_placeholder, name=name)
  return feed_placeholder, delete_op
245
246
def _register_handle_feeder(graph, feeder, dtype):
  """Records `dtype` for the handle-feeding placeholder `feeder`.

  The entry is stored in `graph._handle_feeders`, keyed by `feeder.op.name`,
  and overwrites any previous entry for that name.
  """
  feeder_table = graph._handle_feeders
  feeder_table[feeder.op.name] = dtype
249
250
def _get_handle_feeder(graph, feeder):
  """Returns the dtype registered for placeholder `feeder`, or None."""
  feeder_name = feeder.op.name
  return graph._handle_feeders.get(feeder_name, None)
253
254
def _get_handle_reader(graph, handle, dtype):
  """Return a read subgraph for this handle.

  Readers are cached in `graph._handle_readers` under the handle's reader
  key (tensor name plus device), so they are built at most once per key.

  Returns:
    A `(placeholder, reader)` pair, where the reader is a get-session-tensor
    op of type `dtype` placed on the handle's device.
  """
  reader_key = TensorHandle._get_reader_key(handle)
  cached_entry = graph._handle_readers.get(reader_key)
  if cached_entry is not None:
    return cached_entry

  # Cache miss: build the placeholder and reader on the handle's device.
  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    read_op = gen_data_flow_ops.get_session_tensor(feed_placeholder, dtype)
  new_entry = (feed_placeholder, read_op)
  graph._handle_readers[reader_key] = new_entry
  return new_entry
269
270
def _get_handle_mover(graph, feeder, handle):
  """Return a move subgraph for this pair of feeder and handle.

  Returns None when `feeder` is not a registered handle feeder, or when it
  already lives on the handle's device. Otherwise returns a cached
  `(placeholder, mover)` pair: the tensor is read on the handle's device and
  a fresh handle for it is produced on the feeder's device.
  """
  feeder_dtype = _get_handle_feeder(graph, feeder)
  if feeder_dtype is None:
    return None
  source_device = TensorHandle._get_device_name(handle)
  if feeder.op.device == source_device:
    # Same device: nothing needs to move.
    return None

  mover_key = TensorHandle._get_mover_key(feeder, handle)
  cached_entry = graph._handle_movers.get(mover_key)
  if cached_entry is not None:
    return cached_entry

  # Cache miss: read on the source device, then re-handle on the feeder's.
  feed_placeholder, read_op = _get_handle_reader(graph, handle, feeder_dtype)
  with graph.as_default(), graph.device(feeder.op.device):
    move_op = gen_data_flow_ops.get_session_handle(read_op)
  new_entry = (feed_placeholder, move_op)
  graph._handle_movers[mover_key] = new_entry
  return new_entry
290
291
def _get_handle_deleter(graph, deleter_key, handle):
  """Return a deletion subgraph for this handle.

  Deleters are cached in `graph._handle_deleters` under `deleter_key`. On a
  cache hit, `handle` is not consulted. On a miss, the new deleter is placed
  on `handle`'s device.

  Returns:
    A `(placeholder, deleter)` pair.
  """
  deleter_cache = graph._handle_deleters
  cached_entry = deleter_cache.get(deleter_key)
  if cached_entry is not None:
    return cached_entry

  device_name = TensorHandle._get_device_name(handle)
  with graph.as_default(), graph.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops.delete_session_tensor(feed_placeholder)
  new_entry = (feed_placeholder, delete_op)
  deleter_cache[deleter_key] = new_entry
  return new_entry
304