# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Bigtable Ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import bigtable
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
from tensorflow.contrib.bigtable.python.ops import bigtable_api
from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.util import compat

# Load the op library shipped as the data file "_bigtable_test.so".
# NOTE(review): presumably this registers the ops behind
# `gen_bigtable_test_ops`; confirm against the build rules.
_bigtable_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_bigtable_test.so"))


def _ListOfTuplesOfStringsToBytes(values):
  """Converts the first two items of every element in `values` to bytes.

  Args:
    values: An iterable of indexable elements with at least two items each.

  Returns:
    A list of 2-tuples whose members went through `compat.as_bytes`.
  """
  return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]


class BigtableOpsTest(test.TestCase):
  """Exercises the Bigtable dataset ops against the test client op."""

  # Row keys and cell values written by `_writeCommonValues`; the i-th key is
  # paired with the i-th value.
  COMMON_ROW_KEYS = ["r1", "r2", "r3"]
  COMMON_VALUES = ["v1", "v2", "v3"]

  def setUp(self):
    """Creates the test client and a `BigtableTable` named "testtable"."""
    self._client = gen_bigtable_test_ops.bigtable_test_client()
    table = gen_bigtable_ops.bigtable_table(self._client, "testtable")
    self._table = bigtable.BigtableTable("testtable", None, table)

  def _makeSimpleDataset(self):
    """Returns a dataset of (row key, value) pairs built from the fixtures."""
    output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS)
    output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES)
    return dataset_ops.Dataset.zip((output_rows, output_values))

  def _writeCommonValues(self, sess):
    """Writes the fixture pairs to column family "cf1", column "c1".

    Args:
      sess: The session used to run the write op.
    """
    output_ds = self._makeSimpleDataset()
    write_op = self._table.write(output_ds, ["cf1"], ["c1"])
    sess.run(write_op)

  def runReadKeyTest(self, read_ds):
    """Checks that `read_ds` yields `COMMON_ROW_KEYS` in order.

    Args:
      read_ds: A dataset whose elements are single row keys.
    """
    itr = dataset_ops.make_initializable_iterator(read_ds)
    n = itr.get_next()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i, want in enumerate(self.COMMON_ROW_KEYS):
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(want), compat.as_bytes(output),
            "Unequal at step %d: want: %s, got: %s" % (i, want, output))

  def testReadPrefixKeys(self):
    self.runReadKeyTest(self._table.keys_by_prefix_dataset("r"))

  def testReadRangeKeys(self):
    self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4"))

  def runScanTest(self, read_ds):
    """Checks that `read_ds` yields the fixture (key, value) pairs in order.

    Args:
      read_ds: A dataset whose elements are indexable, with the row key at
        position 0 and the cell value at position 1.
    """
    itr = dataset_ops.make_initializable_iterator(read_ds)
    n = itr.get_next()
    expected_pairs = zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i, (want_key, want_value) in enumerate(expected_pairs):
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(want_key), compat.as_bytes(output[0]),
            "Unequal keys at step %d: want: %s, got: %s" % (i, want_key,
                                                             output[0]))
        self.assertEqual(
            compat.as_bytes(want_value), compat.as_bytes(output[1]),
            "Unequal values at step: %d: want: %s, got: %s" % (i, want_value,
                                                                output[1]))

  def testScanPrefixStringCol(self):
    self.runScanTest(self._table.scan_prefix("r", cf1="c1"))

  def testScanPrefixListCol(self):
    self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))

  def testScanPrefixTupleCol(self):
    self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))

  def testScanRangeStringCol(self):
    self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))

  def testScanRangeListCol(self):
    self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))

  def testScanRangeTupleCol(self):
    self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))

  def testLookup(self):
    """Looks up column cf1:c1 for every key with prefix "r"."""
    ds = self._table.keys_by_prefix_dataset("r")
    ds = ds.apply(self._table.lookup_columns(cf1="c1"))
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    expected_tuples = zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i, elem in enumerate(expected_tuples):
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(elem[0]), compat.as_bytes(output[0]),
            "Unequal keys at step %d: want: %s, got: %s" %
            (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0])))
        self.assertEqual(
            compat.as_bytes(elem[1]), compat.as_bytes(output[1]),
            "Unequal values at step %d: want: %s, got: %s" %
            (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))

  def testSampleKeys(self):
    """Expects `sample_keys` to yield the first and third keys, then stop."""
    ds = self._table.sample_keys()
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    # NOTE(review): the skipped middle key is presumably a property of the
    # test client's sampling; confirm against the test client implementation.
    expected_keys = [self.COMMON_ROW_KEYS[0], self.COMMON_ROW_KEYS[2]]
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for expected_key in expected_keys:
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(expected_key), compat.as_bytes(output),
            "Unequal keys: want: %s, got: %s" %
            (compat.as_bytes(expected_key), compat.as_bytes(output)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  def runSampleKeyPairsTest(self, ds, expected_key_pairs):
    """Checks that `ds` yields exactly `expected_key_pairs`, in order.

    Args:
      ds: A dataset whose elements are indexable pairs of keys.
      expected_key_pairs: An iterable of the expected 2-tuples of strings.
    """
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      for i, elems in enumerate(expected_key_pairs):
        output = sess.run(n)
        self.assertEqual(
            compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
            "Unequal key pair (first element) at step %d; want: %s, got %s" %
            (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
        self.assertEqual(
            compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
            "Unequal key pair (second element) at step %d; want: %s, got %s" %
            (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
      # The dataset must be exhausted once every expected pair was produced.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  def testSampleKeyPairsSimplePrefix(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r", start="", end="")
    expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsSimpleRange(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="r1", end="r3")
    expected_key_pairs = [("r1", "r3")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsSkipRangePrefix(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r2", start="", end="")
    expected_key_pairs = [("r2", "r3")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsSkipRangeRange(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="r2", end="r3")
    expected_key_pairs = [("r2", "r3")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsOffsetRanges(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="r2", end="r4")
    expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairEverything(self):
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="", start="", end="")
    expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
    self.runSampleKeyPairsTest(ds, expected_key_pairs)

  def testSampleKeyPairsPrefixAndStartKey(self):
    """Expects InvalidArgumentError when both prefix and start are given."""
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r", start="r1", end="")
    itr = dataset_ops.make_initializable_iterator(ds)
    with self.cached_session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(itr.initializer)

  def testSampleKeyPairsPrefixAndEndKey(self):
    """Expects InvalidArgumentError when both prefix and end are given."""
    ds = bigtable_api._BigtableSampleKeyPairsDataset(
        self._table, prefix="r", start="", end="r3")
    itr = dataset_ops.make_initializable_iterator(ds)
    with self.cached_session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(itr.initializer)

  def _runParallelScanTest(self, ds):
    """Checks that `ds` yields the fixture pairs in any order, then stops.

    Args:
      ds: A dataset whose elements are indexable (row key, value) pairs.
    """
    itr = dataset_ops.make_initializable_iterator(ds)
    n = itr.get_next()
    with self.cached_session() as sess:
      self._writeCommonValues(sess)
      sess.run(itr.initializer)
      expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
      actual_values = [sess.run(n) for _ in expected_values]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)
      # Compared as unordered collections: only the set of rows is checked,
      # not the order in which they were produced.
      self.assertItemsEqual(
          _ListOfTuplesOfStringsToBytes(expected_values),
          _ListOfTuplesOfStringsToBytes(actual_values))

  def testParallelScanPrefix(self):
    self._runParallelScanTest(
        self._table.parallel_scan_prefix(prefix="r", cf1="c1"))

  def testParallelScanRange(self):
    self._runParallelScanTest(
        self._table.parallel_scan_range(start="r1", end="r4", cf1="c1"))


if __name__ == "__main__":
  test.main()