1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for head.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23# pylint: disable=g-bad-todo,g-import-not-at-top
24import numpy as np
25import six
26
27from tensorflow.contrib.learn.python.learn.estimators import constants
28from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
29from tensorflow.contrib.learn.python.learn.estimators import model_fn
30from tensorflow.contrib.learn.python.learn.estimators import prediction_key
31from tensorflow.core.framework import summary_pb2
32from tensorflow.python.client import session
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.ops import lookup_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import variables
38from tensorflow.python.ops.losses import losses as losses_lib
39from tensorflow.python.platform import test
40
41
def _assert_variables(test_case,
                      expected_global=None,
                      expected_model=None,
                      expected_trainable=None):
  """Asserts the graph's global/model/trainable variable names match.

  Each `expected_*` argument is an iterable of variable names, or `None`
  meaning "expect none". Comparison ignores ordering.
  """
  checks = (
      (expected_global, variables.global_variables),
      (expected_model, variables.model_variables),
      (expected_trainable, variables.trainable_variables),
  )
  for expected, list_variables_fn in checks:
    expected_names = tuple(expected) if expected is not None else ()
    actual_names = tuple(v.name for v in list_variables_fn())
    test_case.assertItemsEqual(expected_names, actual_names)
55
56
def _assert_no_variables(test_case):
  """Asserts that the graph contains no global, model, or trainable variables."""
  _assert_variables(
      test_case,
      expected_global=None,
      expected_model=None,
      expected_trainable=None)
59
60
61# This must be called from within a tf.Session.
def _assert_metrics(test_case, expected_loss, expected_eval_metrics,
                    model_fn_ops):
  """Asserts model loss and each eval metric match expected values.

  Must be called from within an active tf.Session. For every expected metric
  both the update op and the value op are evaluated and compared to 4 places.

  Args:
    test_case: `TestCase` instance running the assertions.
    expected_loss: Scalar expected value of `model_fn_ops.loss`.
    expected_eval_metrics: Dict of metric key to expected scalar value.
    model_fn_ops: `ModelFnOps` providing `loss` and `eval_metric_ops`.
  """
  test_case.assertAlmostEqual(expected_loss, model_fn_ops.loss.eval(), places=4)
  for k in six.iterkeys(expected_eval_metrics):
    test_case.assertIn(k, six.iterkeys(model_fn_ops.eval_metric_ops))
  # Metric ops create local variables that must be initialized before use.
  # `local_variables_initializer` is the non-deprecated replacement for
  # `initialize_local_variables`.
  variables.local_variables_initializer().run()
  for key, expected_value in six.iteritems(expected_eval_metrics):
    value_tensor, update_tensor = model_fn_ops.eval_metric_ops[key]
    update = update_tensor.eval()
    test_case.assertAlmostEqual(
        expected_value,
        update,
        places=4,
        msg="%s: update, expected %s, got %s." % (key, expected_value, update))
    value = value_tensor.eval()
    test_case.assertAlmostEqual(
        expected_value,
        value,
        places=4,
        msg="%s: value, expected %s, got %s." % (key, expected_value, value))
82
83
84# This must be called from within a tf.Session.
def _assert_summary_tags(test_case, expected_tags=None):
  """Asserts the first tag of each graph summary matches `expected_tags`.

  Must be called from within a tf.Session (summary ops are evaluated).
  """

  def _first_tag(serialized_summary):
    summary_proto = summary_pb2.Summary()
    summary_proto.ParseFromString(serialized_summary)
    return summary_proto.value[0].tag

  actual_tags = [
      _first_tag(summary_op.eval())
      for summary_op in ops.get_collection(ops.GraphKeys.SUMMARIES)
  ]
  test_case.assertItemsEqual(expected_tags or [], actual_tags)
92
93
94def _sigmoid(x):
95  return 1. / (1. + math.exp(-1 * x))
96
97
class PoissonHeadTest(test.TestCase):
  """Tests for the Poisson regression head."""

  def _assert_output_alternatives(self, model_fn_ops):
    """Asserts a single None-keyed LINEAR_REGRESSION output alternative."""
    actual = {
        key: alternative[0]
        for key, alternative in six.iteritems(model_fn_ops.output_alternatives)
    }
    self.assertEquals({None: constants.ProblemType.LINEAR_REGRESSION}, actual)

  def _log_poisson_loss(self, logits, labels):
    """Computes the expected mean log-Poisson loss for 1-D logits/labels.

    Adds the Stirling approximation of log(z!) only for labels z > 1,
    mirroring the compute-full-loss behavior of the head under test.
    """
    x = np.array([row[0] for row in logits])
    z = np.array([row[0] for row in labels])
    per_example = np.exp(x) - z * x
    stirling = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z)
    # Only labels greater than one contribute the Stirling term.
    per_example = per_example + np.where(z > 1, stirling, 0.)
    return np.sum(per_example) / len(per_example)

  def testPoissonWithLogits(self):
    """Poisson head in TRAIN mode with explicit logits."""
    head = head_lib.poisson_regression_head()
    labels = ((0.,), (1.,), (1.,))
    logits = ((0.,), (-1.,), (3.,))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_summary_tags(self, ["loss"])
      _assert_no_variables(self)
      expected_loss = self._log_poisson_loss(logits, labels)
      _assert_metrics(self, expected_loss, {"loss": expected_loss},
                      model_fn_ops)
