1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras category crossing preprocessing layers.""" 16# pylint: disable=g-classes-have-attributes 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22import numpy as np 23 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.framework import tensor_spec 29from tensorflow.python.keras.engine import base_preprocessing_layer 30from tensorflow.python.keras.utils import tf_utils 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import sparse_ops 33from tensorflow.python.ops.ragged import ragged_array_ops 34from tensorflow.python.ops.ragged import ragged_tensor 35from tensorflow.python.util.tf_export import keras_export 36 37 38@keras_export('keras.layers.experimental.preprocessing.CategoryCrossing') 39class CategoryCrossing(base_preprocessing_layer.PreprocessingLayer): 40 """Category crossing layer. 41 42 This layer concatenates multiple categorical inputs into a single categorical 43 output (similar to Cartesian product). The output dtype is string. 44 45 Usage: 46 >>> inp_1 = ['a', 'b', 'c'] 47 >>> inp_2 = ['d', 'e', 'f'] 48 >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing() 49 >>> layer([inp_1, inp_2]) 50 <tf.Tensor: shape=(3, 1), dtype=string, numpy= 51 array([[b'a_X_d'], 52 [b'b_X_e'], 53 [b'c_X_f']], dtype=object)> 54 55 56 >>> inp_1 = ['a', 'b', 'c'] 57 >>> inp_2 = ['d', 'e', 'f'] 58 >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing( 59 ... separator='-') 60 >>> layer([inp_1, inp_2]) 61 <tf.Tensor: shape=(3, 1), dtype=string, numpy= 62 array([[b'a-d'], 63 [b'b-e'], 64 [b'c-f']], dtype=object)> 65 66 Args: 67 depth: depth of input crossing. By default None, all inputs are crossed into 68 one output. It can also be an int or tuple/list of ints. Passing an 69 integer will create combinations of crossed outputs with depth up to that 70 integer, i.e., [1, 2, ..., `depth`), and passing a tuple of integers will 71 create crossed outputs with depth for the specified values in the tuple, 72 i.e., `depth`=(N1, N2) will create all possible crossed outputs with depth 73 equal to N1 or N2. Passing `None` means a single crossed output with all 74 inputs. For example, with inputs `a`, `b` and `c`, `depth=2` means the 75 output will be [a;b;c;cross(a, b);cross(bc);cross(ca)]. 76 separator: A string added between each input being joined. Defaults to 77 '_X_'. 78 name: Name to give to the layer. 79 **kwargs: Keyword arguments to construct a layer. 80 81 Input shape: a list of string or int tensors or sparse tensors of shape 82 `[batch_size, d1, ..., dm]` 83 84 Output shape: a single string or int tensor or sparse tensor of shape 85 `[batch_size, d1, ..., dm]` 86 87 Returns: 88 If any input is `RaggedTensor`, the output is `RaggedTensor`. 89 Else, if any input is `SparseTensor`, the output is `SparseTensor`. 90 Otherwise, the output is `Tensor`. 91 92 Example: (`depth`=None) 93 If the layer receives three inputs: 94 `a=[[1], [4]]`, `b=[[2], [5]]`, `c=[[3], [6]]` 95 the output will be a string tensor: 96 `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]` 97 98 Example: (`depth` is an integer) 99 With the same input above, and if `depth`=2, 100 the output will be a list of 6 string tensors: 101 `[[b'1'], [b'4']]` 102 `[[b'2'], [b'5']]` 103 `[[b'3'], [b'6']]` 104 `[[b'1_X_2'], [b'4_X_5']]`, 105 `[[b'2_X_3'], [b'5_X_6']]`, 106 `[[b'3_X_1'], [b'6_X_4']]` 107 108 Example: (`depth` is a tuple/list of integers) 109 With the same input above, and if `depth`=(2, 3) 110 the output will be a list of 4 string tensors: 111 `[[b'1_X_2'], [b'4_X_5']]`, 112 `[[b'2_X_3'], [b'5_X_6']]`, 113 `[[b'3_X_1'], [b'6_X_4']]`, 114 `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]` 115 """ 116 117 def __init__(self, depth=None, name=None, separator='_X_', **kwargs): 118 super(CategoryCrossing, self).__init__(name=name, **kwargs) 119 base_preprocessing_layer.keras_kpl_gauge.get_cell( 120 'CategoryCrossing').set(True) 121 self.depth = depth 122 self.separator = separator 123 if isinstance(depth, (tuple, list)): 124 self._depth_tuple = depth 125 elif depth is not None: 126 self._depth_tuple = tuple([i for i in range(1, depth + 1)]) 127 128 def partial_crossing(self, partial_inputs, ragged_out, sparse_out): 129 """Gets the crossed output from a partial list/tuple of inputs.""" 130 # If ragged_out=True, convert output from sparse to ragged. 131 if ragged_out: 132 # TODO(momernick): Support separator with ragged_cross. 133 if self.separator != '_X_': 134 raise ValueError('Non-default separator with ragged input is not ' 135 'supported yet, given {}'.format(self.separator)) 136 return ragged_array_ops.cross(partial_inputs) 137 elif sparse_out: 138 return sparse_ops.sparse_cross(partial_inputs, separator=self.separator) 139 else: 140 return sparse_ops.sparse_tensor_to_dense( 141 sparse_ops.sparse_cross(partial_inputs, separator=self.separator)) 142 143 def _preprocess_input(self, inp): 144 if isinstance(inp, (list, tuple, np.ndarray)): 145 inp = ops.convert_to_tensor_v2_with_dispatch(inp) 146 if inp.shape.rank == 1: 147 inp = array_ops.expand_dims(inp, axis=-1) 148 return inp 149 150 def call(self, inputs): 151 inputs = [self._preprocess_input(inp) for inp in inputs] 152 depth_tuple = self._depth_tuple if self.depth else (len(inputs),) 153 ragged_out = sparse_out = False 154 if any(tf_utils.is_ragged(inp) for inp in inputs): 155 ragged_out = True 156 elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs): 157 sparse_out = True 158 159 outputs = [] 160 for depth in depth_tuple: 161 if len(inputs) < depth: 162 raise ValueError( 163 'Number of inputs cannot be less than depth, got {} input tensors, ' 164 'and depth {}'.format(len(inputs), depth)) 165 for partial_inps in itertools.combinations(inputs, depth): 166 partial_out = self.partial_crossing( 167 partial_inps, ragged_out, sparse_out) 168 outputs.append(partial_out) 169 if sparse_out: 170 return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=outputs) 171 return array_ops.concat(outputs, axis=1) 172 173 def compute_output_shape(self, input_shape): 174 if not isinstance(input_shape, (tuple, list)): 175 raise ValueError('A `CategoryCrossing` layer should be called ' 176 'on a list of inputs.') 177 input_shapes = input_shape 178 batch_size = None 179 for inp_shape in input_shapes: 180 inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list() 181 if len(inp_tensor_shape) != 2: 182 raise ValueError('Inputs must be rank 2, get {}'.format(input_shapes)) 183 if batch_size is None: 184 batch_size = inp_tensor_shape[0] 185 # The second dimension is dynamic based on inputs. 186 output_shape = [batch_size, None] 187 return tensor_shape.TensorShape(output_shape) 188 189 def compute_output_signature(self, input_spec): 190 input_shapes = [x.shape for x in input_spec] 191 output_shape = self.compute_output_shape(input_shapes) 192 if any( 193 isinstance(inp_spec, ragged_tensor.RaggedTensorSpec) 194 for inp_spec in input_spec): 195 return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string) 196 elif any( 197 isinstance(inp_spec, sparse_tensor.SparseTensorSpec) 198 for inp_spec in input_spec): 199 return sparse_tensor.SparseTensorSpec( 200 shape=output_shape, dtype=dtypes.string) 201 return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string) 202 203 def get_config(self): 204 config = { 205 'depth': self.depth, 206 'separator': self.separator, 207 } 208 base_config = super(CategoryCrossing, self).get_config() 209 return dict(list(base_config.items()) + list(config.items())) 210