# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pyfunc creation utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import script_ops


class MatchDType(namedtuple('MatchDType', ('arg_number',))):
  """Allows matching the dtype of an argument.

  Used in conjunction with function calls. For example, MatchDType(0) will
  match the DType of the first argument.
  """

  pass


def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False):
  """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments are captured by the wrapper when it is created and will not be
  updated on subsequent calls (this is consistent with its argument list,
  which only includes the tensor arguments). In general, it's safest not to
  reuse this wrapper.

  Args:
    f: Callable
    return_dtypes: None, individual or tuple/list of DType or MatchDType, the
        data type for each of f's return value(s). Set to None if f has no
        return values or use_dummy_return is True. Use MatchDType to define a
        dtype identical to that of the `i`th positional argument (argument 0
        is the first); an argument must be of Tensor type if it is to be used
        with MatchDType.
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.
    use_dummy_return: If True, the function will return a dummy value of 1
        and discard its actual return value.
  Returns:
    The return values of f converted to tensor. If use_dummy_return is True,
    a single int32 tensor holding the value 1.
  Raises:
    ValueError: if any of the arguments are incorrect, including a MatchDType
        that refers to a missing or non-tensor positional argument, or a
        return_dtypes value of an unsupported type.
  """

  if return_dtypes and use_dummy_return:
    raise ValueError('if use_dummy_return is True, return_dtypes must be empty')

  tensor_args = []
  # Maps a positional index (int) or a keyword name (str) to the position of
  # the corresponding value within tensor_args.
  tensor_args_idx = {}

  # Of the positional arguments, only grab the tensor ones to be passed through
  # the py_func.
  n_args = len(args)
  arg_is_tensor = tuple(map(tensor_util.is_tf_type, args))
  for i, (arg, is_tensor) in enumerate(zip(args, arg_is_tensor)):
    if is_tensor:
      tensor_args_idx[i] = len(tensor_args)
      tensor_args.append(arg)

  # We essentially take the tensor kwargs, if any, and add them to the list of
  # positional arguments. The kwargs are then reconstructed inside the py_func.
  #
  # For example, if
  #
  #     args = [Tensor(1), 'foo']
  #     kwargs = {'a': Tensor(2), 'b': 'bar'}
  #
  # Then
  #
  #     tensor_args = (Tensor(1), Tensor(2))
  #     kwarg_keys = ('a', 'b')
  kwarg_keys = tuple(kwargs.keys()) if kwargs else ()
  kwarg_is_tensor = {k: tensor_util.is_tf_type(kwargs[k]) for k in kwarg_keys}
  for k in kwarg_keys:
    if kwarg_is_tensor[k]:
      tensor_args_idx[k] = len(tensor_args)
      tensor_args.append(kwargs[k])

  # Set up return dtypes.
  def match_arg_dtype(arg_number):
    """Returns the dtype of positional argument `arg_number` (must be a tensor)."""
    try:
      arg = args[arg_number]
    except IndexError:
      # Surface a bad MatchDType as the documented ValueError, not IndexError.
      raise ValueError(
          'argument %d was used with MatchDType, but only %d positional '
          'arguments were given' % (arg_number, n_args))
    if not arg_is_tensor[arg_number]:
      raise ValueError(
          'argument %d was used with MatchDType and must be a tf.Tensor, but '
          'was %s instead' % (arg_number, type(arg)))
    return arg.dtype

  if return_dtypes:
    # MatchDType is itself a tuple subclass, so it must be checked before the
    # generic list/tuple case below.
    if isinstance(return_dtypes, MatchDType):
      return_dtypes = match_arg_dtype(return_dtypes.arg_number)
    elif isinstance(return_dtypes, (list, tuple)):
      return_dtypes = tuple(
          match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a
          for a in return_dtypes)
    elif not isinstance(return_dtypes, dtypes.DType):
      # An explicit check (rather than `assert`) survives `python -O`.
      raise ValueError(
          'return_dtypes must be None, a DType, a MatchDType, or a list/tuple '
          'of those; got %s' % (return_dtypes,))

  def f_wrapper(*tensor_values):
    """Rebuilds f's full argument list from the py_func inputs and calls f."""
    # tensor_values are the values py_func feeds in for tensor_args, in order;
    # non-tensor arguments are taken from the captured args/kwargs.
    f_args = tuple(tensor_values[tensor_args_idx[i]] if arg_is_tensor[i] else a
                   for i, a in enumerate(args))
    f_kwargs = {
        k: tensor_values[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k]
        for k in kwarg_keys
    }
    retval = f(*f_args, **f_kwargs)
    return 1 if use_dummy_return else retval

  if use_dummy_return:
    return_dtypes = dtypes.int32
  elif return_dtypes is None:
    # py_func expects an empty list, not None, for a function with no outputs.
    return_dtypes = []
  return script_ops.eager_py_func(f_wrapper, tensor_args, return_dtypes)