# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Bigtable Ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib import bigtable
22from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
23from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
24from tensorflow.contrib.bigtable.python.ops import bigtable_api
25from tensorflow.contrib.util import loader
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.framework import errors
28from tensorflow.python.platform import resource_loader
29from tensorflow.python.platform import test
30from tensorflow.python.util import compat
31
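# Load the test-only op library, which registers the bigtable_test_client op
# used in setUp below so these tests run against a test client rather than a
# real Cloud Bigtable instance.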
_bigtable_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_bigtable_test.so"))


def _ListOfTuplesOfStringsToBytes(values):
  """Converts each pair of strings in `values` to a pair of bytes."""
  return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]


class BigtableOpsTest(test.TestCase):
  COMMON_ROW_KEYS = ["r1", "r2", "r3"]
  COMMON_VALUES = ["v1", "v2", "v3"]

  def setUp(self):
    # Create a test client and wrap its "testtable" in the public
    # BigtableTable API that these tests exercise.
    self._client = gen_bigtable_test_ops.bigtable_test_client()
    table = gen_bigtable_ops.bigtable_table(self._client, "testtable")
    self._table = bigtable.BigtableTable("testtable", None, table)

  def _makeSimpleDataset(self):
    """Returns a dataset of (row key, value) pairs from the test data."""
    output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS)
    output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES)
    return dataset_ops.Dataset.zip((output_rows, output_values))

  def _writeCommonValues(self, sess):
    """Writes the common test values to column cf1:c1 of the test table."""
    output_ds = self._makeSimpleDataset()
    write_op = self._table.write(output_ds, ["cf1"], ["c1"])
    sess.run(write_op)

  def runReadKeyTest(self, read_ds):
    """Checks that `read_ds` yields the common row keys in order."""
    itr = dataset_ops.make_initializable_iterator(read_ds)
    n = itr.get_next()
    expected = list(self.COMMON_ROW_KEYS)
    expected.reverse()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i in range(3):
        output = sess.run(n)
        want = expected.pop()
        self.assertEqual(
            compat.as_bytes(want), compat.as_bytes(output),
            "Unequal at step %d: want: %s, got: %s" % (i, want, output))

  def testReadPrefixKeys(self):
    self.runReadKeyTest(self._table.keys_by_prefix_dataset("r"))

  def testReadRangeKeys(self):
    self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4"))

  def runScanTest(self, read_ds):
    """Checks that `read_ds` yields the common (key, value) pairs in order."""
    itr = dataset_ops.make_initializable_iterator(read_ds)
    n = itr.get_next()
    expected_keys = list(self.COMMON_ROW_KEYS)
    expected_keys.reverse()
    expected_values = list(self.COMMON_VALUES)
    expected_values.reverse()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i in range(3):
        output = sess.run(n)
        want = expected_keys.pop()
        self.assertEqual(
            compat.as_bytes(want), compat.as_bytes(output[0]),
            "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0]))
        want = expected_values.pop()
        self.assertEqual(
            compat.as_bytes(want), compat.as_bytes(output[1]),
            "Unequal values at step %d: want: %s, got: %s" %
            (i, want, output[1]))

  def testScanPrefixStringCol(self):
    self.runScanTest(self._table.scan_prefix("r", cf1="c1"))

  def testScanPrefixListCol(self):
    self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))

  def testScanPrefixTupleCol(self):
    self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))

  def testScanRangeStringCol(self):
    self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))

  def testScanRangeListCol(self):
    self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))

  def testScanRangeTupleCol(self):
    self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))

  def testLookup(self):
    ds = self._table.keys_by_prefix_dataset("r")
    ds = ds.apply(self._table.lookup_columns(cf1="c1"))
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    expected_keys = list(self.COMMON_ROW_KEYS)
    expected_values = list(self.COMMON_VALUES)
    expected_tuples = zip(expected_keys, expected_values)
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i, elem in enumerate(expected_tuples):
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(elem[0]), compat.as_bytes(output[0]),
            "Unequal keys at step %d: want: %s, got: %s" %
            (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0])))
        self.assertEqual(
            compat.as_bytes(elem[1]), compat.as_bytes(output[1]),
            "Unequal values at step %d: want: %s, got: %s" %
            (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))

  def testSampleKeys(self):
    ds = self._table.sample_keys()
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    expected_key = self.COMMON_ROW_KEYS[0]
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      # The test client samples row keys r1 and r3 and then signals end of
      # sequence.
      output = sess.run(n)
      self.assertEqual(
          compat.as_bytes(self.COMMON_ROW_KEYS[0]), compat.as_bytes(output),
          "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
              self.COMMON_ROW_KEYS[0]), compat.as_bytes(output)))
      output = sess.run(n)
      self.assertEqual(
          compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output),
          "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
              self.COMMON_ROW_KEYS[2]), compat.as_bytes(output)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  def runSampleKeyPairsTest(self, ds, expected_key_pairs):
    """Checks that `ds` yields exactly `expected_key_pairs`, in order."""
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i, elems in enumerate(expected_key_pairs):
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
            "Unequal key pair (first element) at step %d; want: %s, got %s" %
            (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
        self.assertEqual(
            compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
            "Unequal key pair (second element) at step %d; want: %s, got %s" %
            (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

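  # The tests below exercise the internal _BigtableSampleKeyPairsDataset,
  # which splits the requested prefix or range into contiguous sub-ranges at
  # the sample keys reported by the client (r1 and r3 for the test client, as
  # seen in testSampleKeys above).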
  def testSampleKeyPairsSimplePrefix(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r", start="", end="")
    expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsSimpleRange(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="r1", end="r3")
    expected_key_pairs = [("r1", "r3")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsSkipRangePrefix(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r2", start="", end="")
    expected_key_pairs = [("r2", "r3")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsSkipRangeRange(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="r2", end="r3")
    expected_key_pairs = [("r2", "r3")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsOffsetRanges(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="r2", end="r4")
    expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairEverything(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="", end="")
    expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsPrefixAndStartKey(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r", start="r1", end="")
    itr = dataset_ops.make_initializable_iterator(ds)
    with self.cached_session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(itr.initializer)

  def testSampleKeyPairsPrefixAndEndKey(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r", start="", end="r3")
    itr = dataset_ops.make_initializable_iterator(ds)
    with self.cached_session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(itr.initializer)

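  # Parallel scans may return rows from the scanned sub-ranges in any order,
  # so the tests below compare against the expected values with
  # assertItemsEqual rather than asserting a fixed order.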
  def testParallelScanPrefix(self):
    ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
      actual_values = []
      for _ in range(len(expected_values)):
        output = sess.run(n)
        actual_values.append(output)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)
      self.assertItemsEqual(
          _ListOfTuplesOfStringsToBytes(expected_values),
          _ListOfTuplesOfStringsToBytes(actual_values))

  def testParallelScanRange(self):
    ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
      actual_values = []
      for _ in range(len(expected_values)):
        output = sess.run(n)
        actual_values.append(output)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)
      self.assertItemsEqual(
          _ListOfTuplesOfStringsToBytes(expected_values),
          _ListOfTuplesOfStringsToBytes(actual_values))


if __name__ == "__main__":
  test.main()