1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Python API for TensorFlow's Cloud Bigtable integration.
16
17TensorFlow has support for reading from and writing to Cloud Bigtable. To use
18TensorFlow + Cloud Bigtable integration, first create a BigtableClient to
19configure your connection to Cloud Bigtable, and then create a BigtableTable
20object to allow you to create numerous `tf.data.Dataset`s to read data, or
21write a `tf.data.Dataset` object to the underlying Cloud Bigtable table.
22
23For background on Cloud Bigtable, see: https://cloud.google.com/bigtable .
24"""
25
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30from six import iteritems
31from six import string_types
32
33from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
34from tensorflow.contrib.util import loader
35from tensorflow.python.data.experimental.ops import interleave_ops
36from tensorflow.python.data.ops import dataset_ops
37from tensorflow.python.data.util import nest
38from tensorflow.python.data.util import structure
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import tensor_shape
41from tensorflow.python.platform import resource_loader
42
43_bigtable_so = loader.load_op_library(
44    resource_loader.get_path_to_datafile("_bigtable.so"))
45
46
47class BigtableClient(object):
48  """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
49
50  BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
51  `table` method to open a Bigtable table.
52  """
53
54  def __init__(self,
55               project_id,
56               instance_id,
57               connection_pool_size=None,
58               max_receive_message_size=None):
59    """Creates a BigtableClient that can be used to open connections to tables.
60
61    Args:
62      project_id: A string representing the GCP project id to connect to.
63      instance_id: A string representing the Bigtable instance to connect to.
64      connection_pool_size: (Optional.) A number representing the number of
65        concurrent connections to the Cloud Bigtable service to make.
66      max_receive_message_size: (Optional.) The maximum bytes received in a
67        single gRPC response.
68
69    Raises:
70      ValueError: if the arguments are invalid (e.g. wrong type, or out of
71        expected ranges (e.g. negative).)
72    """
73    if not isinstance(project_id, str):
74      raise ValueError("`project_id` must be a string")
75    self._project_id = project_id
76
77    if not isinstance(instance_id, str):
78      raise ValueError("`instance_id` must be a string")
79    self._instance_id = instance_id
80
81    if connection_pool_size is None:
82      connection_pool_size = -1
83    elif connection_pool_size < 1:
84      raise ValueError("`connection_pool_size` must be positive")
85
86    if max_receive_message_size is None:
87      max_receive_message_size = -1
88    elif max_receive_message_size < 1:
89      raise ValueError("`max_receive_message_size` must be positive")
90
91    self._connection_pool_size = connection_pool_size
92
93    self._resource = gen_bigtable_ops.bigtable_client(
94        project_id, instance_id, connection_pool_size, max_receive_message_size)
95
96  def table(self, name, snapshot=None):
97    """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object.
98
99    Args:
100      name: A `tf.string` `tf.Tensor` name of the table to open.
101      snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to
102        request the creation of a snapshot. (Note: currently unimplemented.)
103
104    Returns:
105      A `tf.contrib.bigtable.BigtableTable` Python object representing the
106      operations available on the table.
107    """
108    # TODO(saeta): Implement snapshot functionality.
109    table = gen_bigtable_ops.bigtable_table(self._resource, name)
110    return BigtableTable(name, snapshot, table)
111
112
113class BigtableTable(object):
114  """Entry point for reading and writing data in Cloud Bigtable.
115
116  This BigtableTable class is the Python representation of the Cloud Bigtable
117  table within TensorFlow. Methods on this class allow data to be read from and
118  written to the Cloud Bigtable service in flexible and high performance
119  manners.
120  """
121
122  # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface.
123  # TODO(saeta): Consider variant tensors instead of resources (while supporting
124  #    connection pooling).
125
126  def __init__(self, name, snapshot, resource):
127    self._name = name
128    self._snapshot = snapshot
129    self._resource = resource
130
131  def lookup_columns(self, *args, **kwargs):
132    """Retrieves the values of columns for a dataset of keys.
133
134    Example usage:
135
136    ```python
137    table = bigtable_client.table("my_table")
138    key_dataset = table.get_keys_prefix("imagenet")
139    images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
140                                                    ("cf2", "label"),
141                                                    ("cf2", "boundingbox")))
142    training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
143    ```
144
145    Alternatively, you can use keyword arguments to specify the columns to
146    capture. Example (same as above, rewritten):
147
148    ```python
149    table = bigtable_client.table("my_table")
150    key_dataset = table.get_keys_prefix("imagenet")
151    images = key_dataset.apply(table.lookup_columns(
152        cf1="image", cf2=("label", "boundingbox")))
153    training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
154    ```
155
156    Note: certain `kwargs` keys are reserved, and thus, some column families
157    cannot be identified using the `kwargs` syntax. Instead, please use the
158    `args` syntax. This list includes:
159
160      - 'name'
161
162    Note: this list can change at any time.
163
164    Args:
165      *args: A list of tuples containing (column family, column name) pairs.
166      **kwargs: Column families (keys) and column qualifiers (values).
167
168    Returns:
169      A function that can be passed to `tf.data.Dataset.apply` to retrieve the
170      values of columns for the rows.
171    """
172    table = self  # Capture self
173    normalized = args
174    if normalized is None:
175      normalized = []
176    if isinstance(normalized, tuple):
177      normalized = list(normalized)
178    for key, value in iteritems(kwargs):
179      if key == "name":
180        continue
181      if isinstance(value, str):
182        normalized.append((key, value))
183        continue
184      for col in value:
185        normalized.append((key, col))
186
187    def _apply_fn(dataset):
188      # TODO(saeta): Verify dataset's types are correct!
189      return _BigtableLookupDataset(dataset, table, normalized)
190
191    return _apply_fn
192
193  def keys_by_range_dataset(self, start, end):
194    """Retrieves all row keys between start and end.
195
196    Note: it does NOT retrieve the values of columns.
197
198    Args:
199      start: The start row key. The row keys for rows after start (inclusive)
200        will be retrieved.
201      end: (Optional.) The end row key. Rows up to (but not including) end will
202        be retrieved. If end is None, all subsequent row keys will be retrieved.
203
204    Returns:
205      A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
206      of the row keys between `start` and `end`.
207    """
208    # TODO(saeta): Make inclusive / exclusive configurable?
209    if end is None:
210      end = ""
211    return _BigtableRangeKeyDataset(self, start, end)
212
213  def keys_by_prefix_dataset(self, prefix):
214    """Retrieves the row keys matching a given prefix.
215
216    Args:
217      prefix: All row keys that begin with `prefix` in the table will be
218        retrieved.
219
220    Returns:
221      A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all
222      of the row keys matching that prefix.
223    """
224    return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix))
225
226  def sample_keys(self):
227    """Retrieves a sampling of row keys from the Bigtable table.
228
229    This dataset is most often used in conjunction with
230    `tf.data.experimental.parallel_interleave` to construct a set of ranges for
231    scanning in parallel.
232
233    Returns:
234      A `tf.data.Dataset` returning string row keys.
235    """
236    return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self))
237
238  def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
239    """Retrieves row (including values) from the Bigtable service.
240
241    Rows with row-key prefixed by `prefix` will be retrieved.
242
243    Specifying the columns to retrieve for each row is done by either using
244    kwargs or in the columns parameter. To retrieve values of the columns "c1",
245    and "c2" from the column family "cfa", and the value of the column "c3"
246    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
247    equivalent:
248
249    ```
250    table = # ...
251    ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"),
252                                                   ("cfa", "c2"),
253                                                   ("cfb", "c3")])
254    ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
255    ```
256
257    Note: only the latest value of a cell will be retrieved.
258
259    Args:
260      prefix: The prefix all row keys must match to be retrieved for prefix-
261        based scans.
262      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
263        A non-1 value indicates to probabilistically sample rows with the
264        provided probability.
265      columns: The columns to read. Note: most commonly, they are expressed as
266        kwargs. Use the columns value if you are using column families that are
267        reserved. The value of columns and kwargs are merged. Columns is a list
268        of tuples of strings ("column_family", "column_qualifier").
269      **kwargs: The column families and columns to read. Keys are treated as
270        column_families, and values can be either lists of strings, or strings
271        that are treated as the column qualifier (column name).
272
273    Returns:
274      A `tf.data.Dataset` returning the row keys and the cell contents.
275
276    Raises:
277      ValueError: If the configured probability is unexpected.
278    """
279    probability = _normalize_probability(probability)
280    normalized = _normalize_columns(columns, kwargs)
281    return dataset_ops.DatasetV1Adapter(
282        _BigtableScanDataset(self, prefix, "", "", normalized, probability))
283
284  def scan_range(self, start, end, probability=None, columns=None, **kwargs):
285    """Retrieves rows (including values) from the Bigtable service.
286
287    Rows with row-keys between `start` and `end` will be retrieved.
288
289    Specifying the columns to retrieve for each row is done by either using
290    kwargs or in the columns parameter. To retrieve values of the columns "c1",
291    and "c2" from the column family "cfa", and the value of the column "c3"
292    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
293    equivalent:
294
295    ```
296    table = # ...
297    ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
298                                                            ("cfa", "c2"),
299                                                            ("cfb", "c3")])
300    ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
301    ```
302
303    Note: only the latest value of a cell will be retrieved.
304
305    Args:
306      start: The start of the range when scanning by range.
307      end: (Optional.) The end of the range when scanning by range.
308      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
309        A non-1 value indicates to probabilistically sample rows with the
310        provided probability.
311      columns: The columns to read. Note: most commonly, they are expressed as
312        kwargs. Use the columns value if you are using column families that are
313        reserved. The value of columns and kwargs are merged. Columns is a list
314        of tuples of strings ("column_family", "column_qualifier").
315      **kwargs: The column families and columns to read. Keys are treated as
316        column_families, and values can be either lists of strings, or strings
317        that are treated as the column qualifier (column name).
318
319    Returns:
320      A `tf.data.Dataset` returning the row keys and the cell contents.
321
322    Raises:
323      ValueError: If the configured probability is unexpected.
324    """
325    probability = _normalize_probability(probability)
326    normalized = _normalize_columns(columns, kwargs)
327    return dataset_ops.DatasetV1Adapter(
328        _BigtableScanDataset(self, "", start, end, normalized, probability))
329
330  def parallel_scan_prefix(self,
331                           prefix,
332                           num_parallel_scans=None,
333                           probability=None,
334                           columns=None,
335                           **kwargs):
336    """Retrieves row (including values) from the Bigtable service at high speed.
337
338    Rows with row-key prefixed by `prefix` will be retrieved. This method is
339    similar to `scan_prefix`, but by contrast performs multiple sub-scans in
340    parallel in order to achieve higher performance.
341
342    Note: The dataset produced by this method is not deterministic!
343
344    Specifying the columns to retrieve for each row is done by either using
345    kwargs or in the columns parameter. To retrieve values of the columns "c1",
346    and "c2" from the column family "cfa", and the value of the column "c3"
347    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
348    equivalent:
349
350    ```
351    table = # ...
352    ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
353                                                            ("cfa", "c2"),
354                                                            ("cfb", "c3")])
355    ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
356    ```
357
358    Note: only the latest value of a cell will be retrieved.
359
360    Args:
361      prefix: The prefix all row keys must match to be retrieved for prefix-
362        based scans.
363      num_parallel_scans: (Optional.) The number of concurrent scans against the
364        Cloud Bigtable instance.
365      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
366        A non-1 value indicates to probabilistically sample rows with the
367        provided probability.
368      columns: The columns to read. Note: most commonly, they are expressed as
369        kwargs. Use the columns value if you are using column families that are
370        reserved. The value of columns and kwargs are merged. Columns is a list
371        of tuples of strings ("column_family", "column_qualifier").
372      **kwargs: The column families and columns to read. Keys are treated as
373        column_families, and values can be either lists of strings, or strings
374        that are treated as the column qualifier (column name).
375
376    Returns:
377      A `tf.data.Dataset` returning the row keys and the cell contents.
378
379    Raises:
380      ValueError: If the configured probability is unexpected.
381    """
382    probability = _normalize_probability(probability)
383    normalized = _normalize_columns(columns, kwargs)
384    ds = dataset_ops.DatasetV1Adapter(
385        _BigtableSampleKeyPairsDataset(self, prefix, "", ""))
386    return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
387                                            normalized)
388
389  def parallel_scan_range(self,
390                          start,
391                          end,
392                          num_parallel_scans=None,
393                          probability=None,
394                          columns=None,
395                          **kwargs):
396    """Retrieves rows (including values) from the Bigtable service.
397
398    Rows with row-keys between `start` and `end` will be retrieved. This method
399    is similar to `scan_range`, but by contrast performs multiple sub-scans in
400    parallel in order to achieve higher performance.
401
402    Note: The dataset produced by this method is not deterministic!
403
404    Specifying the columns to retrieve for each row is done by either using
405    kwargs or in the columns parameter. To retrieve values of the columns "c1",
406    and "c2" from the column family "cfa", and the value of the column "c3"
407    from column family "cfb", the following datasets (`ds1`, and `ds2`) are
408    equivalent:
409
410    ```
411    table = # ...
412    ds1 = table.parallel_scan_range("row_start",
413                                    "row_end",
414                                    columns=[("cfa", "c1"),
415                                             ("cfa", "c2"),
416                                             ("cfb", "c3")])
417    ds2 = table.parallel_scan_range("row_start", "row_end",
418                                    cfa=["c1", "c2"], cfb="c3")
419    ```
420
421    Note: only the latest value of a cell will be retrieved.
422
423    Args:
424      start: The start of the range when scanning by range.
425      end: (Optional.) The end of the range when scanning by range.
426      num_parallel_scans: (Optional.) The number of concurrent scans against the
427        Cloud Bigtable instance.
428      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
429        A non-1 value indicates to probabilistically sample rows with the
430        provided probability.
431      columns: The columns to read. Note: most commonly, they are expressed as
432        kwargs. Use the columns value if you are using column families that are
433        reserved. The value of columns and kwargs are merged. Columns is a list
434        of tuples of strings ("column_family", "column_qualifier").
435      **kwargs: The column families and columns to read. Keys are treated as
436        column_families, and values can be either lists of strings, or strings
437        that are treated as the column qualifier (column name).
438
439    Returns:
440      A `tf.data.Dataset` returning the row keys and the cell contents.
441
442    Raises:
443      ValueError: If the configured probability is unexpected.
444    """
445    probability = _normalize_probability(probability)
446    normalized = _normalize_columns(columns, kwargs)
447    ds = dataset_ops.DatasetV1Adapter(
448        _BigtableSampleKeyPairsDataset(self, "", start, end))
449    return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
450                                            normalized)
451
452  def write(self, dataset, column_families, columns, timestamp=None):
453    """Writes a dataset to the table.
454
455    Args:
456      dataset: A `tf.data.Dataset` to be written to this table. It must produce
457        a list of number-of-columns+1 elements, all of which must be strings.
458        The first value will be used as the row key, and subsequent values will
459        be used as cell values for the corresponding columns from the
460        corresponding column_families and columns entries.
461      column_families: A `tf.Tensor` of `tf.string`s corresponding to the
462        column names to store the dataset's elements into.
463      columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
464        to store the dataset's elements into.
465      timestamp: (Optional.) An int64 timestamp to write all the values at.
466        Leave as None to use server-provided timestamps.
467
468    Returns:
469      A `tf.Operation` that can be run to perform the write.
470
471    Raises:
472      ValueError: If there are unexpected or incompatible types, or if the
473        number of columns and column_families does not match the output of
474        `dataset`.
475    """
476    if timestamp is None:
477      timestamp = -1  # Bigtable server provided timestamp.
478    for tensor_type in nest.flatten(
479        dataset_ops.get_legacy_output_types(dataset)):
480      if tensor_type != dtypes.string:
481        raise ValueError("Not all elements of the dataset were `tf.string`")
482    for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
483      if not shape.is_compatible_with(tensor_shape.scalar()):
484        raise ValueError("Not all elements of the dataset were scalars")
485    if len(column_families) != len(columns):
486      raise ValueError("len(column_families) != len(columns)")
487    if len(nest.flatten(
488        dataset_ops.get_legacy_output_types(dataset))) != len(columns) + 1:
489      raise ValueError("A column name must be specified for every component of "
490                       "the dataset elements. (e.g.: len(columns) != "
491                       "len(dataset.output_types))")
492    return gen_bigtable_ops.dataset_to_bigtable(
493        self._resource,
494        dataset._variant_tensor,  # pylint: disable=protected-access
495        column_families,
496        columns,
497        timestamp)
498
499  def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
500                                  normalized_probability, normalized_columns):
501    """Builds a parallel dataset from a given range.
502
503    Args:
504      ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
505      num_parallel_scans: The number of concurrent parallel scans to use.
506      normalized_probability: A number between 0 and 1 for the keep probability.
507      normalized_columns: The column families and column qualifiers to retrieve.
508
509    Returns:
510      A `tf.data.Dataset` representing the result of the parallel scan.
511    """
512    if num_parallel_scans is None:
513      num_parallel_scans = 50
514
515    ds = ds.shuffle(buffer_size=10000)  # TODO(saeta): Make configurable.
516
517    def _interleave_fn(start, end):
518      return _BigtableScanDataset(
519          self,
520          prefix="",
521          start=start,
522          end=end,
523          normalized=normalized_columns,
524          probability=normalized_probability)
525
526    # Note prefetch_input_elements must be set in order to avoid rpc timeouts.
527    ds = ds.apply(
528        interleave_ops.parallel_interleave(
529            _interleave_fn,
530            cycle_length=num_parallel_scans,
531            sloppy=True,
532            prefetch_input_elements=1))
533    return ds
534
535
536def _normalize_probability(probability):
537  if probability is None:
538    probability = 1.0
539  if isinstance(probability, float) and (probability <= 0.0 or
540                                         probability > 1.0):
541    raise ValueError("probability must be in the range (0, 1].")
542  return probability
543
544
545def _normalize_columns(columns, provided_kwargs):
546  """Converts arguments (columns, and kwargs dict) to C++ representation.
547
548  Args:
549    columns: a datastructure containing the column families and qualifier to
550      retrieve. Valid types include (1) None, (2) list of tuples, (3) a tuple of
551      strings.
552    provided_kwargs: a dictionary containing the column families and qualifiers
553      to retrieve
554
555  Returns:
556    A list of pairs of column family+qualifier to retrieve.
557
558  Raises:
559    ValueError: If there are no cells to retrieve or the columns are in an
560      incorrect format.
561  """
562  normalized = columns
563  if normalized is None:
564    normalized = []
565  if isinstance(normalized, tuple):
566    if len(normalized) == 2:
567      normalized = [normalized]
568    else:
569      raise ValueError("columns was a tuple of inappropriate length")
570  for key, value in iteritems(provided_kwargs):
571    if key == "name":
572      continue
573    if isinstance(value, string_types):
574      normalized.append((key, value))
575      continue
576    for col in value:
577      normalized.append((key, col))
578  if not normalized:
579    raise ValueError("At least one column + column family must be specified.")
580  return normalized
581
582
583class _BigtableKeyDataset(dataset_ops.DatasetSource):
584  """_BigtableKeyDataset is an abstract class representing the keys of a table.
585  """
586
587  def __init__(self, table, variant_tensor):
588    """Constructs a _BigtableKeyDataset.
589
590    Args:
591      table: a Bigtable class.
592      variant_tensor: DT_VARIANT representation of the dataset.
593    """
594    super(_BigtableKeyDataset, self).__init__(variant_tensor)
595    self._table = table
596
597  @property
598  def _element_structure(self):
599    return structure.TensorStructure(dtypes.string, [])
600
601
602class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
603  """_BigtablePrefixKeyDataset represents looking up keys by prefix.
604  """
605
606  def __init__(self, table, prefix):
607    self._prefix = prefix
608    variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
609        table=table._resource,  # pylint: disable=protected-access
610        prefix=self._prefix)
611    super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor)
612
613
614class _BigtableRangeKeyDataset(_BigtableKeyDataset):
615  """_BigtableRangeKeyDataset represents looking up keys by range.
616  """
617
618  def __init__(self, table, start, end):
619    self._start = start
620    self._end = end
621    variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset(
622        table=table._resource,  # pylint: disable=protected-access
623        start_key=self._start,
624        end_key=self._end)
625    super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor)
626
627
628class _BigtableSampleKeysDataset(_BigtableKeyDataset):
629  """_BigtableSampleKeysDataset represents a sampling of row keys.
630  """
631
632  # TODO(saeta): Expose the data size offsets into the keys.
633
634  def __init__(self, table):
635    variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset(
636        table=table._resource)  # pylint: disable=protected-access
637    super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor)
638
639
640class _BigtableLookupDataset(dataset_ops.DatasetSource):
641  """_BigtableLookupDataset represents a dataset that retrieves values for keys.
642  """
643
644  def __init__(self, dataset, table, normalized):
645    self._num_outputs = len(normalized) + 1  # 1 for row key
646    self._dataset = dataset
647    self._table = table
648    self._normalized = normalized
649    self._column_families = [i[0] for i in normalized]
650    self._columns = [i[1] for i in normalized]
651    variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset(
652        keys_dataset=self._dataset._variant_tensor,  # pylint: disable=protected-access
653        table=self._table._resource,  # pylint: disable=protected-access
654        column_families=self._column_families,
655        columns=self._columns)
656    super(_BigtableLookupDataset, self).__init__(variant_tensor)
657
658  @property
659  def _element_structure(self):
660    return structure.NestedStructure(tuple(
661        [structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
662
663
664class _BigtableScanDataset(dataset_ops.DatasetSource):
665  """_BigtableScanDataset represents a dataset that retrieves keys and values.
666  """
667
668  def __init__(self, table, prefix, start, end, normalized, probability):
669    self._table = table
670    self._prefix = prefix
671    self._start = start
672    self._end = end
673    self._column_families = [i[0] for i in normalized]
674    self._columns = [i[1] for i in normalized]
675    self._probability = probability
676    self._num_outputs = len(normalized) + 1  # 1 for row key
677    variant_tensor = gen_bigtable_ops.bigtable_scan_dataset(
678        table=self._table._resource,  # pylint: disable=protected-access
679        prefix=self._prefix,
680        start_key=self._start,
681        end_key=self._end,
682        column_families=self._column_families,
683        columns=self._columns,
684        probability=self._probability)
685    super(_BigtableScanDataset, self).__init__(variant_tensor)
686
687  @property
688  def _element_structure(self):
689    return structure.NestedStructure(
690        tuple(
691            [structure.TensorStructure(dtypes.string, [])] * self._num_outputs))
692
693
694class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
695  """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
696  """
697
698  def __init__(self, table, prefix, start, end):
699    self._table = table
700    self._prefix = prefix
701    self._start = start
702    self._end = end
703    variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
704        table=self._table._resource,  # pylint: disable=protected-access
705        prefix=self._prefix,
706        start_key=self._start,
707        end_key=self._end)
708    super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
709
710  @property
711  def _element_structure(self):
712    return structure.NestedStructure(
713        (structure.TensorStructure(dtypes.string, []),
714         structure.TensorStructure(dtypes.string, [])))
715