# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow collective Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import device
from tensorflow.python.ops import gen_collective_ops


def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
               subdiv_offsets=(0,)):
  """Reduces tensors collectively, across devices.

  Args:
    t: the tensor to be reduced.
    group_size: the total number of tensors to be collectively reduced.
      Each must reside on a different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.
    merge_op: string naming the binary Op to be applied to compute each
      partial reduction.
    final_op: string naming the unary Op to be applied to each fully
      reduced value. Can be 'Id' for no operation.
    subdiv_offsets: a list of integer offsets into the tensor at which each
      independent subdivision should begin. Use [0] if no subdivision should
      be done.

  Returns:
    An Op implementing the distributed reduction.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if not device.canonical_name(t.device):
    raise ValueError('Device assignment required for collective ops')
  if group_size <= 1:
    raise ValueError('Parameter group_size to all_reduce must be at least 2.')
  return gen_collective_ops.collective_reduce(t,
                                              group_size=group_size,
                                              group_key=group_key,
                                              instance_key=instance_key,
                                              merge_op=merge_op,
                                              final_op=final_op,
                                              subdiv_offsets=subdiv_offsets)
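

# A minimal usage sketch for all_reduce (illustrative only: the device
# strings and the group_key/instance_key values below are arbitrary, and a
# session with two visible CPU devices is assumed, e.g. via
# config_pb2.ConfigProto(device_count={'CPU': 2})). All group members must
# be run in the same session.run call, since each collective op blocks
# until all group_size members have joined.
#
#   from tensorflow.python.framework import constant_op, ops
#   with ops.device('/CPU:0'):
#     r0 = all_reduce(constant_op.constant([1., 2.]), group_size=2,
#                     group_key=43, instance_key=7, merge_op='Add',
#                     final_op='Div')
#   with ops.device('/CPU:1'):
#     r1 = all_reduce(constant_op.constant([3., 4.]), group_size=2,
#                     group_key=43, instance_key=7, merge_op='Add',
#                     final_op='Div')
#   # session.run([r0, r1]) -> both yield [2., 3.]: 'Add' merges the
#   # tensors element-wise and 'Div' divides the result by group_size.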


def all_gather(t, group_size, group_key, instance_key):
  """Accumulates tensors collectively, across devices, along first dimension.

  Args:
    t: the tensor to participate in the accumulation.
    group_size: the total number of tensors to be collectively accumulated.
      Each must reside on a different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.

  Returns:
    An Op implementing the distributed operation.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if not device.canonical_name(t.device):
    raise ValueError('Device assignment required for collective ops')
  if group_size <= 1:
    raise ValueError('Parameter group_size to all_gather must be at least 2.')
  dims = t.shape.as_list()
  output_shape = [dims[0] * group_size] + dims[1:]
  return gen_collective_ops.collective_gather(t,
                                              shape=output_shape,
                                              group_size=group_size,
                                              group_key=group_key,
                                              instance_key=instance_key)


def broadcast_send(t, shape, dtype, group_size, group_key, instance_key):
  """Broadcasts one tensor to a group of others, across devices.

  Args:
    t: the tensor to be sent.
    shape: the shape of the tensor being sent, which must agree with t.
    dtype: the type of the tensor being sent, which must agree with t.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating. Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.

  Returns:
    An Op implementing the distributed broadcast send.

  Raises:
    ValueError: if any of the input parameter constraints are not met.

  Note that the shape and dtype arguments appear redundant since they
  should be obtainable from t. There are two reasons for including
  them. First, the shape and type of tensors passed via broadcast must
  be known ahead of time in their most specific form so that the receive
  side can allocate memory for the operation and shape/type inference can
  carry forward from there. Including the same declarations on the
  send side clarifies a commitment already made. Secondly, having nearly
  identical use syntax for send and receive sides may simplify tool-driven
  generation of broadcast.
  """
  if not device.canonical_name(t.device):
    raise ValueError('Device assignment required for collective ops')
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_send must be at least 2.')
  if t.shape != shape:
    raise ValueError(
        'Shape of broadcast_send tensor not equal to declared shape')
  if t.dtype != dtype:
    raise ValueError(
        'Type of broadcast_send tensor not equal to declared type')
  return gen_collective_ops.collective_bcast_send(t,
                                                  shape=shape,
                                                  group_size=group_size,
                                                  group_key=group_key,
                                                  instance_key=instance_key)


def broadcast_recv(shape, dtype, group_size, group_key, instance_key):
  """Receives a broadcasted tensor, across devices.

  Args:
    shape: Shape of the tensor to be received.
    dtype: Type of the tensor to be received.
    group_size: one plus the number of receiving tensors, i.e. the total
      number of devices participating. Each tensor must reside on a
      different device.
    group_key: an integer identifying the group of devices.
    instance_key: an integer identifying the participating group of Ops.

  Returns:
    An Op implementing the broadcast receive.

  Raises:
    ValueError: if any of the input parameter constraints are not met.
  """
  if group_size <= 1:
    raise ValueError(
        'Parameter group_size to broadcast_recv must be at least 2.')
  return gen_collective_ops.collective_bcast_recv(shape=shape,
                                                  T=dtype,
                                                  group_size=group_size,
                                                  group_key=group_key,
                                                  instance_key=instance_key)
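

# A minimal pairing sketch for broadcast_send/broadcast_recv (illustrative
# only: device strings and key values are arbitrary, and a session with two
# visible devices is assumed). The send op and every recv op must run in
# the same session.run call, and the declared shape and dtype must match on
# both sides so the receivers can allocate memory ahead of time.
#
#   from tensorflow.python.framework import constant_op, dtypes, ops
#   with ops.device('/CPU:0'):
#     b0 = broadcast_send(constant_op.constant([1., 2.]), [2], dtypes.float32,
#                         group_size=2, group_key=5, instance_key=9)
#   with ops.device('/CPU:1'):
#     b1 = broadcast_recv([2], dtypes.float32,
#                         group_size=2, group_key=5, instance_key=9)
#   # session.run([b0, b1]) -> both evaluate to [1., 2.].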