• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for V2 summary ops from summary_ops_v2."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import unittest
23
24import six
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import node_def_pb2
28from tensorflow.core.framework import step_stats_pb2
29from tensorflow.core.framework import summary_pb2
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.util import event_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.framework import test_util
41from tensorflow.python.keras.engine.sequential import Sequential
42from tensorflow.python.keras.engine.training import Model
43from tensorflow.python.keras.layers.core import Activation
44from tensorflow.python.keras.layers.core import Dense
45from tensorflow.python.lib.io import tf_record
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import summary_ops_v2 as summary_ops
48from tensorflow.python.ops import variables
49from tensorflow.python.platform import gfile
50from tensorflow.python.platform import test
51from tensorflow.python.platform import tf_logging as logging
52
53
class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
  """Tests for write(), step handling, record_if() and summary_scope()."""

  def _assert_single_tag_event(self, logdir):
    """Asserts events[1] of two events in `logdir` is 'tag' == 42, step 12."""
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    written = events[1]
    self.assertEqual(12, written.step)
    value = written.summary.value[0]
    self.assertEqual('tag', value.tag)
    self.assertEqual(42, to_numpy(value))

  def _assert_event_steps(self, logdir, expected_steps):
    """Asserts the steps of the events following events[0] in `logdir`.

    Args:
      logdir: Directory handed to `events_from_logdir`.
      expected_steps: Expected `step` of events[1], events[2], ... in order.
        The total event count must be one more than its length.
    """
    events = events_from_logdir(logdir)
    self.assertEqual(len(expected_steps) + 1, len(events))
    for position, expected in enumerate(expected_steps, 1):
      self.assertEqual(expected, events[position].step)

  def _assert_event_tags(self, logdir, expected_tags):
    """Asserts the tags of the events following events[0] in `logdir`.

    Args:
      logdir: Directory handed to `events_from_logdir`.
      expected_tags: Expected tag of the first summary value of events[1],
        events[2], ... in order. The total event count must be one more than
        its length.
    """
    events = events_from_logdir(logdir)
    self.assertEqual(len(expected_tags) + 1, len(events))
    for position, expected in enumerate(expected_tags, 1):
      self.assertEqual(expected, events[position].summary.value[0].tag)

  def _assert_written_value(self, logdir, expected):
    """Asserts the first summary value of events[1] decodes to `expected`."""
    events = events_from_logdir(logdir)
    value = events[1].summary.value[0]
    self.assertAllEqual(expected, to_numpy(value))

  def testWrite(self):
    """An eager write() returns a true tensor and records tag/step/value."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        wrote = summary_ops.write('tag', 42, step=12)
        self.assertTrue(wrote.numpy())
    self._assert_single_tag_event(logdir)

  def testWrite_fromFunction(self):
    """write() called inside a def_function records the same event."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def write_once():
        with writer.as_default():
          return summary_ops.write('tag', 42, step=12)
      wrote = write_once()
      self.assertTrue(wrote.numpy())
    self._assert_single_tag_event(logdir)

  def testWrite_metadata(self):
    """Metadata given as proto, bytes or string tensor is stored unchanged."""
    logdir = self.get_temp_dir()
    metadata = summary_pb2.SummaryMetadata()
    metadata.plugin_data.plugin_name = 'foo'
    serialized = metadata.SerializeToString()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('obj', 0, 0, metadata=metadata)
        summary_ops.write('bytes', 0, 0, metadata=serialized)
        as_tensor = constant_op.constant(serialized)
        summary_ops.write('string_tensor', 0, 0, metadata=as_tensor)
    events = events_from_logdir(logdir)
    self.assertEqual(4, len(events))
    for position in (1, 2, 3):
      self.assertEqual(metadata, events[position].summary.value[0].metadata)

  def testWrite_name(self):
    """The `name` argument becomes the prefix of the returned op's name."""
    @def_function.function
    def write_named():
      result = summary_ops.write('tag', 42, step=12, name='anonymous')
      self.assertTrue(result.name.startswith('anonymous'))
    write_named()

  def testWrite_ndarray(self):
    """A nested Python list is written and read back with equal contents."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
    self._assert_written_value(logdir, [[1, 2], [3, 4]])

  def testWrite_tensor(self):
    """An eager tensor is written and read back with equal contents."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      tensor = constant_op.constant([[1, 2], [3, 4]])
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', tensor, step=12)
      expected = tensor.numpy()
    self._assert_written_value(logdir, expected)

  def testWrite_tensor_fromFunction(self):
    """A tensor passed into a def_function is written with equal contents."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def write_tensor(tensor):
        with writer.as_default():
          summary_ops.write('tag', tensor, step=12)
      source = constant_op.constant([[1, 2], [3, 4]])
      write_tensor(source)
      expected = source.numpy()
    self._assert_written_value(logdir, expected)

  def testWrite_stringTensor(self):
    """A list of byte strings is written and read back unchanged."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.write('tag', [b'foo', b'bar'], step=12)
    self._assert_written_value(logdir, [b'foo', b'bar'])

  @test_util.run_gpu_only
  def testWrite_gpuDeviceContext(self):
    """write() works when value and step are created under a GPU device."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer(logdir).as_default():
        with ops.device('/GPU:0'):
          gpu_value = constant_op.constant(42.0)
          gpu_step = constant_op.constant(12, dtype=dtypes.int64)
          summary_ops.write('tag', gpu_value, step=gpu_step).numpy()
    events = events_from_logdir(logdir)
    self.assertEqual(2, len(events))
    written = events[1]
    self.assertEqual(12, written.step)
    self.assertEqual(42, to_numpy(written.summary.value[0]))
    # No metadata was passed, so the stored metadata must equal an empty proto.
    self.assertEqual(
        summary_pb2.SummaryMetadata(), written.summary.value[0].metadata)

  @test_util.also_run_as_tf_function
  def testWrite_noDefaultWriter(self):
    """Without a default writer, write() evaluates to False."""
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42, step=0))

  @test_util.also_run_as_tf_function
  def testWrite_noStep_okayIfAlsoNoDefaultWriter(self):
    """A missing step is not an error when there is no default writer."""
    # Use assertAllEqual instead of assertFalse since it works in a defun.
    self.assertAllEqual(False, summary_ops.write('tag', 42))

  @test_util.also_run_as_tf_function
  def testWrite_noStep(self):
    """With a default writer, write() without any step raises ValueError."""
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer(logdir).as_default():
      with self.assertRaisesRegex(ValueError, 'No step set'):
        summary_ops.write('tag', 42)

  def testWrite_usingDefaultStep(self):
    """Eager writes pick up the current default step, constant or variable."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        with summary_ops.create_file_writer(logdir).as_default():
          summary_ops.set_step(1)
          summary_ops.write('tag', 1.0)
          summary_ops.set_step(2)
          summary_ops.write('tag', 1.0)
          step_var = variables.Variable(10, dtype=dtypes.int64)
          summary_ops.set_step(step_var)
          summary_ops.write('tag', 1.0)
          step_var.assign_add(1)
          summary_ops.write('tag', 1.0)
      self._assert_event_steps(logdir, [1, 2, 10, 11])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromFunction(self):
    """A constant default step is frozen into a def_function at trace time."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        @def_function.function
        def write_with_default_step():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        summary_ops.set_step(1)
        write_with_default_step()
        summary_ops.set_step(2)
        write_with_default_step()
      # The second step value will still be 1 because the value was captured
      # at the time the function was first traced.
      self._assert_event_steps(logdir, [1, 1])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromFunction(self):
    """A variable default step is re-read on every def_function call."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        writer = summary_ops.create_file_writer(logdir)
        @def_function.function
        def write_with_default_step():
          with writer.as_default():
            summary_ops.write('tag', 1.0)
        step_var = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(step_var)
        write_with_default_step()
        step_var.assign_add(1)
        write_with_default_step()
        step_var.assign(10)
        write_with_default_step()
      self._assert_event_steps(logdir, [0, 1, 10])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepConstant_fromLegacyGraph(self):
    """A constant default step is frozen into a graph-mode write op."""
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer(logdir)
        summary_ops.set_step(1)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        summary_ops.set_step(2)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(write_op)
          sess.run(write_op)
          sess.run(writer.flush())
      # The second step value will still be 1 because the value was captured
      # at the time the graph was constructed.
      self._assert_event_steps(logdir, [1, 1])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_usingDefaultStepVariable_fromLegacyGraph(self):
    """A variable default step is re-read on every graph-mode write run."""
    logdir = self.get_temp_dir()
    try:
      with context.graph_mode():
        writer = summary_ops.create_file_writer(logdir)
        step_var = variables.Variable(0, dtype=dtypes.int64)
        summary_ops.set_step(step_var)
        with writer.as_default():
          write_op = summary_ops.write('tag', 1.0)
        increment_op = step_var.assign_add(1)
        set_to_ten_op = step_var.assign(10)
        with self.cached_session() as sess:
          sess.run(writer.init())
          sess.run(step_var.initializer)
          sess.run(write_op)
          sess.run(increment_op)
          sess.run(write_op)
          sess.run(set_to_ten_op)
          sess.run(write_op)
          sess.run(writer.flush())
      self._assert_event_steps(logdir, [0, 1, 10])
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testWrite_recordIf_constant(self):
    """record_if(True/False) enables or suppresses eager writes."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        self.assertTrue(summary_ops.write('default', 1, step=0))
        with summary_ops.record_if(True):
          self.assertTrue(summary_ops.write('set_on', 1, step=0))
        with summary_ops.record_if(False):
          self.assertFalse(summary_ops.write('set_off', 1, step=0))
    self._assert_event_tags(logdir, ['default', 'set_on'])

  def testWrite_recordIf_constant_fromFunction(self):
    """record_if(True/False) behaves the same inside a def_function."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function
      def write_under_conditions():
        with writer.as_default():
          # Use assertAllEqual instead of assertTrue since it works in a defun.
          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
          with summary_ops.record_if(True):
            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
          with summary_ops.record_if(False):
            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
      write_under_conditions()
    self._assert_event_tags(logdir, ['default', 'set_on'])

  def testWrite_recordIf_callable(self):
    """A callable condition is evaluated once for every eager write."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      step = variables.Variable(-1, dtype=dtypes.int64)
      def record_fn():
        # Each evaluation advances the step, so only even steps are recorded.
        step.assign_add(1)
        return int(step % 2) == 0
      with summary_ops.create_file_writer_v2(logdir).as_default():
        with summary_ops.record_if(record_fn):
          for should_record in (True, False, True, False, True):
            outcome = summary_ops.write('tag', 1, step=step)
            if should_record:
              self.assertTrue(outcome)
            else:
              self.assertFalse(outcome)
    self._assert_event_steps(logdir, [0, 2, 4])

  def testWrite_recordIf_callable_fromFunction(self):
    """A def_function condition is evaluated once per write in a function."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      step = variables.Variable(-1, dtype=dtypes.int64)
      @def_function.function
      def record_fn():
        step.assign_add(1)
        return math_ops.equal(step % 2, 0)
      @def_function.function
      def write_three_times():
        with writer.as_default():
          with summary_ops.record_if(record_fn):
            return [
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step),
                summary_ops.write('tag', 1, step=step)]
      self.assertAllEqual(write_three_times(), [True, False, True])
      self.assertAllEqual(write_three_times(), [False, True, False])
    self._assert_event_steps(logdir, [0, 2, 4])

  def testWrite_recordIf_tensorInput_fromFunction(self):
    """A condition computed from a function argument gates each call."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
      def write_if_even(step):
        with writer.as_default():
          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
            return summary_ops.write('tag', 1, step=step)
      for step_value in range(5):
        outcome = write_if_even(step_value)
        if step_value % 2 == 0:
          self.assertTrue(outcome)
        else:
          self.assertFalse(outcome)
    self._assert_event_steps(logdir, [0, 2, 4])

  @test_util.also_run_as_tf_function
  def testGetSetStep(self):
    """get_step() starts as None and returns whatever set_step() stored."""
    try:
      self.assertIsNone(summary_ops.get_step())
      summary_ops.set_step(1)
      # Use assertAllEqual instead of assertEqual since it works in a defun.
      self.assertAllEqual(1, summary_ops.get_step())
      summary_ops.set_step(constant_op.constant(2))
      self.assertAllEqual(2, summary_ops.get_step())
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  def testGetSetStep_variable(self):
    """A variable step stays usable after the caller drops its reference."""
    with context.eager_mode():
      try:
        step_var = variables.Variable(0)
        summary_ops.set_step(step_var)
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        step_var.assign_add(1)
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        # Check that set_step() properly maintains reference to variable.
        del step_var
        self.assertAllEqual(1, summary_ops.get_step().read_value())
        summary_ops.get_step().assign_add(1)
        self.assertAllEqual(2, summary_ops.get_step().read_value())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  def testGetSetStep_variable_fromFunction(self):
    """A variable step set inside a def_function is shared across functions."""
    with context.eager_mode():
      try:
        @def_function.function
        def set_step(step):
          summary_ops.set_step(step)
          return summary_ops.get_step()
        @def_function.function
        def get_and_increment():
          summary_ops.get_step().assign_add(1)
          return summary_ops.get_step()
        step_var = variables.Variable(0)
        self.assertAllEqual(0, set_step(step_var))
        self.assertAllEqual(0, summary_ops.get_step().read_value())
        self.assertAllEqual(1, get_and_increment())
        self.assertAllEqual(2, get_and_increment())
        # Check that set_step() properly maintains reference to variable.
        del step_var
        self.assertAllEqual(3, get_and_increment())
      finally:
        # Reset to default state for other tests.
        summary_ops.set_step(None)

  @test_util.also_run_as_tf_function
  def testSummaryScope(self):
    """Nested summary scopes prefix the tag; name_scope(None) resets it."""
    with summary_ops.summary_scope('foo') as (tag, scope):
      self.assertEqual('foo', tag)
      self.assertEqual('foo/', scope)
      with summary_ops.summary_scope('bar') as (tag, scope):
        self.assertEqual('foo/bar', tag)
        self.assertEqual('foo/bar/', scope)
      with summary_ops.summary_scope('with/slash') as (tag, scope):
        self.assertEqual('foo/with/slash', tag)
        self.assertEqual('foo/with/slash/', scope)
      with ops.name_scope(None):
        with summary_ops.summary_scope('unnested') as (tag, scope):
          self.assertEqual('unnested', tag)
          self.assertEqual('unnested/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_defaultName(self):
    """A None name falls back to 'summary' or to the given default name."""
    with summary_ops.summary_scope(None) as (tag, scope):
      self.assertEqual('summary', tag)
      self.assertEqual('summary/', scope)
    with summary_ops.summary_scope(None, 'backup') as (tag, scope):
      self.assertEqual('backup', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_handlesCharactersIllegalForScope(self):
    """Illegal scope characters are dropped from the scope, not the tag."""
    with summary_ops.summary_scope('f?o?o') as (tag, scope):
      self.assertEqual('f?o?o', tag)
      self.assertEqual('foo/', scope)
    # If all characters aren't legal for a scope name, use default name.
    with summary_ops.summary_scope('???', 'backup') as (tag, scope):
      self.assertEqual('???', tag)
      self.assertEqual('backup/', scope)

  @test_util.also_run_as_tf_function
  def testSummaryScope_nameNotUniquifiedForTag(self):
    """The tag stays un-suffixed even when the name is already in use."""
    constant_op.constant(0, name='foo')
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with summary_ops.summary_scope('foo') as (tag, _):
      self.assertEqual('foo', tag)
    with ops.name_scope('with'):
      constant_op.constant(0, name='slash')
    with summary_ops.summary_scope('with/slash') as (tag, _):
      self.assertEqual('with/slash', tag)
509
510
class SummaryWriterTest(test_util.TensorFlowTestCase):
  """Tests for the lifecycle of writers made by create_file_writer_v2()."""

  def _num_events(self, logdir):
    """Returns the number of events currently readable from `logdir`."""
    return len(events_from_logdir(logdir))

  def _list_logdir(self, logdir):
    """Returns the paths matching `<logdir>/*`."""
    return gfile.Glob(os.path.join(logdir, '*'))

  def _assert_two_events_at_step(self, path, step):
    """Asserts `path` holds a version event then exactly two `step` events."""
    events = iter(events_from_file(path))
    self.assertEqual('brain.Event:2', next(events).file_version)
    self.assertEqual(step, next(events).step)
    self.assertEqual(step, next(events).step)
    with self.assertRaises(StopIteration):
      next(events)

  def _open_filenames_getter(self):
    """Returns a callable giving the set of files this process has open.

    Raises:
      unittest.SkipTest: If psutil cannot be imported.
    """
    try:
      import psutil  # pylint: disable=g-import-not-at-top
    except ImportError:
      raise unittest.SkipTest('test requires psutil')
    proc = psutil.Process()
    def get_open_filenames():
      return {info[0] for info in proc.open_files()}
    return get_open_filenames

  def testCreate_withInitAndClose(self):
    """Re-running init() is a no-op and close() flushes pending events."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      self.assertEqual(1, self._num_events(logdir))  # file_version Event
      # Calling init() again while writer is open has no effect
      writer.init()
      self.assertEqual(1, self._num_events(logdir))
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, self._num_events(logdir))
        # Calling .close() should do an implicit flush
        writer.close()
        self.assertEqual(2, self._num_events(logdir))

  def testCreate_fromFunction(self):
    """A writer created inside a def_function produces one event file."""
    logdir = self.get_temp_dir()
    @def_function.function
    def create_writer():
      # Returned SummaryWriter must be stored in a non-local variable so it
      # lives throughout the function execution.
      if not hasattr(create_writer, 'writer'):
        create_writer.writer = summary_ops.create_file_writer_v2(logdir)
    with context.eager_mode():
      create_writer()
    self.assertEqual(1, len(self._list_logdir(logdir)))

  def testCreate_graphTensorArgument_raisesError(self):
    """Passing a graph tensor as logdir in eager mode raises ValueError."""
    logdir = self.get_temp_dir()
    with context.graph_mode():
      graph_logdir = constant_op.constant(logdir)
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        summary_ops.create_file_writer_v2(graph_logdir)
    self.assertEmpty(self._list_logdir(logdir))

  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
    """Passing a graph tensor as logdir inside a function raises ValueError."""
    logdir = self.get_temp_dir()
    @def_function.function
    def create_from_tensor():
      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
    with context.eager_mode():
      with self.assertRaisesRegex(
          ValueError, 'Invalid graph Tensor argument.*logdir'):
        create_from_tensor()
    self.assertEmpty(self._list_logdir(logdir))

  def testCreate_fromFunction_unpersistedResource_raisesError(self):
    """Using a writer that was not kept alive in a function raises NotFound."""
    logdir = self.get_temp_dir()
    @def_function.function
    def use_unpersisted_writer():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        pass  # Calling .as_default() is enough to indicate use.
    with context.eager_mode():
      # TODO(nickfelt): change this to a better error
      with self.assertRaisesRegex(
          errors.NotFoundError, 'Resource.*does not exist'):
        use_unpersisted_writer()
    # Even though we didn't use it, an event file will have been created.
    self.assertEqual(1, len(self._list_logdir(logdir)))

  def testCreate_immediateSetAsDefault_retainsReference(self):
    """flush() works after set_as_default() on an otherwise unbound writer."""
    logdir = self.get_temp_dir()
    try:
      with context.eager_mode():
        summary_ops.create_file_writer_v2(logdir).set_as_default()
        summary_ops.flush()
    finally:
      # Ensure we clean up no matter how the test executes.
      context.context().summary_writer_resource = None

  def testCreate_immediateAsDefault_retainsReference(self):
    """flush() works inside as_default() on an otherwise unbound writer."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(logdir).as_default():
        summary_ops.flush()

  def testNoSharing(self):
    """Two eager writers on one logdir write to two separate event files."""
    # Two writers with the same logdir should not share state.
    logdir = self.get_temp_dir()
    with context.eager_mode():
      first_writer = summary_ops.create_file_writer_v2(logdir)
      with first_writer.as_default():
        summary_ops.write('tag', 1, step=1)
      found = self._list_logdir(logdir)
      self.assertEqual(1, len(found))
      first_file = found[0]

      second_writer = summary_ops.create_file_writer_v2(logdir)
      with second_writer.as_default():
        summary_ops.write('tag', 1, step=2)
      found = self._list_logdir(logdir)
      self.assertEqual(2, len(found))
      # Whatever remains once the first file is removed is the second file.
      found.remove(first_file)
      second_file = found[0]

      # Extra writes to ensure interleaved usage works.
      with first_writer.as_default():
        summary_ops.write('tag', 1, step=1)
      with second_writer.as_default():
        summary_ops.write('tag', 1, step=2)

    self._assert_two_events_at_step(first_file, 1)
    self._assert_two_events_at_step(second_file, 2)

  def testNoSharing_fromFunction(self):
    """Two writers made in def_functions write to two separate event files."""
    logdir = self.get_temp_dir()
    @def_function.function
    def write_step_one():
      if not hasattr(write_step_one, 'writer'):
        write_step_one.writer = summary_ops.create_file_writer_v2(logdir)
      with write_step_one.writer.as_default():
        summary_ops.write('tag', 1, step=1)
    @def_function.function
    def write_step_two():
      if not hasattr(write_step_two, 'writer'):
        write_step_two.writer = summary_ops.create_file_writer_v2(logdir)
      with write_step_two.writer.as_default():
        summary_ops.write('tag', 1, step=2)
    with context.eager_mode():
      write_step_one()
      found = self._list_logdir(logdir)
      self.assertEqual(1, len(found))
      first_file = found[0]

      write_step_two()
      found = self._list_logdir(logdir)
      self.assertEqual(2, len(found))
      # Whatever remains once the first file is removed is the second file.
      found.remove(first_file)
      second_file = found[0]

      # Extra writes to ensure interleaved usage works.
      write_step_one()
      write_step_two()

    self._assert_two_events_at_step(first_file, 1)
    self._assert_two_events_at_step(second_file, 2)

  def testMaxQueue(self):
    """With max_queue=1 the second write triggers a flush of both events."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      with summary_ops.create_file_writer_v2(
          logdir, max_queue=1, flush_millis=999999).as_default():
        # Note: First tf.Event is always file_version.
        self.assertEqual(1, self._num_events(logdir))
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, self._num_events(logdir))
        # Should flush after second summary since max_queue = 1
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, self._num_events(logdir))

  def testWriterFlush(self):
    """writer.flush() and leaving as_default() both persist queued events."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
      self.assertEqual(1, self._num_events(logdir))  # file_version Event
      with writer.as_default():
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, self._num_events(logdir))
        writer.flush()
        self.assertEqual(2, self._num_events(logdir))
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(2, self._num_events(logdir))
      # Exiting the "as_default()" should do an implicit flush
      self.assertEqual(3, self._num_events(logdir))

  def testFlushFunction(self):
    """summary_ops.flush() works with no argument, a writer or a resource."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=999999, flush_millis=999999)
      with writer.as_default():
        # Note: First tf.Event is always file_version.
        self.assertEqual(1, self._num_events(logdir))
        summary_ops.write('tag', 1, step=0)
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(1, self._num_events(logdir))
        summary_ops.flush()
        self.assertEqual(3, self._num_events(logdir))
        # Test "writer" parameter
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(3, self._num_events(logdir))
        summary_ops.flush(writer=writer)
        self.assertEqual(4, self._num_events(logdir))
        summary_ops.write('tag', 1, step=0)
        self.assertEqual(4, self._num_events(logdir))
        summary_ops.flush(writer=writer._resource)  # pylint:disable=protected-access
        self.assertEqual(5, self._num_events(logdir))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testEagerMemory(self):
    """Creating a writer and writing once runs under the pyobject check."""
    logdir = self.get_temp_dir()
    with summary_ops.create_file_writer_v2(logdir).as_default():
      summary_ops.write('tag', 1, step=0)

  def testClose_preventsLaterUse(self):
    """After close(), re-closing and flushing are no-ops; reuse raises."""
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      writer.close()
      writer.close()  # redundant close() is a no-op
      writer.flush()  # redundant flush() is a no-op
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.init()
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        with writer.as_default():
          self.fail('should not get here')
      with self.assertRaisesRegex(RuntimeError, 'already closed'):
        writer.set_as_default()

  def testClose_closesOpenFile(self):
    """close() releases the event file handle held by this process."""
    get_open_filenames = self._open_filenames_getter()
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      found = self._list_logdir(logdir)
      self.assertEqual(1, len(found))
      event_file = found[0]
      self.assertIn(event_file, get_open_filenames())
      writer.close()
      self.assertNotIn(event_file, get_open_filenames())

  def testDereference_closesOpenFile(self):
    """Deleting the only writer reference releases the event file handle."""
    get_open_filenames = self._open_filenames_getter()
    logdir = self.get_temp_dir()
    with context.eager_mode():
      writer = summary_ops.create_file_writer_v2(logdir)
      found = self._list_logdir(logdir)
      self.assertEqual(1, len(found))
      event_file = found[0]
      self.assertIn(event_file, get_open_filenames())
      del writer
      self.assertNotIn(event_file, get_open_filenames())
