# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""IMDB sentiment classification dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

import numpy as np

from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.datasets.imdb.load_data')
def load_data(path='imdb.npz',
              num_words=None,
              skip_top=0,
              maxlen=None,
              seed=113,
              start_char=1,
              oov_char=2,
              index_from=3,
              **kwargs):
  """Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).

  This is a dataset of 25,000 movie reviews from IMDB, labeled by sentiment
  (positive/negative). Reviews have been preprocessed, and each review is
  encoded as a list of word indexes (integers).
  For convenience, words are indexed by overall frequency in the dataset,
  so that for instance the integer "3" encodes the 3rd most frequent word in
  the data. This allows for quick filtering operations such as:
  "only consider the top 10,000 most common words, but eliminate the top 20
  most common words".

  As a convention, "0" does not stand for a specific word, but instead is
  used to encode the padding token.

  Args:
    path: where to cache the data (relative to `~/.keras/datasets`).
    num_words: integer or None. Words are
      ranked by how often they occur (in the training set) and only
      the `num_words` most frequent words are kept. Any less frequent word
      will appear as the `oov_char` value in the sequence data. If None,
      all words are kept. Defaults to None, so all words are kept.
    skip_top: skip the top N most frequently occurring words
      (which may not be informative). These words will appear as the
      `oov_char` value in the dataset. Defaults to 0, so no words are
      skipped.
    maxlen: int or None. Maximum sequence length.
      Any longer sequence will be filtered out. Defaults to None, which
      means no filtering.
    seed: int. Seed for reproducible data shuffling.
    start_char: int. The start of a sequence will be marked with this
      character. Defaults to 1 because 0 is usually the padding character.
    oov_char: int. The out-of-vocabulary character.
      Words that were cut out because of the `num_words` or
      `skip_top` limits will be replaced with this character.
    index_from: int. Index actual words with this index and higher.
    **kwargs: Used for backwards compatibility.

  Returns:
    Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

    **x_train, x_test**: lists of sequences, which are lists of indexes
      (integers). If the `num_words` argument was specified, the maximum
      possible index value is `num_words - 1`. If the `maxlen` argument was
      specified, the largest possible sequence length is `maxlen`.

    **y_train, y_test**: lists of integer labels (1 or 0).

  Raises:
    ValueError: in case `maxlen` is so low
      that no input sequence could be kept.

  Note that the 'out of vocabulary' character is only used for
  words that were present in the training set but were dropped
  because they did not make the `num_words` cut.
  Words that were not seen in the training set but are in the test set
  have simply been skipped.
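
  Example (an illustrative sketch; the `num_words` value is arbitrary):

  ```python
  # Assumes `import tensorflow as tf`.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(
      num_words=10000)
  print(len(x_train), len(x_test))  # 25000 25000
  # Each review is a list of word indices; with the defaults it starts with
  # the `start_char` marker 1, and with `num_words=10000` every index of a
  # kept word is below 10000.
  ```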
  """
  # Legacy support
  if 'nb_words' in kwargs:
    logging.warning('The `nb_words` argument in `load_data` '
                    'has been renamed `num_words`.')
    num_words = kwargs.pop('nb_words')
  if kwargs:
    raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'imdb.npz',
      file_hash=
      '69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f')
  with np.load(path, allow_pickle=True) as f:
    x_train, labels_train = f['x_train'], f['y_train']
    x_test, labels_test = f['x_test'], f['y_test']

  # Shuffle the train and test sets with a seeded RNG so the order is
  # reproducible across calls with the same `seed`.
  rng = np.random.RandomState(seed)
  indices = np.arange(len(x_train))
  rng.shuffle(indices)
  x_train = x_train[indices]
  labels_train = labels_train[indices]

  indices = np.arange(len(x_test))
  rng.shuffle(indices)
  x_test = x_test[indices]
  labels_test = labels_test[indices]

  # Prepend `start_char` and shift all word indices up by `index_from` so
  # that 0, 1 and 2 stay reserved for the padding, start and OOV markers.
  if start_char is not None:
    x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
    x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
  elif index_from:
    x_train = [[w + index_from for w in x] for x in x_train]
    x_test = [[w + index_from for w in x] for x in x_test]

  if maxlen:
    x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
    x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
    if not x_train or not x_test:
      raise ValueError('After filtering for sequences shorter than maxlen=' +
                       str(maxlen) + ', no sequence was kept. '
                       'Increase maxlen.')

  xs = np.concatenate([x_train, x_test])
  labels = np.concatenate([labels_train, labels_test])

  if not num_words:
    # Keep every word: the filter below requires `w < num_words`, so the
    # bound must be one above the maximum observed index.
    num_words = max(max(x) for x in xs) + 1

  # by convention, use 2 as OOV word
  # reserve 'index_from' (=3 by default) characters:
  # 0 (padding), 1 (start), 2 (OOV)
  if oov_char is not None:
    xs = [
        [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
    ]
  else:
    xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

  idx = len(x_train)
  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])

  return (x_train, y_train), (x_test, y_test)


@keras_export('keras.datasets.imdb.get_word_index')
def get_word_index(path='imdb_word_index.json'):
  """Retrieves a dict mapping words to their index in the IMDB dataset.

  Args:
    path: where to cache the data (relative to `~/.keras/datasets`).

  Returns:
    The word index dictionary. Keys are word strings, values are their index.
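
  Example (an illustrative decode round-trip; `load_data` shifts every word
  index up by `index_from` (3 by default) and reserves 0, 1 and 2 for the
  padding, start and out-of-vocabulary markers, so decoding must undo that
  offset):

  ```python
  # Assumes `import tensorflow as tf`.
  (x_train, _), _ = tf.keras.datasets.imdb.load_data()
  word_index = tf.keras.datasets.imdb.get_word_index()
  # Invert the mapping so it goes index -> word.
  inverted_word_index = dict((i, word) for (word, i) in word_index.items())
  # Skip the reserved indices 0 (padding), 1 (start) and 2 (OOV), and
  # subtract the default `index_from=3` offset to recover the words.
  decoded_review = ' '.join(
      inverted_word_index.get(i - 3, '?') for i in x_train[0] if i > 2)
  ```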
174 """ 175 origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' 176 path = get_file( 177 path, 178 origin=origin_folder + 'imdb_word_index.json', 179 file_hash='bfafd718b763782e994055a2d397834f') 180 with open(path) as f: 181 return json.load(f) 182