1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
17 #include "tensorflow/core/platform/test.h"
18 #include "tensorflow/core/platform/test_benchmark.h"
19 #include "tensorflow/core/protobuf/worker.pb.h"
20 
21 namespace tensorflow {
22 
23 namespace {
ToString(const grpc::ByteBuffer & buf)24 string ToString(const grpc::ByteBuffer& buf) {
25   std::vector<grpc::Slice> slices;
26   CHECK(buf.Dump(&slices).ok());
27   string result;
28   for (const grpc::Slice& s : slices) {
29     result.append(reinterpret_cast<const char*>(s.begin()), s.size());
30   }
31   return result;
32 }
33 
34 // Return a ByteBuffer that contains str split up into num_slices slices.
MakeBuffer(const string & str,int num_slices)35 grpc::ByteBuffer MakeBuffer(const string& str, int num_slices) {
36   // Convert to a ByteBuffer.
37   std::vector<::grpc::Slice> slices;
38   const size_t per_slice = (str.size() + num_slices - 1) / num_slices;
39   for (size_t pos = 0; pos < str.size();) {
40     const size_t n = std::min(str.size() - pos, per_slice);
41     slices.emplace_back(&str[pos], n);
42     pos += n;
43   }
44   if (slices.empty()) {
45     slices.emplace_back();
46   }
47   return ::grpc::ByteBuffer(&slices[0], slices.size());
48 }
49 
50 // Make a proto with approximately the specified length.
MakeProto(int size)51 CleanupAllRequest MakeProto(int size) {
52   int approx_size = 0;
53   CleanupAllRequest proto;
54   int index = 0;
55   while (approx_size < size) {
56     int item_size = std::min(size - approx_size, 1024);
57     proto.add_container(string(item_size, 'a' + static_cast<char>(index % 26)));
58     approx_size += item_size + 3;  // +3 for encoding overhead.
59     index++;
60   }
61   return proto;
62 }
63 }  // namespace
64 
// Serializing a proto directly into a grpc::ByteBuffer and parsing the
// flattened bytes back must reproduce an equal message.
TEST(GrpcProto, Unparse) {
  CleanupAllRequest original;
  for (const char* name : {"hello", "world"}) {
    original.add_container(name);
  }
  grpc::ByteBuffer buffer;
  ASSERT_TRUE(GrpcMaybeUnparseProto(original, &buffer).ok());
  CleanupAllRequest round_tripped;
  ASSERT_TRUE(round_tripped.ParseFromString(ToString(buffer)));
  ASSERT_EQ(original.DebugString(), round_tripped.DebugString());
}
75 
// Copying an already-serialized string into a grpc::ByteBuffer and parsing
// the flattened bytes back must reproduce an equal message.
TEST(GrpcProto, UnparseToString) {
  CleanupAllRequest original;
  for (const char* name : {"hello", "world"}) {
    original.add_container(name);
  }
  string wire;
  CHECK(original.SerializeToString(&wire));
  grpc::ByteBuffer buffer;
  ASSERT_TRUE(GrpcMaybeUnparseProto(wire, &buffer).ok());
  CleanupAllRequest round_tripped;
  ASSERT_TRUE(round_tripped.ParseFromString(ToString(buffer)));
  ASSERT_EQ(original.DebugString(), round_tripped.DebugString());
}
88 
// Parses protos of several sizes out of ByteBuffers built from one or more
// slices and checks that GrpcMaybeParseProto reconstructs an equal message,
// including when the serialization straddles slice boundaries.
TEST(GrpcProto, Parse) {
  struct Case {
    int length;  // Approximate serialized size handed to MakeProto.
    int slices;  // Requested slice count handed to MakeBuffer.
  };
  const std::vector<Case> cases = {
      {0, 1}, {20, 1}, {100, 1}, {1 << 20, 1}, {100, 5}, {10000, 50},
  };
  for (const Case& c : cases) {
    const CleanupAllRequest expected = MakeProto(c.length);
    ::grpc::ByteBuffer src = MakeBuffer(expected.SerializeAsString(), c.slices);
    CleanupAllRequest actual;
    ASSERT_TRUE(GrpcMaybeParseProto(&src, &actual))
        << c.length << " " << c.slices;
    ASSERT_EQ(expected.DebugString(), actual.DebugString());
  }
}
111 
// Extracts the raw bytes of one- and multi-slice ByteBuffers into a string
// with GrpcMaybeParseProto, then checks that parsing that string yields a
// message equal to the one that was serialized.
TEST(GrpcProto, ParseFromString) {
  struct Case {
    int length;  // Approximate serialized size handed to MakeProto.
    int slices;  // Requested slice count handed to MakeBuffer.
  };
  const std::vector<Case> cases = {
      {0, 1}, {20, 1}, {100, 1}, {1 << 20, 1}, {100, 5}, {10000, 50},
  };
  for (const Case& c : cases) {
    const CleanupAllRequest expected = MakeProto(c.length);
    ::grpc::ByteBuffer src = MakeBuffer(expected.SerializeAsString(), c.slices);
    string wire;
    CleanupAllRequest actual;
    ASSERT_TRUE(GrpcMaybeParseProto(&src, &wire))
        << c.length << " " << c.slices;
    ASSERT_TRUE(actual.ParseFromString(wire));
    ASSERT_EQ(expected.DebugString(), actual.DebugString());
  }
}
136 
BM_UnparseGrpc(int iters,int size)137 static void BM_UnparseGrpc(int iters, int size) {
138   testing::StopTiming();
139   auto proto = MakeProto(size);
140   testing::StartTiming();
141   for (int i = 0; i < iters; i++) {
142     grpc::ByteBuffer buf;
143     CHECK(GrpcMaybeUnparseProto(proto, &buf).ok());
144   }
145   testing::StopTiming();
146 }
147 BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
148 
BM_UnparseString(int iters,int size)149 static void BM_UnparseString(int iters, int size) {
150   testing::StopTiming();
151   auto proto = MakeProto(size);
152   testing::StartTiming();
153 
154   for (int i = 0; i < iters; i++) {
155     string buf;
156     proto.SerializeToString(&buf);
157   }
158 
159   testing::StopTiming();
160 }
161 BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
162 
BM_ParseGrpc(int iters,int size,int num_slices)163 static void BM_ParseGrpc(int iters, int size, int num_slices) {
164   testing::StopTiming();
165   CleanupAllRequest proto = MakeProto(size);
166   auto buf = MakeBuffer(proto.SerializeAsString(), num_slices);
167   testing::StartTiming();
168 
169   for (int i = 0; i < iters; i++) {
170     CHECK(GrpcMaybeParseProto(&buf, &proto));
171   }
172 
173   testing::StopTiming();
174 }
175 BENCHMARK(BM_ParseGrpc)
176     ->ArgPair(1, 1)
177     ->ArgPair(1 << 10, 1)
178     ->ArgPair(1 << 10, 4)
179     ->ArgPair(1 << 20, 1)
180     ->ArgPair(1 << 20, 4);
181 
BM_ParseString(int iters,int size)182 static void BM_ParseString(int iters, int size) {
183   testing::StopTiming();
184   CleanupAllRequest proto = MakeProto(size);
185   string serial = proto.SerializeAsString();
186   testing::StartTiming();
187 
188   for (int i = 0; i < iters; i++) {
189     CHECK(proto.ParseFromString(serial));
190   }
191 
192   testing::StopTiming();
193 }
194 BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
195 
196 }  // namespace tensorflow
197