1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Methods to allow dask.DataFrame (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import numpy as np
28
29from tensorflow.python.util.deprecation import deprecated
30
31try:
32  # pylint: disable=g-import-not-at-top
33  import dask.dataframe as dd
34  allowed_classes = (dd.Series, dd.DataFrame)
35  HAS_DASK = True
36except ImportError:
37  HAS_DASK = False
38
39
40def _add_to_index(df, start):
41  """New dask.dataframe with values added to index of each subdataframe."""
42  df = df.copy()
43  df.index += start
44  return df
45
46
47def _get_divisions(df):
48  """Number of rows in each sub-dataframe."""
49  lengths = df.map_partitions(len).compute()
50  divisions = np.cumsum(lengths).tolist()
51  divisions.insert(0, 0)
52  return divisions
53
54
def _construct_dask_df_with_divisions(df):
  """Wrap `df` in a new dask collection whose partitions have shifted indices.

  For every partition `i` of `df` a task is added that calls `_add_to_index`
  on that partition with the `i`-th value returned by `_get_divisions`. The new
  tasks are merged with the existing graph of `df`.

  Args:
    df: A dask.DataFrame or dask.Series.

  Returns:
    A dask.DataFrame (or dask.Series) built from the merged graph, using the
    values from `_get_divisions` as its divisions. Returns None when `df` is
    neither of those two types.
  """
  offsets = _get_divisions(df)
  # pylint: disable=protected-access
  source_name = df._name
  # pylint: enable=protected-access
  name = 'csv-index' + source_name

  # One task per partition: shift that partition's index by its offset.
  graph = {}
  for part in range(df.npartitions):
    graph[(name, part)] = (_add_to_index, (source_name, part), offsets[part])

  from toolz import merge  # pylint: disable=g-import-not-at-top
  if isinstance(df, dd.DataFrame):
    return dd.DataFrame(merge(graph, df.dask), name, df.columns, offsets)
  if isinstance(df, dd.Series):
    return dd.Series(merge(graph, df.dask), name, df.name, offsets)
  return None
68
69
@deprecated(None, 'Please feed input to tf.data to support dask.')
def extract_dask_data(data):
  """Extract data from dask.Series or dask.DataFrame for predictors.

  Given a distributed dask.DataFrame or dask.Series containing columns or names
  for one or more predictors, this operation returns a single dask.DataFrame or
  dask.Series that can be iterated over.

  Args:
    data: A distributed dask.DataFrame or dask.Series.

  Returns:
    A dask.DataFrame or dask.Series that can be iterated over.
    If the supplied argument is neither a dask.DataFrame nor a dask.Series
    (which is always the case when dask is not installed) this operation
    returns it without modification.
  """
  # The dask classes are only bound when the optional import succeeded, so
  # check HAS_DASK first to avoid a NameError when dask is unavailable.
  if HAS_DASK and isinstance(data, allowed_classes):
    return _construct_dask_df_with_divisions(data)
  return data
90
91
@deprecated(None, 'Please feed input to tf.data to support dask.')
def extract_dask_labels(labels):
  """Extract data from dask.Series or dask.DataFrame for labels.

  Given a distributed dask.DataFrame containing exactly one column, or a
  dask.Series, this operation returns a single dask.DataFrame or dask.Series
  that can be iterated over.

  Args:
    labels: A distributed dask.DataFrame with exactly one column, or a
            dask.Series.

  Returns:
    A dask.DataFrame or dask.Series that can be iterated over.
    If the supplied argument is neither a dask.DataFrame nor a dask.Series
    (which is always the case when dask is not installed) this operation
    returns it without modification.

  Raises:
    ValueError: If the supplied dask.DataFrame contains more than one
                column.
  """
  # The dask classes are only bound when the optional import succeeded, so
  # check HAS_DASK first to avoid a NameError when dask is unavailable.
  if not HAS_DASK or not isinstance(labels, allowed_classes):
    return labels
  # Only a DataFrame can carry several columns. A Series is a single column by
  # construction, and measuring the length of its name would reject any name
  # longer than one character (and fail outright on an unnamed Series).
  if isinstance(labels, dd.DataFrame) and len(labels.columns) > 1:
    raise ValueError('Only one column for labels is allowed.')
  return _construct_dask_df_with_divisions(labels)
124