780
781
class SummaryOpsTest(test_util.TensorFlowTestCase):
  """Tests for the run-metadata, Keras-model and trace summary ops.

  The helper methods write through a fresh file writer rooted at the test's
  temp directory and hand back the second event of the resulting event file;
  the first event contains no summary values.
  """

  def tearDown(self):
    # Some tests turn tracing on without turning it off again; reset it so the
    # next test starts clean, then let the base class run its own cleanup.
    summary_ops.trace_off()
    super(SummaryOpsTest, self).tearDown()

  def _write_and_read_event(self, write_fn):
    """Calls `write_fn` under a default file writer and reads back the result.

    Args:
      write_fn: Zero-argument callable invoked while a newly created file
        writer for this test's temp directory is the default writer.

    Returns:
      The second event (index 1) of the single event file in the temp
      directory.
    """
    logdir = self.get_temp_dir()
    writer = summary_ops.create_file_writer(logdir)
    try:
      with writer.as_default():
        write_fn()
    finally:
      # Close even when `write_fn` raises so the writer is never left open.
      writer.close()
    # The first event contains no summary values. The written content goes to
    # the second event.
    return events_from_logdir(logdir)[1]

  def run_metadata(self, *args, **kwargs):
    """Writes with `summary_ops.run_metadata` and returns the written event.

    Args:
      *args: Positional arguments forwarded to `summary_ops.run_metadata`.
      **kwargs: Keyword arguments forwarded to `summary_ops.run_metadata`.

    Returns:
      The second event read back from the event file.
    """
    assert context.executing_eagerly()
    return self._write_and_read_event(
        lambda: summary_ops.run_metadata(*args, **kwargs))

  def run_metadata_graphs(self, *args, **kwargs):
    """Writes with `summary_ops.run_metadata_graphs`; returns the event.

    Args:
      *args: Positional arguments forwarded to
        `summary_ops.run_metadata_graphs`.
      **kwargs: Keyword arguments forwarded to
        `summary_ops.run_metadata_graphs`.

    Returns:
      The second event read back from the event file.
    """
    assert context.executing_eagerly()
    return self._write_and_read_event(
        lambda: summary_ops.run_metadata_graphs(*args, **kwargs))

  def create_run_metadata(self):
    """Builds a RunMetadata with one node stat and one function graph node."""
    step_stats = step_stats_pb2.StepStats(dev_stats=[
        step_stats_pb2.DeviceStepStats(
            device='cpu:0',
            node_stats=[step_stats_pb2.NodeExecStats(node_name='hello')])
    ])
    return config_pb2.RunMetadata(
        function_graphs=[
            config_pb2.RunMetadata.FunctionGraphs(
                pre_optimization_graph=graph_pb2.GraphDef(
                    node=[node_def_pb2.NodeDef(name='foo')]))
        ],
        step_stats=step_stats)

  def keras_model(self, *args, **kwargs):
    """Writes with `summary_ops.keras_model` and returns the written event.

    Args:
      *args: Positional arguments forwarded to `summary_ops.keras_model`.
      **kwargs: Keyword arguments forwarded to `summary_ops.keras_model`.

    Returns:
      The second event read back from the event file.
    """
    return self._write_and_read_event(
        lambda: summary_ops.keras_model(*args, **kwargs))

  def run_trace(self, f, step=1):
    """Calls `f` with graph tracing on and returns the exported trace event.

    Args:
      f: Zero-argument callable executed while tracing is enabled.
      step: Step value passed to `summary_ops.trace_export`.

    Returns:
      The second event read back from the event file.
    """
    assert context.executing_eagerly()
    logdir = self.get_temp_dir()
    # The writer is created before tracing is switched on, so this helper does
    # not go through `_write_and_read_event`.
    writer = summary_ops.create_file_writer(logdir)
    try:
      summary_ops.trace_on(graph=True, profiler=False)
      with writer.as_default():
        f()
        summary_ops.trace_export(name='foo', step=step)
    finally:
      # Close even when `f` or the export raises.
      writer.close()
    return events_from_logdir(logdir)[1]

  @test_util.run_v2_only
  def testRunMetadata_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo'):
      event = self.run_metadata(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadata_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadata_wholeRunMetadata(self):
    expected_run_metadata = """
      step_stats {
        dev_stats {
          device: "cpu:0"
          node_stats {
            node_name: "hello"
          }
        }
      }
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()
    event = self.run_metadata(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadata_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesNameAsTag(self):
    meta = config_pb2.RunMetadata()

    with ops.name_scope('foo'):
      event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
      first_val = event.summary.value[0]

    self.assertEqual('foo/my_name', first_val.tag)

  @test_util.run_v2_only
  def testRunMetadataGraph_summaryMetadata(self):
    expected_summary_metadata = """
      plugin_data {
        plugin_name: "graph_run_metadata_graph"
        content: "1"
      }
    """
    meta = config_pb2.RunMetadata()
    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    actual_summary_metadata = event.summary.value[0].metadata
    self.assertProtoEquals(expected_summary_metadata, actual_summary_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_runMetadataFragment(self):
    # Only the function graphs are expected back, not the step stats.
    expected_run_metadata = """
      function_graphs {
        pre_optimization_graph {
          node {
            name: "foo"
          }
        }
      }
    """
    meta = self.create_run_metadata()

    event = self.run_metadata_graphs(name='my_name', data=meta, step=1)
    first_val = event.summary.value[0]

    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])
    self.assertProtoEquals(expected_run_metadata, actual_run_metadata)

  @test_util.run_v2_only
  def testRunMetadataGraph_usesDefaultStep(self):
    meta = config_pb2.RunMetadata()
    try:
      summary_ops.set_step(42)
      event = self.run_metadata_graphs(name='my_name', data=meta)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testKerasModel(self):
    model = Sequential(
        [Dense(10, input_shape=(100,)),
         Activation('relu', name='my_relu')])
    event = self.keras_model(name='my_name', data=model, step=1)
    first_val = event.summary.value[0]
    self.assertEqual(model.to_json(), first_val.tensor.string_val[0].decode())

  @test_util.run_v2_only
  def testKerasModel_usesDefaultStep(self):
    model = Sequential(
        [Dense(10, input_shape=(100,)),
         Activation('relu', name='my_relu')])
    try:
      summary_ops.set_step(42)
      event = self.keras_model(name='my_name', data=model)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)

  @test_util.run_v2_only
  def testKerasModel_subclass(self):

    class SimpleSubclass(Model):

      def __init__(self):
        super(SimpleSubclass, self).__init__(name='subclass')
        self.dense = Dense(10, input_shape=(100,))
        self.activation = Activation('relu', name='my_relu')

      def call(self, inputs):
        x = self.dense(inputs)
        return self.activation(x)

    model = SimpleSubclass()
    with test.mock.patch.object(logging, 'warn') as mock_log:
      self.assertFalse(
          summary_ops.keras_model(name='my_name', data=model, step=1))
      self.assertRegex(
          str(mock_log.call_args), 'Model failed to serialize as JSON.')

  @test_util.run_v2_only
  def testKerasModel_otherExceptions(self):
    model = Sequential()

    with test.mock.patch.object(model, 'to_json') as mock_to_json:
      with test.mock.patch.object(logging, 'warn') as mock_log:
        mock_to_json.side_effect = Exception('oops')
        self.assertFalse(
            summary_ops.keras_model(name='my_name', data=model, step=1))
        self.assertRegex(
            str(mock_log.call_args),
            'Model failed to serialize as JSON. Ignoring... oops')

  @test_util.run_v2_only
  def testTrace(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    event = self.run_trace(f)

    first_val = event.summary.value[0]
    actual_run_metadata = config_pb2.RunMetadata.FromString(
        first_val.tensor.string_val[0])

    # Content of function_graphs is large and, for instance, device can change.
    # NOTE(review): `function_graphs` is a declared field of RunMetadata, so
    # this hasattr check presumably always passes; consider asserting that the
    # field is non-empty once that is confirmed to be stable.
    self.assertTrue(hasattr(actual_run_metadata, 'function_graphs'))

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInFunction(self):

    @def_function.function
    def f():
      summary_ops.trace_on(graph=True, profiler=False)
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args), 'Cannot enable trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotEnableTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_on(graph=True, profiler=False)
      self.assertRegex(
          str(mock_log.call_args), 'Must enable trace in eager mode.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceWithoutTrace(self):
    with self.assertRaisesRegex(ValueError, 'Must enable trace before export.'):
      summary_ops.trace_export(name='foo', step=1)

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInFunction(self):
    # Tracing stays enabled after this test body; tearDown switches it off.
    summary_ops.trace_on(graph=True, profiler=False)

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      summary_ops.trace_export(name='foo', step=1)
      return x**y

    with test.mock.patch.object(logging, 'warn') as mock_log:
      f()
      self.assertRegex(
          str(mock_log.call_args),
          'Cannot export trace inside a tf.function.')

  @test_util.run_v2_only
  def testTrace_cannotExportTraceInGraphMode(self):
    with test.mock.patch.object(logging, 'warn') as mock_log:
      with context.graph_mode():
        summary_ops.trace_export(name='foo', step=1)
      self.assertRegex(
          str(mock_log.call_args),
          'Can only export trace while executing eagerly.')

  @test_util.run_v2_only
  def testTrace_usesDefaultStep(self):

    @def_function.function
    def f():
      x = constant_op.constant(2)
      y = constant_op.constant(3)
      return x**y

    try:
      summary_ops.set_step(42)
      event = self.run_trace(f, step=None)
      self.assertEqual(42, event.step)
    finally:
      # Reset to default state for other tests.
      summary_ops.set_step(None)
1104
1105
def events_from_file(filepath):
  """Reads one event file and parses every record in it.

  Args:
    filepath: Location of the event file to read.

  Returns:
    A list with one parsed tf.Event proto per record, in file order.
  """
  return [
      event_pb2.Event.FromString(serialized)
      for serialized in tf_record.tf_record_iterator(filepath)
  ]
1122
1123
def events_from_logdir(logdir):
  """Parses the events of the one event file that lives in `logdir`.

  Args:
    logdir: Directory expected to hold exactly one event file.

  Returns:
    A list of all tf.Event protos read from that file.

  Raises:
    AssertionError: If `logdir` is missing or does not hold exactly one file.
  """
  assert gfile.Exists(logdir)
  entries = gfile.ListDirectory(logdir)
  assert len(entries) == 1, 'Found not exactly one file in logdir: %s' % entries
  event_path = os.path.join(logdir, entries[0])
  return events_from_file(event_path)
1140
1141
def to_numpy(summary_value):
  """Converts the tensor proto carried by `summary_value` into an ndarray."""
  tensor_proto = summary_value.tensor
  return tensor_util.MakeNdarray(tensor_proto)
1144
1145
# Invoke the test framework's entry point when this file is run as a script.
if __name__ == '__main__':
  test.main()
1148