1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import numpy as np
23import tensorflow as tf
24import tensorflow.contrib.mpi_collectives as mpi
25from tensorflow.python.platform import test
26
27
# Averaging flag for the op under test. It is passed as the second argument
# to mpi.allreduce, and when True the checker divides the expected values by
# the number of ranks before comparing.
average_allgather = False
29
30
class AllgatherTest(test.TestCase):
  """Tests that MPI allreduce of IndexedSlices gathers slices from all ranks."""

  def checkAllgather(self, num_ranks, all_gathered, local_gathered):
    """Assert that a gathered IndexedSlices value matches the expected one.

    The gathered result may list its slices in a different order than the
    expected value, so rows are paired by index rather than by position.
    Duplicate indices are paired in order of appearance: the k-th occurrence
    of an index in `all_gathered` is matched with the k-th occurrence of that
    index in `local_gathered`.

    Args:
      num_ranks: Number of MPI ranks. The expected values are divided by this
        when the module-level `average_allgather` flag is set.
      all_gathered: Evaluated IndexedSlices produced by the MPI op.
      local_gathered: Evaluated IndexedSlices holding the expected result
        (every rank's slices concatenated).
    """
    all_indices = np.asarray(all_gathered.indices)
    local_indices = np.asarray(local_gathered.indices)

    # A stable sort keeps equal indices in their original relative order.
    # Duplicates are therefore paired exactly as a "first unmatched" linear
    # search would pair them, but in O(n log n) instead of O(n^2).
    all_order = np.argsort(all_indices, kind='mergesort')
    local_order = np.argsort(local_indices, kind='mergesort')

    # Both sides must hold the same multiset of indices. This also checks that
    # the lengths agree. Unlike bare `assert`, it is not stripped under -O.
    self.assertAllEqual(local_indices[local_order], all_indices[all_order])

    expected_values = np.asarray(local_gathered.values)[local_order]
    if average_allgather:
      expected_values = expected_values / float(num_ranks)
    # Compare whole rows, with a tolerance. Averaging is floating-point
    # division, and the op's float32 result need not equal a float64 quotient
    # exactly.
    self.assertAllClose(expected_values,
                        np.asarray(all_gathered.values)[all_order])

  def test_mpi_allgather(self):
    """Allreducing IndexedSlices should gather every rank's slices.

    Must be launched so that the PMI_RANK and PMI_SIZE environment variables
    are set (e.g. by an MPI launcher).
    """
    # Get MPI rank and world size.
    my_rank = int(os.environ['PMI_RANK'])
    num_ranks = int(os.environ['PMI_SIZE'])

    indices_per_rank = 100
    tensor_width = 10

    # Every rank builds the slices of *all* ranks so that it can compute the
    # expected gathered result locally. Rank r uses the indices
    # (r + 1) * [1, 2, ..., indices_per_rank], so ranks share some indices
    # (e.g. with two or more ranks, index 2 appears on ranks 0 and 1). Each
    # row of values is filled with its own index.
    rank_indices = []
    rank_values = []
    for rank_id in range(num_ranks):
      multiple = rank_id + 1
      indices = multiple * np.arange(1, indices_per_rank + 1, dtype=np.int32)
      values = np.repeat(indices.astype(np.float32)[:, np.newaxis],
                         tensor_width, axis=1)
      rank_indices.append(indices)
      rank_values.append(values)

    # This rank's contribution (int32 indices, float32 values).
    my_slices = tf.IndexedSlices(
        tf.constant(rank_values[my_rank], dtype=tf.float32),
        tf.constant(rank_indices[my_rank], dtype=tf.int32))

    # Expected gathered result: all ranks' slices concatenated in rank order.
    correct_gather = tf.IndexedSlices(
        tf.constant(np.concatenate(rank_values, axis=0), dtype=tf.float32),
        tf.constant(np.concatenate(rank_indices, axis=0), dtype=tf.int32))

    # The second argument is allreduce's averaging flag. The comparison below
    # expects allreduce of IndexedSlices to behave as an allgather.
    all_gather = mpi.allreduce(my_slices, average_allgather)

    # NOTE: This assumes that device IDs are numbered the same as ranks.
    gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
    config = tf.ConfigProto(gpu_options=gpu_options)

    # MPI Session to test allgather.
    with mpi.Session(config=config) as sess:
      sess.run(tf.global_variables_initializer())

      all_gathered, local_gathered = sess.run([all_gather, correct_gather])

      # Compare all_gathered with local_gathered.
      self.checkAllgather(num_ranks, all_gathered, local_gathered)
111
112
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
115