1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Built-in WideNDeep model classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.keras import activations
23from tensorflow.python.keras import backend as K
24from tensorflow.python.keras import layers as layer_module
25from tensorflow.python.keras.engine import base_layer
26from tensorflow.python.keras.engine import data_adapter
27from tensorflow.python.keras.engine import training as keras_training
28from tensorflow.python.keras.utils import generic_utils
29from tensorflow.python.util import nest
30from tensorflow.python.util.tf_export import keras_export
31
32
@keras_export('keras.experimental.WideDeepModel')
class WideDeepModel(keras_training.Model):
  r"""Wide & Deep Model for regression and classification problems.

  This model jointly trains a linear and a dnn model.

  Example:

  ```python
  linear_model = LinearModel()
  dnn_model = keras.Sequential([keras.layers.Dense(units=64),
                               keras.layers.Dense(units=1)])
  combined_model = WideDeepModel(linear_model, dnn_model)
  combined_model.compile(
      optimizer=['sgd', 'adam'], loss='mse', metrics=['mse'])
  # define dnn_inputs and linear_inputs as separate numpy arrays or
  # a single numpy array if dnn_inputs is same as linear_inputs.
  combined_model.fit([linear_inputs, dnn_inputs], y, epochs=epochs)
  # or define a single `tf.data.Dataset` that contains a single tensor or
  # separate tensors for dnn_inputs and linear_inputs.
  dataset = tf.data.Dataset.from_tensors(([linear_inputs, dnn_inputs], y))
  combined_model.fit(dataset, epochs=epochs)
  ```

  Both linear and dnn model can be pre-compiled and trained separately
  before jointly training:

  Example:
  ```python
  linear_model = LinearModel()
  linear_model.compile('adagrad', 'mse')
  linear_model.fit(linear_inputs, y, epochs=epochs)
  dnn_model = keras.Sequential([keras.layers.Dense(units=1)])
  dnn_model.compile('rmsprop', 'mse')
  dnn_model.fit(dnn_inputs, y, epochs=epochs)
  combined_model = WideDeepModel(linear_model, dnn_model)
  combined_model.compile(
      optimizer=['sgd', 'adam'], loss='mse', metrics=['mse'])
  combined_model.fit([linear_inputs, dnn_inputs], y, epochs=epochs)
  ```

  """

  def __init__(self, linear_model, dnn_model, activation=None, **kwargs):
    """Create a Wide & Deep Model.

    Args:
      linear_model: a premade LinearModel, its output must match the output of
        the dnn model.
      dnn_model: a `tf.keras.Model`, its output must match the output of the
        linear model.
      activation: Activation function. Set it to None to maintain a linear
        activation.
      **kwargs: The keyword arguments that are passed on to BaseLayer.__init__.
        Allowed keyword arguments include `name`.
    """
    super(WideDeepModel, self).__init__(**kwargs)
    base_layer.keras_premade_model_gauge.get_cell('WideDeep').set(True)
    self.linear_model = linear_model
    self.dnn_model = dnn_model
    self.activation = activations.get(activation)

  def call(self, inputs, training=None):
    """Runs both sub-models and combines their outputs.

    Args:
      inputs: Either a tuple/list of exactly two elements
        `(linear_inputs, dnn_inputs)`, or any other value, in which case the
        same `inputs` is fed to both the linear and the dnn model.
      training: Forwarded to the dnn model only when that model expects a
        `training` argument. When None, `K.learning_phase()` is used instead.

    Returns:
      The element-wise sum of the linear and dnn outputs (mapped over the
      output structure), passed through `self.activation` when it is set.
    """
    if not isinstance(inputs, (tuple, list)) or len(inputs) != 2:
      linear_inputs = dnn_inputs = inputs
    else:
      linear_inputs, dnn_inputs = inputs
    linear_output = self.linear_model(linear_inputs)
    # pylint: disable=protected-access
    if self.dnn_model._expects_training_arg:
      if training is None:
        training = K.learning_phase()
      dnn_output = self.dnn_model(dnn_inputs, training=training)
    else:
      dnn_output = self.dnn_model(dnn_inputs)
    output = nest.map_structure(lambda x, y: (x + y), linear_output, dnn_output)
    if self.activation:
      return nest.map_structure(self.activation, output)
    return output

  def train_step(self, data):
    """Runs one training step, optionally with one optimizer per sub-model.

    This does not support gradient scaling and LossScaleOptimizer.

    Args:
      data: A batch that `data_adapter.unpack_x_y_sample_weight` can split
        into `(x, y, sample_weight)`.

    Returns:
      A dict mapping the name of each metric in `self.metrics` to its result.
    """
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    x, y, sample_weight = data_adapter.expand_1d((x, y, sample_weight))

    with backprop.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)
    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    if isinstance(self.optimizer, (list, tuple)):
      # Two optimizers: the first updates the linear model's variables, the
      # second updates the dnn model's variables. Both gradient sets come
      # from the same (non-persistent) tape in a single `gradient` call.
      linear_vars = self.linear_model.trainable_variables
      dnn_vars = self.dnn_model.trainable_variables
      linear_grads, dnn_grads = tape.gradient(loss, (linear_vars, dnn_vars))

      linear_optimizer = self.optimizer[0]
      dnn_optimizer = self.optimizer[1]
      linear_optimizer.apply_gradients(zip(linear_grads, linear_vars))
      dnn_optimizer.apply_gradients(zip(dnn_grads, dnn_vars))
    else:
      # Single optimizer: update every trainable variable of this model.
      trainable_variables = self.trainable_variables
      grads = tape.gradient(loss, trainable_variables)
      self.optimizer.apply_gradients(zip(grads, trainable_variables))

    return {m.name: m.result() for m in self.metrics}

  def _make_train_function(self):
    """Builds the backend train function and stores it as `train_function`.

    Only needed for graph mode and model_to_estimator. The function is
    (re)built when no `train_function` exists yet or when the loss/weighted
    metrics were recompiled.
    """
    has_recompiled = self._recompile_weights_loss_and_weighted_metrics()
    self._check_trainable_weights_consistency()
    # If we have re-compiled the loss/weighted metric sub-graphs then create
    # train function even if one exists already. This is because
    # `_feed_sample_weights` list has been updated on re-compile.
    if getattr(self, 'train_function', None) is None or has_recompiled:
      # Restore the compiled trainable state.
      current_trainable_state = self._get_trainable_state()
      self._set_trainable_state(self._compiled_trainable_state)

      inputs = (
          self._feed_inputs + self._feed_targets + self._feed_sample_weights)
      if not isinstance(K.symbolic_learning_phase(), int):
        inputs += [K.symbolic_learning_phase()]

      # With a list/tuple of optimizers the first drives the linear model and
      # the second the dnn model; otherwise the same optimizer drives both.
      if isinstance(self.optimizer, (list, tuple)):
        linear_optimizer = self.optimizer[0]
        dnn_optimizer = self.optimizer[1]
      else:
        linear_optimizer = self.optimizer
        dnn_optimizer = self.optimizer

      with K.get_graph().as_default():
        with K.name_scope('training'):
          # Training updates
          updates = []
          linear_updates = linear_optimizer.get_updates(
              params=self.linear_model.trainable_weights,  # pylint: disable=protected-access
              loss=self.total_loss)
          updates += linear_updates
          dnn_updates = dnn_optimizer.get_updates(
              params=self.dnn_model.trainable_weights,  # pylint: disable=protected-access
              loss=self.total_loss)
          updates += dnn_updates
          # Unconditional updates
          updates += self.get_updates_for(None)
          # Conditional updates relevant to this model
          updates += self.get_updates_for(self.inputs)

        metrics = self._get_training_eval_metrics()
        metrics_tensors = [
            m._call_result for m in metrics if hasattr(m, '_call_result')  # pylint: disable=protected-access
        ]

      with K.name_scope('training'):
        # Gets loss and metrics. Updates weights at each call.
        fn = K.function(
            inputs, [self.total_loss] + metrics_tensors,
            updates=updates,
            name='train_function',
            **self._function_kwargs)
        setattr(self, 'train_function', fn)

      # Restore the current trainable state
      self._set_trainable_state(current_trainable_state)

  def get_config(self):
    """Returns the config dict of this model.

    Returns:
      The base layer config extended with the serialized `linear_model`,
      `dnn_model` and `activation` entries.
    """
    linear_config = generic_utils.serialize_keras_object(self.linear_model)
    dnn_config = generic_utils.serialize_keras_object(self.dnn_model)
    config = {
        'linear_model': linear_config,
        'dnn_model': dnn_config,
        'activation': activations.serialize(self.activation),
    }
    base_config = base_layer.Layer.get_config(self)
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a model from a config dict such as `get_config` produces.

    Args:
      config: A dict with `linear_model` and `dnn_model` entries and an
        optional `activation` entry; the remaining entries are forwarded to
        the constructor as keyword arguments. The dict is not modified.
      custom_objects: Optional custom objects passed to the deserializers.

    Returns:
      A new instance of `cls`.
    """
    # Work on a shallow copy: popping from the caller's dict would make the
    # same config unusable for a second `from_config` call.
    config = dict(config)
    linear_config = config.pop('linear_model')
    linear_model = layer_module.deserialize(linear_config, custom_objects)
    dnn_config = config.pop('dnn_model')
    dnn_model = layer_module.deserialize(dnn_config, custom_objects)
    activation = activations.deserialize(
        config.pop('activation', None), custom_objects=custom_objects)
    return cls(
        linear_model=linear_model,
        dnn_model=dnn_model,
        activation=activation,
        **config)
220