1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various context managers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import tensor_array_ops 25 26 27def control_dependency_on_returns(return_value): 28 """Create a TF control dependency on the return values of a function. 29 30 If the function had no return value, a no-op context is returned. 31 32 Args: 33 return_value: The return value to set as control dependency. 34 35 Returns: 36 A context manager. 37 """ 38 def control_dependency_handle(t): 39 if isinstance(t, tensor_array_ops.TensorArray): 40 return t.flow 41 return t 42 43 if return_value is None: 44 return contextlib.contextmanager(lambda: (yield))() 45 # TODO(mdan): Filter to tensor objects. 46 if not isinstance(return_value, (list, tuple)): 47 return_value = (return_value,) 48 return_value = tuple(control_dependency_handle(t) for t in return_value) 49 return ops.control_dependencies(return_value) 50