1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
17 #include "tensorflow/core/platform/test.h"
18 #include "tensorflow/core/platform/test_benchmark.h"
19 #include "tensorflow/core/protobuf/worker.pb.h"
20
21 namespace tensorflow {
22
23 namespace {
ToString(const grpc::ByteBuffer & buf)24 string ToString(const grpc::ByteBuffer& buf) {
25 std::vector<grpc::Slice> slices;
26 CHECK(buf.Dump(&slices).ok());
27 string result;
28 for (const grpc::Slice& s : slices) {
29 result.append(reinterpret_cast<const char*>(s.begin()), s.size());
30 }
31 return result;
32 }
33
34 // Return a ByteBuffer that contains str split up into num_slices slices.
MakeBuffer(const string & str,int num_slices)35 grpc::ByteBuffer MakeBuffer(const string& str, int num_slices) {
36 // Convert to a ByteBuffer.
37 std::vector<::grpc::Slice> slices;
38 const size_t per_slice = (str.size() + num_slices - 1) / num_slices;
39 for (size_t pos = 0; pos < str.size();) {
40 const size_t n = std::min(str.size() - pos, per_slice);
41 slices.emplace_back(&str[pos], n);
42 pos += n;
43 }
44 if (slices.empty()) {
45 slices.emplace_back();
46 }
47 return ::grpc::ByteBuffer(&slices[0], slices.size());
48 }
49
50 // Make a proto with approximately the specified length.
MakeProto(int size)51 CleanupAllRequest MakeProto(int size) {
52 int approx_size = 0;
53 CleanupAllRequest proto;
54 int index = 0;
55 while (approx_size < size) {
56 int item_size = std::min(size - approx_size, 1024);
57 proto.add_container(string(item_size, 'a' + static_cast<char>(index % 26)));
58 approx_size += item_size + 3; // +3 for encoding overhead.
59 index++;
60 }
61 return proto;
62 }
63 } // namespace
64
// Round trip: GrpcMaybeUnparseProto serializes a message into a ByteBuffer.
// The flattened buffer must parse back into an identical message.
TEST(GrpcProto, Unparse) {
  CleanupAllRequest proto;
  for (const char* name : {"hello", "world"}) {
    proto.add_container(name);
  }
  grpc::ByteBuffer buf;
  ASSERT_TRUE(GrpcMaybeUnparseProto(proto, &buf).ok());

  CleanupAllRequest parsed;
  ASSERT_TRUE(parsed.ParseFromString(ToString(buf)));
  ASSERT_EQ(proto.DebugString(), parsed.DebugString());
}
75
// Round trip starting from an already-serialized string. GrpcMaybeUnparseProto
// wraps the bytes in a ByteBuffer, and the flattened buffer must parse back
// into a message identical to the original.
TEST(GrpcProto, UnparseToString) {
  CleanupAllRequest proto;
  proto.add_container("hello");
  proto.add_container("world");
  string str;
  // ASSERT_TRUE rather than CHECK: a serialization failure should be reported
  // as a failure of this test, not abort the whole test binary.
  ASSERT_TRUE(proto.SerializeToString(&str));
  grpc::ByteBuffer buf;
  ASSERT_TRUE(GrpcMaybeUnparseProto(str, &buf).ok());
  CleanupAllRequest parsed;
  ASSERT_TRUE(parsed.ParseFromString(ToString(buf)));
  ASSERT_EQ(proto.DebugString(), parsed.DebugString());
}
88
// Parses a CleanupAllRequest directly out of a ByteBuffer holding its
// serialization, split into one or more slices. The cases cover an empty
// payload, a ~1MiB payload, and payloads split across many slices.
TEST(GrpcProto, Parse) {
  struct Case {
    int length;
    int slices;
  };
  for (const Case& c : std::vector<Case>{
           {0, 1},
           {20, 1},
           {100, 1},
           {1 << 20, 1},
           {100, 5},
           {10000, 50},
       }) {
    // Attach the case parameters to every assertion failure in this
    // iteration, not just the first one.
    SCOPED_TRACE(::testing::Message()
                 << "length=" << c.length << " slices=" << c.slices);
    CleanupAllRequest proto = MakeProto(c.length);
    ::grpc::ByteBuffer src = MakeBuffer(proto.SerializeAsString(), c.slices);
    CleanupAllRequest parsed;
    ASSERT_TRUE(GrpcMaybeParseProto(&src, &parsed));
    ASSERT_EQ(proto.DebugString(), parsed.DebugString());
  }
}
111
// Copies a multi-slice ByteBuffer holding a serialized CleanupAllRequest into a
// string via GrpcMaybeParseProto. The string must then parse back into an
// identical message. The cases match the slice layouts used for direct parsing.
TEST(GrpcProto, ParseFromString) {
  struct Case {
    int length;
    int slices;
  };
  for (const Case& c : std::vector<Case>{
           {0, 1},
           {20, 1},
           {100, 1},
           {1 << 20, 1},
           {100, 5},
           {10000, 50},
       }) {
    // Attach the case parameters to every assertion failure in this
    // iteration, not just the first one.
    SCOPED_TRACE(::testing::Message()
                 << "length=" << c.length << " slices=" << c.slices);
    CleanupAllRequest proto = MakeProto(c.length);
    ::grpc::ByteBuffer src = MakeBuffer(proto.SerializeAsString(), c.slices);
    string parsed_str;
    CleanupAllRequest parsed;
    ASSERT_TRUE(GrpcMaybeParseProto(&src, &parsed_str));
    ASSERT_TRUE(parsed.ParseFromString(parsed_str));
    ASSERT_EQ(proto.DebugString(), parsed.DebugString());
  }
}
136
BM_UnparseGrpc(int iters,int size)137 static void BM_UnparseGrpc(int iters, int size) {
138 testing::StopTiming();
139 auto proto = MakeProto(size);
140 testing::StartTiming();
141 for (int i = 0; i < iters; i++) {
142 grpc::ByteBuffer buf;
143 CHECK(GrpcMaybeUnparseProto(proto, &buf).ok());
144 }
145 testing::StopTiming();
146 }
147 BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
148
BM_UnparseString(int iters,int size)149 static void BM_UnparseString(int iters, int size) {
150 testing::StopTiming();
151 auto proto = MakeProto(size);
152 testing::StartTiming();
153
154 for (int i = 0; i < iters; i++) {
155 string buf;
156 proto.SerializeToString(&buf);
157 }
158
159 testing::StopTiming();
160 }
161 BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
162
BM_ParseGrpc(int iters,int size,int num_slices)163 static void BM_ParseGrpc(int iters, int size, int num_slices) {
164 testing::StopTiming();
165 CleanupAllRequest proto = MakeProto(size);
166 auto buf = MakeBuffer(proto.SerializeAsString(), num_slices);
167 testing::StartTiming();
168
169 for (int i = 0; i < iters; i++) {
170 CHECK(GrpcMaybeParseProto(&buf, &proto));
171 }
172
173 testing::StopTiming();
174 }
175 BENCHMARK(BM_ParseGrpc)
176 ->ArgPair(1, 1)
177 ->ArgPair(1 << 10, 1)
178 ->ArgPair(1 << 10, 4)
179 ->ArgPair(1 << 20, 1)
180 ->ArgPair(1 << 20, 4);
181
BM_ParseString(int iters,int size)182 static void BM_ParseString(int iters, int size) {
183 testing::StopTiming();
184 CleanupAllRequest proto = MakeProto(size);
185 string serial = proto.SerializeAsString();
186 testing::StartTiming();
187
188 for (int i = 0; i < iters; i++) {
189 CHECK(proto.ParseFromString(serial));
190 }
191
192 testing::StopTiming();
193 }
194 BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
195
196 } // namespace tensorflow
197