1# Lint as python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16# pylint: disable=g-import-not-at-top
17"""Utilities for file download and caching."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from abc import abstractmethod
23from contextlib import closing
24import errno
25import functools
26import hashlib
27import multiprocessing
28import multiprocessing.dummy
29import os
30import random
31import shutil
32import sys
33import tarfile
34import threading
35import time
36import weakref
37import zipfile
38
39import numpy as np
40import six
41from six.moves.urllib.error import HTTPError
42from six.moves.urllib.error import URLError
43
44from tensorflow.python.framework import ops
45from six.moves.urllib.request import urlopen
46from tensorflow.python.keras.utils import tf_inspect
47from tensorflow.python.keras.utils.generic_utils import Progbar
48from tensorflow.python.keras.utils.io_utils import path_to_string
49from tensorflow.python.util.tf_export import keras_export
50
51
52try:
53  import queue
54except ImportError:
55  import Queue as queue
56
57try:
58  import typing
59  is_iterator = lambda x: isinstance(x, typing.Iterator)
60except ImportError:
61  # Python2 uses next, and Python3 should have typing so __next__ is not needed.
62  is_iterator = lambda x: hasattr(x, '__iter__') and hasattr(x, 'next')
63
64
65if sys.version_info[0] == 2:
66
67  def urlretrieve(url, filename, reporthook=None, data=None):
68    """Replacement for `urlretrieve` for Python 2.
69
70    Under Python 2, `urlretrieve` relies on `FancyURLopener` from legacy
71    `urllib` module, known to have issues with proxy management.
72
73    Args:
74        url: url to retrieve.
75        filename: where to store the retrieved data locally.
76        reporthook: a hook function that will be called once on establishment of
77          the network connection and once after each block read thereafter. The
78          hook will be passed three arguments; a count of blocks transferred so
79          far, a block size in bytes, and the total size of the file.
80        data: `data` argument passed to `urlopen`.
81    """
82
83    def chunk_read(response, chunk_size=8192, reporthook=None):
84      content_type = response.info().get('Content-Length')
85      total_size = -1
86      if content_type is not None:
87        total_size = int(content_type.strip())
88      count = 0
89      while True:
90        chunk = response.read(chunk_size)
91        count += 1
92        if reporthook is not None:
93          reporthook(count, chunk_size, total_size)
94        if chunk:
95          yield chunk
96        else:
97          break
98
99    response = urlopen(url, data)
100    with open(filename, 'wb') as fd:
101      for chunk in chunk_read(response, reporthook=reporthook):
102        fd.write(chunk)
103else:
104  from six.moves.urllib.request import urlretrieve
105
106
107def is_generator_or_sequence(x):
108  """Check if `x` is a Keras generator type."""
109  builtin_iterators = (str, list, tuple, dict, set, frozenset)
110  if isinstance(x, (ops.Tensor, np.ndarray) + builtin_iterators):
111    return False
112  return tf_inspect.isgenerator(x) or isinstance(x, Sequence) or is_iterator(x)
113
114
115def _extract_archive(file_path, path='.', archive_format='auto'):
116  """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
117
118  Args:
119      file_path: path to the archive file
120      path: path to extract the archive file
121      archive_format: Archive format to try for extracting the file.
122          Options are 'auto', 'tar', 'zip', and None.
123          'tar' includes tar, tar.gz, and tar.bz files.
124          The default 'auto' is ['tar', 'zip'].
125          None or an empty list will return no matches found.
126
127  Returns:
128      True if a match was found and an archive extraction was completed,
129      False otherwise.
130  """
131  if archive_format is None:
132    return False
133  if archive_format == 'auto':
134    archive_format = ['tar', 'zip']
135  if isinstance(archive_format, six.string_types):
136    archive_format = [archive_format]
137
138  file_path = path_to_string(file_path)
139  path = path_to_string(path)
140
141  for archive_type in archive_format:
142    if archive_type == 'tar':
143      open_fn = tarfile.open
144      is_match_fn = tarfile.is_tarfile
145    if archive_type == 'zip':
146      open_fn = zipfile.ZipFile
147      is_match_fn = zipfile.is_zipfile
148
149    if is_match_fn(file_path):
150      with open_fn(file_path) as archive:
151        try:
152          archive.extractall(path)
153        except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
154          if os.path.exists(path):
155            if os.path.isfile(path):
156              os.remove(path)
157            else:
158              shutil.rmtree(path)
159          raise
160      return True
161  return False
162
163
164@keras_export('keras.utils.get_file')
165def get_file(fname,
166             origin,
167             untar=False,
168             md5_hash=None,
169             file_hash=None,
170             cache_subdir='datasets',
171             hash_algorithm='auto',
172             extract=False,
173             archive_format='auto',
174             cache_dir=None):
175  """Downloads a file from a URL if it not already in the cache.
176
177  By default the file at the url `origin` is downloaded to the
178  cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
179  and given the filename `fname`. The final location of a file
180  `example.txt` would therefore be `~/.keras/datasets/example.txt`.
181
182  Files in tar, tar.gz, tar.bz, and zip formats can also be extracted.
183  Passing a hash will verify the file after download. The command line
184  programs `shasum` and `sha256sum` can compute the hash.
185
186  Example:
187
188  ```python
189  path_to_downloaded_file = tf.keras.utils.get_file(
190      "flower_photos",
191      "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
192      untar=True)
193  ```
194
195  Args:
196      fname: Name of the file. If an absolute path `/path/to/file.txt` is
197          specified the file will be saved at that location.
198      origin: Original URL of the file.
199      untar: Deprecated in favor of `extract` argument.
200          boolean, whether the file should be decompressed
201      md5_hash: Deprecated in favor of `file_hash` argument.
202          md5 hash of the file for verification
203      file_hash: The expected hash string of the file after download.
204          The sha256 and md5 hash algorithms are both supported.
205      cache_subdir: Subdirectory under the Keras cache dir where the file is
206          saved. If an absolute path `/path/to/folder` is
207          specified the file will be saved at that location.
208      hash_algorithm: Select the hash algorithm to verify the file.
209          options are `'md5'`, `'sha256'`, and `'auto'`.
210          The default 'auto' detects the hash algorithm in use.
211      extract: True tries extracting the file as an Archive, like tar or zip.
212      archive_format: Archive format to try for extracting the file.
213          Options are `'auto'`, `'tar'`, `'zip'`, and `None`.
214          `'tar'` includes tar, tar.gz, and tar.bz files.
215          The default `'auto'` corresponds to `['tar', 'zip']`.
216          None or an empty list will return no matches found.
217      cache_dir: Location to store cached files, when None it
218          defaults to the default directory `~/.keras/`.
219
220  Returns:
221      Path to the downloaded file
222  """
223  if cache_dir is None:
224    cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
225  if md5_hash is not None and file_hash is None:
226    file_hash = md5_hash
227    hash_algorithm = 'md5'
228  datadir_base = os.path.expanduser(cache_dir)
229  if not os.access(datadir_base, os.W_OK):
230    datadir_base = os.path.join('/tmp', '.keras')
231  datadir = os.path.join(datadir_base, cache_subdir)
232  _makedirs_exist_ok(datadir)
233
234  fname = path_to_string(fname)
235
236  if untar:
237    untar_fpath = os.path.join(datadir, fname)
238    fpath = untar_fpath + '.tar.gz'
239  else:
240    fpath = os.path.join(datadir, fname)
241
242  download = False
243  if os.path.exists(fpath):
244    # File found; verify integrity if a hash was provided.
245    if file_hash is not None:
246      if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
247        print('A local file was found, but it seems to be '
248              'incomplete or outdated because the ' + hash_algorithm +
249              ' file hash does not match the original value of ' + file_hash +
250              ' so we will re-download the data.')
251        download = True
252  else:
253    download = True
254
255  if download:
256    print('Downloading data from', origin)
257
258    class ProgressTracker(object):
259      # Maintain progbar for the lifetime of download.
260      # This design was chosen for Python 2.7 compatibility.
261      progbar = None
262
263    def dl_progress(count, block_size, total_size):
264      if ProgressTracker.progbar is None:
265        if total_size == -1:
266          total_size = None
267        ProgressTracker.progbar = Progbar(total_size)
268      else:
269        ProgressTracker.progbar.update(count * block_size)
270
271    error_msg = 'URL fetch failure on {}: {} -- {}'
272    try:
273      try:
274        urlretrieve(origin, fpath, dl_progress)
275      except HTTPError as e:
276        raise Exception(error_msg.format(origin, e.code, e.msg))
277      except URLError as e:
278        raise Exception(error_msg.format(origin, e.errno, e.reason))
279    except (Exception, KeyboardInterrupt) as e:
280      if os.path.exists(fpath):
281        os.remove(fpath)
282      raise
283    ProgressTracker.progbar = None
284
285  if untar:
286    if not os.path.exists(untar_fpath):
287      _extract_archive(fpath, datadir, archive_format='tar')
288    return untar_fpath
289
290  if extract:
291    _extract_archive(fpath, datadir, archive_format)
292
293  return fpath
294
295
296def _makedirs_exist_ok(datadir):
297  if six.PY2:
298    # Python 2 doesn't have the exist_ok arg, so we try-except here.
299    try:
300      os.makedirs(datadir)
301    except OSError as e:
302      if e.errno != errno.EEXIST:
303        raise
304  else:
305    os.makedirs(datadir, exist_ok=True)  # pylint: disable=unexpected-keyword-arg
306
307
308def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
309  """Calculates a file sha256 or md5 hash.
310
311  Example:
312
313  ```python
314  _hash_file('/path/to/file.zip')
315  'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
316  ```
317
318  Args:
319      fpath: path to the file being validated
320      algorithm: hash algorithm, one of `'auto'`, `'sha256'`, or `'md5'`.
321          The default `'auto'` detects the hash algorithm in use.
322      chunk_size: Bytes to read at a time, important for large files.
323
324  Returns:
325      The file hash
326  """
327  if (algorithm == 'sha256') or (algorithm == 'auto' and len(hash) == 64):
328    hasher = hashlib.sha256()
329  else:
330    hasher = hashlib.md5()
331
332  with open(fpath, 'rb') as fpath_file:
333    for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
334      hasher.update(chunk)
335
336  return hasher.hexdigest()
337
338
339def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
340  """Validates a file against a sha256 or md5 hash.
341
342  Args:
343      fpath: path to the file being validated
344      file_hash:  The expected hash string of the file.
345          The sha256 and md5 hash algorithms are both supported.
346      algorithm: Hash algorithm, one of 'auto', 'sha256', or 'md5'.
347          The default 'auto' detects the hash algorithm in use.
348      chunk_size: Bytes to read at a time, important for large files.
349
350  Returns:
351      Whether the file is valid
352  """
353  if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
354    hasher = 'sha256'
355  else:
356    hasher = 'md5'
357
358  if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash):
359    return True
360  else:
361    return False
362
363
364class ThreadsafeIter(object):
365  """Wrap an iterator with a lock and propagate exceptions to all threads."""
366
367  def __init__(self, it):
368    self.it = it
369    self.lock = threading.Lock()
370
371    # After a generator throws an exception all subsequent next() calls raise a
372    # StopIteration Exception. This, however, presents an issue when mixing
373    # generators and threading because it means the order of retrieval need not
374    # match the order in which the generator was called. This can make it appear
375    # that a generator exited normally when in fact the terminating exception is
376    # just in a different thread. In order to provide thread safety, once
377    # self.it has thrown an exception we continue to throw the same exception.
378    self._exception = None
379
380  def __iter__(self):
381    return self
382
383  def next(self):
384    return self.__next__()
385
386  def __next__(self):
387    with self.lock:
388      if self._exception:
389        raise self._exception  # pylint: disable=raising-bad-type
390
391      try:
392        return next(self.it)
393      except Exception as e:
394        self._exception = e
395        raise
396
397
398def threadsafe_generator(f):
399
400  @functools.wraps(f)
401  def g(*a, **kw):
402    return ThreadsafeIter(f(*a, **kw))
403
404  return g
405
406
407@keras_export('keras.utils.Sequence')
408class Sequence(object):
409  """Base object for fitting to a sequence of data, such as a dataset.
410
411  Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
412  If you want to modify your dataset between epochs you may implement
413  `on_epoch_end`.
414  The method `__getitem__` should return a complete batch.
415
416  Notes:
417
418  `Sequence` are a safer way to do multiprocessing. This structure guarantees
419  that the network will only train once
420   on each sample per epoch which is not the case with generators.
421
422  Examples:
423
424  ```python
425  from skimage.io import imread
426  from skimage.transform import resize
427  import numpy as np
428  import math
429
430  # Here, `x_set` is list of path to the images
431  # and `y_set` are the associated classes.
432
433  class CIFAR10Sequence(Sequence):
434
435      def __init__(self, x_set, y_set, batch_size):
436          self.x, self.y = x_set, y_set
437          self.batch_size = batch_size
438
439      def __len__(self):
440          return math.ceil(len(self.x) / self.batch_size)
441
442      def __getitem__(self, idx):
443          batch_x = self.x[idx * self.batch_size:(idx + 1) *
444          self.batch_size]
445          batch_y = self.y[idx * self.batch_size:(idx + 1) *
446          self.batch_size]
447
448          return np.array([
449              resize(imread(file_name), (200, 200))
450                 for file_name in batch_x]), np.array(batch_y)
451  ```
452  """
453
454  @abstractmethod
455  def __getitem__(self, index):
456    """Gets batch at position `index`.
457
458    Args:
459        index: position of the batch in the Sequence.
460
461    Returns:
462        A batch
463    """
464    raise NotImplementedError
465
466  @abstractmethod
467  def __len__(self):
468    """Number of batch in the Sequence.
469
470    Returns:
471        The number of batches in the Sequence.
472    """
473    raise NotImplementedError
474
475  def on_epoch_end(self):
476    """Method called at the end of every epoch.
477    """
478    pass
479
480  def __iter__(self):
481    """Create a generator that iterate over the Sequence."""
482    for item in (self[i] for i in range(len(self))):
483      yield item
484
485
486def iter_sequence_infinite(seq):
487  """Iterates indefinitely over a Sequence.
488
489  Args:
490    seq: `Sequence` instance.
491
492  Yields:
493    Batches of data from the `Sequence`.
494  """
495  while True:
496    for item in seq:
497      yield item
498
499
500# Global variables to be shared across processes
501_SHARED_SEQUENCES = {}
502# We use a Value to provide unique id to different processes.
503_SEQUENCE_COUNTER = None
504
505
506# Because multiprocessing pools are inherently unsafe, starting from a clean
507# state can be essential to avoiding deadlocks. In order to accomplish this, we
508# need to be able to check on the status of Pools that we create.
509_DATA_POOLS = weakref.WeakSet()
510_WORKER_ID_QUEUE = None  # Only created if needed.
511_WORKER_IDS = set()
512_FORCE_THREADPOOL = False
513_FORCE_THREADPOOL_LOCK = threading.RLock()
514
515
516def dont_use_multiprocessing_pool(f):
517  @functools.wraps(f)
518  def wrapped(*args, **kwargs):
519    with _FORCE_THREADPOOL_LOCK:
520      global _FORCE_THREADPOOL
521      old_force_threadpool, _FORCE_THREADPOOL = _FORCE_THREADPOOL, True
522      out = f(*args, **kwargs)
523      _FORCE_THREADPOOL = old_force_threadpool
524      return out
525  return wrapped
526
527
528def get_pool_class(use_multiprocessing):
529  global _FORCE_THREADPOOL
530  if not use_multiprocessing or _FORCE_THREADPOOL:
531    return multiprocessing.dummy.Pool  # ThreadPool
532  return multiprocessing.Pool
533
534
535def get_worker_id_queue():
536  """Lazily create the queue to track worker ids."""
537  global _WORKER_ID_QUEUE
538  if _WORKER_ID_QUEUE is None:
539    _WORKER_ID_QUEUE = multiprocessing.Queue()
540  return _WORKER_ID_QUEUE
541
542
543def init_pool(seqs):
544  global _SHARED_SEQUENCES
545  _SHARED_SEQUENCES = seqs
546
547
548def get_index(uid, i):
549  """Get the value from the Sequence `uid` at index `i`.
550
551  To allow multiple Sequences to be used at the same time, we use `uid` to
552  get a specific one. A single Sequence would cause the validation to
553  overwrite the training Sequence.
554
555  Args:
556      uid: int, Sequence identifier
557      i: index
558
559  Returns:
560      The value at index `i`.
561  """
562  return _SHARED_SEQUENCES[uid][i]
563
564
565@keras_export('keras.utils.SequenceEnqueuer')
566class SequenceEnqueuer(object):
567  """Base class to enqueue inputs.
568
569  The task of an Enqueuer is to use parallelism to speed up preprocessing.
570  This is done with processes or threads.
571
572  Example:
573
574  ```python
575      enqueuer = SequenceEnqueuer(...)
576      enqueuer.start()
577      datas = enqueuer.get()
578      for data in datas:
579          # Use the inputs; training, evaluating, predicting.
580          # ... stop sometime.
581      enqueuer.stop()
582  ```
583
584  The `enqueuer.get()` should be an infinite stream of datas.
585  """
586
587  def __init__(self, sequence,
588               use_multiprocessing=False):
589    self.sequence = sequence
590    self.use_multiprocessing = use_multiprocessing
591
592    global _SEQUENCE_COUNTER
593    if _SEQUENCE_COUNTER is None:
594      try:
595        _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
596      except OSError:
597        # In this case the OS does not allow us to use
598        # multiprocessing. We resort to an int
599        # for enqueuer indexing.
600        _SEQUENCE_COUNTER = 0
601
602    if isinstance(_SEQUENCE_COUNTER, int):
603      self.uid = _SEQUENCE_COUNTER
604      _SEQUENCE_COUNTER += 1
605    else:
606      # Doing Multiprocessing.Value += x is not process-safe.
607      with _SEQUENCE_COUNTER.get_lock():
608        self.uid = _SEQUENCE_COUNTER.value
609        _SEQUENCE_COUNTER.value += 1
610
611    self.workers = 0
612    self.executor_fn = None
613    self.queue = None
614    self.run_thread = None
615    self.stop_signal = None
616
617  def is_running(self):
618    return self.stop_signal is not None and not self.stop_signal.is_set()
619
620  def start(self, workers=1, max_queue_size=10):
621    """Starts the handler's workers.
622
623    Args:
624        workers: Number of workers.
625        max_queue_size: queue size
626            (when full, workers could block on `put()`)
627    """
628    if self.use_multiprocessing:
629      self.executor_fn = self._get_executor_init(workers)
630    else:
631      # We do not need the init since it's threads.
632      self.executor_fn = lambda _: get_pool_class(False)(workers)
633    self.workers = workers
634    self.queue = queue.Queue(max_queue_size)
635    self.stop_signal = threading.Event()
636    self.run_thread = threading.Thread(target=self._run)
637    self.run_thread.daemon = True
638    self.run_thread.start()
639
640  def _send_sequence(self):
641    """Sends current Iterable to all workers."""
642    # For new processes that may spawn
643    _SHARED_SEQUENCES[self.uid] = self.sequence
644
645  def stop(self, timeout=None):
646    """Stops running threads and wait for them to exit, if necessary.
647
648    Should be called by the same thread which called `start()`.
649
650    Args:
651        timeout: maximum time to wait on `thread.join()`
652    """
653    self.stop_signal.set()
654    with self.queue.mutex:
655      self.queue.queue.clear()
656      self.queue.unfinished_tasks = 0
657      self.queue.not_full.notify()
658    self.run_thread.join(timeout)
659    _SHARED_SEQUENCES[self.uid] = None
660
661  def __del__(self):
662    if self.is_running():
663      self.stop()
664
665  @abstractmethod
666  def _run(self):
667    """Submits request to the executor and queue the `Future` objects."""
668    raise NotImplementedError
669
670  @abstractmethod
671  def _get_executor_init(self, workers):
672    """Gets the Pool initializer for multiprocessing.
673
674    Args:
675        workers: Number of workers.
676
677    Returns:
678        Function, a Function to initialize the pool
679    """
680    raise NotImplementedError
681
682  @abstractmethod
683  def get(self):
684    """Creates a generator to extract data from the queue.
685
686    Skip the data if it is `None`.
687    # Returns
688        Generator yielding tuples `(inputs, targets)`
689            or `(inputs, targets, sample_weights)`.
690    """
691    raise NotImplementedError
692
693
694@keras_export('keras.utils.OrderedEnqueuer')
695class OrderedEnqueuer(SequenceEnqueuer):
696  """Builds a Enqueuer from a Sequence.
697
698  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
699
700  Args:
701      sequence: A `tf.keras.utils.data_utils.Sequence` object.
702      use_multiprocessing: use multiprocessing if True, otherwise threading
703      shuffle: whether to shuffle the data at the beginning of each epoch
704  """
705
706  def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
707    super(OrderedEnqueuer, self).__init__(sequence, use_multiprocessing)
708    self.shuffle = shuffle
709
710  def _get_executor_init(self, workers):
711    """Gets the Pool initializer for multiprocessing.
712
713    Args:
714        workers: Number of workers.
715
716    Returns:
717        Function, a Function to initialize the pool
718    """
719    def pool_fn(seqs):
720      pool = get_pool_class(True)(
721          workers, initializer=init_pool_generator,
722          initargs=(seqs, None, get_worker_id_queue()))
723      _DATA_POOLS.add(pool)
724      return pool
725
726    return pool_fn
727
728  def _wait_queue(self):
729    """Wait for the queue to be empty."""
730    while True:
731      time.sleep(0.1)
732      if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
733        return
734
735  def _run(self):
736    """Submits request to the executor and queue the `Future` objects."""
737    sequence = list(range(len(self.sequence)))
738    self._send_sequence()  # Share the initial sequence
739    while True:
740      if self.shuffle:
741        random.shuffle(sequence)
742
743      with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
744        for i in sequence:
745          if self.stop_signal.is_set():
746            return
747
748          self.queue.put(
749              executor.apply_async(get_index, (self.uid, i)), block=True)
750
751        # Done with the current epoch, waiting for the final batches
752        self._wait_queue()
753
754        if self.stop_signal.is_set():
755          # We're done
756          return
757
758      # Call the internal on epoch end.
759      self.sequence.on_epoch_end()
760      self._send_sequence()  # Update the pool
761
762  def get(self):
763    """Creates a generator to extract data from the queue.
764
765    Skip the data if it is `None`.
766
767    Yields:
768        The next element in the queue, i.e. a tuple
769        `(inputs, targets)` or
770        `(inputs, targets, sample_weights)`.
771    """
772    while self.is_running():
773      try:
774        inputs = self.queue.get(block=True, timeout=5).get()
775        if self.is_running():
776          self.queue.task_done()
777        if inputs is not None:
778          yield inputs
779      except queue.Empty:
780        pass
781      except Exception:  # pylint: disable=broad-except
782        self.stop()
783        six.reraise(*sys.exc_info())
784
785
786def init_pool_generator(gens, random_seed=None, id_queue=None):
787  """Initializer function for pool workers.
788
789  Args:
790    gens: State which should be made available to worker processes.
791    random_seed: An optional value with which to seed child processes.
792    id_queue: A multiprocessing Queue of worker ids. This is used to indicate
793      that a worker process was created by Keras and can be terminated using
794      the cleanup_all_keras_forkpools utility.
795  """
796  global _SHARED_SEQUENCES
797  _SHARED_SEQUENCES = gens
798
799  worker_proc = multiprocessing.current_process()
800
801  # name isn't used for anything, but setting a more descriptive name is helpful
802  # when diagnosing orphaned processes.
803  worker_proc.name = 'Keras_worker_{}'.format(worker_proc.name)
804
805  if random_seed is not None:
806    np.random.seed(random_seed + worker_proc.ident)
807
808  if id_queue is not None:
809    # If a worker dies during init, the pool will just create a replacement.
810    id_queue.put(worker_proc.ident, block=True, timeout=0.1)
811
812
813def next_sample(uid):
814  """Gets the next value from the generator `uid`.
815
816  To allow multiple generators to be used at the same time, we use `uid` to
817  get a specific one. A single generator would cause the validation to
818  overwrite the training generator.
819
820  Args:
821      uid: int, generator identifier
822
823  Returns:
824      The next value of generator `uid`.
825  """
826  return six.next(_SHARED_SEQUENCES[uid])
827
828
829@keras_export('keras.utils.GeneratorEnqueuer')
830class GeneratorEnqueuer(SequenceEnqueuer):
831  """Builds a queue out of a data generator.
832
833  The provided generator can be finite in which case the class will throw
834  a `StopIteration` exception.
835
836  Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
837
838  Args:
839      generator: a generator function which yields data
840      use_multiprocessing: use multiprocessing if True, otherwise threading
841      wait_time: time to sleep in-between calls to `put()`
842      random_seed: Initial seed for workers,
843          will be incremented by one for each worker.
844  """
845
846  def __init__(self, sequence,
847               use_multiprocessing=False,
848               random_seed=None):
849    super(GeneratorEnqueuer, self).__init__(sequence, use_multiprocessing)
850    self.random_seed = random_seed
851
852  def _get_executor_init(self, workers):
853    """Gets the Pool initializer for multiprocessing.
854
855    Args:
856      workers: Number of works.
857
858    Returns:
859        A Function to initialize the pool
860    """
861    def pool_fn(seqs):
862      pool = get_pool_class(True)(
863          workers, initializer=init_pool_generator,
864          initargs=(seqs, self.random_seed, get_worker_id_queue()))
865      _DATA_POOLS.add(pool)
866      return pool
867    return pool_fn
868
869  def _run(self):
870    """Submits request to the executor and queue the `Future` objects."""
871    self._send_sequence()  # Share the initial generator
872    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
873      while True:
874        if self.stop_signal.is_set():
875          return
876
877        self.queue.put(
878            executor.apply_async(next_sample, (self.uid,)), block=True)
879
880  def get(self):
881    """Creates a generator to extract data from the queue.
882
883    Skip the data if it is `None`.
884
885    Yields:
886        The next element in the queue, i.e. a tuple
887        `(inputs, targets)` or
888        `(inputs, targets, sample_weights)`.
889    """
890    try:
891      while self.is_running():
892        inputs = self.queue.get(block=True).get()
893        self.queue.task_done()
894        if inputs is not None:
895          yield inputs
896    except StopIteration:
897      # Special case for finite generators
898      last_ones = []
899      while self.queue.qsize() > 0:
900        last_ones.append(self.queue.get(block=True))
901      # Wait for them to complete
902      for f in last_ones:
903        f.wait()
904      # Keep the good ones
905      last_ones = [future.get() for future in last_ones if future.successful()]
906      for inputs in last_ones:
907        if inputs is not None:
908          yield inputs
909    except Exception as e:  # pylint: disable=broad-except
910      self.stop()
911      if 'generator already executing' in str(e):
912        raise RuntimeError(
913            'Your generator is NOT thread-safe. '
914            'Keras requires a thread-safe generator when '
915            '`use_multiprocessing=False, workers > 1`. ')
916      six.reraise(*sys.exc_info())
917