131
132
class RegressionHeadTest(test.TestCase):
  """Tests for the squared-loss regression head."""

  def _assert_output_alternatives(self, model_fn_ops):
    """Asserts a single None-keyed LINEAR_REGRESSION output alternative."""
    self.assertEquals({
        None: constants.ProblemType.LINEAR_REGRESSION
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })

  # TODO(zakaria): test multilabel regression.
  def testRegressionWithLogits(self):
    """TRAIN mode with logits; mean squared error (1 + 0 + 4) / 3 = 5/3."""
    head = head_lib.regression_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=((0.,), (1.,), (1.,)),
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_summary_tags(self, ["loss"])
      _assert_no_variables(self)
      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

  def testRegressionWithLogitFn(self):
    """A custom link_fn (square) transforms logits into predictions."""
    head = head_lib.regression_head(link_fn=math_ops.square)
    # NOTE(review): "preditions" is a typo for "predictions"; renaming the
    # local helper would be safe but is left as-is here.
    def _assert_preditions(test_case, expected_predictions, model_fn_ops):
      variables.initialize_local_variables().run()
      test_case.assertAllClose(expected_predictions,
                               model_fn_ops.predictions["scores"].eval())
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=((0.,), (1.,), (1.,)),
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_summary_tags(self, ["loss"])
      _assert_no_variables(self)
      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
      # Squared logits: 1^2, 1^2, 3^2.
      _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops)

  def testRegressionWithInvalidLogits(self):
    """2-D logits are incompatible with 1-D regression labels."""
    head = head_lib.regression_head()
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
            {},
            labels=((0.,), (1.,), (1.,)),
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=head_lib.no_op_train_fn,
            logits=((1., 1.), (1., 1.), (3., 1.)))

  def testRegressionWithLogitsInput(self):
    """logits_input makes the head create its own logits layer variables."""
    head = head_lib.regression_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=((0.,), (1.,), (1.,)),
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits_input=((0., 0.), (0., 0.), (0., 0.)))
      self._assert_output_alternatives(model_fn_ops)
      w = ("regression_head/logits/weights:0",
           "regression_head/logits/biases:0")
      _assert_variables(
          self, expected_global=w, expected_model=w, expected_trainable=w)
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, ["loss"])
      # Zero-initialized logits layer predicts 0 everywhere: (0+1+1)/3 = 2/3.
      _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)

  def testRegressionWithLogitsAndLogitsInput(self):
    """Supplying both logits and logits_input is an error."""
    head = head_lib.regression_head()
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
        head.create_model_fn_ops(
            {},
            labels=((0.,), (1.,), (1.,)),
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=head_lib.no_op_train_fn,
            logits_input=((0., 0.), (0., 0.), (0., 0.)),
            logits=((1.,), (1.,), (3.,)))

  def testRegressionEvalMode(self):
    """EVAL mode produces metrics but no train op."""
    head = head_lib.regression_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=((1.,), (1.,), (3.,)),
          mode=model_fn.ModeKeys.EVAL,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((0.,), (1.,), (1.,)))
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

  def testRegressionWithLabelName(self):
    """Labels may be passed as a dict keyed by label_name."""
    label_name = "my_label"
    head = head_lib.regression_head(label_name=label_name)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels={label_name: ((0.,), (1.,), (1.,))},
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

  def testRegressionWithScalarWeights(self):
    """A scalar weight scales the training loss but cancels in the metric."""
    head = head_lib.regression_head(weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      weights = 2.
      labels = ((0.,), (1.,), (1.,))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      # Training loss divides by example count; eval "loss" divides by the
      # total weight, so the scalar weight cancels out there.
      _assert_metrics(self, (weights * 5.) / len(labels), {
          "loss": (weights * 5.) / (weights * len(labels))
      }, model_fn_ops)

  def testRegressionWith1DWeights(self):
    """Per-example 1-D weights: weighted loss (2*1 + 5*0 + 0*4) = 2."""
    head = head_lib.regression_head(weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      weights = (2., 5., 0.)
      labels = ((0.,), (1.,), (1.,))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},
                      model_fn_ops)

  def testRegressionWith2DWeights(self):
    """Per-example 2-D (column) weights behave the same as 1-D weights."""
    head = head_lib.regression_head(weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      weights = ((2.,), (5.,), (0.,))
      labels = ((0.,), (1.,), (1.,))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},
                      model_fn_ops)

  def testRegressionWithCenteredBias(self):
    """enable_centered_bias adds bias variables and their summaries."""
    head = head_lib.regression_head(enable_centered_bias=True)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels=((0.,), (1.,), (1.,)),
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1.,), (1.,), (3.,)))
      self._assert_output_alternatives(model_fn_ops)
      _assert_variables(
          self,
          expected_global=(
              "regression_head/centered_bias_weight:0",
              "regression_head/regression_head/centered_bias_weight/Adagrad:0",
          ),
          expected_trainable=("regression_head/centered_bias_weight:0",))
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, [
          "loss",
          "regression_head/centered_bias/bias_0"
      ])
      _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

  def testRegressionErrorInSparseTensorLabels(self):
    """SparseTensor labels are rejected for regression."""
    head = head_lib.regression_head()
    with ops.Graph().as_default():
      labels = sparse_tensor.SparseTensorValue(
          indices=((0, 0), (1, 0), (2, 0)),
          values=(0., 1., 1.),
          dense_shape=(3, 1))
      with self.assertRaisesRegexp(ValueError,
                                   "SparseTensor is not supported"):
        head.create_model_fn_ops(
            {},
            labels=labels,
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=head_lib.no_op_train_fn,
            logits=((1.,), (1.,), (3.,)))
