1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/data_flow_ops.h"
17 #include "tensorflow/cc/ops/data_flow_ops_internal.h"
18 #include "tensorflow/cc/ops/standard_ops.h"
19
20 #include "tensorflow/cc/framework/grad_op_registry.h"
21 #include "tensorflow/cc/framework/gradients.h"
22
23 namespace tensorflow {
24 namespace ops {
25 namespace {
26
27 REGISTER_NO_GRADIENT_OP("Queue");
28 REGISTER_NO_GRADIENT_OP("QueueEnqueue");
29 REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
30 REGISTER_NO_GRADIENT_OP("QueueDequeue");
31 REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
32 REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
33 REGISTER_NO_GRADIENT_OP("QueueClose");
34 REGISTER_NO_GRADIENT_OP("QueueSize");
35 REGISTER_NO_GRADIENT_OP("Stack");
36 REGISTER_NO_GRADIENT_OP("StackPush");
37 REGISTER_NO_GRADIENT_OP("StackPop");
38 REGISTER_NO_GRADIENT_OP("StackClose");
39 REGISTER_NO_GRADIENT_OP("GetSessionHandle");
40 REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
41 REGISTER_NO_GRADIENT_OP("GetSessionTensor");
42 REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");
43
DynamicPartitionGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)44 Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
45 const std::vector<Output>& grad_inputs,
46 std::vector<Output>* grad_outputs) {
47 // DynamicPartition only moves input values into various positions
48 // in the output, so the gradient operation only has to map incoming
49 // gradients into their input source locations.
50 // running example:
51 // data = [10, 20, 30, 40, 50]
52 // partitions = [0, 0, 1, 1, 0]
53 // num_partitions = 2
54 // dynamic_partition(data, partitions, num_partitions) = {
55 // [10, 20, 50],
56 // [30, 40]
57 // }
58 // grads = {
59 // [g1, g2, g3],
60 // [g4, g5]
61 // }
62 // The desired propagation of the gradients back to the data inputs is:
63 // [g1, g2, g4, g5, g3]
64 auto data = op.input(0);
65 auto partitions = op.input(1);
66 int32 num_partitions;
67 TF_RETURN_IF_ERROR(
68 GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));
69
70 // Note: the shape of the partitions is a prefix of the data shape.
71 // shape(partitions) = [5]
72 auto partitions_shape = Shape(scope, partitions);
73 // We now create a partitions-shaped tensor with integers from
74 // [0..size(partitions)) This will be dynamic_partitioned with the
75 // input parameters, providing the destination index for a given
76 // source item.
77 // partitions_size = prod([5]) = 5
78 // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
79 auto zero = Const(scope, 0);
80 auto one = Const(scope, 1);
81 auto original_indices = Reshape(
82 scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
83 partitions_shape);
84 // dynamic_partition(
85 // [0, 1, 2, 3, 4],
86 // [0, 0, 1, 1, 0], 2)
87 // = { [0, 1, 4],
88 // [2, 3] }
89 auto partitioned_indices =
90 DynamicPartition(scope, original_indices, partitions, num_partitions);
91
92 // Invert these indices with dynamic_stitch to map the incoming
93 // gradients to their source inputs.
94 // dynamic_stitch(
95 // { [0, 1, 4], [2, 3] },
96 // { [g1, g2, g3], [g4, g5] })
97 // = [g1, g2, g4, g5, g3]
98 auto reconstructed =
99 DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
100 // reshape back into a data-shaped tensor to propagate gradients for the data
101 // input.
102 grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
103 // Stop propagation along the partitions input
104 grad_outputs->push_back(NoGradient());
105 return scope.status();
106 }
107 REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);
108
DynamicStitchGrad(const Scope & scope,const Operation & op,const std::vector<Output> & grad_inputs,std::vector<Output> * grad_outputs)109 Status DynamicStitchGrad(const Scope& scope, const Operation& op,
110 const std::vector<Output>& grad_inputs,
111 std::vector<Output>* grad_outputs) {
112 // Running example:
113 // indices = {2, [1, 0]}
114 // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
115 // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
116 // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]
117
118 // indices and data are two equal-sized lists passed
119 // into DynamicStitch.
120 // num_values = 2
121 int32 num_values = op.num_inputs() / 2;
122
123 // Stop propagation along the indices list
124 for (int32 i = 0; i < num_values; i++) {
125 grad_outputs->push_back(NoGradient());
126 }
127
128 // DynamicStitch shuffles its data to the output (using items in
129 // indices) so the gradient propagated to a given data input simply
130 // selects the gradient for its output position.
131 for (int32 i = 0; i < num_values; i++) {
132 // index has the destination positions for the i'th data
133 // element. We cast it into an int32 if necessary, so we can use
134 // it from a Gather op.
135 // i = 0: index = 2
136 // i = 1: index = [1, 0]
137 auto index = op.input(i);
138 if (index.type() != DT_INT32) {
139 index = Cast(scope, index, DT_INT32);
140 }
141 // Gather the index specified locations in the gradient and
142 // propagate it as the gradient for the i'th data item.
143 // i = 0: gather(grad, 2) = [g_5, g_6]
144 // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
145 grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
146 }
147
148 return scope.status();
149 }
150 REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
151 REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);
152
153 } // anonymous namespace
154 } // namespace ops
155 } // namespace tensorflow
156