1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""MNIST handwritten digits dataset. 16""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.keras.utils.data_utils import get_file 24from tensorflow.python.util.tf_export import keras_export 25 26 27@keras_export('keras.datasets.mnist.load_data') 28def load_data(path='mnist.npz'): 29 """Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). 30 31 This is a dataset of 60,000 28x28 grayscale images of the 10 digits, 32 along with a test set of 10,000 images. 33 More info can be found at the 34 [MNIST homepage](http://yann.lecun.com/exdb/mnist/). 35 36 37 Args: 38 path: path where to cache the dataset locally 39 (relative to `~/.keras/datasets`). 40 41 Returns: 42 Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 43 44 **x_train, x_test**: uint8 arrays of grayscale image data with shapes 45 (num_samples, 28, 28). 46 47 **y_train, y_test**: uint8 arrays of digit labels (integers in range 0-9) 48 with shapes (num_samples,). 49 50 License: 51 Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset, 52 which is a derivative work from original NIST datasets. 53 MNIST dataset is made available under the terms of the 54 [Creative Commons Attribution-Share Alike 3.0 license.]( 55 https://creativecommons.org/licenses/by-sa/3.0/) 56 """ 57 origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' 58 path = get_file( 59 path, 60 origin=origin_folder + 'mnist.npz', 61 file_hash= 62 '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1') 63 with np.load(path, allow_pickle=True) as f: 64 x_train, y_train = f['x_train'], f['y_train'] 65 x_test, y_test = f['x_test'], f['y_test'] 66 67 return (x_train, y_train), (x_test, y_test) 68