1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for V2 summary ops from summary_ops_v2.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import unittest 23 24import six 25 26from tensorflow.core.framework import graph_pb2 27from tensorflow.core.framework import node_def_pb2 28from tensorflow.core.framework import step_stats_pb2 29from tensorflow.core.framework import summary_pb2 30from tensorflow.core.protobuf import config_pb2 31from tensorflow.core.util import event_pb2 32from tensorflow.python.eager import context 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.framework import test_util 41from tensorflow.python.keras.engine.sequential import Sequential 42from tensorflow.python.keras.engine.training import Model 43from tensorflow.python.keras.layers.core import Activation 44from tensorflow.python.keras.layers.core import Dense 45from tensorflow.python.lib.io import tf_record 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import summary_ops_v2 as summary_ops 48from tensorflow.python.ops import variables 49from tensorflow.python.platform import gfile 50from tensorflow.python.platform import test 51from tensorflow.python.platform import tf_logging as logging 52 53 54class SummaryOpsCoreTest(test_util.TensorFlowTestCase): 55 56 def testWrite(self): 57 logdir = self.get_temp_dir() 58 with context.eager_mode(): 59 with summary_ops.create_file_writer_v2(logdir).as_default(): 60 output = summary_ops.write('tag', 42, step=12) 61 self.assertTrue(output.numpy()) 62 events = events_from_logdir(logdir) 63 self.assertEqual(2, len(events)) 64 self.assertEqual(12, events[1].step) 65 value = events[1].summary.value[0] 66 self.assertEqual('tag', value.tag) 67 self.assertEqual(42, to_numpy(value)) 68 69 def testWrite_fromFunction(self): 70 logdir = self.get_temp_dir() 71 with context.eager_mode(): 72 writer = summary_ops.create_file_writer_v2(logdir) 73 @def_function.function 74 def f(): 75 with writer.as_default(): 76 return summary_ops.write('tag', 42, step=12) 77 output = f() 78 self.assertTrue(output.numpy()) 79 events = events_from_logdir(logdir) 80 self.assertEqual(2, len(events)) 81 self.assertEqual(12, events[1].step) 82 value = events[1].summary.value[0] 83 self.assertEqual('tag', value.tag) 84 self.assertEqual(42, to_numpy(value)) 85 86 def testWrite_metadata(self): 87 logdir = self.get_temp_dir() 88 metadata = summary_pb2.SummaryMetadata() 89 metadata.plugin_data.plugin_name = 'foo' 90 with context.eager_mode(): 91 with summary_ops.create_file_writer_v2(logdir).as_default(): 92 summary_ops.write('obj', 0, 0, metadata=metadata) 93 summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString()) 94 m = constant_op.constant(metadata.SerializeToString()) 95 summary_ops.write('string_tensor', 0, 0, metadata=m) 96 events = events_from_logdir(logdir) 97 self.assertEqual(4, len(events)) 98 self.assertEqual(metadata, events[1].summary.value[0].metadata) 99 self.assertEqual(metadata, events[2].summary.value[0].metadata) 100 self.assertEqual(metadata, events[3].summary.value[0].metadata) 101 102 def testWrite_name(self): 103 @def_function.function 104 def f(): 105 output = summary_ops.write('tag', 42, step=12, name='anonymous') 106 self.assertTrue(output.name.startswith('anonymous')) 107 f() 108 109 def testWrite_ndarray(self): 110 logdir = self.get_temp_dir() 111 with context.eager_mode(): 112 with summary_ops.create_file_writer_v2(logdir).as_default(): 113 summary_ops.write('tag', [[1, 2], [3, 4]], step=12) 114 events = events_from_logdir(logdir) 115 value = events[1].summary.value[0] 116 self.assertAllEqual([[1, 2], [3, 4]], to_numpy(value)) 117 118 def testWrite_tensor(self): 119 logdir = self.get_temp_dir() 120 with context.eager_mode(): 121 t = constant_op.constant([[1, 2], [3, 4]]) 122 with summary_ops.create_file_writer_v2(logdir).as_default(): 123 summary_ops.write('tag', t, step=12) 124 expected = t.numpy() 125 events = events_from_logdir(logdir) 126 value = events[1].summary.value[0] 127 self.assertAllEqual(expected, to_numpy(value)) 128 129 def testWrite_tensor_fromFunction(self): 130 logdir = self.get_temp_dir() 131 with context.eager_mode(): 132 writer = summary_ops.create_file_writer_v2(logdir) 133 @def_function.function 134 def f(t): 135 with writer.as_default(): 136 summary_ops.write('tag', t, step=12) 137 t = constant_op.constant([[1, 2], [3, 4]]) 138 f(t) 139 expected = t.numpy() 140 events = events_from_logdir(logdir) 141 value = events[1].summary.value[0] 142 self.assertAllEqual(expected, to_numpy(value)) 143 144 def testWrite_stringTensor(self): 145 logdir = self.get_temp_dir() 146 with context.eager_mode(): 147 with summary_ops.create_file_writer_v2(logdir).as_default(): 148 summary_ops.write('tag', [b'foo', b'bar'], step=12) 149 events = events_from_logdir(logdir) 150 value = events[1].summary.value[0] 151 self.assertAllEqual([b'foo', b'bar'], to_numpy(value)) 152 153 @test_util.run_gpu_only 154 def testWrite_gpuDeviceContext(self): 155 logdir = self.get_temp_dir() 156 with context.eager_mode(): 157 with summary_ops.create_file_writer(logdir).as_default(): 158 with ops.device('/GPU:0'): 159 value = constant_op.constant(42.0) 160 step = constant_op.constant(12, dtype=dtypes.int64) 161 summary_ops.write('tag', value, step=step).numpy() 162 empty_metadata = summary_pb2.SummaryMetadata() 163 events = events_from_logdir(logdir) 164 self.assertEqual(2, len(events)) 165 self.assertEqual(12, events[1].step) 166 self.assertEqual(42, to_numpy(events[1].summary.value[0])) 167 self.assertEqual(empty_metadata, events[1].summary.value[0].metadata) 168 169 @test_util.also_run_as_tf_function 170 def testWrite_noDefaultWriter(self): 171 # Use assertAllEqual instead of assertFalse since it works in a defun. 172 self.assertAllEqual(False, summary_ops.write('tag', 42, step=0)) 173 174 @test_util.also_run_as_tf_function 175 def testWrite_noStep_okayIfAlsoNoDefaultWriter(self): 176 # Use assertAllEqual instead of assertFalse since it works in a defun. 177 self.assertAllEqual(False, summary_ops.write('tag', 42)) 178 179 @test_util.also_run_as_tf_function 180 def testWrite_noStep(self): 181 logdir = self.get_temp_dir() 182 with summary_ops.create_file_writer(logdir).as_default(): 183 with self.assertRaisesRegex(ValueError, 'No step set'): 184 summary_ops.write('tag', 42) 185 186 def testWrite_usingDefaultStep(self): 187 logdir = self.get_temp_dir() 188 try: 189 with context.eager_mode(): 190 with summary_ops.create_file_writer(logdir).as_default(): 191 summary_ops.set_step(1) 192 summary_ops.write('tag', 1.0) 193 summary_ops.set_step(2) 194 summary_ops.write('tag', 1.0) 195 mystep = variables.Variable(10, dtype=dtypes.int64) 196 summary_ops.set_step(mystep) 197 summary_ops.write('tag', 1.0) 198 mystep.assign_add(1) 199 summary_ops.write('tag', 1.0) 200 events = events_from_logdir(logdir) 201 self.assertEqual(5, len(events)) 202 self.assertEqual(1, events[1].step) 203 self.assertEqual(2, events[2].step) 204 self.assertEqual(10, events[3].step) 205 self.assertEqual(11, events[4].step) 206 finally: 207 # Reset to default state for other tests. 208 summary_ops.set_step(None) 209 210 def testWrite_usingDefaultStepConstant_fromFunction(self): 211 logdir = self.get_temp_dir() 212 try: 213 with context.eager_mode(): 214 writer = summary_ops.create_file_writer(logdir) 215 @def_function.function 216 def f(): 217 with writer.as_default(): 218 summary_ops.write('tag', 1.0) 219 summary_ops.set_step(1) 220 f() 221 summary_ops.set_step(2) 222 f() 223 events = events_from_logdir(logdir) 224 self.assertEqual(3, len(events)) 225 self.assertEqual(1, events[1].step) 226 # The step value will still be 1 because the value was captured at the 227 # time the function was first traced. 228 self.assertEqual(1, events[2].step) 229 finally: 230 # Reset to default state for other tests. 231 summary_ops.set_step(None) 232 233 def testWrite_usingDefaultStepVariable_fromFunction(self): 234 logdir = self.get_temp_dir() 235 try: 236 with context.eager_mode(): 237 writer = summary_ops.create_file_writer(logdir) 238 @def_function.function 239 def f(): 240 with writer.as_default(): 241 summary_ops.write('tag', 1.0) 242 mystep = variables.Variable(0, dtype=dtypes.int64) 243 summary_ops.set_step(mystep) 244 f() 245 mystep.assign_add(1) 246 f() 247 mystep.assign(10) 248 f() 249 events = events_from_logdir(logdir) 250 self.assertEqual(4, len(events)) 251 self.assertEqual(0, events[1].step) 252 self.assertEqual(1, events[2].step) 253 self.assertEqual(10, events[3].step) 254 finally: 255 # Reset to default state for other tests. 256 summary_ops.set_step(None) 257 258 def testWrite_usingDefaultStepConstant_fromLegacyGraph(self): 259 logdir = self.get_temp_dir() 260 try: 261 with context.graph_mode(): 262 writer = summary_ops.create_file_writer(logdir) 263 summary_ops.set_step(1) 264 with writer.as_default(): 265 write_op = summary_ops.write('tag', 1.0) 266 summary_ops.set_step(2) 267 with self.cached_session() as sess: 268 sess.run(writer.init()) 269 sess.run(write_op) 270 sess.run(write_op) 271 sess.run(writer.flush()) 272 events = events_from_logdir(logdir) 273 self.assertEqual(3, len(events)) 274 self.assertEqual(1, events[1].step) 275 # The step value will still be 1 because the value was captured at the 276 # time the graph was constructed. 277 self.assertEqual(1, events[2].step) 278 finally: 279 # Reset to default state for other tests. 280 summary_ops.set_step(None) 281 282 def testWrite_usingDefaultStepVariable_fromLegacyGraph(self): 283 logdir = self.get_temp_dir() 284 try: 285 with context.graph_mode(): 286 writer = summary_ops.create_file_writer(logdir) 287 mystep = variables.Variable(0, dtype=dtypes.int64) 288 summary_ops.set_step(mystep) 289 with writer.as_default(): 290 write_op = summary_ops.write('tag', 1.0) 291 first_assign_op = mystep.assign_add(1) 292 second_assign_op = mystep.assign(10) 293 with self.cached_session() as sess: 294 sess.run(writer.init()) 295 sess.run(mystep.initializer) 296 sess.run(write_op) 297 sess.run(first_assign_op) 298 sess.run(write_op) 299 sess.run(second_assign_op) 300 sess.run(write_op) 301 sess.run(writer.flush()) 302 events = events_from_logdir(logdir) 303 self.assertEqual(4, len(events)) 304 self.assertEqual(0, events[1].step) 305 self.assertEqual(1, events[2].step) 306 self.assertEqual(10, events[3].step) 307 finally: 308 # Reset to default state for other tests. 309 summary_ops.set_step(None) 310 311 def testWrite_recordIf_constant(self): 312 logdir = self.get_temp_dir() 313 with context.eager_mode(): 314 with summary_ops.create_file_writer_v2(logdir).as_default(): 315 self.assertTrue(summary_ops.write('default', 1, step=0)) 316 with summary_ops.record_if(True): 317 self.assertTrue(summary_ops.write('set_on', 1, step=0)) 318 with summary_ops.record_if(False): 319 self.assertFalse(summary_ops.write('set_off', 1, step=0)) 320 events = events_from_logdir(logdir) 321 self.assertEqual(3, len(events)) 322 self.assertEqual('default', events[1].summary.value[0].tag) 323 self.assertEqual('set_on', events[2].summary.value[0].tag) 324 325 def testWrite_recordIf_constant_fromFunction(self): 326 logdir = self.get_temp_dir() 327 with context.eager_mode(): 328 writer = summary_ops.create_file_writer_v2(logdir) 329 @def_function.function 330 def f(): 331 with writer.as_default(): 332 # Use assertAllEqual instead of assertTrue since it works in a defun. 333 self.assertAllEqual(summary_ops.write('default', 1, step=0), True) 334 with summary_ops.record_if(True): 335 self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True) 336 with summary_ops.record_if(False): 337 self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False) 338 f() 339 events = events_from_logdir(logdir) 340 self.assertEqual(3, len(events)) 341 self.assertEqual('default', events[1].summary.value[0].tag) 342 self.assertEqual('set_on', events[2].summary.value[0].tag) 343 344 def testWrite_recordIf_callable(self): 345 logdir = self.get_temp_dir() 346 with context.eager_mode(): 347 step = variables.Variable(-1, dtype=dtypes.int64) 348 def record_fn(): 349 step.assign_add(1) 350 return int(step % 2) == 0 351 with summary_ops.create_file_writer_v2(logdir).as_default(): 352 with summary_ops.record_if(record_fn): 353 self.assertTrue(summary_ops.write('tag', 1, step=step)) 354 self.assertFalse(summary_ops.write('tag', 1, step=step)) 355 self.assertTrue(summary_ops.write('tag', 1, step=step)) 356 self.assertFalse(summary_ops.write('tag', 1, step=step)) 357 self.assertTrue(summary_ops.write('tag', 1, step=step)) 358 events = events_from_logdir(logdir) 359 self.assertEqual(4, len(events)) 360 self.assertEqual(0, events[1].step) 361 self.assertEqual(2, events[2].step) 362 self.assertEqual(4, events[3].step) 363 364 def testWrite_recordIf_callable_fromFunction(self): 365 logdir = self.get_temp_dir() 366 with context.eager_mode(): 367 writer = summary_ops.create_file_writer_v2(logdir) 368 step = variables.Variable(-1, dtype=dtypes.int64) 369 @def_function.function 370 def record_fn(): 371 step.assign_add(1) 372 return math_ops.equal(step % 2, 0) 373 @def_function.function 374 def f(): 375 with writer.as_default(): 376 with summary_ops.record_if(record_fn): 377 return [ 378 summary_ops.write('tag', 1, step=step), 379 summary_ops.write('tag', 1, step=step), 380 summary_ops.write('tag', 1, step=step)] 381 self.assertAllEqual(f(), [True, False, True]) 382 self.assertAllEqual(f(), [False, True, False]) 383 events = events_from_logdir(logdir) 384 self.assertEqual(4, len(events)) 385 self.assertEqual(0, events[1].step) 386 self.assertEqual(2, events[2].step) 387 self.assertEqual(4, events[3].step) 388 389 def testWrite_recordIf_tensorInput_fromFunction(self): 390 logdir = self.get_temp_dir() 391 with context.eager_mode(): 392 writer = summary_ops.create_file_writer_v2(logdir) 393 @def_function.function(input_signature=[ 394 tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)]) 395 def f(step): 396 with writer.as_default(): 397 with summary_ops.record_if(math_ops.equal(step % 2, 0)): 398 return summary_ops.write('tag', 1, step=step) 399 self.assertTrue(f(0)) 400 self.assertFalse(f(1)) 401 self.assertTrue(f(2)) 402 self.assertFalse(f(3)) 403 self.assertTrue(f(4)) 404 events = events_from_logdir(logdir) 405 self.assertEqual(4, len(events)) 406 self.assertEqual(0, events[1].step) 407 self.assertEqual(2, events[2].step) 408 self.assertEqual(4, events[3].step) 409 410 @test_util.also_run_as_tf_function 411 def testGetSetStep(self): 412 try: 413 self.assertIsNone(summary_ops.get_step()) 414 summary_ops.set_step(1) 415 # Use assertAllEqual instead of assertEqual since it works in a defun. 416 self.assertAllEqual(1, summary_ops.get_step()) 417 summary_ops.set_step(constant_op.constant(2)) 418 self.assertAllEqual(2, summary_ops.get_step()) 419 finally: 420 # Reset to default state for other tests. 421 summary_ops.set_step(None) 422 423 def testGetSetStep_variable(self): 424 with context.eager_mode(): 425 try: 426 mystep = variables.Variable(0) 427 summary_ops.set_step(mystep) 428 self.assertAllEqual(0, summary_ops.get_step().read_value()) 429 mystep.assign_add(1) 430 self.assertAllEqual(1, summary_ops.get_step().read_value()) 431 # Check that set_step() properly maintains reference to variable. 432 del mystep 433 self.assertAllEqual(1, summary_ops.get_step().read_value()) 434 summary_ops.get_step().assign_add(1) 435 self.assertAllEqual(2, summary_ops.get_step().read_value()) 436 finally: 437 # Reset to default state for other tests. 438 summary_ops.set_step(None) 439 440 def testGetSetStep_variable_fromFunction(self): 441 with context.eager_mode(): 442 try: 443 @def_function.function 444 def set_step(step): 445 summary_ops.set_step(step) 446 return summary_ops.get_step() 447 @def_function.function 448 def get_and_increment(): 449 summary_ops.get_step().assign_add(1) 450 return summary_ops.get_step() 451 mystep = variables.Variable(0) 452 self.assertAllEqual(0, set_step(mystep)) 453 self.assertAllEqual(0, summary_ops.get_step().read_value()) 454 self.assertAllEqual(1, get_and_increment()) 455 self.assertAllEqual(2, get_and_increment()) 456 # Check that set_step() properly maintains reference to variable. 457 del mystep 458 self.assertAllEqual(3, get_and_increment()) 459 finally: 460 # Reset to default state for other tests. 461 summary_ops.set_step(None) 462 463 @test_util.also_run_as_tf_function 464 def testSummaryScope(self): 465 with summary_ops.summary_scope('foo') as (tag, scope): 466 self.assertEqual('foo', tag) 467 self.assertEqual('foo/', scope) 468 with summary_ops.summary_scope('bar') as (tag, scope): 469 self.assertEqual('foo/bar', tag) 470 self.assertEqual('foo/bar/', scope) 471 with summary_ops.summary_scope('with/slash') as (tag, scope): 472 self.assertEqual('foo/with/slash', tag) 473 self.assertEqual('foo/with/slash/', scope) 474 with ops.name_scope(None): 475 with summary_ops.summary_scope('unnested') as (tag, scope): 476 self.assertEqual('unnested', tag) 477 self.assertEqual('unnested/', scope) 478 479 @test_util.also_run_as_tf_function 480 def testSummaryScope_defaultName(self): 481 with summary_ops.summary_scope(None) as (tag, scope): 482 self.assertEqual('summary', tag) 483 self.assertEqual('summary/', scope) 484 with summary_ops.summary_scope(None, 'backup') as (tag, scope): 485 self.assertEqual('backup', tag) 486 self.assertEqual('backup/', scope) 487 488 @test_util.also_run_as_tf_function 489 def testSummaryScope_handlesCharactersIllegalForScope(self): 490 with summary_ops.summary_scope('f?o?o') as (tag, scope): 491 self.assertEqual('f?o?o', tag) 492 self.assertEqual('foo/', scope) 493 # If all characters aren't legal for a scope name, use default name. 494 with summary_ops.summary_scope('???', 'backup') as (tag, scope): 495 self.assertEqual('???', tag) 496 self.assertEqual('backup/', scope) 497 498 @test_util.also_run_as_tf_function 499 def testSummaryScope_nameNotUniquifiedForTag(self): 500 constant_op.constant(0, name='foo') 501 with summary_ops.summary_scope('foo') as (tag, _): 502 self.assertEqual('foo', tag) 503 with summary_ops.summary_scope('foo') as (tag, _): 504 self.assertEqual('foo', tag) 505 with ops.name_scope('with'): 506 constant_op.constant(0, name='slash') 507 with summary_ops.summary_scope('with/slash') as (tag, _): 508 self.assertEqual('with/slash', tag) 509 510 511class SummaryWriterTest(test_util.TensorFlowTestCase): 512 513 def testCreate_withInitAndClose(self): 514 logdir = self.get_temp_dir() 515 with context.eager_mode(): 516 writer = summary_ops.create_file_writer_v2( 517 logdir, max_queue=1000, flush_millis=1000000) 518 get_total = lambda: len(events_from_logdir(logdir)) 519 self.assertEqual(1, get_total()) # file_version Event 520 # Calling init() again while writer is open has no effect 521 writer.init() 522 self.assertEqual(1, get_total()) 523 with writer.as_default(): 524 summary_ops.write('tag', 1, step=0) 525 self.assertEqual(1, get_total()) 526 # Calling .close() should do an implicit flush 527 writer.close() 528 self.assertEqual(2, get_total()) 529 530 def testCreate_fromFunction(self): 531 logdir = self.get_temp_dir() 532 @def_function.function 533 def f(): 534 # Returned SummaryWriter must be stored in a non-local variable so it 535 # lives throughout the function execution. 536 if not hasattr(f, 'writer'): 537 f.writer = summary_ops.create_file_writer_v2(logdir) 538 with context.eager_mode(): 539 f() 540 event_files = gfile.Glob(os.path.join(logdir, '*')) 541 self.assertEqual(1, len(event_files)) 542 543 def testCreate_graphTensorArgument_raisesError(self): 544 logdir = self.get_temp_dir() 545 with context.graph_mode(): 546 logdir_tensor = constant_op.constant(logdir) 547 with context.eager_mode(): 548 with self.assertRaisesRegex( 549 ValueError, 'Invalid graph Tensor argument.*logdir'): 550 summary_ops.create_file_writer_v2(logdir_tensor) 551 self.assertEmpty(gfile.Glob(os.path.join(logdir, '*'))) 552 553 def testCreate_fromFunction_graphTensorArgument_raisesError(self): 554 logdir = self.get_temp_dir() 555 @def_function.function 556 def f(): 557 summary_ops.create_file_writer_v2(constant_op.constant(logdir)) 558 with context.eager_mode(): 559 with self.assertRaisesRegex( 560 ValueError, 'Invalid graph Tensor argument.*logdir'): 561 f() 562 self.assertEmpty(gfile.Glob(os.path.join(logdir, '*'))) 563 564 def testCreate_fromFunction_unpersistedResource_raisesError(self): 565 logdir = self.get_temp_dir() 566 @def_function.function 567 def f(): 568 with summary_ops.create_file_writer_v2(logdir).as_default(): 569 pass # Calling .as_default() is enough to indicate use. 570 with context.eager_mode(): 571 # TODO(nickfelt): change this to a better error 572 with self.assertRaisesRegex( 573 errors.NotFoundError, 'Resource.*does not exist'): 574 f() 575 # Even though we didn't use it, an event file will have been created. 576 self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*')))) 577 578 def testCreate_immediateSetAsDefault_retainsReference(self): 579 logdir = self.get_temp_dir() 580 try: 581 with context.eager_mode(): 582 summary_ops.create_file_writer_v2(logdir).set_as_default() 583 summary_ops.flush() 584 finally: 585 # Ensure we clean up no matter how the test executes. 586 context.context().summary_writer_resource = None 587 588 def testCreate_immediateAsDefault_retainsReference(self): 589 logdir = self.get_temp_dir() 590 with context.eager_mode(): 591 with summary_ops.create_file_writer_v2(logdir).as_default(): 592 summary_ops.flush() 593 594 def testNoSharing(self): 595 # Two writers with the same logdir should not share state. 596 logdir = self.get_temp_dir() 597 with context.eager_mode(): 598 writer1 = summary_ops.create_file_writer_v2(logdir) 599 with writer1.as_default(): 600 summary_ops.write('tag', 1, step=1) 601 event_files = gfile.Glob(os.path.join(logdir, '*')) 602 self.assertEqual(1, len(event_files)) 603 file1 = event_files[0] 604 605 writer2 = summary_ops.create_file_writer_v2(logdir) 606 with writer2.as_default(): 607 summary_ops.write('tag', 1, step=2) 608 event_files = gfile.Glob(os.path.join(logdir, '*')) 609 self.assertEqual(2, len(event_files)) 610 event_files.remove(file1) 611 file2 = event_files[0] 612 613 # Extra writes to ensure interleaved usage works. 614 with writer1.as_default(): 615 summary_ops.write('tag', 1, step=1) 616 with writer2.as_default(): 617 summary_ops.write('tag', 1, step=2) 618 619 events = iter(events_from_file(file1)) 620 self.assertEqual('brain.Event:2', next(events).file_version) 621 self.assertEqual(1, next(events).step) 622 self.assertEqual(1, next(events).step) 623 self.assertRaises(StopIteration, lambda: next(events)) 624 events = iter(events_from_file(file2)) 625 self.assertEqual('brain.Event:2', next(events).file_version) 626 self.assertEqual(2, next(events).step) 627 self.assertEqual(2, next(events).step) 628 self.assertRaises(StopIteration, lambda: next(events)) 629 630 def testNoSharing_fromFunction(self): 631 logdir = self.get_temp_dir() 632 @def_function.function 633 def f1(): 634 if not hasattr(f1, 'writer'): 635 f1.writer = summary_ops.create_file_writer_v2(logdir) 636 with f1.writer.as_default(): 637 summary_ops.write('tag', 1, step=1) 638 @def_function.function 639 def f2(): 640 if not hasattr(f2, 'writer'): 641 f2.writer = summary_ops.create_file_writer_v2(logdir) 642 with f2.writer.as_default(): 643 summary_ops.write('tag', 1, step=2) 644 with context.eager_mode(): 645 f1() 646 event_files = gfile.Glob(os.path.join(logdir, '*')) 647 self.assertEqual(1, len(event_files)) 648 file1 = event_files[0] 649 650 f2() 651 event_files = gfile.Glob(os.path.join(logdir, '*')) 652 self.assertEqual(2, len(event_files)) 653 event_files.remove(file1) 654 file2 = event_files[0] 655 656 # Extra writes to ensure interleaved usage works. 657 f1() 658 f2() 659 660 events = iter(events_from_file(file1)) 661 self.assertEqual('brain.Event:2', next(events).file_version) 662 self.assertEqual(1, next(events).step) 663 self.assertEqual(1, next(events).step) 664 self.assertRaises(StopIteration, lambda: next(events)) 665 events = iter(events_from_file(file2)) 666 self.assertEqual('brain.Event:2', next(events).file_version) 667 self.assertEqual(2, next(events).step) 668 self.assertEqual(2, next(events).step) 669 self.assertRaises(StopIteration, lambda: next(events)) 670 671 def testMaxQueue(self): 672 logdir = self.get_temp_dir() 673 with context.eager_mode(): 674 with summary_ops.create_file_writer_v2( 675 logdir, max_queue=1, flush_millis=999999).as_default(): 676 get_total = lambda: len(events_from_logdir(logdir)) 677 # Note: First tf.Event is always file_version. 678 self.assertEqual(1, get_total()) 679 summary_ops.write('tag', 1, step=0) 680 self.assertEqual(1, get_total()) 681 # Should flush after second summary since max_queue = 1 682 summary_ops.write('tag', 1, step=0) 683 self.assertEqual(3, get_total()) 684 685 def testWriterFlush(self): 686 logdir = self.get_temp_dir() 687 get_total = lambda: len(events_from_logdir(logdir)) 688 with context.eager_mode(): 689 writer = summary_ops.create_file_writer_v2( 690 logdir, max_queue=1000, flush_millis=1000000) 691 self.assertEqual(1, get_total()) # file_version Event 692 with writer.as_default(): 693 summary_ops.write('tag', 1, step=0) 694 self.assertEqual(1, get_total()) 695 writer.flush() 696 self.assertEqual(2, get_total()) 697 summary_ops.write('tag', 1, step=0) 698 self.assertEqual(2, get_total()) 699 # Exiting the "as_default()" should do an implicit flush 700 self.assertEqual(3, get_total()) 701 702 def testFlushFunction(self): 703 logdir = self.get_temp_dir() 704 with context.eager_mode(): 705 writer = summary_ops.create_file_writer_v2( 706 logdir, max_queue=999999, flush_millis=999999) 707 with writer.as_default(): 708 get_total = lambda: len(events_from_logdir(logdir)) 709 # Note: First tf.Event is always file_version. 710 self.assertEqual(1, get_total()) 711 summary_ops.write('tag', 1, step=0) 712 summary_ops.write('tag', 1, step=0) 713 self.assertEqual(1, get_total()) 714 summary_ops.flush() 715 self.assertEqual(3, get_total()) 716 # Test "writer" parameter 717 summary_ops.write('tag', 1, step=0) 718 self.assertEqual(3, get_total()) 719 summary_ops.flush(writer=writer) 720 self.assertEqual(4, get_total()) 721 summary_ops.write('tag', 1, step=0) 722 self.assertEqual(4, get_total()) 723 summary_ops.flush(writer=writer._resource) # pylint:disable=protected-access 724 self.assertEqual(5, get_total()) 725 726 @test_util.assert_no_new_pyobjects_executing_eagerly 727 def testEagerMemory(self): 728 logdir = self.get_temp_dir() 729 with summary_ops.create_file_writer_v2(logdir).as_default(): 730 summary_ops.write('tag', 1, step=0) 731 732 def testClose_preventsLaterUse(self): 733 logdir = self.get_temp_dir() 734 with context.eager_mode(): 735 writer = summary_ops.create_file_writer_v2(logdir) 736 writer.close() 737 writer.close() # redundant close() is a no-op 738 writer.flush() # redundant flush() is a no-op 739 with self.assertRaisesRegex(RuntimeError, 'already closed'): 740 writer.init() 741 with self.assertRaisesRegex(RuntimeError, 'already closed'): 742 with writer.as_default(): 743 self.fail('should not get here') 744 with self.assertRaisesRegex(RuntimeError, 'already closed'): 745 writer.set_as_default() 746 747 def testClose_closesOpenFile(self): 748 try: 749 import psutil # pylint: disable=g-import-not-at-top 750 except ImportError: 751 raise unittest.SkipTest('test requires psutil') 752 proc = psutil.Process() 753 get_open_filenames = lambda: set(info[0] for info in proc.open_files()) 754 logdir = self.get_temp_dir() 755 with context.eager_mode(): 756 writer = summary_ops.create_file_writer_v2(logdir) 757 files = gfile.Glob(os.path.join(logdir, '*')) 758 self.assertEqual(1, len(files)) 759 eventfile = files[0] 760 self.assertIn(eventfile, get_open_filenames()) 761 writer.close() 762 self.assertNotIn(eventfile, get_open_filenames()) 763 764 def testDereference_closesOpenFile(self): 765 try: 766 import psutil # pylint: disable=g-import-not-at-top 767 except ImportError: 768 raise unittest.SkipTest('test requires psutil') 769 proc = psutil.Process() 770 get_open_filenames = lambda: set(info[0] for info in proc.open_files()) 771 logdir = self.get_temp_dir() 772 with context.eager_mode(): 773 writer = summary_ops.create_file_writer_v2(logdir) 774 files = gfile.Glob(os.path.join(logdir, '*')) 775 self.assertEqual(1, len(files)) 776 eventfile = files[0] 777 self.assertIn(eventfile, get_open_filenames()) 778 del writer 779 self.assertNotIn(eventfile, get_open_filenames()) 780 781 782class SummaryOpsTest(test_util.TensorFlowTestCase): 783 784 def tearDown(self): 785 summary_ops.trace_off() 786 787 def run_metadata(self, *args, **kwargs): 788 assert context.executing_eagerly() 789 logdir = self.get_temp_dir() 790 writer = summary_ops.create_file_writer(logdir) 791 with writer.as_default(): 792 summary_ops.run_metadata(*args, **kwargs) 793 writer.close() 794 events = events_from_logdir(logdir) 795 return events[1] 796 797 def run_metadata_graphs(self, *args, **kwargs): 798 assert context.executing_eagerly() 799 logdir = self.get_temp_dir() 800 writer = summary_ops.create_file_writer(logdir) 801 with writer.as_default(): 802 summary_ops.run_metadata_graphs(*args, **kwargs) 803 writer.close() 804 events = events_from_logdir(logdir) 805 return events[1] 806 807 def create_run_metadata(self): 808 step_stats = step_stats_pb2.StepStats(dev_stats=[ 809 step_stats_pb2.DeviceStepStats( 810 device='cpu:0', 811 node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')]) 812 ]) 813 return config_pb2.RunMetadata( 814 function_graphs=[ 815 config_pb2.RunMetadata.FunctionGraphs( 816 pre_optimization_graph=graph_pb2.GraphDef( 817 node=[node_def_pb2.NodeDef(name='foo')])) 818 ], 819 step_stats=step_stats) 820 821 def keras_model(self, *args, **kwargs): 822 logdir = self.get_temp_dir() 823 writer = summary_ops.create_file_writer(logdir) 824 with writer.as_default(): 825 summary_ops.keras_model(*args, **kwargs) 826 writer.close() 827 events = events_from_logdir(logdir) 828 # The first event contains no summary values. The written content goes to 829 # the second event. 830 return events[1] 831 832 def run_trace(self, f, step=1): 833 assert context.executing_eagerly() 834 logdir = self.get_temp_dir() 835 writer = summary_ops.create_file_writer(logdir) 836 summary_ops.trace_on(graph=True, profiler=False) 837 with writer.as_default(): 838 f() 839 summary_ops.trace_export(name='foo', step=step) 840 writer.close() 841 events = events_from_logdir(logdir) 842 return events[1] 843 844 @test_util.run_v2_only 845 def testRunMetadata_usesNameAsTag(self): 846 meta = config_pb2.RunMetadata() 847 848 with ops.name_scope('foo'): 849 event = self.run_metadata(name='my_name', data=meta, step=1) 850 first_val = event.summary.value[0] 851 852 self.assertEqual('foo/my_name', first_val.tag) 853 854 @test_util.run_v2_only 855 def testRunMetadata_summaryMetadata(self): 856 expected_summary_metadata = """ 857 plugin_data { 858 plugin_name: "graph_run_metadata" 859 content: "1" 860 } 861 """ 862 meta = config_pb2.RunMetadata() 863 event = self.run_metadata(name='my_name', data=meta, step=1) 864 actual_summary_metadata = event.summary.value[0].metadata 865 self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata) 866 867 @test_util.run_v2_only 868 def testRunMetadata_wholeRunMetadata(self): 869 expected_run_metadata = """ 870 step_stats { 871 dev_stats { 872 device: "cpu:0" 873 node_stats { 874 node_name: "hello" 875 } 876 } 877 } 878 function_graphs { 879 pre_optimization_graph { 880 node { 881 name: "foo" 882 } 883 } 884 } 885 """ 886 meta = self.create_run_metadata() 887 event = self.run_metadata(name='my_name', data=meta, step=1) 888 first_val = event.summary.value[0] 889 890 actual_run_metadata = config_pb2.RunMetadata.FromString( 891 first_val.tensor.string_val[0]) 892 self.assertProtoEquals(expected_run_metadata, actual_run_metadata) 893 894 @test_util.run_v2_only 895 def testRunMetadata_usesDefaultStep(self): 896 meta = config_pb2.RunMetadata() 897 try: 898 summary_ops.set_step(42) 899 event = self.run_metadata(name='my_name', data=meta) 900 self.assertEqual(42, event.step) 901 finally: 902 # Reset to default state for other tests. 903 summary_ops.set_step(None) 904 905 @test_util.run_v2_only 906 def testRunMetadataGraph_usesNameAsTag(self): 907 meta = config_pb2.RunMetadata() 908 909 with ops.name_scope('foo'): 910 event = self.run_metadata_graphs(name='my_name', data=meta, step=1) 911 first_val = event.summary.value[0] 912 913 self.assertEqual('foo/my_name', first_val.tag) 914 915 @test_util.run_v2_only 916 def testRunMetadataGraph_summaryMetadata(self): 917 expected_summary_metadata = """ 918 plugin_data { 919 plugin_name: "graph_run_metadata_graph" 920 content: "1" 921 } 922 """ 923 meta = config_pb2.RunMetadata() 924 event = self.run_metadata_graphs(name='my_name', data=meta, step=1) 925 actual_summary_metadata = event.summary.value[0].metadata 926 self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata) 927 928 @test_util.run_v2_only 929 def testRunMetadataGraph_runMetadataFragment(self): 930 expected_run_metadata = """ 931 function_graphs { 932 pre_optimization_graph { 933 node { 934 name: "foo" 935 } 936 } 937 } 938 """ 939 meta = self.create_run_metadata() 940 941 event = self.run_metadata_graphs(name='my_name', data=meta, step=1) 942 first_val = event.summary.value[0] 943 944 actual_run_metadata = config_pb2.RunMetadata.FromString( 945 first_val.tensor.string_val[0]) 946 self.assertProtoEquals(expected_run_metadata, actual_run_metadata) 947 948 @test_util.run_v2_only 949 def testRunMetadataGraph_usesDefaultStep(self): 950 meta = config_pb2.RunMetadata() 951 try: 952 summary_ops.set_step(42) 953 event = self.run_metadata_graphs(name='my_name', data=meta) 954 self.assertEqual(42, event.step) 955 finally: 956 # Reset to default state for other tests. 957 summary_ops.set_step(None) 958 959 @test_util.run_v2_only 960 def testKerasModel(self): 961 model = Sequential( 962 [Dense(10, input_shape=(100,)), 963 Activation('relu', name='my_relu')]) 964 event = self.keras_model(name='my_name', data=model, step=1) 965 first_val = event.summary.value[0] 966 self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode()) 967 968 @test_util.run_v2_only 969 def testKerasModel_usesDefaultStep(self): 970 model = Sequential( 971 [Dense(10, input_shape=(100,)), 972 Activation('relu', name='my_relu')]) 973 try: 974 summary_ops.set_step(42) 975 event = self.keras_model(name='my_name', data=model) 976 self.assertEqual(42, event.step) 977 finally: 978 # Reset to default state for other tests. 979 summary_ops.set_step(None) 980 981 @test_util.run_v2_only 982 def testKerasModel_subclass(self): 983 984 class SimpleSubclass(Model): 985 986 def __init__(self): 987 super(SimpleSubclass, self).__init__(name='subclass') 988 self.dense = Dense(10, input_shape=(100,)) 989 self.activation = Activation('relu', name='my_relu') 990 991 def call(self, inputs): 992 x = self.dense(inputs) 993 return self.activation(x) 994 995 model = SimpleSubclass() 996 with test.mock.patch.object(logging, 'warn') as mock_log: 997 self.assertFalse( 998 summary_ops.keras_model(name='my_name', data=model, step=1)) 999 self.assertRegexpMatches( 1000 str(mock_log.call_args), 'Model failed to serialize as JSON.') 1001 1002 @test_util.run_v2_only 1003 def testKerasModel_otherExceptions(self): 1004 model = Sequential() 1005 1006 with test.mock.patch.object(model, 'to_json') as mock_to_json: 1007 with test.mock.patch.object(logging, 'warn') as mock_log: 1008 mock_to_json.side_effect = Exception('oops') 1009 self.assertFalse( 1010 summary_ops.keras_model(name='my_name', data=model, step=1)) 1011 self.assertRegexpMatches( 1012 str(mock_log.call_args), 1013 'Model failed to serialize as JSON. Ignoring... oops') 1014 1015 @test_util.run_v2_only 1016 def testTrace(self): 1017 1018 @def_function.function 1019 def f(): 1020 x = constant_op.constant(2) 1021 y = constant_op.constant(3) 1022 return x**y 1023 1024 event = self.run_trace(f) 1025 1026 first_val = event.summary.value[0] 1027 actual_run_metadata = config_pb2.RunMetadata.FromString( 1028 first_val.tensor.string_val[0]) 1029 1030 # Content of function_graphs is large and, for instance, device can change. 1031 self.assertTrue(hasattr(actual_run_metadata, 'function_graphs')) 1032 1033 @test_util.run_v2_only 1034 def testTrace_cannotEnableTraceInFunction(self): 1035 1036 @def_function.function 1037 def f(): 1038 summary_ops.trace_on(graph=True, profiler=False) 1039 x = constant_op.constant(2) 1040 y = constant_op.constant(3) 1041 return x**y 1042 1043 with test.mock.patch.object(logging, 'warn') as mock_log: 1044 f() 1045 self.assertRegexpMatches( 1046 str(mock_log.call_args), 'Cannot enable trace inside a tf.function.') 1047 1048 @test_util.run_v2_only 1049 def testTrace_cannotEnableTraceInGraphMode(self): 1050 with test.mock.patch.object(logging, 'warn') as mock_log: 1051 with context.graph_mode(): 1052 summary_ops.trace_on(graph=True, profiler=False) 1053 self.assertRegexpMatches( 1054 str(mock_log.call_args), 'Must enable trace in eager mode.') 1055 1056 @test_util.run_v2_only 1057 def testTrace_cannotExportTraceWithoutTrace(self): 1058 with six.assertRaisesRegex(self, ValueError, 1059 'Must enable trace before export.'): 1060 summary_ops.trace_export(name='foo', step=1) 1061 1062 @test_util.run_v2_only 1063 def testTrace_cannotExportTraceInFunction(self): 1064 summary_ops.trace_on(graph=True, profiler=False) 1065 1066 @def_function.function 1067 def f(): 1068 x = constant_op.constant(2) 1069 y = constant_op.constant(3) 1070 summary_ops.trace_export(name='foo', step=1) 1071 return x**y 1072 1073 with test.mock.patch.object(logging, 'warn') as mock_log: 1074 f() 1075 self.assertRegexpMatches( 1076 str(mock_log.call_args), 1077 'Cannot export trace inside a tf.function.') 1078 1079 @test_util.run_v2_only 1080 def testTrace_cannotExportTraceInGraphMode(self): 1081 with test.mock.patch.object(logging, 'warn') as mock_log: 1082 with context.graph_mode(): 1083 summary_ops.trace_export(name='foo', step=1) 1084 self.assertRegexpMatches( 1085 str(mock_log.call_args), 1086 'Can only export trace while executing eagerly.') 1087 1088 @test_util.run_v2_only 1089 def testTrace_usesDefaultStep(self): 1090 1091 @def_function.function 1092 def f(): 1093 x = constant_op.constant(2) 1094 y = constant_op.constant(3) 1095 return x**y 1096 1097 try: 1098 summary_ops.set_step(42) 1099 event = self.run_trace(f, step=None) 1100 self.assertEqual(42, event.step) 1101 finally: 1102 # Reset to default state for other tests. 1103 summary_ops.set_step(None) 1104 1105 1106def events_from_file(filepath): 1107 """Returns all events in a single event file. 1108 1109 Args: 1110 filepath: Path to the event file. 1111 1112 Returns: 1113 A list of all tf.Event protos in the event file. 1114 """ 1115 records = list(tf_record.tf_record_iterator(filepath)) 1116 result = [] 1117 for r in records: 1118 event = event_pb2.Event() 1119 event.ParseFromString(r) 1120 result.append(event) 1121 return result 1122 1123 1124def events_from_logdir(logdir): 1125 """Returns all events in the single eventfile in logdir. 1126 1127 Args: 1128 logdir: The directory in which the single event file is sought. 1129 1130 Returns: 1131 A list of all tf.Event protos from the single event file. 1132 1133 Raises: 1134 AssertionError: If logdir does not contain exactly one file. 1135 """ 1136 assert gfile.Exists(logdir) 1137 files = gfile.ListDirectory(logdir) 1138 assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files 1139 return events_from_file(os.path.join(logdir, files[0])) 1140 1141 1142def to_numpy(summary_value): 1143 return tensor_util.MakeNdarray(summary_value.tensor) 1144 1145 1146if __name__ == '__main__': 1147 test.main() 1148