339
340
class MultiLabelHeadTest(test.TestCase):
  """Tests for the multi-label (sigmoid per class) classification head."""

  def _assert_output_alternatives(self, model_fn_ops):
    """Asserts a single None-keyed CLASSIFICATION output alternative."""
    self.assertEquals({
        None: constants.ProblemType.CLASSIFICATION
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })

  def setUp(self):
    # Single example, 3 classes: logit favors class 0, true label is class 2.
    self._logits = ((1., 0., 0.),)
    self._labels = ((0, 0, 1),)

  def _expected_eval_metrics(self, expected_loss):
    """Returns the eval-metric dict expected for the setUp fixture."""
    return {
        "accuracy": 1. / 3,
        "loss": expected_loss,
        "auc": 1. / 4,
        "auc/class0": 1.,
        "auc/class1": 1.,
        "auc/class2": 0.,
        "auc_precision_recall": 0.166667,
        "auc_precision_recall/class0": 0,
        "auc_precision_recall/class1": 0.,
        "auc_precision_recall/class2": 0.49999,
        "labels/actual_label_mean/class0": self._labels[0][0],
        "labels/actual_label_mean/class1": self._labels[0][1],
        "labels/actual_label_mean/class2": self._labels[0][2],
        "labels/logits_mean/class0": self._logits[0][0],
        "labels/logits_mean/class1": self._logits[0][1],
        "labels/logits_mean/class2": self._logits[0][2],
        "labels/prediction_mean/class0": self._logits[0][0],
        "labels/prediction_mean/class1": self._logits[0][1],
        "labels/prediction_mean/class2": self._logits[0][2],
        "labels/probability_mean/class0": _sigmoid(self._logits[0][0]),
        "labels/probability_mean/class1": _sigmoid(self._logits[0][1]),
        "labels/probability_mean/class2": _sigmoid(self._logits[0][2]),
    }

  def testMultiLabelWithLogits(self):
    """TRAIN mode with explicit logits."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .89985204
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiLabelTwoClasses(self):
    """Multi-label head works with just two classes."""
    n_classes = 2
    labels = ((0, 1),)
    logits = ((1., 0.),)
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, labels=labels,
          train_op_fn=head_lib.no_op_train_fn, logits=logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.00320443
      _assert_metrics(self, expected_loss, {
          "accuracy": 0.,
          "auc": 0.,
          "loss": expected_loss,
          "auc/class0": 1.,
          "auc/class1": 0.,
          "labels/actual_label_mean/class0": labels[0][0],
          "labels/actual_label_mean/class1": labels[0][1],
          "labels/logits_mean/class0": logits[0][0],
          "labels/logits_mean/class1": logits[0][1],
          "labels/prediction_mean/class0": logits[0][0],
          "labels/prediction_mean/class1": logits[0][1],
          "labels/probability_mean/class0": _sigmoid(logits[0][0]),
          "labels/probability_mean/class1": _sigmoid(logits[0][1]),
      }, model_fn_ops)

  def testMultiLabelWithInvalidLogits(self):
    """Logits whose dimension mismatches n_classes raise ValueError."""
    head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits=self._logits)

  def testMultiLabelWithLogitsInput(self):
    """logits_input makes the head create its own logits layer variables."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits_input=((0., 0.),))
      self._assert_output_alternatives(model_fn_ops)
      w = ("multi_label_head/logits/weights:0",
           "multi_label_head/logits/biases:0")
      _assert_variables(
          self, expected_global=w, expected_model=w, expected_trainable=w)
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, ["loss"])
      # Zero logits everywhere => sigmoid cross entropy log(2) per class.
      expected_loss = .69314718
      _assert_metrics(self, expected_loss, {
          "accuracy": 2. / 3,
          "auc": 2. / 4,
          "loss": expected_loss,
          "auc/class0": 1.,
          "auc/class1": 1.,
          "auc/class2": 0.,
          "labels/actual_label_mean/class0": self._labels[0][0],
          "labels/actual_label_mean/class1": self._labels[0][1],
          "labels/actual_label_mean/class2": self._labels[0][2],
          "labels/logits_mean/class0": 0.,
          "labels/logits_mean/class1": 0.,
          "labels/logits_mean/class2": 0.,
          "labels/prediction_mean/class0": 0.,
          "labels/prediction_mean/class1": 0.,
          "labels/prediction_mean/class2": 0.,
          "labels/probability_mean/class0": .5,
          "labels/probability_mean/class1": .5,
          "labels/probability_mean/class2": .5,
      }, model_fn_ops)

  def testMultiLabelWithLogitsAndLogitsInput(self):
    """Supplying both logits and logits_input is an error."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits_input=((0., 0.),), logits=self._logits)

  def testMultiLabelEval(self):
    """EVAL mode produces metrics but no train op."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .89985204
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassEvalWithLargeLogits(self):
    """EVAL mode with large-magnitude logits; metrics computed explicitly."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    logits = ((2., 0., -1),)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,
          logits=logits)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.377779
      expected_eval_metrics = {
          "accuracy": 1. / 3,
          "auc": 9.99999e-07,
          "loss": expected_loss,
          "auc/class0": 1.,
          "auc/class1": 1.,
          "auc/class2": 0.,
          "labels/actual_label_mean/class0": 0. / 1,
          "labels/actual_label_mean/class1": 0. / 1,
          "labels/actual_label_mean/class2": 1. / 1,
          "labels/logits_mean/class0": logits[0][0],
          "labels/logits_mean/class1": logits[0][1],
          "labels/logits_mean/class2": logits[0][2],
          "labels/prediction_mean/class0": 1,
          "labels/prediction_mean/class1": 0,
          "labels/prediction_mean/class2": 0,
          "labels/probability_mean/class0": _sigmoid(logits[0][0]),
          "labels/probability_mean/class1": _sigmoid(logits[0][1]),
          "labels/probability_mean/class2": _sigmoid(logits[0][2]),
      }
      _assert_metrics(self, expected_loss,
                      expected_eval_metrics, model_fn_ops)

  def testMultiLabelInfer(self):
    """INFER mode exposes classes and probabilities for serving."""
    n_classes = 3
    head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
          logits=((1., 0., 0.), (0., 0., 1)))
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      with session.Session():
        self.assertListEqual(
            [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])
        self.assertItemsEqual(
            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
        self.assertEqual(
            constants.ProblemType.CLASSIFICATION,
            model_fn_ops.output_alternatives["head_name"][0])

        predictions_for_serving = (
            model_fn_ops.output_alternatives["head_name"][1])
        self.assertIn("classes", six.iterkeys(predictions_for_serving))
        self.assertAllEqual(
            [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
            predictions_for_serving["classes"].eval())
        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
        self.assertAllClose(
            [[0.731059, 0.5, 0.5],
             [0.5, 0.5, 0.731059,]],
            predictions_for_serving["probabilities"].eval())

  def testMultiLabelWithLabelName(self):
    """Labels may be passed as a dict keyed by label_name."""
    n_classes = 3
    label_name = "my_label"
    head = head_lib.multi_label_head(
        n_classes=n_classes,
        label_name=label_name,
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, {label_name: self._labels},
          head_lib.no_op_train_fn, logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .89985204
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiLabelWithScalarWeight(self):
    """A scalar weight of .1 scales the training loss by .1."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": .1},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      _assert_metrics(self, .089985214,
                      self._expected_eval_metrics(.89985214), model_fn_ops)

  def testMultiLabelWith1DWeight(self):
    """A 1-D per-class weight vector cannot broadcast and raises."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "weights can not be broadcast to values"):
        head.create_model_fn_ops(
            features={"label_weight": (.1, .1, .1)},
            labels=self._labels,
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=head_lib.no_op_train_fn,
            logits=self._logits)

  def testMultiLabelWith2DWeight(self):
    """A 2-D (batch x classes) weight matrix is accepted."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": ((.1, .1, .1),)},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      _assert_metrics(self, .089985214,
                      self._expected_eval_metrics(.89985214), model_fn_ops)

  def testMultiLabelWithCustomLoss(self):
    """A custom loss_fn is used in place of the default sigmoid loss."""
    n_classes = 3
    # NOTE(review): `_sigmoid_cross_entropy` is defined elsewhere in this
    # file (not visible in this chunk).
    head = head_lib.multi_label_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes),
        loss_fn=_sigmoid_cross_entropy)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": .1},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .089985214
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiLabelWithCenteredBias(self):
    """enable_centered_bias adds one bias variable and summary per class."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes,
        enable_centered_bias=True,
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_variables(
          self,
          expected_global=(
              "multi_label_head/centered_bias_weight:0",
              ("multi_label_head/multi_label_head/centered_bias_weight/"
               "Adagrad:0"),),
          expected_trainable=("multi_label_head/centered_bias_weight:0",))
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, (
          "loss",
          "multi_label_head/centered_bias/bias_0",
          "multi_label_head/centered_bias/bias_1",
          "multi_label_head/centered_bias/bias_2"
      ))
      expected_loss = .89985204
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiLabelSparseTensorLabels(self):
    """Sparse labels holding class ids are accepted by this head."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      labels = sparse_tensor.SparseTensorValue(
          indices=((0, 0),),
          values=(2,),
          dense_shape=(1, 1))
      model_fn_ops = head.create_model_fn_ops(
          features={},
          mode=model_fn.ModeKeys.TRAIN,
          labels=labels,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .89985204
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiLabelSparseTensorLabelsTooFewClasses(self):
    """Sparse labels with n_classes < 2 raise at graph-construction time."""
    n_classes = 3
    head = head_lib.multi_label_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    # Set _logits_dimension (n_classes) to a lower value; if it's set to 1
    # upfront, the class throws an error during initialization.
    head._logits_dimension = 1
    with ops.Graph().as_default(), session.Session():
      labels = sparse_tensor.SparseTensorValue(
          indices=((0, 0),),
          values=(2,),
          dense_shape=(1, 1))
      with self.assertRaisesRegexp(ValueError,
                                   "Must set num_classes >= 2 when passing"):
        head.create_model_fn_ops(
            features={},
            labels=labels,
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=head_lib.no_op_train_fn,
            logits=[0.])
729
730
class BinaryClassificationHeadTest(test.TestCase):
  """Tests for multi_class_head with n_classes=2 (binary logistic head)."""

  def _assert_output_alternatives(self, model_fn_ops):
    """Asserts one default (None-keyed) LOGISTIC_REGRESSION alternative."""
    self.assertEquals({
        None: constants.ProblemType.LOGISTIC_REGRESSION
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })

  def setUp(self):
    # Two examples, both with logit 1.; labels are 1 and 0 respectively.
    self._logits = ((1.,), (1.,))
    self._labels = ((1.,), (0.,))

  def _expected_eval_metrics(self, expected_loss):
    """Returns the eval metrics expected for self._logits vs. self._labels."""
    label_mean = np.mean(self._labels)
    return {
        "accuracy": 1. / 2,
        "accuracy/baseline_label_mean": label_mean,
        "accuracy/threshold_0.500000_mean": 1. / 2,
        "auc": 1. / 2,
        "auc_precision_recall": 0.25,
        "labels/actual_label_mean": label_mean,
        "labels/prediction_mean": .731059,  # sigmoid(1.)
        "loss": expected_loss,
        "precision/positive_threshold_0.500000_mean": 1. / 2,
        "recall/positive_threshold_0.500000_mean": 1. / 1,
    }

  def testBinaryClassificationWithLogits(self):
    """TRAIN mode with explicit logits yields the expected log loss."""
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .81326175
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testBinaryClassificationWithInvalidLogits(self):
    """Logits whose dimension disagrees with n_classes raise ValueError."""
    head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits=self._logits)

  def testBinaryClassificationWithLogitsInput(self):
    """logits_input makes the head create its own logits weights/biases."""
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits_input=((0., 0.), (0., 0.)))
      self._assert_output_alternatives(model_fn_ops)
      w = ("binary_logistic_head/logits/weights:0",
           "binary_logistic_head/logits/biases:0")
      _assert_variables(
          self, expected_global=w, expected_model=w, expected_trainable=w)
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, ["loss"])
      expected_loss = .69314718
      label_mean = np.mean(self._labels)
      _assert_metrics(self, expected_loss, {
          "accuracy": 1. / 2,
          "accuracy/baseline_label_mean": label_mean,
          "accuracy/threshold_0.500000_mean": 1. / 2,
          "auc": 1. / 2,
          "labels/actual_label_mean": label_mean,
          "labels/prediction_mean": .5,  # sigmoid(0.)
          "loss": expected_loss,
          "precision/positive_threshold_0.500000_mean": 0. / 2,
          "recall/positive_threshold_0.500000_mean": 0. / 1,
      }, model_fn_ops)

  def testBinaryClassificationWithLogitsAndLogitsInput(self):
    """Supplying both logits and logits_input is an error."""
    head = head_lib.multi_class_head(n_classes=2)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits_input=((0., 0.), (0., 0.)), logits=self._logits)

  def testBinaryClassificationEval(self):
    """EVAL mode produces metrics and no train_op."""
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .81326175
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testBinaryClassificationInfer(self):
    """INFER mode exposes classes and probabilities for serving."""
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      with session.Session():
        self.assertListEqual(
            [1, 1], list(model_fn_ops.predictions["classes"].eval()))
        self.assertItemsEqual(
            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
        self.assertEqual(
            constants.ProblemType.LOGISTIC_REGRESSION,
            model_fn_ops.output_alternatives["head_name"][0])
        predictions_for_serving = (
            model_fn_ops.output_alternatives["head_name"][1])
        self.assertIn("classes", six.iterkeys(predictions_for_serving))
        predicted_classes = predictions_for_serving["classes"].eval().tolist()
        self.assertListEqual(
            [b"0", b"1"], predicted_classes[0])
        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))

  def testBinaryClassificationInferMode_withWeightColumn(self):
    """INFER mode must not require the weight column in features."""
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes,
                                     weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          # This is what is being tested, features should not have weight for
          # inference.
          {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)

  def testErrorInSparseTensorLabels(self):
    """SparseTensor labels are rejected by the binary head."""
    n_classes = 2
    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default():
      labels = sparse_tensor.SparseTensorValue(
          indices=((0, 0), (1, 0), (2, 0)),
          values=(0, 1, 1),
          dense_shape=(3, 1))
      with self.assertRaisesRegexp(ValueError,
                                   "SparseTensor is not supported"):
        head.create_model_fn_ops(
            {},
            model_fn.ModeKeys.TRAIN,
            labels,
            head_lib.no_op_train_fn,
            logits=((1.,), (1.,), (3.,)))

  def testBinaryClassificationWithLabelName(self):
    """Labels may be passed as a dict keyed by label_name."""
    label_name = "my_label"
    head = head_lib.multi_class_head(n_classes=2, label_name=label_name)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {},
          labels={label_name: self._labels},
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = .81326175
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testBinaryClassificationWith1DWeights(self):
    """1-D per-example weights scale the loss; zero weight drops an example."""
    n_classes = 2
    head = head_lib.multi_class_head(
        n_classes=n_classes, weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      weights = (1., 0.)
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_total_loss = .31326166
      _assert_metrics(
          self,
          expected_total_loss / len(weights),
          {
              "accuracy": 1. / 1,
              "accuracy/baseline_label_mean": 1. / 1,
              "accuracy/threshold_0.500000_mean": 1. / 1,
              "auc": 0. / 1,
              "labels/actual_label_mean": 1. / 1,
              "labels/prediction_mean": .731059,  # sigmoid(1.)
              # eval loss is weighted loss divided by sum of weights.
              "loss": expected_total_loss,
              "precision/positive_threshold_0.500000_mean": 1. / 1,
              "recall/positive_threshold_0.500000_mean": 1. / 1,
          },
          model_fn_ops)

  def testBinaryClassificationWith2DWeights(self):
    """2-D (column-shaped) weights behave like 1-D weights."""
    n_classes = 2
    head = head_lib.multi_class_head(
        n_classes=n_classes, weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      weights = ((1.,), (0.,))
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_total_loss = .31326166
      _assert_metrics(
          self,
          expected_total_loss / len(weights),
          {
              "accuracy": 1. / 1,
              "accuracy/baseline_label_mean": 1. / 1,
              "accuracy/threshold_0.500000_mean": 1. / 1,
              "auc": 0. / 1,
              "labels/actual_label_mean": 1. / 1,
              "labels/prediction_mean": .731059,  # sigmoid(1.)
              # eval loss is weighted loss divided by sum of weights.
              "loss": expected_total_loss,
              "precision/positive_threshold_0.500000_mean": 1. / 1,
              "recall/positive_threshold_0.500000_mean": 1. / 1,
          },
          model_fn_ops)

  def testBinaryClassificationWithCustomLoss(self):
    """A custom loss_fn is applied and weighted per example."""
    head = head_lib.multi_class_head(
        n_classes=2, weight_column_name="label_weight",
        loss_fn=_sigmoid_cross_entropy)
    with ops.Graph().as_default(), session.Session():
      weights = ((.2,), (0.,))
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero
      # weight.
      expected_loss = 0.062652342
      _assert_metrics(
          self,
          expected_loss,
          {
              "accuracy": 1. / 1,
              "accuracy/baseline_label_mean": 1. / 1,
              "accuracy/threshold_0.500000_mean": 1. / 1,
              "auc": 0. / 1,
              "labels/actual_label_mean": 1. / 1,
              "labels/prediction_mean": .731059,  # sigmoid(1.)
              "loss": expected_loss,
              "precision/positive_threshold_0.500000_mean": 1. / 1,
              "recall/positive_threshold_0.500000_mean": 1. / 1,
          },
          model_fn_ops)

  def testBinaryClassificationWithCenteredBias(self):
    """enable_centered_bias creates bias variables and bias summaries."""
    head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_variables(
          self,
          expected_global=(
              "binary_logistic_head/centered_bias_weight:0",
              ("binary_logistic_head/binary_logistic_head/centered_bias_weight/"
               "Adagrad:0"),),
          expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, [
          "loss",
          "binary_logistic_head/centered_bias/bias_0"
      ])
      expected_loss = .81326175
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)
1048
1049
class MultiClassHeadTest(test.TestCase):
  """Tests for multi_class_head with n_classes > 2 (softmax head)."""

  def _assert_output_alternatives(self, model_fn_ops):
    """Asserts one default (None-keyed) CLASSIFICATION alternative."""
    self.assertEquals({
        None: constants.ProblemType.CLASSIFICATION
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })

  def setUp(self):
    # One example whose logits favor class 0 but which is labeled class 2.
    self._logits = ((1., 0., 0.),)
    self._labels = ((2,),)

  def _expected_eval_metrics(self, expected_loss):
    """Returns the eval metrics expected for self._logits vs. self._labels."""
    return {
        "accuracy": 0.,
        "loss": expected_loss,
        "labels/actual_label_mean/class0": 0. / 1,
        "labels/actual_label_mean/class1": 0. / 1,
        "labels/actual_label_mean/class2": 1. / 1,
        "labels/logits_mean/class0": self._logits[0][0],
        "labels/logits_mean/class1": self._logits[0][1],
        "labels/logits_mean/class2": self._logits[0][2],
        "labels/prediction_mean/class0": self._logits[0][0],
        "labels/prediction_mean/class1": self._logits[0][1],
        "labels/prediction_mean/class2": self._logits[0][2],
        "labels/probability_mean/class0": 0.576117,  # softmax
        "labels/probability_mean/class1": 0.211942,  # softmax
        "labels/probability_mean/class2": 0.211942,  # softmax
    }

  def testMultiClassWithLogits(self):
    """TRAIN mode with explicit logits yields the expected loss."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWithInvalidLogits(self):
    """Logits whose dimension disagrees with n_classes raise ValueError."""
    head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits=self._logits)

  def testMultiClassWithNoneTrainOpFnInTrain(self):
    """TRAIN mode requires a train_op_fn."""
    head = head_lib.multi_class_head(n_classes=3)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "train_op_fn can not be None in TRAIN mode"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels,
            train_op_fn=None,
            logits=self._logits)

  def testMultiClassWithLogitsInput(self):
    """logits_input makes the head create its own logits weights/biases."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits_input=((0., 0.),))
      self._assert_output_alternatives(model_fn_ops)
      w = ("multi_class_head/logits/weights:0",
           "multi_class_head/logits/biases:0")
      _assert_variables(
          self, expected_global=w, expected_model=w, expected_trainable=w)
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.0986123
      _assert_metrics(self, expected_loss, {
          "accuracy": 0.,
          "loss": expected_loss,
          "labels/actual_label_mean/class0": 0. / 1,
          "labels/actual_label_mean/class1": 0. / 1,
          "labels/actual_label_mean/class2": 1. / 1,
          "labels/logits_mean/class0": 0.,
          "labels/logits_mean/class1": 0.,
          "labels/logits_mean/class2": 0.,
          "labels/prediction_mean/class0": 1.,
          "labels/prediction_mean/class1": 0.,
          "labels/prediction_mean/class2": 0.,
          "labels/probability_mean/class0": 0.333333,  # softmax
          "labels/probability_mean/class1": 0.333333,  # softmax
          "labels/probability_mean/class2": 0.333333,  # softmax
      }, model_fn_ops)

  def testMultiClassWithLogitsAndLogitsInput(self):
    """Supplying both logits and logits_input is an error."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits_input=((0., 0.),), logits=self._logits)

  def testMultiClassEnableCenteredBias(self):
    """enable_centered_bias creates one bias summary per class."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes, enable_centered_bias=True)
    with ops.Graph().as_default(), session.Session():
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_variables(
          self,
          expected_global=(
              "multi_class_head/centered_bias_weight:0",
              ("multi_class_head/multi_class_head/centered_bias_weight/"
               "Adagrad:0"),
          ),
          expected_trainable=("multi_class_head/centered_bias_weight:0",))
      variables.global_variables_initializer().run()
      _assert_summary_tags(self,
                           ["loss",
                            "multi_class_head/centered_bias/bias_0",
                            "multi_class_head/centered_bias/bias_1",
                            "multi_class_head/centered_bias/bias_2"])

  def testMultiClassEval(self):
    """EVAL mode produces metrics and no train_op."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassEvalModeWithLargeLogits(self):
    """EVAL metrics remain correct for larger-magnitude logits."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    logits = ((2., 0., -1),)
    with ops.Graph().as_default(), session.Session():
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,
          logits=logits)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 3.1698461
      expected_eval_metrics = {
          "accuracy": 0.,
          "loss": expected_loss,
          "labels/actual_label_mean/class0": 0. / 1,
          "labels/actual_label_mean/class1": 0. / 1,
          "labels/actual_label_mean/class2": 1. / 1,
          "labels/logits_mean/class0": logits[0][0],
          "labels/logits_mean/class1": logits[0][1],
          "labels/logits_mean/class2": logits[0][2],
          "labels/prediction_mean/class0": 1,
          "labels/prediction_mean/class1": 0,
          "labels/prediction_mean/class2": 0,
          "labels/probability_mean/class0": 0.843795,  # softmax
          "labels/probability_mean/class1": 0.114195,  # softmax
          "labels/probability_mean/class2": 0.0420101,  # softmax
      }
      _assert_metrics(self, expected_loss,
                      expected_eval_metrics, model_fn_ops)

  def testMultiClassWithScalarWeight(self):
    """A scalar weight feature scales the training loss."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      weight = .1
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weight},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447
      _assert_metrics(self, expected_loss * weight,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWith1DWeight(self):
    """A 1-D weight feature scales the training loss."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      weight = .1
      weights = (weight,)
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447
      _assert_metrics(self, expected_loss * weight,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWith2DWeight(self):
    """A 2-D (column-shaped) weight feature scales the training loss."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      weight = .1
      weights = ((weight,),)
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weights},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447
      _assert_metrics(self, expected_loss * weight,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWithCustomLoss(self):
    """An explicit sparse_softmax_cross_entropy loss_fn matches the default."""
    n_classes = 3
    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes),
        loss_fn=losses_lib.sparse_softmax_cross_entropy)
    with ops.Graph().as_default(), session.Session():
      weight = .1
      # Cross entropy: -log(softmax(logits)[label])
      model_fn_ops = head.create_model_fn_ops(
          features={"label_weight": weight},
          labels=self._labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._logits)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.5514447 * weight
      _assert_metrics(self, expected_loss,
                      self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassInfer(self):
    """INFER mode exposes classes and probabilities for serving."""
    n_classes = 3
    # NOTE(review): this test uses the private _multi_class_head while others
    # use multi_class_head; presumably equivalent — confirm against head.py.
    head = head_lib._multi_class_head(
        n_classes=n_classes,
        head_name="head_name")
    with ops.Graph().as_default():
      model_fn_ops = head.create_model_fn_ops(
          features={},
          mode=model_fn.ModeKeys.INFER,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1., 0., 0.), (0., 0., 1.),))
      with session.Session():
        lookup_ops.tables_initializer().run()
        self.assertAllEqual(
            [0, 2],
            model_fn_ops.predictions["classes"].eval())
        self.assertItemsEqual(
            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
        self.assertEqual(
            constants.ProblemType.CLASSIFICATION,
            model_fn_ops.output_alternatives["head_name"][0])
        predictions_for_serving = (
            model_fn_ops.output_alternatives["head_name"][1])
        self.assertIn("classes", six.iterkeys(predictions_for_serving))
        self.assertAllEqual(
            [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],
            predictions_for_serving["classes"].eval())
        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
        self.assertAllClose(
            [[0.576117, 0.2119416, 0.2119416],
             [0.2119416, 0.2119416, 0.576117]],
            predictions_for_serving["probabilities"].eval())

  def testInvalidNClasses(self):
    """n_classes must be an integer greater than 1."""
    for n_classes in (None, -1, 0, 1):
      with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
        head_lib.multi_class_head(n_classes=n_classes)

  def testMultiClassWithLabelKeysInvalidShape(self):
    """label_keys must have exactly n_classes entries."""
    with self.assertRaisesRegexp(
        ValueError, "Length of label_keys must equal n_classes"):
      head_lib._multi_class_head(
          n_classes=3, label_keys=("key0", "key1"))

  def testMultiClassWithLabelKeysTwoClasses(self):
    """label_keys is not supported for the binary (n_classes=2) head."""
    with self.assertRaisesRegexp(
        ValueError, "label_keys is not supported for n_classes=2"):
      head_lib._multi_class_head(
          n_classes=2, label_keys=("key0", "key1"))

  def testMultiClassWithLabelKeysInfer(self):
    """INFER mode maps predicted class ids through label_keys."""
    n_classes = 3
    label_keys = ("key0", "key1", "key2")
    head = head_lib._multi_class_head(
        n_classes=n_classes, label_keys=label_keys,
        metric_class_ids=range(n_classes),
        head_name="head_name")
    with ops.Graph().as_default():
      model_fn_ops = head.create_model_fn_ops(
          features={},
          mode=model_fn.ModeKeys.INFER,
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1., 0., 0.), (0., 0., 1.),))
      with session.Session():
        lookup_ops.tables_initializer().run()
        self.assertAllEqual(
            [b"key0", b"key2"],
            model_fn_ops.predictions["classes"].eval())
        self.assertItemsEqual(
            ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
        self.assertEqual(
            constants.ProblemType.CLASSIFICATION,
            model_fn_ops.output_alternatives["head_name"][0])
        predictions_for_serving = (
            model_fn_ops.output_alternatives["head_name"][1])
        self.assertIn("classes", six.iterkeys(predictions_for_serving))
        self.assertAllEqual(
            [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]],
            predictions_for_serving["classes"].eval())
        self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
        self.assertAllClose(
            [[0.576117, 0.2119416, 0.2119416],
             [0.2119416, 0.2119416, 0.576117]],
            predictions_for_serving["probabilities"].eval())

  def testMultiClassWithLabelKeysEvalAccuracy0(self):
    """EVAL with string labels: mispredicted example yields accuracy 0."""
    n_classes = 3
    label_keys = ("key0", "key1", "key2")
    head = head_lib._multi_class_head(
        n_classes=n_classes,
        label_keys=label_keys)
    with ops.Graph().as_default():
      model_fn_ops = head.create_model_fn_ops(
          features={},
          mode=model_fn.ModeKeys.EVAL,
          labels=("key2",),
          train_op_fn=head_lib.no_op_train_fn,
          logits=((1., 0., 0.),))
      with session.Session():
        lookup_ops.tables_initializer().run()
        self.assertIsNone(model_fn_ops.train_op)
        _assert_no_variables(self)
        _assert_summary_tags(self, ["loss"])
        expected_loss = 1.5514447
        expected_eval_metrics = {
            "accuracy": 0.,
            "loss": expected_loss,
        }
        _assert_metrics(self, expected_loss,
                        expected_eval_metrics, model_fn_ops)

  def testMultiClassWithLabelKeysEvalAccuracy1(self):
    """EVAL with string labels: correct example yields accuracy 1."""
    n_classes = 3
    label_keys = ("key0", "key1", "key2")
    head = head_lib._multi_class_head(
        n_classes=n_classes,
        label_keys=label_keys)
    with ops.Graph().as_default():
      model_fn_ops = head.create_model_fn_ops(
          features={},
          mode=model_fn.ModeKeys.EVAL,
          labels=("key2",),
          train_op_fn=head_lib.no_op_train_fn,
          logits=((0., 0., 1.),))
      with session.Session():
        lookup_ops.tables_initializer().run()
        self.assertIsNone(model_fn_ops.train_op)
        _assert_no_variables(self)
        _assert_summary_tags(self, ["loss"])
        expected_loss = 0.5514447
        expected_eval_metrics = {
            "accuracy": 1.,
            "loss": expected_loss,
        }
        _assert_metrics(self, expected_loss,
                        expected_eval_metrics, model_fn_ops)
1472
1473
class BinarySvmHeadTest(test.TestCase):
  """Tests for `head_lib.binary_svm_head` (hinge-loss binary classifier)."""

  def _assert_output_alternatives(self, model_fn_ops):
    """Asserts the head exports one unnamed LOGISTIC_REGRESSION alternative."""
    # The binary SVM head uses a single None-keyed output alternative whose
    # first tuple element is the problem type.
    self.assertEquals({
        None: constants.ProblemType.LOGISTIC_REGRESSION
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })

  def setUp(self):
    # Prediction for first example is in the right side of the hyperplane
    # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss
    # incurred by this example. The 2nd prediction is outside the margin so it
    # incurs no loss at all.
    self._predictions = ((-.5,), (1.2,))
    self._labels = (0, 1)
    self._expected_losses = (.5, 0.)

  def testBinarySVMWithLogits(self):
    """TRAIN mode with explicit logits: mean hinge loss, perfect accuracy."""
    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          model_fn.ModeKeys.TRAIN,
          self._labels,
          head_lib.no_op_train_fn,
          logits=self._predictions)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = np.average(self._expected_losses)
      _assert_metrics(self, expected_loss, {
          "accuracy": 1.,
          "loss": expected_loss,
      }, model_fn_ops)

  def testBinarySVMWithInvalidLogits(self):
    """Logits with an incompatible shape (2 columns) raise ValueError."""
    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
            {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
            logits=np.ones((2, 2)))

  def testBinarySVMWithLogitsInput(self):
    """`logits_input` makes the head build its own logits layer variables."""
    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          model_fn.ModeKeys.TRAIN,
          self._labels,
          head_lib.no_op_train_fn,
          logits_input=((0., 0.), (0., 0.)))
      self._assert_output_alternatives(model_fn_ops)
      # The internally created logits layer contributes exactly these two
      # variables (global, model, and trainable alike).
      w = ("binary_svm_head/logits/weights:0",
           "binary_svm_head/logits/biases:0")
      _assert_variables(
          self, expected_global=w, expected_model=w, expected_trainable=w)
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, ["loss"])
      expected_loss = 1.
      _assert_metrics(self, expected_loss, {
          "accuracy": .5,
          "loss": expected_loss,
      }, model_fn_ops)

  def testBinarySVMWithLogitsAndLogitsInput(self):
    """Supplying both `logits` and `logits_input` is rejected."""
    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
        head.create_model_fn_ops(
            {},
            model_fn.ModeKeys.TRAIN,
            self._labels,
            head_lib.no_op_train_fn,
            logits_input=((0., 0.), (0., 0.)),
            logits=self._predictions)

  def testBinarySVMEvalMode(self):
    """EVAL mode computes loss and metrics but no train_op."""
    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          model_fn.ModeKeys.EVAL,
          self._labels,
          head_lib.no_op_train_fn,
          logits=self._predictions)
      self._assert_output_alternatives(model_fn_ops)
      self.assertIsNone(model_fn_ops.train_op)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = np.average(self._expected_losses)
      _assert_metrics(self, expected_loss, {
          "accuracy": 1.,
          "loss": expected_loss,
      }, model_fn_ops)

  def testBinarySVMWithLabelName(self):
    """Labels may be passed as a dict keyed by the configured label_name."""
    label_name = "my_label"
    head = head_lib.binary_svm_head(label_name=label_name)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          model_fn.ModeKeys.TRAIN,
          {label_name: self._labels},
          head_lib.no_op_train_fn,
          logits=self._predictions)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_loss = np.average(self._expected_losses)
      _assert_metrics(self, expected_loss, {
          "accuracy": 1.,
          "loss": expected_loss,
      }, model_fn_ops)

  def testBinarySVMWith1DWeights(self):
    """Per-example weights given as a flat 1-D sequence."""
    head = head_lib.binary_svm_head(weight_column_name="weights")
    with ops.Graph().as_default(), session.Session():
      weights = (7., 11.)
      model_fn_ops = head.create_model_fn_ops(
          # 1-D weights, one scalar per example; the 2-D variant (with an
          # explicit extra dim) is covered by testBinarySVMWith2DWeights.
          features={"weights": weights},
          mode=model_fn.ModeKeys.TRAIN,
          labels=self._labels,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._predictions)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      # Training loss is the mean of weighted per-example losses; the "loss"
      # metric is the weighted average (sum of weighted losses / sum of
      # weights).
      expected_weighted_losses = np.multiply(weights, self._expected_losses)
      _assert_metrics(self, np.mean(expected_weighted_losses), {
          "accuracy": 1.,
          "loss": np.sum(expected_weighted_losses) / np.sum(weights),
      }, model_fn_ops)

  def testBinarySVMWith2DWeights(self):
    """Per-example weights given with an explicit trailing dimension."""
    head = head_lib.binary_svm_head(weight_column_name="weights")
    with ops.Graph().as_default(), session.Session():
      weights = (7., 11.)
      model_fn_ops = head.create_model_fn_ops(
          # We have to add an extra dim here for weights broadcasting to work.
          features={"weights": tuple([(w,) for w in weights])},
          mode=model_fn.ModeKeys.TRAIN,
          labels=self._labels,
          train_op_fn=head_lib.no_op_train_fn,
          logits=self._predictions)
      self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
      _assert_summary_tags(self, ["loss"])
      expected_weighted_losses = np.multiply(weights, self._expected_losses)
      _assert_metrics(self, np.mean(expected_weighted_losses), {
          "accuracy": 1.,
          "loss": np.sum(expected_weighted_losses) / np.sum(weights),
      }, model_fn_ops)

  def testBinarySVMWithCenteredBias(self):
    """enable_centered_bias adds a trainable bias (plus Adagrad slot)."""
    head = head_lib.binary_svm_head(enable_centered_bias=True)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
          model_fn.ModeKeys.TRAIN,
          self._labels,
          head_lib.no_op_train_fn,
          logits=self._predictions)
      self._assert_output_alternatives(model_fn_ops)
      _assert_variables(
          self,
          expected_global=(
              "binary_svm_head/centered_bias_weight:0",
              ("binary_svm_head/binary_svm_head/centered_bias_weight/"
               "Adagrad:0"),
          ),
          expected_trainable=("binary_svm_head/centered_bias_weight:0",))
      variables.global_variables_initializer().run()
      _assert_summary_tags(self, [
          "loss",
          "binary_svm_head/centered_bias/bias_0"
      ])
      expected_loss = np.average(self._expected_losses)
      _assert_metrics(self, expected_loss, {
          "accuracy": 1.,
          "loss": expected_loss,
      }, model_fn_ops)
1659
1660
class LossOnlyHead(test.TestCase):
  """Tests for `head_lib.loss_only_head`."""

  def testNoPredictionsAndNoMetrics(self):
    """A loss-only head produces empty predictions/metrics and a loss of 1."""
    head = head_lib.loss_only_head(lambda: 1, head_name="const")
    ops_out = head.create_model_fn_ops(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=head_lib.no_op_train_fn)
    self.assertDictEqual(ops_out.predictions, {})
    self.assertDictEqual(ops_out.eval_metric_ops, {})
    self.assertIsNotNone(ops_out.loss)
    with session.Session() as sess:
      # The loss is exactly the constant returned by the loss_fn above.
      self.assertEqual(1, sess.run(ops_out.loss))
1674
1675
class MultiHeadTest(test.TestCase):
  """Tests for `head_lib.multi_head`, which combines several named heads."""

  def testInvalidHeads(self):
    """Every member head must have a head_name."""
    named_head = head_lib.multi_class_head(
        n_classes=3, label_name="label", head_name="head1")
    unnamed_head = head_lib.multi_class_head(
        n_classes=4, label_name="label")
    with self.assertRaisesRegexp(ValueError, "must have names"):
      head_lib.multi_head((named_head, unnamed_head))

  def testTrainWithNoneTrainOpFn(self):
    """TRAIN mode requires a non-None train_op_fn."""
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
    head = head_lib.multi_head((head1, head2))
    labels = {
        "label1": (1,),
        "label2": (1,)
    }
    with self.assertRaisesRegexp(
        ValueError, "train_op_fn can not be None in TRAIN mode"):
      head.create_model_fn_ops(
          features={"weights": (2.0, 10.0)},
          labels=labels,
          mode=model_fn.ModeKeys.TRAIN,
          train_op_fn=None,
          # 7 logits: the first 3 feed head1, the remaining 4 feed head2.
          logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))

  def testTrain_withNoHeadWeights(self):
    """Without head weights the total loss is the plain sum of head losses."""
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
    head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
    head = head_lib.multi_head((head1, head2, head3))
    labels = {
        "label1": (1,),
        "label2": (1,)
    }
    model_fn_ops = head.create_model_fn_ops(
        features={"weights": (2.0, 10.0)},
        labels=labels,
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=head_lib.no_op_train_fn,
        # 7 logits: the first 3 feed head1, the remaining 4 feed head2.
        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))

    self.assertIsNone(model_fn_ops.predictions)
    self.assertIsNotNone(model_fn_ops.loss)
    self.assertIsNotNone(model_fn_ops.train_op)
    self.assertTrue(model_fn_ops.eval_metric_ops)
    self.assertIsNone(model_fn_ops.output_alternatives)

    with session.Session() as sess:
      # 3.224 = 2.224 (head1 + head2, see testTrain_withDictLogits) + 1.0
      # contributed by the const loss-only head.
      self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)

  def testTrain_withHeadWeights(self):
    """Head weights (1, .5) scale each head's contribution to the loss."""
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
    head = head_lib.multi_head((head1, head2), (1, .5))
    labels = {
        "label1": (1,),
        "label2": (1,)
    }
    model_fn_ops = head.create_model_fn_ops(
        features={"weights": (2.0, 10.0)},
        labels=labels,
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=head_lib.no_op_train_fn,
        # 7 logits: the first 3 feed head1, the remaining 4 feed head2.
        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))
    self.assertIsNone(model_fn_ops.predictions)
    self.assertIsNotNone(model_fn_ops.loss)
    self.assertIsNotNone(model_fn_ops.train_op)
    self.assertTrue(model_fn_ops.eval_metric_ops)
    self.assertIsNone(model_fn_ops.output_alternatives)

    with session.Session() as sess:
      # Empirically pinned weighted-sum loss for the fixed logits above.
      self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)

  def testTrain_withDictLogits(self):
    """Logits may be passed as a dict keyed by head_name instead of merged."""
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
    head = head_lib.multi_head((head1, head2))
    labels = {
        "label1": (1,),
        "label2": (1,)
    }
    model_fn_ops = head.create_model_fn_ops(
        features={"weights": (2.0, 10.0)},
        labels=labels,
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=head_lib.no_op_train_fn,
        logits={head1.head_name: ((-0.7, 0.2, .1),),
                head2.head_name: ((.1, .1, .1, .1),)})

    self.assertIsNone(model_fn_ops.predictions)
    self.assertIsNotNone(model_fn_ops.loss)
    self.assertIsNotNone(model_fn_ops.train_op)
    self.assertTrue(model_fn_ops.eval_metric_ops)
    self.assertIsNone(model_fn_ops.output_alternatives)

    with session.Session() as sess:
      # Same logits as the merged-tensor tests, so head1 + head2 = 2.224.
      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)

  def testInfer(self):
    """INFER mode: predictions and output alternatives per head, no loss."""
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
    head = head_lib.multi_head((head1, head2), (1, .5))
    labels = {
        "label1": (1,),
        "label2": (1,)
    }
    model_fn_ops = head.create_model_fn_ops(
        features={"weights": (2.0, 10.0)},
        labels=labels,
        mode=model_fn.ModeKeys.INFER,
        train_op_fn=head_lib.no_op_train_fn,
        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))

    self.assertIsNotNone(model_fn_ops.predictions)
    self.assertIsNone(model_fn_ops.loss)
    self.assertIsNone(model_fn_ops.train_op)
    self.assertFalse(model_fn_ops.eval_metric_ops)

    # Tests predictions keys: each head contributes (head_name, key) pairs.
    self.assertItemsEqual((
        ("head1", prediction_key.PredictionKey.LOGITS),
        ("head1", prediction_key.PredictionKey.PROBABILITIES),
        ("head1", prediction_key.PredictionKey.CLASSES),
        ("head2", prediction_key.PredictionKey.LOGITS),
        ("head2", prediction_key.PredictionKey.PROBABILITIES),
        ("head2", prediction_key.PredictionKey.CLASSES),
    ), model_fn_ops.predictions.keys())

    # Tests output alternative: one CLASSIFICATION entry per head name.
    self.assertEquals({
        "head1": constants.ProblemType.CLASSIFICATION,
        "head2": constants.ProblemType.CLASSIFICATION,
    }, {
        k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)
    })
    self.assertItemsEqual((
        prediction_key.PredictionKey.PROBABILITIES,
        prediction_key.PredictionKey.CLASSES,
    ), model_fn_ops.output_alternatives["head1"][1].keys())
    self.assertItemsEqual((
        prediction_key.PredictionKey.PROBABILITIES,
        prediction_key.PredictionKey.CLASSES,
    ), model_fn_ops.output_alternatives["head2"][1].keys())

  def testEval(self):
    """EVAL mode: loss and per-head metrics, but no train_op or alternatives."""
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
    head = head_lib.multi_head((head1, head2), (1, .5))
    labels = {
        "label1": (1,),
        "label2": (1,)
    }
    model_fn_ops = head.create_model_fn_ops(
        features={"weights": (2.0, 10.0)},
        labels=labels,
        mode=model_fn.ModeKeys.EVAL,
        train_op_fn=head_lib.no_op_train_fn,
        logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))

    self.assertIsNotNone(model_fn_ops.predictions)
    self.assertIsNotNone(model_fn_ops.loss)
    self.assertIsNone(model_fn_ops.train_op)
    self.assertIsNotNone(model_fn_ops.eval_metric_ops)
    self.assertIsNone(model_fn_ops.output_alternatives)

    metric_ops = model_fn_ops.eval_metric_ops

    # Tests eval keys: metric names are suffixed with the head name.
    self.assertIn("accuracy/head1", metric_ops.keys())
    self.assertIn("accuracy/head2", metric_ops.keys())
1860
1861
def _sigmoid_cross_entropy(labels, logits, weights):
  """Thin wrapper delegating to `losses_lib.sigmoid_cross_entropy`."""
  return losses_lib.sigmoid_cross_entropy(labels, logits, weights=weights)
1864
1865
# Run all test cases in this module when executed as a script.
if __name__ == "__main__":
  test.main()
1868