# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.mpi_collectives as mpi
from tensorflow.python.platform import test


average_allgather = False


class AllgatherTest(test.TestCase):
  def checkAllgather(self, num_ranks, all_gathered, local_gathered):
    # Ensure that indices match.
    all_gat_ind = np.sort(all_gathered.indices)
    loc_gat_ind = np.sort(local_gathered.indices)
    assert(len(loc_gat_ind) == len(all_gat_ind))
    for i in range(len(loc_gat_ind)):
      assert(loc_gat_ind[i] == all_gat_ind[i])

    # For each index, verify same values.
    local_checked = []
    for i in range(len(local_gathered.indices)):
      local_checked.append(False)
    for i in range(len(all_gathered.indices)):
      all_index = all_gathered.indices[i]
      # TODO(jthestness): Make this lookup quicker using sorting.
      loc_index = -1
      for j in range(len(local_gathered.indices)):
        if local_gathered.indices[j] == all_index and not local_checked[j]:
          loc_index = j
          local_checked[j] = True
          break
      assert(loc_index >= 0)
      correct_output = local_gathered.values[loc_index][0]
      if average_allgather:
        correct_output = correct_output / float(num_ranks)
      assert(all_gathered.values[i][0] == correct_output)

  def test_mpi_allgather(self):
    # Get MPI rank
    my_rank = int(os.environ['PMI_RANK'])
    num_ranks = int(os.environ['PMI_SIZE'])

    indices_per_rank = 100
    tensor_width = 10

    # Create IndexedSlices for each rank, some with overlapping indices.
    to_gather_indices = []
    to_gather_values = []
    to_gather = []
    for rank_id in range(num_ranks):
      indices = []
      values = []
      my_multiple = rank_id + 1
      current_index = my_multiple
      for i in range(indices_per_rank):
        indices.append(current_index)
        ones_tensor = tf.ones([tensor_width])
        values.append(tf.multiply(ones_tensor,
                                  tf.fill(ones_tensor.get_shape(),
                                          float(current_index))))
        current_index += my_multiple
      concat_ind = tf.stack(indices)
      concat_vals = tf.stack(values)
      to_gather_indices.append(concat_ind)
      to_gather_values.append(concat_vals)
      to_gather.append(tf.IndexedSlices(concat_vals, concat_ind))

    # Collect the local IndexedSlices (indices and values) to create
    # correct IndexedSlices output.
    correct_gather_indices = tf.concat(to_gather_indices, 0)
    correct_gather_values = tf.concat(to_gather_values, 0)
    correct_gather = tf.IndexedSlices(correct_gather_values,
                                      correct_gather_indices)

    all_gather = mpi.allreduce(to_gather[my_rank], average_allgather)

    # NOTE: This assumes that device IDs are numbered the same as ranks.
    gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
    config = tf.ConfigProto(gpu_options=gpu_options)

    # MPI Session to test allgather.
    with mpi.Session(config=config) as sess:
      sess.run(tf.global_variables_initializer())

      all_gathered, local_gathered = sess.run([all_gather, correct_gather])

      # Compare all_gathered with local_gathered.
      self.checkAllgather(num_ranks, all_gathered, local_gathered)


if __name__ == '__main__':
  test.main()
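
# Usage note (an assumption added for illustration, not part of the original
# test): PMI_RANK and PMI_SIZE are exported by MPICH/Hydra-style process
# managers, so an invocation along the lines of
#
#   mpiexec -n 2 python mpi_allgather_test.py   # filename assumed
#
# should provide the rank and world size this test reads from the environment.
# Launchers that export different variables (e.g. Open MPI, which sets
# OMPI_COMM_WORLD_RANK/OMPI_COMM_WORLD_SIZE) would need those values remapped
# to PMI_RANK/PMI_SIZE before the test starts.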