# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for controlling threading in `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export


# Each entry pairs a Python option attribute with the name of the proto oneof
# that records whether the corresponding proto field has been explicitly set.
# The attribute name doubles as the proto field name. Order matches the
# declaration order of the options below.
_THREADING_PROTO_FIELDS = (
    ("max_intra_op_parallelism", "optional_max_intra_op_parallelism"),
    ("private_threadpool_size", "optional_private_threadpool_size"),
)


@tf_export("data.experimental.ThreadingOptions")
class ThreadingOptions(options.OptionsBase):
  """Represents options for dataset threading.

  These options are reached through the `experimental_threading` property of
  `tf.data.Options`, which holds an instance of
  `tf.data.experimental.ThreadingOptions`. Set fields on that instance, then
  attach the options to a dataset:

  ```python
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 10
  dataset = dataset.with_options(options)
  ```
  """

  max_intra_op_parallelism = options.create_option(
      name="max_intra_op_parallelism",
      ty=int,
      docstring=
      "If set, it overrides the maximum degree of intra-op parallelism.")

  private_threadpool_size = options.create_option(
      name="private_threadpool_size",
      ty=int,
      docstring=
      "If set, the dataset will use a private threadpool of the given size.")

  def _to_proto(self):
    """Builds a proto holding the options that have been set.

    Returns:
      A `dataset_options_pb2.ThreadingOptions`. Only the fields whose Python
      value is not `None` are assigned; the rest are left unset.
    """
    proto = dataset_options_pb2.ThreadingOptions()
    for attr_name, _ in _THREADING_PROTO_FIELDS:
      value = getattr(self, attr_name)
      if value is None:
        # Leave the proto field (and its oneof) unset.
        continue
      setattr(proto, attr_name, value)
    return proto

  def _from_proto(self, pb):
    """Copies onto this object every field that is present in `pb`.

    Args:
      pb: A `dataset_options_pb2.ThreadingOptions`. A field is treated as
        present when its `optional_*` oneof reports a set member. Absent
        fields are skipped, so the current Python value stays as it was.
    """
    for attr_name, oneof_name in _THREADING_PROTO_FIELDS:
      if pb.WhichOneof(oneof_name) is None:
        continue
      setattr(self, attr_name, getattr(pb, attr_name))