# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for optimizing `tf.data` pipelines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops


def map_defun(fn,
              elems,
              output_dtypes,
              output_shapes,
              max_intra_op_parallelism=1):
  """Map a function on the list of tensors unpacked from `elems` on dimension 0.

  Args:
    fn: A function (`function.defun`) that takes a list of tensors and returns
      another list of tensors. The output list has the same types as
      `output_dtypes`. The elements of the output list have the same
      dimension 0 as `elems`, and the remaining dimensions correspond to those
      of `output_shapes`.
    elems: A list of tensors (or values convertible to tensors via
      `convert_to_tensor`).
    output_dtypes: A list of dtypes corresponding to the output types of the
      function.
    output_shapes: A list of `TensorShape`s (or values accepted by the
      `TensorShape` constructor) corresponding to the output shapes from each
      invocation of the function on slices of inputs.
    max_intra_op_parallelism: An integer. If positive, sets the max parallelism
      limit of each function call to this.

  Returns:
    A list of `Tensor` objects with the same types as `output_dtypes`.

  Raises:
    ValueError: if `elems`, `output_dtypes` or `output_shapes` is not a
      `list`.
  """
  # Only plain lists are accepted (tuples are rejected too), matching the
  # list-typed arguments of the underlying MapDefun op wrapper.
  if not isinstance(elems, list):
    raise ValueError("`elems` must be a list of tensors.")
  if not isinstance(output_dtypes, list):
    raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
  if not isinstance(output_shapes, list):
    raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
                     "objects.")

  # The op needs a concrete (traced) function; its captured tensors are passed
  # to the op separately as `captured_inputs` below.
  concrete_fn = fn._get_concrete_function_internal()  # pylint: disable=protected-access
  # TODO(shivaniagrawal/rachelim): what about functions created without
  # input_signature.

  # Normalize user-supplied values so the generated op wrapper receives
  # `Tensor`s and `TensorShape`s.
  elems = [ops.convert_to_tensor(e) for e in elems]
  output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
  return gen_dataset_ops.map_defun(elems, concrete_fn.captured_inputs,
                                   output_dtypes, output_shapes, concrete_fn,
                                   max_intra_op_parallelism)