1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras category crossing preprocessing layers."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import numpy as np
23
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.keras.engine import base_preprocessing_layer
30from tensorflow.python.keras.utils import tf_utils
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import sparse_ops
33from tensorflow.python.ops.ragged import ragged_array_ops
34from tensorflow.python.ops.ragged import ragged_tensor
35from tensorflow.python.util.tf_export import keras_export
36
37
@keras_export('keras.layers.experimental.preprocessing.CategoryCrossing')
class CategoryCrossing(base_preprocessing_layer.PreprocessingLayer):
  """Category crossing layer.

  This layer concatenates multiple categorical inputs into a single categorical
  output (similar to Cartesian product). The output dtype is string.

  Usage:
  >>> inp_1 = ['a', 'b', 'c']
  >>> inp_2 = ['d', 'e', 'f']
  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
  >>> layer([inp_1, inp_2])
  <tf.Tensor: shape=(3, 1), dtype=string, numpy=
    array([[b'a_X_d'],
           [b'b_X_e'],
           [b'c_X_f']], dtype=object)>


  >>> inp_1 = ['a', 'b', 'c']
  >>> inp_2 = ['d', 'e', 'f']
  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
  ...    separator='-')
  >>> layer([inp_1, inp_2])
  <tf.Tensor: shape=(3, 1), dtype=string, numpy=
    array([[b'a-d'],
           [b'b-e'],
           [b'c-f']], dtype=object)>

  Args:
    depth: depth of input crossing. By default None, all inputs are crossed into
      one output. It can also be an int or tuple/list of ints. Passing an
      integer will create combinations of crossed outputs with every depth from
      1 up to and including that integer, i.e., [1, 2, ..., `depth`], and
      passing a tuple of integers will create crossed outputs with depth for the
      specified values in the tuple, i.e., `depth`=(N1, N2) will create all
      possible crossed outputs with depth equal to N1 or N2. Passing `None`
      means a single crossed output with all inputs. Combinations of a given
      depth are generated in `itertools.combinations` order. For example, with
      inputs `a`, `b` and `c`, `depth=2` means the output will be
      [a; b; c; cross(a, b); cross(a, c); cross(b, c)].
    separator: A string added between each input being joined. Defaults to
      '_X_'. A non-default separator is not yet supported when any input is a
      `RaggedTensor` (a `ValueError` is raised).
    name: Name to give to the layer.
    **kwargs: Keyword arguments to construct a layer.

  Input shape: a list of string or int tensors, sparse tensors or ragged tensors
    of shape `[batch_size, d1]`. Python lists/tuples and numpy arrays are
    converted to tensors, and rank-1 inputs of shape `[batch_size]` are expanded
    to `[batch_size, 1]`.

  Output shape: a single string tensor, sparse tensor or ragged tensor of shape
    `[batch_size, None]`; the size of the second dimension depends on the
    inputs and on `depth`.

  Returns:
    If any input is `RaggedTensor`, the output is `RaggedTensor`.
    Else, if any input is `SparseTensor`, the output is `SparseTensor`.
    Otherwise, the output is `Tensor`.
    The crossed outputs for every requested depth are concatenated along axis 1
    into that single output.

  Example: (`depth`=None)
    If the layer receives three inputs:
    `a=[[1], [4]]`, `b=[[2], [5]]`, `c=[[3], [6]]`
    the output will be a string tensor:
    `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]`

  Example: (`depth` is an integer)
    With the same input above, and if `depth`=2,
    the output will be a single string tensor with 6 columns:
    `[[b'1', b'2', b'3', b'1_X_2', b'1_X_3', b'2_X_3'],
      [b'4', b'5', b'6', b'4_X_5', b'4_X_6', b'5_X_6']]`

  Example: (`depth` is a tuple/list of integers)
    With the same input above, and if `depth`=(2, 3)
    the output will be a single string tensor with 4 columns:
    `[[b'1_X_2', b'1_X_3', b'2_X_3', b'1_X_2_X_3'],
      [b'4_X_5', b'4_X_6', b'5_X_6', b'4_X_5_X_6']]`
  """

  def __init__(self, depth=None, name=None, separator='_X_', **kwargs):
    super(CategoryCrossing, self).__init__(name=name, **kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell(
        'CategoryCrossing').set(True)
    self.depth = depth
    self.separator = separator
    # `_depth_tuple` is the sequence of crossing depths iterated in `call`. It is
    # only consulted when `self.depth` is truthy; otherwise `call` crosses all
    # inputs together in one output.
    if isinstance(depth, (tuple, list)):
      # Copy into a tuple so that later mutation of a caller-owned list cannot
      # silently change this layer's crossing behavior.
      self._depth_tuple = tuple(depth)
    elif depth is not None:
      # An integer depth means every depth from 1 up to and including `depth`.
      self._depth_tuple = tuple(range(1, depth + 1))
    else:
      self._depth_tuple = None

  def partial_crossing(self, partial_inputs, ragged_out, sparse_out):
    """Gets the crossed output from a partial list/tuple of inputs.

    Args:
      partial_inputs: A tuple/list of preprocessed inputs to cross together.
      ragged_out: Whether the result should be a `RaggedTensor`.
      sparse_out: Whether the result should be a `SparseTensor` (ignored when
        `ragged_out` is True).

    Returns:
      The crossed values as a `RaggedTensor`, `SparseTensor` or dense `Tensor`.

    Raises:
      ValueError: If `ragged_out` is True and a non-default separator was set.
    """
    if ragged_out:
      # TODO(momernick): Support separator with ragged_cross.
      if self.separator != '_X_':
        raise ValueError('Non-default separator with ragged input is not '
                         'supported yet, given {}'.format(self.separator))
      return ragged_array_ops.cross(partial_inputs)
    elif sparse_out:
      return sparse_ops.sparse_cross(partial_inputs, separator=self.separator)
    else:
      # Dense outputs are produced by crossing sparsely and densifying.
      return sparse_ops.sparse_tensor_to_dense(
          sparse_ops.sparse_cross(partial_inputs, separator=self.separator))

  def _preprocess_input(self, inp):
    """Converts `inp` to a tensor-like value, expanding rank-1 inputs.

    Python lists/tuples and numpy arrays are converted to tensors. Rank-1
    inputs of shape `[batch_size]` are expanded to `[batch_size, 1]`.
    """
    if isinstance(inp, (list, tuple, np.ndarray)):
      inp = ops.convert_to_tensor_v2_with_dispatch(inp)
    if inp.shape.rank == 1:
      if isinstance(inp, sparse_tensor.SparseTensor):
        # Use the sparse-specific op instead of relying on the dense
        # `array_ops.expand_dims` accepting a SparseTensor.
        inp = sparse_ops.sparse_expand_dims(inp, axis=-1)
      else:
        inp = array_ops.expand_dims(inp, axis=-1)
    return inp

  def call(self, inputs):
    """Crosses `inputs` for every requested depth and concatenates the results.

    Raises:
      ValueError: If any requested depth exceeds the number of inputs.
    """
    inputs = [self._preprocess_input(inp) for inp in inputs]
    # Without a (truthy) depth, all inputs are crossed into one output.
    depth_tuple = self._depth_tuple if self.depth else (len(inputs),)
    # The output kind is chosen by precedence: ragged > sparse > dense.
    ragged_out = sparse_out = False
    if any(tf_utils.is_ragged(inp) for inp in inputs):
      ragged_out = True
    elif any(isinstance(inp, sparse_tensor.SparseTensor) for inp in inputs):
      sparse_out = True

    outputs = []
    for depth in depth_tuple:
      if len(inputs) < depth:
        raise ValueError(
            'Number of inputs cannot be less than depth, got {} input tensors, '
            'and depth {}'.format(len(inputs), depth))
      for partial_inps in itertools.combinations(inputs, depth):
        partial_out = self.partial_crossing(
            partial_inps, ragged_out, sparse_out)
        outputs.append(partial_out)
    if sparse_out:
      return sparse_ops.sparse_concat_v2(axis=1, sp_inputs=outputs)
    return array_ops.concat(outputs, axis=1)

  def compute_output_shape(self, input_shape):
    """Returns `[batch_size, None]`, taking the first known batch size.

    Raises:
      ValueError: If `input_shape` is not a list/tuple of shapes, or if any
        shape is not rank 2.
    """
    if not isinstance(input_shape, (tuple, list)):
      raise ValueError('A `CategoryCrossing` layer should be called '
                       'on a list of inputs.')
    input_shapes = input_shape
    batch_size = None
    for inp_shape in input_shapes:
      inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list()
      if len(inp_tensor_shape) != 2:
        raise ValueError('Inputs must be rank 2, get {}'.format(input_shapes))
      if batch_size is None:
        batch_size = inp_tensor_shape[0]
    # The second dimension is dynamic based on inputs.
    output_shape = [batch_size, None]
    return tensor_shape.TensorShape(output_shape)

  def compute_output_signature(self, input_spec):
    """Returns the output spec, matching the output type chosen in `call`."""
    input_shapes = [x.shape for x in input_spec]
    output_shape = self.compute_output_shape(input_shapes)
    if any(
        isinstance(inp_spec, ragged_tensor.RaggedTensorSpec)
        for inp_spec in input_spec):
      # `call` returns a RaggedTensor whenever any input is ragged, so the
      # signature must be ragged too (previously a dense TensorSpec).
      return ragged_tensor.RaggedTensorSpec(
          shape=output_shape, dtype=dtypes.string)
    elif any(
        isinstance(inp_spec, sparse_tensor.SparseTensorSpec)
        for inp_spec in input_spec):
      return sparse_tensor.SparseTensorSpec(
          shape=output_shape, dtype=dtypes.string)
    return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.string)

  def get_config(self):
    """Returns the layer config, including `depth` and `separator`."""
    config = {
        'depth': self.depth,
        'separator': self.separator,
    }
    base_config = super(CategoryCrossing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
210