1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for head.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22 23# pylint: disable=g-bad-todo,g-import-not-at-top 24import numpy as np 25import six 26 27from tensorflow.contrib.learn.python.learn.estimators import constants 28from tensorflow.contrib.learn.python.learn.estimators import head as head_lib 29from tensorflow.contrib.learn.python.learn.estimators import model_fn 30from tensorflow.contrib.learn.python.learn.estimators import prediction_key 31from tensorflow.core.framework import summary_pb2 32from tensorflow.python.client import session 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import sparse_tensor 35from tensorflow.python.ops import lookup_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import variables 38from tensorflow.python.ops.losses import losses as losses_lib 39from tensorflow.python.platform import test 40 41 42def _assert_variables(test_case, 43 expected_global=None, 44 expected_model=None, 45 expected_trainable=None): 46 test_case.assertItemsEqual( 47 tuple([] if expected_global is None else expected_global), 48 tuple([k.name for k in variables.global_variables()])) 49 test_case.assertItemsEqual( 50 tuple([] if expected_model is None else expected_model), 51 tuple([k.name for k in variables.model_variables()])) 52 test_case.assertItemsEqual( 53 tuple([] if expected_trainable is None else expected_trainable), 54 tuple([k.name for k in variables.trainable_variables()])) 55 56 57def _assert_no_variables(test_case): 58 _assert_variables(test_case) 59 60 61# This must be called from within a tf.Session. 62def _assert_metrics(test_case, expected_loss, expected_eval_metrics, 63 model_fn_ops): 64 test_case.assertAlmostEqual(expected_loss, model_fn_ops.loss.eval(), places=4) 65 for k in six.iterkeys(expected_eval_metrics): 66 test_case.assertIn(k, six.iterkeys(model_fn_ops.eval_metric_ops)) 67 variables.initialize_local_variables().run() 68 for key, expected_value in six.iteritems(expected_eval_metrics): 69 value_tensor, update_tensor = model_fn_ops.eval_metric_ops[key] 70 update = update_tensor.eval() 71 test_case.assertAlmostEqual( 72 expected_value, 73 update, 74 places=4, 75 msg="%s: update, expected %s, got %s." % (key, expected_value, update)) 76 value = value_tensor.eval() 77 test_case.assertAlmostEqual( 78 expected_value, 79 value, 80 places=4, 81 msg="%s: value, expected %s, got %s." % (key, expected_value, value)) 82 83 84# This must be called from within a tf.Session. 85def _assert_summary_tags(test_case, expected_tags=None): 86 actual_tags = [] 87 for summary_op in ops.get_collection(ops.GraphKeys.SUMMARIES): 88 summ = summary_pb2.Summary() 89 summ.ParseFromString(summary_op.eval()) 90 actual_tags.append(summ.value[0].tag) 91 test_case.assertItemsEqual(expected_tags or [], actual_tags) 92 93 94def _sigmoid(x): 95 return 1. / (1. + math.exp(-1 * x)) 96 97 98class PoissonHeadTest(test.TestCase): 99 100 def _assert_output_alternatives(self, model_fn_ops): 101 self.assertEquals({ 102 None: constants.ProblemType.LINEAR_REGRESSION 103 }, { 104 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 105 }) 106 107 def _log_poisson_loss(self, logits, labels): 108 x = np.array([f[0] for f in logits]) 109 z = np.array([f[0] for f in labels]) 110 lpl = np.exp(x) - z * x 111 stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z) 112 lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.) 113 return sum(lpl)/len(lpl) 114 115 def testPoissonWithLogits(self): 116 head = head_lib.poisson_regression_head() 117 labels = ((0.,), (1.,), (1.,)) 118 logits = ((0.,), (-1.,), (3.,)) 119 with ops.Graph().as_default(), session.Session(): 120 model_fn_ops = head.create_model_fn_ops( 121 {}, 122 labels=labels, 123 mode=model_fn.ModeKeys.TRAIN, 124 train_op_fn=head_lib.no_op_train_fn, 125 logits=logits) 126 self._assert_output_alternatives(model_fn_ops) 127 _assert_summary_tags(self, ["loss"]) 128 _assert_no_variables(self) 129 loss = self._log_poisson_loss(logits, labels) 130 _assert_metrics(self, loss, {"loss": loss}, model_fn_ops) 131 132 133class RegressionHeadTest(test.TestCase): 134 135 def _assert_output_alternatives(self, model_fn_ops): 136 self.assertEquals({ 137 None: constants.ProblemType.LINEAR_REGRESSION 138 }, { 139 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 140 }) 141 142 # TODO(zakaria): test multilabel regression. 143 def testRegressionWithLogits(self): 144 head = head_lib.regression_head() 145 with ops.Graph().as_default(), session.Session(): 146 model_fn_ops = head.create_model_fn_ops( 147 {}, 148 labels=((0.,), (1.,), (1.,)), 149 mode=model_fn.ModeKeys.TRAIN, 150 train_op_fn=head_lib.no_op_train_fn, 151 logits=((1.,), (1.,), (3.,))) 152 self._assert_output_alternatives(model_fn_ops) 153 _assert_summary_tags(self, ["loss"]) 154 _assert_no_variables(self) 155 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) 156 157 def testRegressionWithLogitFn(self): 158 head = head_lib.regression_head(link_fn=math_ops.square) 159 def _assert_preditions(test_case, expected_predictions, model_fn_ops): 160 variables.initialize_local_variables().run() 161 test_case.assertAllClose(expected_predictions, 162 model_fn_ops.predictions["scores"].eval()) 163 with ops.Graph().as_default(), session.Session(): 164 model_fn_ops = head.create_model_fn_ops( 165 {}, 166 labels=((0.,), (1.,), (1.,)), 167 mode=model_fn.ModeKeys.TRAIN, 168 train_op_fn=head_lib.no_op_train_fn, 169 logits=((1.,), (1.,), (3.,))) 170 self._assert_output_alternatives(model_fn_ops) 171 _assert_summary_tags(self, ["loss"]) 172 _assert_no_variables(self) 173 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) 174 _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops) 175 176 def testRegressionWithInvalidLogits(self): 177 head = head_lib.regression_head() 178 with ops.Graph().as_default(), session.Session(): 179 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): 180 head.create_model_fn_ops( 181 {}, 182 labels=((0.,), (1.,), (1.,)), 183 mode=model_fn.ModeKeys.TRAIN, 184 train_op_fn=head_lib.no_op_train_fn, 185 logits=((1., 1.), (1., 1.), (3., 1.))) 186 187 def testRegressionWithLogitsInput(self): 188 head = head_lib.regression_head() 189 with ops.Graph().as_default(), session.Session(): 190 model_fn_ops = head.create_model_fn_ops( 191 {}, 192 labels=((0.,), (1.,), (1.,)), 193 mode=model_fn.ModeKeys.TRAIN, 194 train_op_fn=head_lib.no_op_train_fn, 195 logits_input=((0., 0.), (0., 0.), (0., 0.))) 196 self._assert_output_alternatives(model_fn_ops) 197 w = ("regression_head/logits/weights:0", 198 "regression_head/logits/biases:0") 199 _assert_variables( 200 self, expected_global=w, expected_model=w, expected_trainable=w) 201 variables.global_variables_initializer().run() 202 _assert_summary_tags(self, ["loss"]) 203 _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops) 204 205 def testRegressionWithLogitsAndLogitsInput(self): 206 head = head_lib.regression_head() 207 with ops.Graph().as_default(), session.Session(): 208 with self.assertRaisesRegexp( 209 ValueError, "Both logits and logits_input supplied"): 210 head.create_model_fn_ops( 211 {}, 212 labels=((0.,), (1.,), (1.,)), 213 mode=model_fn.ModeKeys.TRAIN, 214 train_op_fn=head_lib.no_op_train_fn, 215 logits_input=((0., 0.), (0., 0.), (0., 0.)), 216 logits=((1.,), (1.,), (3.,))) 217 218 def testRegressionEvalMode(self): 219 head = head_lib.regression_head() 220 with ops.Graph().as_default(), session.Session(): 221 model_fn_ops = head.create_model_fn_ops( 222 {}, 223 labels=((1.,), (1.,), (3.,)), 224 mode=model_fn.ModeKeys.EVAL, 225 train_op_fn=head_lib.no_op_train_fn, 226 logits=((0.,), (1.,), (1.,))) 227 self._assert_output_alternatives(model_fn_ops) 228 self.assertIsNone(model_fn_ops.train_op) 229 _assert_no_variables(self) 230 _assert_summary_tags(self, ["loss"]) 231 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) 232 233 def testRegressionWithLabelName(self): 234 label_name = "my_label" 235 head = head_lib.regression_head(label_name=label_name) 236 with ops.Graph().as_default(), session.Session(): 237 model_fn_ops = head.create_model_fn_ops( 238 {}, 239 labels={label_name: ((0.,), (1.,), (1.,))}, 240 mode=model_fn.ModeKeys.TRAIN, 241 train_op_fn=head_lib.no_op_train_fn, 242 logits=((1.,), (1.,), (3.,))) 243 self._assert_output_alternatives(model_fn_ops) 244 _assert_no_variables(self) 245 _assert_summary_tags(self, ["loss"]) 246 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) 247 248 def testRegressionWithScalarWeights(self): 249 head = head_lib.regression_head(weight_column_name="label_weight") 250 with ops.Graph().as_default(), session.Session(): 251 weights = 2. 252 labels = ((0.,), (1.,), (1.,)) 253 model_fn_ops = head.create_model_fn_ops( 254 features={"label_weight": weights}, 255 labels=labels, 256 mode=model_fn.ModeKeys.TRAIN, 257 train_op_fn=head_lib.no_op_train_fn, 258 logits=((1.,), (1.,), (3.,))) 259 self._assert_output_alternatives(model_fn_ops) 260 _assert_no_variables(self) 261 _assert_summary_tags(self, ["loss"]) 262 _assert_metrics(self, (weights * 5.) / len(labels), { 263 "loss": (weights * 5.) / (weights * len(labels)) 264 }, model_fn_ops) 265 266 def testRegressionWith1DWeights(self): 267 head = head_lib.regression_head(weight_column_name="label_weight") 268 with ops.Graph().as_default(), session.Session(): 269 weights = (2., 5., 0.) 270 labels = ((0.,), (1.,), (1.,)) 271 model_fn_ops = head.create_model_fn_ops( 272 features={"label_weight": weights}, 273 labels=labels, 274 mode=model_fn.ModeKeys.TRAIN, 275 train_op_fn=head_lib.no_op_train_fn, 276 logits=((1.,), (1.,), (3.,))) 277 self._assert_output_alternatives(model_fn_ops) 278 _assert_no_variables(self) 279 _assert_summary_tags(self, ["loss"]) 280 _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)}, 281 model_fn_ops) 282 283 def testRegressionWith2DWeights(self): 284 head = head_lib.regression_head(weight_column_name="label_weight") 285 with ops.Graph().as_default(), session.Session(): 286 weights = ((2.,), (5.,), (0.,)) 287 labels = ((0.,), (1.,), (1.,)) 288 model_fn_ops = head.create_model_fn_ops( 289 features={"label_weight": weights}, 290 labels=labels, 291 mode=model_fn.ModeKeys.TRAIN, 292 train_op_fn=head_lib.no_op_train_fn, 293 logits=((1.,), (1.,), (3.,))) 294 self._assert_output_alternatives(model_fn_ops) 295 _assert_no_variables(self) 296 _assert_summary_tags(self, ["loss"]) 297 _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)}, 298 model_fn_ops) 299 300 def testRegressionWithCenteredBias(self): 301 head = head_lib.regression_head(enable_centered_bias=True) 302 with ops.Graph().as_default(), session.Session(): 303 model_fn_ops = head.create_model_fn_ops( 304 {}, 305 labels=((0.,), (1.,), (1.,)), 306 mode=model_fn.ModeKeys.TRAIN, 307 train_op_fn=head_lib.no_op_train_fn, 308 logits=((1.,), (1.,), (3.,))) 309 self._assert_output_alternatives(model_fn_ops) 310 _assert_variables( 311 self, 312 expected_global=( 313 "regression_head/centered_bias_weight:0", 314 "regression_head/regression_head/centered_bias_weight/Adagrad:0", 315 ), 316 expected_trainable=("regression_head/centered_bias_weight:0",)) 317 variables.global_variables_initializer().run() 318 _assert_summary_tags(self, [ 319 "loss", 320 "regression_head/centered_bias/bias_0" 321 ]) 322 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) 323 324 def testRegressionErrorInSparseTensorLabels(self): 325 head = head_lib.regression_head() 326 with ops.Graph().as_default(): 327 labels = sparse_tensor.SparseTensorValue( 328 indices=((0, 0), (1, 0), (2, 0)), 329 values=(0., 1., 1.), 330 dense_shape=(3, 1)) 331 with self.assertRaisesRegexp(ValueError, 332 "SparseTensor is not supported"): 333 head.create_model_fn_ops( 334 {}, 335 labels=labels, 336 mode=model_fn.ModeKeys.TRAIN, 337 train_op_fn=head_lib.no_op_train_fn, 338 logits=((1.,), (1.,), (3.,))) 339 340 341class MultiLabelHeadTest(test.TestCase): 342 343 def _assert_output_alternatives(self, model_fn_ops): 344 self.assertEquals({ 345 None: constants.ProblemType.CLASSIFICATION 346 }, { 347 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 348 }) 349 350 def setUp(self): 351 self._logits = ((1., 0., 0.),) 352 self._labels = ((0, 0, 1),) 353 354 def _expected_eval_metrics(self, expected_loss): 355 return { 356 "accuracy": 1. / 3, 357 "loss": expected_loss, 358 "auc": 1. / 4, 359 "auc/class0": 1., 360 "auc/class1": 1., 361 "auc/class2": 0., 362 "auc_precision_recall": 0.166667, 363 "auc_precision_recall/class0": 0, 364 "auc_precision_recall/class1": 0., 365 "auc_precision_recall/class2": 0.49999, 366 "labels/actual_label_mean/class0": self._labels[0][0], 367 "labels/actual_label_mean/class1": self._labels[0][1], 368 "labels/actual_label_mean/class2": self._labels[0][2], 369 "labels/logits_mean/class0": self._logits[0][0], 370 "labels/logits_mean/class1": self._logits[0][1], 371 "labels/logits_mean/class2": self._logits[0][2], 372 "labels/prediction_mean/class0": self._logits[0][0], 373 "labels/prediction_mean/class1": self._logits[0][1], 374 "labels/prediction_mean/class2": self._logits[0][2], 375 "labels/probability_mean/class0": _sigmoid(self._logits[0][0]), 376 "labels/probability_mean/class1": _sigmoid(self._logits[0][1]), 377 "labels/probability_mean/class2": _sigmoid(self._logits[0][2]), 378 } 379 380 def testMultiLabelWithLogits(self): 381 n_classes = 3 382 head = head_lib.multi_label_head( 383 n_classes=n_classes, metric_class_ids=range(n_classes)) 384 with ops.Graph().as_default(), session.Session(): 385 model_fn_ops = head.create_model_fn_ops( 386 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 387 logits=self._logits) 388 self._assert_output_alternatives(model_fn_ops) 389 _assert_no_variables(self) 390 _assert_summary_tags(self, ["loss"]) 391 expected_loss = .89985204 392 _assert_metrics(self, expected_loss, 393 self._expected_eval_metrics(expected_loss), model_fn_ops) 394 395 def testMultiLabelTwoClasses(self): 396 n_classes = 2 397 labels = ((0, 1),) 398 logits = ((1., 0.),) 399 head = head_lib.multi_label_head( 400 n_classes=n_classes, metric_class_ids=range(n_classes)) 401 with ops.Graph().as_default(), session.Session(): 402 model_fn_ops = head.create_model_fn_ops( 403 {}, model_fn.ModeKeys.TRAIN, labels=labels, 404 train_op_fn=head_lib.no_op_train_fn, logits=logits) 405 self._assert_output_alternatives(model_fn_ops) 406 _assert_no_variables(self) 407 _assert_summary_tags(self, ["loss"]) 408 expected_loss = 1.00320443 409 _assert_metrics(self, expected_loss, { 410 "accuracy": 0., 411 "auc": 0., 412 "loss": expected_loss, 413 "auc/class0": 1., 414 "auc/class1": 0., 415 "labels/actual_label_mean/class0": labels[0][0], 416 "labels/actual_label_mean/class1": labels[0][1], 417 "labels/logits_mean/class0": logits[0][0], 418 "labels/logits_mean/class1": logits[0][1], 419 "labels/prediction_mean/class0": logits[0][0], 420 "labels/prediction_mean/class1": logits[0][1], 421 "labels/probability_mean/class0": _sigmoid(logits[0][0]), 422 "labels/probability_mean/class1": _sigmoid(logits[0][1]), 423 }, model_fn_ops) 424 425 def testMultiLabelWithInvalidLogits(self): 426 head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1) 427 with ops.Graph().as_default(), session.Session(): 428 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): 429 head.create_model_fn_ops( 430 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 431 logits=self._logits) 432 433 def testMultiLabelWithLogitsInput(self): 434 n_classes = 3 435 head = head_lib.multi_label_head( 436 n_classes=n_classes, metric_class_ids=range(n_classes)) 437 with ops.Graph().as_default(), session.Session(): 438 model_fn_ops = head.create_model_fn_ops( 439 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 440 logits_input=((0., 0.),)) 441 self._assert_output_alternatives(model_fn_ops) 442 w = ("multi_label_head/logits/weights:0", 443 "multi_label_head/logits/biases:0") 444 _assert_variables( 445 self, expected_global=w, expected_model=w, expected_trainable=w) 446 variables.global_variables_initializer().run() 447 _assert_summary_tags(self, ["loss"]) 448 expected_loss = .69314718 449 _assert_metrics(self, expected_loss, { 450 "accuracy": 2. / 3, 451 "auc": 2. / 4, 452 "loss": expected_loss, 453 "auc/class0": 1., 454 "auc/class1": 1., 455 "auc/class2": 0., 456 "labels/actual_label_mean/class0": self._labels[0][0], 457 "labels/actual_label_mean/class1": self._labels[0][1], 458 "labels/actual_label_mean/class2": self._labels[0][2], 459 "labels/logits_mean/class0": 0., 460 "labels/logits_mean/class1": 0., 461 "labels/logits_mean/class2": 0., 462 "labels/prediction_mean/class0": 0., 463 "labels/prediction_mean/class1": 0., 464 "labels/prediction_mean/class2": 0., 465 "labels/probability_mean/class0": .5, 466 "labels/probability_mean/class1": .5, 467 "labels/probability_mean/class2": .5, 468 }, model_fn_ops) 469 470 def testMultiLabelWithLogitsAndLogitsInput(self): 471 n_classes = 3 472 head = head_lib.multi_label_head( 473 n_classes=n_classes, metric_class_ids=range(n_classes)) 474 with ops.Graph().as_default(), session.Session(): 475 with self.assertRaisesRegexp( 476 ValueError, "Both logits and logits_input supplied"): 477 head.create_model_fn_ops( 478 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 479 logits_input=((0., 0.),), logits=self._logits) 480 481 def testMultiLabelEval(self): 482 n_classes = 3 483 head = head_lib.multi_label_head( 484 n_classes=n_classes, metric_class_ids=range(n_classes)) 485 with ops.Graph().as_default(), session.Session(): 486 model_fn_ops = head.create_model_fn_ops( 487 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn, 488 logits=self._logits) 489 self._assert_output_alternatives(model_fn_ops) 490 self.assertIsNone(model_fn_ops.train_op) 491 _assert_no_variables(self) 492 _assert_summary_tags(self, ["loss"]) 493 expected_loss = .89985204 494 _assert_metrics(self, expected_loss, 495 self._expected_eval_metrics(expected_loss), model_fn_ops) 496 497 def testMultiClassEvalWithLargeLogits(self): 498 n_classes = 3 499 head = head_lib.multi_label_head( 500 n_classes=n_classes, metric_class_ids=range(n_classes)) 501 logits = ((2., 0., -1),) 502 with ops.Graph().as_default(), session.Session(): 503 # logloss: z:label, x:logit 504 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 505 model_fn_ops = head.create_model_fn_ops( 506 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn, 507 logits=logits) 508 self._assert_output_alternatives(model_fn_ops) 509 self.assertIsNone(model_fn_ops.train_op) 510 _assert_no_variables(self) 511 _assert_summary_tags(self, ["loss"]) 512 expected_loss = 1.377779 513 expected_eval_metrics = { 514 "accuracy": 1. / 3, 515 "auc": 9.99999e-07, 516 "loss": expected_loss, 517 "auc/class0": 1., 518 "auc/class1": 1., 519 "auc/class2": 0., 520 "labels/actual_label_mean/class0": 0. / 1, 521 "labels/actual_label_mean/class1": 0. / 1, 522 "labels/actual_label_mean/class2": 1. / 1, 523 "labels/logits_mean/class0": logits[0][0], 524 "labels/logits_mean/class1": logits[0][1], 525 "labels/logits_mean/class2": logits[0][2], 526 "labels/prediction_mean/class0": 1, 527 "labels/prediction_mean/class1": 0, 528 "labels/prediction_mean/class2": 0, 529 "labels/probability_mean/class0": _sigmoid(logits[0][0]), 530 "labels/probability_mean/class1": _sigmoid(logits[0][1]), 531 "labels/probability_mean/class2": _sigmoid(logits[0][2]), 532 } 533 _assert_metrics(self, expected_loss, 534 expected_eval_metrics, model_fn_ops) 535 536 def testMultiLabelInfer(self): 537 n_classes = 3 538 head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name") 539 with ops.Graph().as_default(), session.Session(): 540 model_fn_ops = head.create_model_fn_ops( 541 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, 542 logits=((1., 0., 0.), (0., 0., 1))) 543 self.assertIsNone(model_fn_ops.train_op) 544 _assert_no_variables(self) 545 with session.Session(): 546 self.assertListEqual( 547 [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0]) 548 self.assertItemsEqual( 549 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) 550 self.assertEqual( 551 constants.ProblemType.CLASSIFICATION, 552 model_fn_ops.output_alternatives["head_name"][0]) 553 554 predictions_for_serving = ( 555 model_fn_ops.output_alternatives["head_name"][1]) 556 self.assertIn("classes", six.iterkeys(predictions_for_serving)) 557 self.assertAllEqual( 558 [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], 559 predictions_for_serving["classes"].eval()) 560 self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) 561 self.assertAllClose( 562 [[0.731059, 0.5, 0.5], 563 [0.5, 0.5, 0.731059,]], 564 predictions_for_serving["probabilities"].eval()) 565 566 def testMultiLabelWithLabelName(self): 567 n_classes = 3 568 label_name = "my_label" 569 head = head_lib.multi_label_head( 570 n_classes=n_classes, 571 label_name=label_name, 572 metric_class_ids=range(n_classes)) 573 with ops.Graph().as_default(), session.Session(): 574 model_fn_ops = head.create_model_fn_ops( 575 {}, model_fn.ModeKeys.TRAIN, {label_name: self._labels}, 576 head_lib.no_op_train_fn, logits=self._logits) 577 self._assert_output_alternatives(model_fn_ops) 578 _assert_no_variables(self) 579 _assert_summary_tags(self, ["loss"]) 580 expected_loss = .89985204 581 _assert_metrics(self, expected_loss, 582 self._expected_eval_metrics(expected_loss), model_fn_ops) 583 584 def testMultiLabelWithScalarWeight(self): 585 n_classes = 3 586 head = head_lib.multi_label_head( 587 n_classes=n_classes, 588 weight_column_name="label_weight", 589 metric_class_ids=range(n_classes)) 590 with ops.Graph().as_default(), session.Session(): 591 model_fn_ops = head.create_model_fn_ops( 592 features={"label_weight": .1}, 593 labels=self._labels, 594 mode=model_fn.ModeKeys.TRAIN, 595 train_op_fn=head_lib.no_op_train_fn, 596 logits=self._logits) 597 self._assert_output_alternatives(model_fn_ops) 598 _assert_no_variables(self) 599 _assert_summary_tags(self, ["loss"]) 600 _assert_metrics(self, .089985214, 601 self._expected_eval_metrics(.89985214), model_fn_ops) 602 603 def testMultiLabelWith1DWeight(self): 604 n_classes = 3 605 head = head_lib.multi_label_head( 606 n_classes=n_classes, 607 weight_column_name="label_weight", 608 metric_class_ids=range(n_classes)) 609 with ops.Graph().as_default(), session.Session(): 610 with self.assertRaisesRegexp( 611 ValueError, "weights can not be broadcast to values"): 612 head.create_model_fn_ops( 613 features={"label_weight": (.1, .1, .1)}, 614 labels=self._labels, 615 mode=model_fn.ModeKeys.TRAIN, 616 train_op_fn=head_lib.no_op_train_fn, 617 logits=self._logits) 618 619 def testMultiLabelWith2DWeight(self): 620 n_classes = 3 621 head = head_lib.multi_label_head( 622 n_classes=n_classes, 623 weight_column_name="label_weight", 624 metric_class_ids=range(n_classes)) 625 with ops.Graph().as_default(), session.Session(): 626 model_fn_ops = head.create_model_fn_ops( 627 features={"label_weight": ((.1, .1, .1),)}, 628 labels=self._labels, 629 mode=model_fn.ModeKeys.TRAIN, 630 train_op_fn=head_lib.no_op_train_fn, 631 logits=self._logits) 632 self._assert_output_alternatives(model_fn_ops) 633 _assert_no_variables(self) 634 _assert_summary_tags(self, ["loss"]) 635 _assert_metrics(self, .089985214, 636 self._expected_eval_metrics(.89985214), model_fn_ops) 637 638 def testMultiLabelWithCustomLoss(self): 639 n_classes = 3 640 head = head_lib.multi_label_head( 641 n_classes=n_classes, 642 weight_column_name="label_weight", 643 metric_class_ids=range(n_classes), 644 loss_fn=_sigmoid_cross_entropy) 645 with ops.Graph().as_default(), session.Session(): 646 model_fn_ops = head.create_model_fn_ops( 647 features={"label_weight": .1}, 648 labels=self._labels, 649 mode=model_fn.ModeKeys.TRAIN, 650 train_op_fn=head_lib.no_op_train_fn, 651 logits=self._logits) 652 self._assert_output_alternatives(model_fn_ops) 653 _assert_no_variables(self) 654 _assert_summary_tags(self, ["loss"]) 655 expected_loss = .089985214 656 _assert_metrics(self, expected_loss, 657 self._expected_eval_metrics(expected_loss), model_fn_ops) 658 659 def testMultiLabelWithCenteredBias(self): 660 n_classes = 3 661 head = head_lib.multi_label_head( 662 n_classes=n_classes, 663 enable_centered_bias=True, 664 metric_class_ids=range(n_classes)) 665 with ops.Graph().as_default(), session.Session(): 666 model_fn_ops = head.create_model_fn_ops( 667 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 668 logits=self._logits) 669 self._assert_output_alternatives(model_fn_ops) 670 _assert_variables( 671 self, 672 expected_global=( 673 "multi_label_head/centered_bias_weight:0", 674 ("multi_label_head/multi_label_head/centered_bias_weight/" 675 "Adagrad:0"),), 676 expected_trainable=("multi_label_head/centered_bias_weight:0",)) 677 variables.global_variables_initializer().run() 678 _assert_summary_tags(self, ( 679 "loss", 680 "multi_label_head/centered_bias/bias_0", 681 "multi_label_head/centered_bias/bias_1", 682 "multi_label_head/centered_bias/bias_2" 683 )) 684 expected_loss = .89985204 685 _assert_metrics(self, expected_loss, 686 self._expected_eval_metrics(expected_loss), model_fn_ops) 687 688 def testMultiLabelSparseTensorLabels(self): 689 n_classes = 3 690 head = head_lib.multi_label_head( 691 n_classes=n_classes, metric_class_ids=range(n_classes)) 692 with ops.Graph().as_default(), session.Session(): 693 labels = sparse_tensor.SparseTensorValue( 694 indices=((0, 0),), 695 values=(2,), 696 dense_shape=(1, 1)) 697 model_fn_ops = head.create_model_fn_ops( 698 features={}, 699 mode=model_fn.ModeKeys.TRAIN, 700 labels=labels, 701 train_op_fn=head_lib.no_op_train_fn, 702 logits=self._logits) 703 _assert_no_variables(self) 704 _assert_summary_tags(self, ["loss"]) 705 expected_loss = .89985204 706 _assert_metrics(self, expected_loss, 707 self._expected_eval_metrics(expected_loss), model_fn_ops) 708 709 def testMultiLabelSparseTensorLabelsTooFewClasses(self): 710 n_classes = 3 711 head = head_lib.multi_label_head( 712 n_classes=n_classes, metric_class_ids=range(n_classes)) 713 # Set _logits_dimension (n_classes) to a lower value; if it's set to 1 714 # upfront, the class throws an error during initialization. 715 head._logits_dimension = 1 716 with ops.Graph().as_default(), session.Session(): 717 labels = sparse_tensor.SparseTensorValue( 718 indices=((0, 0),), 719 values=(2,), 720 dense_shape=(1, 1)) 721 with self.assertRaisesRegexp(ValueError, 722 "Must set num_classes >= 2 when passing"): 723 head.create_model_fn_ops( 724 features={}, 725 labels=labels, 726 mode=model_fn.ModeKeys.TRAIN, 727 train_op_fn=head_lib.no_op_train_fn, 728 logits=[0.]) 729 730 731class BinaryClassificationHeadTest(test.TestCase): 732 733 def _assert_output_alternatives(self, model_fn_ops): 734 self.assertEquals({ 735 None: constants.ProblemType.LOGISTIC_REGRESSION 736 }, { 737 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 738 }) 739 740 def setUp(self): 741 self._logits = ((1.,), (1.,)) 742 self._labels = ((1.,), (0.,)) 743 744 def _expected_eval_metrics(self, expected_loss): 745 label_mean = np.mean(self._labels) 746 return { 747 "accuracy": 1. / 2, 748 "accuracy/baseline_label_mean": label_mean, 749 "accuracy/threshold_0.500000_mean": 1. / 2, 750 "auc": 1. / 2, 751 "auc_precision_recall": 0.25, 752 "labels/actual_label_mean": label_mean, 753 "labels/prediction_mean": .731059, # softmax 754 "loss": expected_loss, 755 "precision/positive_threshold_0.500000_mean": 1. / 2, 756 "recall/positive_threshold_0.500000_mean": 1. / 1, 757 } 758 759 def testBinaryClassificationWithLogits(self): 760 n_classes = 2 761 head = head_lib.multi_class_head(n_classes=n_classes) 762 with ops.Graph().as_default(), session.Session(): 763 # logloss: z:label, x:logit 764 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 765 model_fn_ops = head.create_model_fn_ops( 766 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 767 logits=self._logits) 768 self._assert_output_alternatives(model_fn_ops) 769 _assert_no_variables(self) 770 _assert_summary_tags(self, ["loss"]) 771 expected_loss = .81326175 772 _assert_metrics(self, expected_loss, 773 self._expected_eval_metrics(expected_loss), model_fn_ops) 774 775 def testBinaryClassificationWithInvalidLogits(self): 776 head = head_lib.multi_class_head(n_classes=len(self._labels) + 1) 777 with ops.Graph().as_default(), session.Session(): 778 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): 779 head.create_model_fn_ops( 780 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 781 logits=self._logits) 782 783 def testBinaryClassificationWithLogitsInput(self): 784 n_classes = 2 785 head = head_lib.multi_class_head(n_classes=n_classes) 786 with ops.Graph().as_default(), session.Session(): 787 # logloss: z:label, x:logit 788 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 789 model_fn_ops = head.create_model_fn_ops( 790 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 791 logits_input=((0., 0.), (0., 0.))) 792 self._assert_output_alternatives(model_fn_ops) 793 w = ("binary_logistic_head/logits/weights:0", 794 "binary_logistic_head/logits/biases:0") 795 _assert_variables( 796 self, expected_global=w, expected_model=w, expected_trainable=w) 797 variables.global_variables_initializer().run() 798 _assert_summary_tags(self, ["loss"]) 799 expected_loss = .69314718 800 label_mean = np.mean(self._labels) 801 _assert_metrics(self, expected_loss, { 802 "accuracy": 1. / 2, 803 "accuracy/baseline_label_mean": label_mean, 804 "accuracy/threshold_0.500000_mean": 1. / 2, 805 "auc": 1. / 2, 806 "labels/actual_label_mean": label_mean, 807 "labels/prediction_mean": .5, # softmax 808 "loss": expected_loss, 809 "precision/positive_threshold_0.500000_mean": 0. / 2, 810 "recall/positive_threshold_0.500000_mean": 0. / 1, 811 }, model_fn_ops) 812 813 def testBinaryClassificationWithLogitsAndLogitsInput(self): 814 head = head_lib.multi_class_head(n_classes=2) 815 with ops.Graph().as_default(), session.Session(): 816 with self.assertRaisesRegexp( 817 ValueError, "Both logits and logits_input supplied"): 818 head.create_model_fn_ops( 819 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 820 logits_input=((0., 0.), (0., 0.)), logits=self._logits) 821 822 def testBinaryClassificationEval(self): 823 n_classes = 2 824 head = head_lib.multi_class_head(n_classes=n_classes) 825 with ops.Graph().as_default(), session.Session(): 826 # logloss: z:label, x:logit 827 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 828 model_fn_ops = head.create_model_fn_ops( 829 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn, 830 logits=self._logits) 831 self._assert_output_alternatives(model_fn_ops) 832 self.assertIsNone(model_fn_ops.train_op) 833 _assert_no_variables(self) 834 _assert_summary_tags(self, ["loss"]) 835 expected_loss = .81326175 836 _assert_metrics(self, expected_loss, 837 self._expected_eval_metrics(expected_loss), model_fn_ops) 838 839 def testBinaryClassificationInfer(self): 840 n_classes = 2 841 head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name") 842 with ops.Graph().as_default(), session.Session(): 843 # logloss: z:label, x:logit 844 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 845 model_fn_ops = head.create_model_fn_ops( 846 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, 847 logits=self._logits) 848 self.assertIsNone(model_fn_ops.train_op) 849 _assert_no_variables(self) 850 with session.Session(): 851 self.assertListEqual( 852 [1, 1], list(model_fn_ops.predictions["classes"].eval())) 853 self.assertItemsEqual( 854 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) 855 self.assertEqual( 856 constants.ProblemType.LOGISTIC_REGRESSION, 857 model_fn_ops.output_alternatives["head_name"][0]) 858 predictions_for_serving = ( 859 model_fn_ops.output_alternatives["head_name"][1]) 860 self.assertIn("classes", six.iterkeys(predictions_for_serving)) 861 predicted_classes = predictions_for_serving["classes"].eval().tolist() 862 self.assertListEqual( 863 [b"0", b"1"], predicted_classes[0]) 864 self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) 865 866 def testBinaryClassificationInferMode_withWeightColumn(self): 867 n_classes = 2 868 head = head_lib.multi_class_head(n_classes=n_classes, 869 weight_column_name="label_weight") 870 with ops.Graph().as_default(), session.Session(): 871 # logloss: z:label, x:logit 872 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 873 model_fn_ops = head.create_model_fn_ops( 874 # This is what is being tested, features should not have weight for 875 # inference. 876 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn, 877 logits=self._logits) 878 self._assert_output_alternatives(model_fn_ops) 879 self.assertIsNone(model_fn_ops.train_op) 880 _assert_no_variables(self) 881 882 def testErrorInSparseTensorLabels(self): 883 n_classes = 2 884 head = head_lib.multi_class_head(n_classes=n_classes) 885 with ops.Graph().as_default(): 886 labels = sparse_tensor.SparseTensorValue( 887 indices=((0, 0), (1, 0), (2, 0)), 888 values=(0, 1, 1), 889 dense_shape=(3, 1)) 890 with self.assertRaisesRegexp(ValueError, 891 "SparseTensor is not supported"): 892 head.create_model_fn_ops( 893 {}, 894 model_fn.ModeKeys.TRAIN, 895 labels, 896 head_lib.no_op_train_fn, 897 logits=((1.,), (1.,), (3.,))) 898 899 def testBinaryClassificationWithLabelName(self): 900 label_name = "my_label" 901 head = head_lib.multi_class_head(n_classes=2, label_name=label_name) 902 with ops.Graph().as_default(), session.Session(): 903 # logloss: z:label, x:logit 904 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 905 model_fn_ops = head.create_model_fn_ops( 906 {}, 907 labels={label_name: self._labels}, 908 mode=model_fn.ModeKeys.TRAIN, 909 train_op_fn=head_lib.no_op_train_fn, 910 logits=self._logits) 911 self._assert_output_alternatives(model_fn_ops) 912 _assert_no_variables(self) 913 _assert_summary_tags(self, ["loss"]) 914 expected_loss = .81326175 915 _assert_metrics(self, expected_loss, 916 self._expected_eval_metrics(expected_loss), model_fn_ops) 917 918 def testBinaryClassificationWith1DWeights(self): 919 n_classes = 2 920 head = head_lib.multi_class_head( 921 n_classes=n_classes, weight_column_name="label_weight") 922 with ops.Graph().as_default(), session.Session(): 923 weights = (1., 0.) 924 # logloss: z:label, x:logit 925 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 926 model_fn_ops = head.create_model_fn_ops( 927 features={"label_weight": weights}, 928 labels=self._labels, 929 mode=model_fn.ModeKeys.TRAIN, 930 train_op_fn=head_lib.no_op_train_fn, 931 logits=self._logits) 932 self._assert_output_alternatives(model_fn_ops) 933 _assert_no_variables(self) 934 _assert_summary_tags(self, ["loss"]) 935 expected_total_loss = .31326166 936 _assert_metrics( 937 self, 938 expected_total_loss / len(weights), 939 { 940 "accuracy": 1. / 1, 941 "accuracy/baseline_label_mean": 1. / 1, 942 "accuracy/threshold_0.500000_mean": 1. / 1, 943 "auc": 0. / 1, 944 "labels/actual_label_mean": 1. / 1, 945 "labels/prediction_mean": .731059, # softmax 946 # eval loss is weighted loss divided by sum of weights. 947 "loss": expected_total_loss, 948 "precision/positive_threshold_0.500000_mean": 1. / 1, 949 "recall/positive_threshold_0.500000_mean": 1. / 1, 950 }, 951 model_fn_ops) 952 953 def testBinaryClassificationWith2DWeights(self): 954 n_classes = 2 955 head = head_lib.multi_class_head( 956 n_classes=n_classes, weight_column_name="label_weight") 957 with ops.Graph().as_default(), session.Session(): 958 weights = ((1.,), (0.,)) 959 # logloss: z:label, x:logit 960 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 961 model_fn_ops = head.create_model_fn_ops( 962 features={"label_weight": weights}, 963 labels=self._labels, 964 mode=model_fn.ModeKeys.TRAIN, 965 train_op_fn=head_lib.no_op_train_fn, 966 logits=self._logits) 967 self._assert_output_alternatives(model_fn_ops) 968 _assert_no_variables(self) 969 _assert_summary_tags(self, ["loss"]) 970 expected_total_loss = .31326166 971 _assert_metrics( 972 self, 973 expected_total_loss / len(weights), 974 { 975 "accuracy": 1. / 1, 976 "accuracy/baseline_label_mean": 1. / 1, 977 "accuracy/threshold_0.500000_mean": 1. / 1, 978 "auc": 0. / 1, 979 "labels/actual_label_mean": 1. / 1, 980 "labels/prediction_mean": .731059, # softmax 981 # eval loss is weighted loss divided by sum of weights. 982 "loss": expected_total_loss, 983 "precision/positive_threshold_0.500000_mean": 1. / 1, 984 "recall/positive_threshold_0.500000_mean": 1. / 1, 985 }, 986 model_fn_ops) 987 988 def testBinaryClassificationWithCustomLoss(self): 989 head = head_lib.multi_class_head( 990 n_classes=2, weight_column_name="label_weight", 991 loss_fn=_sigmoid_cross_entropy) 992 with ops.Graph().as_default(), session.Session(): 993 weights = ((.2,), (0.,)) 994 model_fn_ops = head.create_model_fn_ops( 995 features={"label_weight": weights}, 996 labels=self._labels, 997 mode=model_fn.ModeKeys.TRAIN, 998 train_op_fn=head_lib.no_op_train_fn, 999 logits=self._logits) 1000 self._assert_output_alternatives(model_fn_ops) 1001 _assert_no_variables(self) 1002 _assert_summary_tags(self, ["loss"]) 1003 # logloss: z:label, x:logit 1004 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1005 # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero 1006 # weight. 1007 expected_loss = 0.062652342 1008 _assert_metrics( 1009 self, 1010 expected_loss, 1011 { 1012 "accuracy": 1. / 1, 1013 "accuracy/baseline_label_mean": 1. / 1, 1014 "accuracy/threshold_0.500000_mean": 1. / 1, 1015 "auc": 0. / 1, 1016 "labels/actual_label_mean": 1. / 1, 1017 "labels/prediction_mean": .731059, # softmax 1018 "loss": expected_loss, 1019 "precision/positive_threshold_0.500000_mean": 1. / 1, 1020 "recall/positive_threshold_0.500000_mean": 1. / 1, 1021 }, 1022 model_fn_ops) 1023 1024 def testBinaryClassificationWithCenteredBias(self): 1025 head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True) 1026 with ops.Graph().as_default(), session.Session(): 1027 # logloss: z:label, x:logit 1028 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1029 model_fn_ops = head.create_model_fn_ops( 1030 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1031 logits=self._logits) 1032 self._assert_output_alternatives(model_fn_ops) 1033 _assert_variables( 1034 self, 1035 expected_global=( 1036 "binary_logistic_head/centered_bias_weight:0", 1037 ("binary_logistic_head/binary_logistic_head/centered_bias_weight/" 1038 "Adagrad:0"),), 1039 expected_trainable=("binary_logistic_head/centered_bias_weight:0",)) 1040 variables.global_variables_initializer().run() 1041 _assert_summary_tags(self, [ 1042 "loss", 1043 "binary_logistic_head/centered_bias/bias_0" 1044 ]) 1045 expected_loss = .81326175 1046 _assert_metrics(self, expected_loss, 1047 self._expected_eval_metrics(expected_loss), model_fn_ops) 1048 1049 1050class MultiClassHeadTest(test.TestCase): 1051 1052 def _assert_output_alternatives(self, model_fn_ops): 1053 self.assertEquals({ 1054 None: constants.ProblemType.CLASSIFICATION 1055 }, { 1056 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 1057 }) 1058 1059 def setUp(self): 1060 self._logits = ((1., 0., 0.),) 1061 self._labels = ((2,),) 1062 1063 def _expected_eval_metrics(self, expected_loss): 1064 return { 1065 "accuracy": 0., 1066 "loss": expected_loss, 1067 "labels/actual_label_mean/class0": 0. / 1, 1068 "labels/actual_label_mean/class1": 0. / 1, 1069 "labels/actual_label_mean/class2": 1. / 1, 1070 "labels/logits_mean/class0": self._logits[0][0], 1071 "labels/logits_mean/class1": self._logits[0][1], 1072 "labels/logits_mean/class2": self._logits[0][2], 1073 "labels/prediction_mean/class0": self._logits[0][0], 1074 "labels/prediction_mean/class1": self._logits[0][1], 1075 "labels/prediction_mean/class2": self._logits[0][2], 1076 "labels/probability_mean/class0": 0.576117, # softmax 1077 "labels/probability_mean/class1": 0.211942, # softmax 1078 "labels/probability_mean/class2": 0.211942, # softmax 1079 } 1080 1081 def testMultiClassWithLogits(self): 1082 n_classes = 3 1083 head = head_lib.multi_class_head( 1084 n_classes=n_classes, metric_class_ids=range(n_classes)) 1085 with ops.Graph().as_default(), session.Session(): 1086 # logloss: z:label, x:logit 1087 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1088 model_fn_ops = head.create_model_fn_ops( 1089 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1090 logits=self._logits) 1091 self._assert_output_alternatives(model_fn_ops) 1092 _assert_no_variables(self) 1093 _assert_summary_tags(self, ["loss"]) 1094 expected_loss = 1.5514447 1095 _assert_metrics(self, expected_loss, 1096 self._expected_eval_metrics(expected_loss), model_fn_ops) 1097 1098 def testMultiClassWithInvalidLogits(self): 1099 head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1) 1100 with ops.Graph().as_default(), session.Session(): 1101 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): 1102 head.create_model_fn_ops( 1103 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1104 logits=self._logits) 1105 1106 def testMultiClassWithNoneTrainOpFnInTrain(self): 1107 head = head_lib.multi_class_head(n_classes=3) 1108 with ops.Graph().as_default(), session.Session(): 1109 with self.assertRaisesRegexp( 1110 ValueError, "train_op_fn can not be None in TRAIN mode"): 1111 head.create_model_fn_ops( 1112 {}, model_fn.ModeKeys.TRAIN, self._labels, 1113 train_op_fn=None, 1114 logits=self._logits) 1115 1116 def testMultiClassWithLogitsInput(self): 1117 n_classes = 3 1118 head = head_lib.multi_class_head( 1119 n_classes=n_classes, metric_class_ids=range(n_classes)) 1120 with ops.Graph().as_default(), session.Session(): 1121 # logloss: z:label, x:logit 1122 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1123 model_fn_ops = head.create_model_fn_ops( 1124 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1125 logits_input=((0., 0.),)) 1126 self._assert_output_alternatives(model_fn_ops) 1127 w = ("multi_class_head/logits/weights:0", 1128 "multi_class_head/logits/biases:0") 1129 _assert_variables( 1130 self, expected_global=w, expected_model=w, expected_trainable=w) 1131 variables.global_variables_initializer().run() 1132 _assert_summary_tags(self, ["loss"]) 1133 expected_loss = 1.0986123 1134 _assert_metrics(self, expected_loss, { 1135 "accuracy": 0., 1136 "loss": expected_loss, 1137 "labels/actual_label_mean/class0": 0. / 1, 1138 "labels/actual_label_mean/class1": 0. / 1, 1139 "labels/actual_label_mean/class2": 1. / 1, 1140 "labels/logits_mean/class0": 0., 1141 "labels/logits_mean/class1": 0., 1142 "labels/logits_mean/class2": 0., 1143 "labels/prediction_mean/class0": 1., 1144 "labels/prediction_mean/class1": 0., 1145 "labels/prediction_mean/class2": 0., 1146 "labels/probability_mean/class0": 0.333333, # softmax 1147 "labels/probability_mean/class1": 0.333333, # softmax 1148 "labels/probability_mean/class2": 0.333333, # softmax 1149 }, model_fn_ops) 1150 1151 def testMultiClassWithLogitsAndLogitsInput(self): 1152 n_classes = 3 1153 head = head_lib.multi_class_head( 1154 n_classes=n_classes, metric_class_ids=range(n_classes)) 1155 with ops.Graph().as_default(), session.Session(): 1156 with self.assertRaisesRegexp( 1157 ValueError, "Both logits and logits_input supplied"): 1158 head.create_model_fn_ops( 1159 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1160 logits_input=((0., 0.),), logits=self._logits) 1161 1162 def testMultiClassEnableCenteredBias(self): 1163 n_classes = 3 1164 head = head_lib.multi_class_head( 1165 n_classes=n_classes, enable_centered_bias=True) 1166 with ops.Graph().as_default(), session.Session(): 1167 # logloss: z:label, x:logit 1168 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1169 model_fn_ops = head.create_model_fn_ops( 1170 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1171 logits=self._logits) 1172 self._assert_output_alternatives(model_fn_ops) 1173 _assert_variables( 1174 self, 1175 expected_global=( 1176 "multi_class_head/centered_bias_weight:0", 1177 ("multi_class_head/multi_class_head/centered_bias_weight/" 1178 "Adagrad:0"), 1179 ), 1180 expected_trainable=("multi_class_head/centered_bias_weight:0",)) 1181 variables.global_variables_initializer().run() 1182 _assert_summary_tags(self, 1183 ["loss", 1184 "multi_class_head/centered_bias/bias_0", 1185 "multi_class_head/centered_bias/bias_1", 1186 "multi_class_head/centered_bias/bias_2"]) 1187 1188 def testMultiClassEval(self): 1189 n_classes = 3 1190 head = head_lib.multi_class_head( 1191 n_classes=n_classes, metric_class_ids=range(n_classes)) 1192 with ops.Graph().as_default(), session.Session(): 1193 # logloss: z:label, x:logit 1194 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1195 model_fn_ops = head.create_model_fn_ops( 1196 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn, 1197 logits=self._logits) 1198 self._assert_output_alternatives(model_fn_ops) 1199 self.assertIsNone(model_fn_ops.train_op) 1200 _assert_no_variables(self) 1201 _assert_summary_tags(self, ["loss"]) 1202 expected_loss = 1.5514447 1203 _assert_metrics(self, expected_loss, 1204 self._expected_eval_metrics(expected_loss), model_fn_ops) 1205 1206 def testMultiClassEvalModeWithLargeLogits(self): 1207 n_classes = 3 1208 head = head_lib.multi_class_head( 1209 n_classes=n_classes, metric_class_ids=range(n_classes)) 1210 logits = ((2., 0., -1),) 1211 with ops.Graph().as_default(), session.Session(): 1212 # logloss: z:label, x:logit 1213 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1214 model_fn_ops = head.create_model_fn_ops( 1215 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn, 1216 logits=logits) 1217 self._assert_output_alternatives(model_fn_ops) 1218 self.assertIsNone(model_fn_ops.train_op) 1219 _assert_no_variables(self) 1220 _assert_summary_tags(self, ["loss"]) 1221 expected_loss = 3.1698461 1222 expected_eval_metrics = { 1223 "accuracy": 0., 1224 "loss": expected_loss, 1225 "labels/actual_label_mean/class0": 0. / 1, 1226 "labels/actual_label_mean/class1": 0. / 1, 1227 "labels/actual_label_mean/class2": 1. / 1, 1228 "labels/logits_mean/class0": logits[0][0], 1229 "labels/logits_mean/class1": logits[0][1], 1230 "labels/logits_mean/class2": logits[0][2], 1231 "labels/prediction_mean/class0": 1, 1232 "labels/prediction_mean/class1": 0, 1233 "labels/prediction_mean/class2": 0, 1234 "labels/probability_mean/class0": 0.843795, # softmax 1235 "labels/probability_mean/class1": 0.114195, # softmax 1236 "labels/probability_mean/class2": 0.0420101, # softmax 1237 } 1238 _assert_metrics(self, expected_loss, 1239 expected_eval_metrics, model_fn_ops) 1240 1241 def testMultiClassWithScalarWeight(self): 1242 n_classes = 3 1243 head = head_lib.multi_class_head( 1244 n_classes=n_classes, 1245 weight_column_name="label_weight", 1246 metric_class_ids=range(n_classes)) 1247 with ops.Graph().as_default(), session.Session(): 1248 weight = .1 1249 # logloss: z:label, x:logit 1250 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1251 model_fn_ops = head.create_model_fn_ops( 1252 features={"label_weight": weight}, 1253 labels=self._labels, 1254 mode=model_fn.ModeKeys.TRAIN, 1255 train_op_fn=head_lib.no_op_train_fn, 1256 logits=self._logits) 1257 self._assert_output_alternatives(model_fn_ops) 1258 _assert_no_variables(self) 1259 _assert_summary_tags(self, ["loss"]) 1260 expected_loss = 1.5514447 1261 _assert_metrics(self, expected_loss * weight, 1262 self._expected_eval_metrics(expected_loss), model_fn_ops) 1263 1264 def testMultiClassWith1DWeight(self): 1265 n_classes = 3 1266 head = head_lib.multi_class_head( 1267 n_classes=n_classes, 1268 weight_column_name="label_weight", 1269 metric_class_ids=range(n_classes)) 1270 with ops.Graph().as_default(), session.Session(): 1271 weight = .1 1272 weights = (weight,) 1273 # logloss: z:label, x:logit 1274 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1275 model_fn_ops = head.create_model_fn_ops( 1276 features={"label_weight": weights}, 1277 labels=self._labels, 1278 mode=model_fn.ModeKeys.TRAIN, 1279 train_op_fn=head_lib.no_op_train_fn, 1280 logits=self._logits) 1281 self._assert_output_alternatives(model_fn_ops) 1282 _assert_no_variables(self) 1283 _assert_summary_tags(self, ["loss"]) 1284 expected_loss = 1.5514447 1285 _assert_metrics(self, expected_loss * weight, 1286 self._expected_eval_metrics(expected_loss), model_fn_ops) 1287 1288 def testMultiClassWith2DWeight(self): 1289 n_classes = 3 1290 head = head_lib.multi_class_head( 1291 n_classes=n_classes, 1292 weight_column_name="label_weight", 1293 metric_class_ids=range(n_classes)) 1294 with ops.Graph().as_default(), session.Session(): 1295 weight = .1 1296 weights = ((weight,),) 1297 # logloss: z:label, x:logit 1298 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1299 model_fn_ops = head.create_model_fn_ops( 1300 features={"label_weight": weights}, 1301 labels=self._labels, 1302 mode=model_fn.ModeKeys.TRAIN, 1303 train_op_fn=head_lib.no_op_train_fn, 1304 logits=self._logits) 1305 self._assert_output_alternatives(model_fn_ops) 1306 _assert_no_variables(self) 1307 _assert_summary_tags(self, ["loss"]) 1308 expected_loss = 1.5514447 1309 _assert_metrics(self, expected_loss * weight, 1310 self._expected_eval_metrics(expected_loss), model_fn_ops) 1311 1312 def testMultiClassWithCustomLoss(self): 1313 n_classes = 3 1314 head = head_lib.multi_class_head( 1315 n_classes=n_classes, 1316 weight_column_name="label_weight", 1317 metric_class_ids=range(n_classes), 1318 loss_fn=losses_lib.sparse_softmax_cross_entropy) 1319 with ops.Graph().as_default(), session.Session(): 1320 weight = .1 1321 # logloss: z:label, x:logit 1322 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) 1323 model_fn_ops = head.create_model_fn_ops( 1324 features={"label_weight": weight}, 1325 labels=self._labels, 1326 mode=model_fn.ModeKeys.TRAIN, 1327 train_op_fn=head_lib.no_op_train_fn, 1328 logits=self._logits) 1329 self._assert_output_alternatives(model_fn_ops) 1330 _assert_no_variables(self) 1331 _assert_summary_tags(self, ["loss"]) 1332 expected_loss = 1.5514447 * weight 1333 _assert_metrics(self, expected_loss, 1334 self._expected_eval_metrics(expected_loss), model_fn_ops) 1335 1336 def testMultiClassInfer(self): 1337 n_classes = 3 1338 head = head_lib._multi_class_head( 1339 n_classes=n_classes, 1340 head_name="head_name") 1341 with ops.Graph().as_default(): 1342 model_fn_ops = head.create_model_fn_ops( 1343 features={}, 1344 mode=model_fn.ModeKeys.INFER, 1345 train_op_fn=head_lib.no_op_train_fn, 1346 logits=((1., 0., 0.), (0., 0., 1.),)) 1347 with session.Session(): 1348 lookup_ops.tables_initializer().run() 1349 self.assertAllEqual( 1350 [0, 2], 1351 model_fn_ops.predictions["classes"].eval()) 1352 self.assertItemsEqual( 1353 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) 1354 self.assertEqual( 1355 constants.ProblemType.CLASSIFICATION, 1356 model_fn_ops.output_alternatives["head_name"][0]) 1357 predictions_for_serving = ( 1358 model_fn_ops.output_alternatives["head_name"][1]) 1359 self.assertIn("classes", six.iterkeys(predictions_for_serving)) 1360 self.assertAllEqual( 1361 [[b"0", b"1", b"2"], [b"0", b"1", b"2"]], 1362 predictions_for_serving["classes"].eval()) 1363 self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) 1364 self.assertAllClose( 1365 [[0.576117, 0.2119416, 0.2119416], 1366 [0.2119416, 0.2119416, 0.576117]], 1367 predictions_for_serving["probabilities"].eval()) 1368 1369 def testInvalidNClasses(self): 1370 for n_classes in (None, -1, 0, 1): 1371 with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"): 1372 head_lib.multi_class_head(n_classes=n_classes) 1373 1374 def testMultiClassWithLabelKeysInvalidShape(self): 1375 with self.assertRaisesRegexp( 1376 ValueError, "Length of label_keys must equal n_classes"): 1377 head_lib._multi_class_head( 1378 n_classes=3, label_keys=("key0", "key1")) 1379 1380 def testMultiClassWithLabelKeysTwoClasses(self): 1381 with self.assertRaisesRegexp( 1382 ValueError, "label_keys is not supported for n_classes=2"): 1383 head_lib._multi_class_head( 1384 n_classes=2, label_keys=("key0", "key1")) 1385 1386 def testMultiClassWithLabelKeysInfer(self): 1387 n_classes = 3 1388 label_keys = ("key0", "key1", "key2") 1389 head = head_lib._multi_class_head( 1390 n_classes=n_classes, label_keys=label_keys, 1391 metric_class_ids=range(n_classes), 1392 head_name="head_name") 1393 with ops.Graph().as_default(): 1394 model_fn_ops = head.create_model_fn_ops( 1395 features={}, 1396 mode=model_fn.ModeKeys.INFER, 1397 train_op_fn=head_lib.no_op_train_fn, 1398 logits=((1., 0., 0.), (0., 0., 1.),)) 1399 with session.Session(): 1400 lookup_ops.tables_initializer().run() 1401 self.assertAllEqual( 1402 [b"key0", b"key2"], 1403 model_fn_ops.predictions["classes"].eval()) 1404 self.assertItemsEqual( 1405 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives)) 1406 self.assertEqual( 1407 constants.ProblemType.CLASSIFICATION, 1408 model_fn_ops.output_alternatives["head_name"][0]) 1409 predictions_for_serving = ( 1410 model_fn_ops.output_alternatives["head_name"][1]) 1411 self.assertIn("classes", six.iterkeys(predictions_for_serving)) 1412 self.assertAllEqual( 1413 [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]], 1414 predictions_for_serving["classes"].eval()) 1415 self.assertIn("probabilities", six.iterkeys(predictions_for_serving)) 1416 self.assertAllClose( 1417 [[0.576117, 0.2119416, 0.2119416], 1418 [0.2119416, 0.2119416, 0.576117]], 1419 predictions_for_serving["probabilities"].eval()) 1420 1421 def testMultiClassWithLabelKeysEvalAccuracy0(self): 1422 n_classes = 3 1423 label_keys = ("key0", "key1", "key2") 1424 head = head_lib._multi_class_head( 1425 n_classes=n_classes, 1426 label_keys=label_keys) 1427 with ops.Graph().as_default(): 1428 model_fn_ops = head.create_model_fn_ops( 1429 features={}, 1430 mode=model_fn.ModeKeys.EVAL, 1431 labels=("key2",), 1432 train_op_fn=head_lib.no_op_train_fn, 1433 logits=((1., 0., 0.),)) 1434 with session.Session(): 1435 lookup_ops.tables_initializer().run() 1436 self.assertIsNone(model_fn_ops.train_op) 1437 _assert_no_variables(self) 1438 _assert_summary_tags(self, ["loss"]) 1439 expected_loss = 1.5514447 1440 expected_eval_metrics = { 1441 "accuracy": 0., 1442 "loss": expected_loss, 1443 } 1444 _assert_metrics(self, expected_loss, 1445 expected_eval_metrics, model_fn_ops) 1446 1447 def testMultiClassWithLabelKeysEvalAccuracy1(self): 1448 n_classes = 3 1449 label_keys = ("key0", "key1", "key2") 1450 head = head_lib._multi_class_head( 1451 n_classes=n_classes, 1452 label_keys=label_keys) 1453 with ops.Graph().as_default(): 1454 model_fn_ops = head.create_model_fn_ops( 1455 features={}, 1456 mode=model_fn.ModeKeys.EVAL, 1457 labels=("key2",), 1458 train_op_fn=head_lib.no_op_train_fn, 1459 logits=((0., 0., 1.),)) 1460 with session.Session(): 1461 lookup_ops.tables_initializer().run() 1462 self.assertIsNone(model_fn_ops.train_op) 1463 _assert_no_variables(self) 1464 _assert_summary_tags(self, ["loss"]) 1465 expected_loss = 0.5514447 1466 expected_eval_metrics = { 1467 "accuracy": 1., 1468 "loss": expected_loss, 1469 } 1470 _assert_metrics(self, expected_loss, 1471 expected_eval_metrics, model_fn_ops) 1472 1473 1474class BinarySvmHeadTest(test.TestCase): 1475 1476 def _assert_output_alternatives(self, model_fn_ops): 1477 self.assertEquals({ 1478 None: constants.ProblemType.LOGISTIC_REGRESSION 1479 }, { 1480 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 1481 }) 1482 1483 def setUp(self): 1484 # Prediction for first example is in the right side of the hyperplane 1485 # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss 1486 # incurred by this example. The 2nd prediction is outside the margin so it 1487 # incurs no loss at all. 1488 self._predictions = ((-.5,), (1.2,)) 1489 self._labels = (0, 1) 1490 self._expected_losses = (.5, 0.) 1491 1492 def testBinarySVMWithLogits(self): 1493 head = head_lib.binary_svm_head() 1494 with ops.Graph().as_default(), session.Session(): 1495 model_fn_ops = head.create_model_fn_ops( 1496 {}, 1497 model_fn.ModeKeys.TRAIN, 1498 self._labels, 1499 head_lib.no_op_train_fn, 1500 logits=self._predictions) 1501 self._assert_output_alternatives(model_fn_ops) 1502 _assert_no_variables(self) 1503 _assert_summary_tags(self, ["loss"]) 1504 expected_loss = np.average(self._expected_losses) 1505 _assert_metrics(self, expected_loss, { 1506 "accuracy": 1., 1507 "loss": expected_loss, 1508 }, model_fn_ops) 1509 1510 def testBinarySVMWithInvalidLogits(self): 1511 head = head_lib.binary_svm_head() 1512 with ops.Graph().as_default(), session.Session(): 1513 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): 1514 head.create_model_fn_ops( 1515 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn, 1516 logits=np.ones((2, 2))) 1517 1518 def testBinarySVMWithLogitsInput(self): 1519 head = head_lib.binary_svm_head() 1520 with ops.Graph().as_default(), session.Session(): 1521 model_fn_ops = head.create_model_fn_ops( 1522 {}, 1523 model_fn.ModeKeys.TRAIN, 1524 self._labels, 1525 head_lib.no_op_train_fn, 1526 logits_input=((0., 0.), (0., 0.))) 1527 self._assert_output_alternatives(model_fn_ops) 1528 w = ("binary_svm_head/logits/weights:0", 1529 "binary_svm_head/logits/biases:0") 1530 _assert_variables( 1531 self, expected_global=w, expected_model=w, expected_trainable=w) 1532 variables.global_variables_initializer().run() 1533 _assert_summary_tags(self, ["loss"]) 1534 expected_loss = 1. 1535 _assert_metrics(self, expected_loss, { 1536 "accuracy": .5, 1537 "loss": expected_loss, 1538 }, model_fn_ops) 1539 1540 def testBinarySVMWithLogitsAndLogitsInput(self): 1541 head = head_lib.binary_svm_head() 1542 with ops.Graph().as_default(), session.Session(): 1543 with self.assertRaisesRegexp( 1544 ValueError, "Both logits and logits_input supplied"): 1545 head.create_model_fn_ops( 1546 {}, 1547 model_fn.ModeKeys.TRAIN, 1548 self._labels, 1549 head_lib.no_op_train_fn, 1550 logits_input=((0., 0.), (0., 0.)), 1551 logits=self._predictions) 1552 1553 def testBinarySVMEvalMode(self): 1554 head = head_lib.binary_svm_head() 1555 with ops.Graph().as_default(), session.Session(): 1556 model_fn_ops = head.create_model_fn_ops( 1557 {}, 1558 model_fn.ModeKeys.EVAL, 1559 self._labels, 1560 head_lib.no_op_train_fn, 1561 logits=self._predictions) 1562 self._assert_output_alternatives(model_fn_ops) 1563 self.assertIsNone(model_fn_ops.train_op) 1564 _assert_no_variables(self) 1565 _assert_summary_tags(self, ["loss"]) 1566 expected_loss = np.average(self._expected_losses) 1567 _assert_metrics(self, expected_loss, { 1568 "accuracy": 1., 1569 "loss": expected_loss, 1570 }, model_fn_ops) 1571 1572 def testBinarySVMWithLabelName(self): 1573 label_name = "my_label" 1574 head = head_lib.binary_svm_head(label_name=label_name) 1575 with ops.Graph().as_default(), session.Session(): 1576 model_fn_ops = head.create_model_fn_ops( 1577 {}, 1578 model_fn.ModeKeys.TRAIN, 1579 {label_name: self._labels}, 1580 head_lib.no_op_train_fn, 1581 logits=self._predictions) 1582 self._assert_output_alternatives(model_fn_ops) 1583 _assert_no_variables(self) 1584 _assert_summary_tags(self, ["loss"]) 1585 expected_loss = np.average(self._expected_losses) 1586 _assert_metrics(self, expected_loss, { 1587 "accuracy": 1., 1588 "loss": expected_loss, 1589 }, model_fn_ops) 1590 1591 def testBinarySVMWith1DWeights(self): 1592 head = head_lib.binary_svm_head(weight_column_name="weights") 1593 with ops.Graph().as_default(), session.Session(): 1594 weights = (7., 11.) 1595 model_fn_ops = head.create_model_fn_ops( 1596 # We have to add an extra dim here for weights broadcasting to work. 1597 features={"weights": weights}, 1598 mode=model_fn.ModeKeys.TRAIN, 1599 labels=self._labels, 1600 train_op_fn=head_lib.no_op_train_fn, 1601 logits=self._predictions) 1602 self._assert_output_alternatives(model_fn_ops) 1603 _assert_no_variables(self) 1604 _assert_summary_tags(self, ["loss"]) 1605 expected_weighted_losses = np.multiply(weights, self._expected_losses) 1606 _assert_metrics(self, np.mean(expected_weighted_losses), { 1607 "accuracy": 1., 1608 "loss": np.sum(expected_weighted_losses) / np.sum(weights), 1609 }, model_fn_ops) 1610 1611 def testBinarySVMWith2DWeights(self): 1612 head = head_lib.binary_svm_head(weight_column_name="weights") 1613 with ops.Graph().as_default(), session.Session(): 1614 weights = (7., 11.) 1615 model_fn_ops = head.create_model_fn_ops( 1616 # We have to add an extra dim here for weights broadcasting to work. 1617 features={"weights": tuple([(w,) for w in weights])}, 1618 mode=model_fn.ModeKeys.TRAIN, 1619 labels=self._labels, 1620 train_op_fn=head_lib.no_op_train_fn, 1621 logits=self._predictions) 1622 self._assert_output_alternatives(model_fn_ops) 1623 _assert_no_variables(self) 1624 _assert_summary_tags(self, ["loss"]) 1625 expected_weighted_losses = np.multiply(weights, self._expected_losses) 1626 _assert_metrics(self, np.mean(expected_weighted_losses), { 1627 "accuracy": 1., 1628 "loss": np.sum(expected_weighted_losses) / np.sum(weights), 1629 }, model_fn_ops) 1630 1631 def testBinarySVMWithCenteredBias(self): 1632 head = head_lib.binary_svm_head(enable_centered_bias=True) 1633 with ops.Graph().as_default(), session.Session(): 1634 model_fn_ops = head.create_model_fn_ops( 1635 {}, 1636 model_fn.ModeKeys.TRAIN, 1637 self._labels, 1638 head_lib.no_op_train_fn, 1639 logits=self._predictions) 1640 self._assert_output_alternatives(model_fn_ops) 1641 _assert_variables( 1642 self, 1643 expected_global=( 1644 "binary_svm_head/centered_bias_weight:0", 1645 ("binary_svm_head/binary_svm_head/centered_bias_weight/" 1646 "Adagrad:0"), 1647 ), 1648 expected_trainable=("binary_svm_head/centered_bias_weight:0",)) 1649 variables.global_variables_initializer().run() 1650 _assert_summary_tags(self, [ 1651 "loss", 1652 "binary_svm_head/centered_bias/bias_0" 1653 ]) 1654 expected_loss = np.average(self._expected_losses) 1655 _assert_metrics(self, expected_loss, { 1656 "accuracy": 1., 1657 "loss": expected_loss, 1658 }, model_fn_ops) 1659 1660 1661class LossOnlyHead(test.TestCase): 1662 1663 def testNoPredictionsAndNoMetrics(self): 1664 head = head_lib.loss_only_head(lambda: 1, head_name="const") 1665 model_fn_ops = head.create_model_fn_ops( 1666 features={}, 1667 mode=model_fn.ModeKeys.TRAIN, 1668 train_op_fn=head_lib.no_op_train_fn) 1669 self.assertDictEqual(model_fn_ops.predictions, {}) 1670 self.assertDictEqual(model_fn_ops.eval_metric_ops, {}) 1671 self.assertIsNotNone(model_fn_ops.loss) 1672 with session.Session() as sess: 1673 self.assertEqual(1, sess.run(model_fn_ops.loss)) 1674 1675 1676class MultiHeadTest(test.TestCase): 1677 1678 def testInvalidHeads(self): 1679 named_head = head_lib.multi_class_head( 1680 n_classes=3, label_name="label", head_name="head1") 1681 unnamed_head = head_lib.multi_class_head( 1682 n_classes=4, label_name="label") 1683 with self.assertRaisesRegexp(ValueError, "must have names"): 1684 head_lib.multi_head((named_head, unnamed_head)) 1685 1686 def testTrainWithNoneTrainOpFn(self): 1687 head1 = head_lib.multi_class_head( 1688 n_classes=3, label_name="label1", head_name="head1") 1689 head2 = head_lib.multi_class_head( 1690 n_classes=4, label_name="label2", head_name="head2") 1691 head = head_lib.multi_head((head1, head2)) 1692 labels = { 1693 "label1": (1,), 1694 "label2": (1,) 1695 } 1696 with self.assertRaisesRegexp( 1697 ValueError, "train_op_fn can not be None in TRAIN mode"): 1698 head.create_model_fn_ops( 1699 features={"weights": (2.0, 10.0)}, 1700 labels=labels, 1701 mode=model_fn.ModeKeys.TRAIN, 1702 train_op_fn=None, 1703 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),)) 1704 1705 def testTrain_withNoHeadWeights(self): 1706 head1 = head_lib.multi_class_head( 1707 n_classes=3, label_name="label1", head_name="head1") 1708 head2 = head_lib.multi_class_head( 1709 n_classes=4, label_name="label2", head_name="head2") 1710 head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const") 1711 head = head_lib.multi_head((head1, head2, head3)) 1712 labels = { 1713 "label1": (1,), 1714 "label2": (1,) 1715 } 1716 model_fn_ops = head.create_model_fn_ops( 1717 features={"weights": (2.0, 10.0)}, 1718 labels=labels, 1719 mode=model_fn.ModeKeys.TRAIN, 1720 train_op_fn=head_lib.no_op_train_fn, 1721 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),)) 1722 1723 self.assertIsNone(model_fn_ops.predictions) 1724 self.assertIsNotNone(model_fn_ops.loss) 1725 self.assertIsNotNone(model_fn_ops.train_op) 1726 self.assertTrue(model_fn_ops.eval_metric_ops) 1727 self.assertIsNone(model_fn_ops.output_alternatives) 1728 1729 with session.Session() as sess: 1730 self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3) 1731 1732 def testTrain_withHeadWeights(self): 1733 head1 = head_lib.multi_class_head( 1734 n_classes=3, label_name="label1", head_name="head1") 1735 head2 = head_lib.multi_class_head( 1736 n_classes=4, label_name="label2", head_name="head2") 1737 head = head_lib.multi_head((head1, head2), (1, .5)) 1738 labels = { 1739 "label1": (1,), 1740 "label2": (1,) 1741 } 1742 model_fn_ops = head.create_model_fn_ops( 1743 features={"weights": (2.0, 10.0)}, 1744 labels=labels, 1745 mode=model_fn.ModeKeys.TRAIN, 1746 train_op_fn=head_lib.no_op_train_fn, 1747 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),)) 1748 self.assertIsNone(model_fn_ops.predictions) 1749 self.assertIsNotNone(model_fn_ops.loss) 1750 self.assertIsNotNone(model_fn_ops.train_op) 1751 self.assertTrue(model_fn_ops.eval_metric_ops) 1752 self.assertIsNone(model_fn_ops.output_alternatives) 1753 1754 with session.Session() as sess: 1755 self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3) 1756 1757 def testTrain_withDictLogits(self): 1758 head1 = head_lib.multi_class_head( 1759 n_classes=3, label_name="label1", head_name="head1") 1760 head2 = head_lib.multi_class_head( 1761 n_classes=4, label_name="label2", head_name="head2") 1762 head = head_lib.multi_head((head1, head2)) 1763 labels = { 1764 "label1": (1,), 1765 "label2": (1,) 1766 } 1767 model_fn_ops = head.create_model_fn_ops( 1768 features={"weights": (2.0, 10.0)}, 1769 labels=labels, 1770 mode=model_fn.ModeKeys.TRAIN, 1771 train_op_fn=head_lib.no_op_train_fn, 1772 logits={head1.head_name: ((-0.7, 0.2, .1),), 1773 head2.head_name: ((.1, .1, .1, .1),)}) 1774 1775 self.assertIsNone(model_fn_ops.predictions) 1776 self.assertIsNotNone(model_fn_ops.loss) 1777 self.assertIsNotNone(model_fn_ops.train_op) 1778 self.assertTrue(model_fn_ops.eval_metric_ops) 1779 self.assertIsNone(model_fn_ops.output_alternatives) 1780 1781 with session.Session() as sess: 1782 self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3) 1783 1784 def testInfer(self): 1785 head1 = head_lib.multi_class_head( 1786 n_classes=3, label_name="label1", head_name="head1") 1787 head2 = head_lib.multi_class_head( 1788 n_classes=4, label_name="label2", head_name="head2") 1789 head = head_lib.multi_head((head1, head2), (1, .5)) 1790 labels = { 1791 "label1": (1,), 1792 "label2": (1,) 1793 } 1794 model_fn_ops = head.create_model_fn_ops( 1795 features={"weights": (2.0, 10.0)}, 1796 labels=labels, 1797 mode=model_fn.ModeKeys.INFER, 1798 train_op_fn=head_lib.no_op_train_fn, 1799 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),)) 1800 1801 self.assertIsNotNone(model_fn_ops.predictions) 1802 self.assertIsNone(model_fn_ops.loss) 1803 self.assertIsNone(model_fn_ops.train_op) 1804 self.assertFalse(model_fn_ops.eval_metric_ops) 1805 1806 # Tests predictions keys. 1807 self.assertItemsEqual(( 1808 ("head1", prediction_key.PredictionKey.LOGITS), 1809 ("head1", prediction_key.PredictionKey.PROBABILITIES), 1810 ("head1", prediction_key.PredictionKey.CLASSES), 1811 ("head2", prediction_key.PredictionKey.LOGITS), 1812 ("head2", prediction_key.PredictionKey.PROBABILITIES), 1813 ("head2", prediction_key.PredictionKey.CLASSES), 1814 ), model_fn_ops.predictions.keys()) 1815 1816 # Tests output alternative. 1817 self.assertEquals({ 1818 "head1": constants.ProblemType.CLASSIFICATION, 1819 "head2": constants.ProblemType.CLASSIFICATION, 1820 }, { 1821 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives) 1822 }) 1823 self.assertItemsEqual(( 1824 prediction_key.PredictionKey.PROBABILITIES, 1825 prediction_key.PredictionKey.CLASSES, 1826 ), model_fn_ops.output_alternatives["head1"][1].keys()) 1827 self.assertItemsEqual(( 1828 prediction_key.PredictionKey.PROBABILITIES, 1829 prediction_key.PredictionKey.CLASSES, 1830 ), model_fn_ops.output_alternatives["head2"][1].keys()) 1831 1832 def testEval(self): 1833 head1 = head_lib.multi_class_head( 1834 n_classes=3, label_name="label1", head_name="head1") 1835 head2 = head_lib.multi_class_head( 1836 n_classes=4, label_name="label2", head_name="head2") 1837 head = head_lib.multi_head((head1, head2), (1, .5)) 1838 labels = { 1839 "label1": (1,), 1840 "label2": (1,) 1841 } 1842 model_fn_ops = head.create_model_fn_ops( 1843 features={"weights": (2.0, 10.0)}, 1844 labels=labels, 1845 mode=model_fn.ModeKeys.EVAL, 1846 train_op_fn=head_lib.no_op_train_fn, 1847 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),)) 1848 1849 self.assertIsNotNone(model_fn_ops.predictions) 1850 self.assertIsNotNone(model_fn_ops.loss) 1851 self.assertIsNone(model_fn_ops.train_op) 1852 self.assertIsNotNone(model_fn_ops.eval_metric_ops) 1853 self.assertIsNone(model_fn_ops.output_alternatives) 1854 1855 metric_ops = model_fn_ops.eval_metric_ops 1856 1857 # Tests eval keys. 1858 self.assertIn("accuracy/head1", metric_ops.keys()) 1859 self.assertIn("accuracy/head2", metric_ops.keys()) 1860 1861 1862def _sigmoid_cross_entropy(labels, logits, weights): 1863 return losses_lib.sigmoid_cross_entropy(labels, logits, weights) 1864 1865 1866if __name__ == "__main__": 1867 test.main() 1868