1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras CategoryEncoding preprocessing layer."""
16# pylint: disable=g-classes-have-attributes
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
37
# Output-mode name strings. CategoryEncoding's `output_mode` validation only
# accepts BINARY and COUNT; TFIDF and INT are defined here but rejected by that
# layer's constructor.
TFIDF = "tf-idf"
INT = "int"
BINARY = "binary"
COUNT = "count"

# The inverse-document-frequency weights
# NOTE(review): `_IDF_NAME` is not referenced anywhere else in this file;
# presumably kept for compatibility with other preprocessing code — verify
# before removing.
_IDF_NAME = "idf"
45
46
def _expand_to_2d(x):
  """Appends a trailing axis of size 1, turning shape `(n,)` into `(n, 1)`.

  `SparseTensor`s need `sparse_ops.sparse_expand_dims`, because the dense
  `array_ops.expand_dims` cannot consume them.
  """
  if isinstance(x, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_expand_dims(x, 1)
  return array_ops.expand_dims(x, 1)


@keras_export("keras.layers.experimental.preprocessing.CategoryEncoding")
class CategoryEncoding(base_preprocessing_layer.PreprocessingLayer):
  """Category encoding layer.

  This layer provides options for condensing data into a categorical encoding
  when the total number of tokens is known in advance. It accepts integer
  values as inputs and outputs a dense (or, with `sparse=True`, sparse)
  representation of those inputs: one sample becomes a rank-1 tensor of float
  values describing that sample's tokens. For integer inputs where the total
  number of tokens is not known, see
  `tf.keras.layers.experimental.preprocessing.IntegerLookup`.

  Examples:

  **Multi-hot encoding data**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="binary")
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
    array([[1., 1., 0., 0.],
           [1., 0., 0., 0.],
           [0., 1., 1., 0.],
           [0., 1., 0., 1.]], dtype=float32)>

  **Using weighted inputs in `count` mode**

  >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
  ...           num_tokens=4, output_mode="count")
  >>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])
  >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)
  <tf.Tensor: shape=(4, 4), dtype=float64, numpy=
    array([[0.1, 0.2, 0. , 0. ],
           [0.2, 0. , 0. , 0. ],
           [0. , 0.2, 0.3, 0. ],
           [0. , 0.2, 0. , 0.4]])>

  Args:
    num_tokens: The total number of tokens the layer should support. All inputs
      to the layer must be integers in the range 0 <= value < num_tokens or an
      error will be thrown.
    output_mode: Specification for the output of the layer.
      Defaults to "binary". Values can be "binary" or "count", configuring the
      layer as follows:
        "binary": Outputs a single float array per batch item, of size
          num_tokens, containing 1s in all elements where the token mapped
          to that index exists at least once in the batch item.
        "count": As "binary", but the array contains a count of the number
          of times the token at that index appeared in the batch item (or the
          sum of the matching `count_weights`, when those are passed).
    sparse: Boolean. If true, returns a `SparseTensor` instead of a dense
      `Tensor`. Defaults to `False`.

  Call arguments:
    inputs: A 2D tensor `(samples, timesteps)` of integer token ids. A 1D
      tensor `(samples,)` is treated as `(samples, 1)`.
    count_weights: A tensor in the same shape as `inputs` indicating the
      weight for each sample value when summing up in `count` mode. Not used in
      `binary` mode.
  """

  def __init__(self,
               num_tokens=None,
               output_mode=BINARY,
               sparse=False,
               **kwargs):
    # max_tokens is an old name for the num_tokens arg we continue to support
    # because of usage. When present it takes precedence over num_tokens and
    # must not be forwarded to the base layer.
    if "max_tokens" in kwargs:
      logging.warning(
          "max_tokens is deprecated, please use num_tokens instead.")
      num_tokens = kwargs.pop("max_tokens")

    super(CategoryEncoding, self).__init__(**kwargs)

    # 'output_mode' must be one of (COUNT, BINARY)
    layer_utils.validate_string_arg(
        output_mode,
        allowable_strings=(COUNT, BINARY),
        layer_name="CategoryEncoding",
        arg_name="output_mode")

    if num_tokens is None:
      raise ValueError("num_tokens must be set to use this layer. If the "
                       "number of tokens is not known beforehand, use the "
                       "IntegerLookup layer instead.")
    if num_tokens < 1:
      raise ValueError("num_tokens must be >= 1.")

    self.num_tokens = num_tokens
    self.output_mode = output_mode
    self.sparse = sparse

  def compute_output_shape(self, input_shape):
    """Returns `(batch_size, num_tokens)` for an input of `input_shape`."""
    return tensor_shape.TensorShape([input_shape[0], self.num_tokens])

  def compute_output_signature(self, input_spec):
    """Returns the `TensorSpec`/`SparseTensorSpec` that `call` produces."""
    output_shape = self.compute_output_shape(input_spec.shape.as_list())
    # Match what `call` actually emits: dense_bincount requests K.floatx() and
    # sparse_bincount casts its result to K.floatx(). (Dense output with
    # `count_weights` may instead follow the weights' dtype, as in the float64
    # docstring example; the signature cannot see the weights.)
    output_dtype = dtypes.as_dtype(K.floatx())
    if self.sparse:
      return sparse_tensor.SparseTensorSpec(
          shape=output_shape, dtype=output_dtype)
    return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)

  def get_config(self):
    """Returns the layer config, including num_tokens/output_mode/sparse."""
    config = {
        "num_tokens": self.num_tokens,
        "output_mode": self.output_mode,
        "sparse": self.sparse,
    }
    base_config = super(CategoryEncoding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, count_weights=None):
    """Encodes integer token ids into a `num_tokens`-wide vector per sample.

    Raises:
      ValueError: if `count_weights` is given while `output_mode` is not
        "count".
    """
    if isinstance(inputs, (list, tuple, np.ndarray)):
      inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)

    if count_weights is not None and self.output_mode != COUNT:
      raise ValueError("count_weights is not used in `output_mode='tf-idf'`, "
                       "or `output_mode='binary'`. Please pass a single input.")

    # A 1D batch of ids means one id per sample: (samples,) -> (samples, 1).
    # count_weights must keep the same shape as inputs, so a 1D weight vector
    # is expanded the same way (already-2D weights are left alone).
    if inputs.shape.rank == 1:
      inputs = _expand_to_2d(inputs)
      if count_weights is not None:
        if isinstance(count_weights, (list, tuple, np.ndarray)):
          count_weights = ops.convert_to_tensor_v2_with_dispatch(count_weights)
        if count_weights.shape.rank == 1:
          count_weights = _expand_to_2d(count_weights)

    out_depth = self.num_tokens
    binary_output = (self.output_mode == BINARY)
    if isinstance(inputs, sparse_tensor.SparseTensor):
      max_value = math_ops.reduce_max(inputs.values)
      min_value = math_ops.reduce_min(inputs.values)
    else:
      max_value = math_ops.reduce_max(inputs)
      min_value = math_ops.reduce_min(inputs)
    condition = math_ops.logical_and(
        math_ops.greater(
            math_ops.cast(out_depth, max_value.dtype), max_value),
        math_ops.greater_equal(
            min_value, math_ops.cast(0, min_value.dtype)))
    assertion = control_flow_ops.Assert(condition, [
        "Input values must be in the range 0 <= values < num_tokens"
        " with num_tokens={}".format(out_depth)
    ])
    # Make the encoding depend on the range check; otherwise graph-mode
    # pruning may drop the (unconsumed) Assert op. In eager mode the Assert
    # has already raised on failure and this context is a no-op.
    with ops.control_dependencies([assertion]):
      if self.sparse:
        return sparse_bincount(inputs, out_depth, binary_output,
                               count_weights)
      return dense_bincount(inputs, out_depth, binary_output, count_weights)
190
191
def sparse_bincount(inputs, out_depth, binary_output, count_weights=None):
  """Apply binary or count encoding to an input and return a sparse tensor.

  Args:
    inputs: Integer token ids: a tensor, `SparseTensor`, or list/NumPy array
      convertible to a tensor. Rank 1 or 2; bins are counted along the last
      axis.
    out_depth: Number of bins, i.e. the size of the last output dimension.
      Values >= out_depth are not counted (maxlength=out_depth).
    binary_output: If True, each bin records presence instead of a count.
    count_weights: Optional weights, expected to match `inputs` in shape; when
      given, bins accumulate weights instead of occurrence counts.

  Returns:
    A `SparseTensor` with values cast to `K.floatx()` and dense shape
    `(batch_size, out_depth)`, or `(out_depth,)` when `inputs` is rank 1.
  """
  # Convert Python/NumPy containers so `.shape` can be read below; tensors,
  # SparseTensors and other tensor-likes pass through untouched.
  if isinstance(inputs, (list, tuple, np.ndarray)):
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)

  result = bincount_ops.sparse_bincount(
      inputs,
      weights=count_weights,
      minlength=out_depth,
      maxlength=out_depth,
      axis=-1,
      binary_output=binary_output)
  result = math_ops.cast(result, K.floatx())

  # Rebuild with an explicit dense_shape whose last entry is the Python int
  # out_depth (presumably so the static shape carries it). A rank-1 input
  # yields rank-1 indices, so its dense_shape must be rank 1 as well.
  if inputs.shape.rank == 1:
    dense_shape = [out_depth]
  else:
    batch_size = array_ops.shape(result)[0]
    dense_shape = [batch_size, out_depth]
  return sparse_tensor.SparseTensor(
      indices=result.indices,
      values=result.values,
      dense_shape=dense_shape)
208
209
def dense_bincount(inputs, out_depth, binary_output, count_weights=None):
  """Apply binary or count encoding to an input.

  Args:
    inputs: Integer token ids: a tensor, `SparseTensor`, or list/NumPy array
      convertible to a tensor. Rank 1 or 2; bins are counted along the last
      axis.
    out_depth: Number of bins, i.e. the size of the last output dimension.
      Values >= out_depth are not counted (maxlength=out_depth).
    binary_output: If True, each bin records presence instead of a count.
    count_weights: Optional weights, expected to match `inputs` in shape; when
      given, bins accumulate weights instead of occurrence counts.

  Returns:
    A dense tensor with last dimension `out_depth`: static shape
    `(batch_size, out_depth)` for rank-2 inputs, `(out_depth,)` for rank-1
    inputs, left as inferred when the input rank is unknown. Requested dtype
    is `K.floatx()`; with `count_weights` the dtype may follow the weights
    (see the float64 example in the CategoryEncoding docstring).
  """
  # Convert Python/NumPy containers so `.shape` can be read below; tensors,
  # SparseTensors and other tensor-likes pass through untouched.
  if isinstance(inputs, (list, tuple, np.ndarray)):
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)

  result = bincount_ops.bincount(
      inputs,
      weights=count_weights,
      minlength=out_depth,
      maxlength=out_depth,
      dtype=K.floatx(),
      axis=-1,
      binary_output=binary_output)

  # Pin the static output shape. `as_list()` raises on an unknown-rank shape,
  # so only pin it when the rank is known.
  input_rank = inputs.shape.rank
  if input_rank == 1:
    result.set_shape(tensor_shape.TensorShape((out_depth,)))
  elif input_rank is not None:
    batch_size = inputs.shape.as_list()[0]
    result.set_shape(tensor_shape.TensorShape((batch_size, out_depth)))
  return result
223