1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
17 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
18 #include "tensorflow/core/lib/random/random.h"
19 
20 namespace tensorflow {
21 
22 namespace {
23 
GenerateUniformRandomNumber()24 double GenerateUniformRandomNumber() {
25   return random::New64() * (1.0 / std::numeric_limits<uint64>::max());
26 }
27 
GenerateUniformRandomNumberBetween(double a,double b)28 double GenerateUniformRandomNumberBetween(double a, double b) {
29   if (a == b) return a;
30   DCHECK_LT(a, b);
31   return a + GenerateUniformRandomNumber() * (b - a);
32 }
33 
34 }  // namespace
35 
ComputeBackoffMicroseconds(int current_retry_attempt,int64 min_delay,int64 max_delay)36 int64 ComputeBackoffMicroseconds(int current_retry_attempt, int64 min_delay,
37                                  int64 max_delay) {
38   DCHECK_GE(current_retry_attempt, 0);
39 
40   // This function with the constants below is calculating:
41   //
42   // (0.4 * min_delay) + (random[0.6,1.0] * min_delay * 1.3^retries)
43   //
44   // Note that there is an extra truncation that occurs and is documented in
45   // comments below.
46   constexpr double kBackoffBase = 1.3;
47   constexpr double kBackoffRandMult = 0.4;
48 
49   // This first term does not vary with current_retry_attempt or a random
50   // number. It exists to ensure the final term is >= min_delay
51   const double first_term = kBackoffRandMult * min_delay;
52 
53   // This is calculating min_delay * 1.3^retries
54   double uncapped_second_term = min_delay;
55   while (current_retry_attempt > 0 &&
56          uncapped_second_term < max_delay - first_term) {
57     current_retry_attempt--;
58     uncapped_second_term *= kBackoffBase;
59   }
60   // Note that first_term + uncapped_second_term can exceed max_delay here
61   // because of the final multiply by kBackoffBase.  We fix that problem with
62   // the min() below.
63   double second_term = std::min(uncapped_second_term, max_delay - first_term);
64 
65   // This supplies the random jitter to ensure that retried don't cause a
66   // thundering herd problem.
67   second_term *=
68       GenerateUniformRandomNumberBetween(1.0 - kBackoffRandMult, 1.0);
69 
70   return std::max(static_cast<int64>(first_term + second_term), min_delay);
71 }
72 
GrpcMaybeUnparseProto(const protobuf::Message & src,grpc::ByteBuffer * dst)73 ::grpc::Status GrpcMaybeUnparseProto(const protobuf::Message& src,
74                                      grpc::ByteBuffer* dst) {
75   bool own_buffer;
76   return ::grpc::GenericSerialize<::grpc::ProtoBufferWriter,
77                                   protobuf::Message>(src, dst, &own_buffer);
78 }
79 
80 // GrpcMaybeUnparseProto from a string simply copies the string to the
81 // ByteBuffer.
GrpcMaybeUnparseProto(const string & src,grpc::ByteBuffer * dst)82 ::grpc::Status GrpcMaybeUnparseProto(const string& src, grpc::ByteBuffer* dst) {
83   ::grpc::Slice s(src.data(), src.size());
84   ::grpc::ByteBuffer buffer(&s, 1);
85   dst->Swap(&buffer);
86   return ::grpc::Status::OK;
87 }
88 
GrpcMaybeParseProto(::grpc::ByteBuffer * src,protobuf::Message * dst)89 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, protobuf::Message* dst) {
90   ::grpc::ProtoBufferReader reader(src);
91   return dst->ParseFromZeroCopyStream(&reader);
92 }
93 
94 // Overload of GrpcParseProto so we can decode a TensorResponse without
95 // extra copying.  This overload is used by the RPCState class in
96 // grpc_state.h.
GrpcMaybeParseProto(::grpc::ByteBuffer * src,TensorResponse * dst)97 bool GrpcMaybeParseProto(::grpc::ByteBuffer* src, TensorResponse* dst) {
98   ::tensorflow::GrpcByteSource byte_source(src);
99   auto s = dst->ParseFrom(&byte_source);
100   return s.ok();
101 }
102 
103 // GrpcMaybeParseProto into a string simply copies bytes into the string.
GrpcMaybeParseProto(grpc::ByteBuffer * src,string * dst)104 bool GrpcMaybeParseProto(grpc::ByteBuffer* src, string* dst) {
105   dst->clear();
106   dst->reserve(src->Length());
107   std::vector<::grpc::Slice> slices;
108   if (!src->Dump(&slices).ok()) {
109     return false;
110   }
111   for (const ::grpc::Slice& s : slices) {
112     dst->append(reinterpret_cast<const char*>(s.begin()), s.size());
113   }
114   return true;
115 }
116 
117 }  // namespace tensorflow
118