1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""MNIST handwritten digits dataset.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.keras.utils.data_utils import get_file
24from tensorflow.python.util.tf_export import keras_export
25
26
@keras_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
  """Fetches (or reuses a cached copy of) MNIST and returns its arrays.

  The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) consists of
  60,000 training images and 10,000 test images of the 10 handwritten
  digits, each a 28x28 grayscale picture. See the
  [MNIST homepage](http://yann.lecun.com/exdb/mnist/) for background.

  Args:
      path: name under which the downloaded archive is cached locally
          (relative to `~/.keras/datasets`).

  Returns:
      Two pairs of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.

      **x_train, x_test**: uint8 grayscale image data, shaped
        (num_samples, 28, 28).

      **y_train, y_test**: uint8 digit labels (integers in range 0-9),
        shaped (num_samples,).

  License:
      Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
      which is a derivative work from original NIST datasets.
      MNIST dataset is made available under the terms of the
      [Creative Commons Attribution-Share Alike 3.0 license.](
      https://creativecommons.org/licenses/by-sa/3.0/)
  """
  base_url = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  expected_hash = (
      '731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')

  # `get_file` hands back the location of the local copy; the expected hash
  # is passed along with the download origin.
  local_archive = get_file(
      path, origin=base_url + 'mnist.npz', file_hash=expected_hash)

  # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays from
  # the archive. Presumably the archive holds only plain numeric arrays and
  # would load without it -- confirm before tightening this flag.
  with np.load(local_archive, allow_pickle=True) as archive:
    # Pull every array out while the archive is still open.
    x_train = archive['x_train']
    y_train = archive['y_train']
    x_test = archive['x_test']
    y_test = archive['y_test']

  train_split = (x_train, y_train)
  test_split = (x_test, y_test)
  return train_split, test_split
68