1 /*
2  * Copyright 2015 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.testing.integration;
18 
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertTrue;
21 
22 import com.google.protobuf.BoolValue;
23 import com.google.protobuf.ByteString;
24 import io.grpc.CallOptions;
25 import io.grpc.Channel;
26 import io.grpc.ClientCall;
27 import io.grpc.ClientInterceptor;
28 import io.grpc.Codec;
29 import io.grpc.CompressorRegistry;
30 import io.grpc.DecompressorRegistry;
31 import io.grpc.ForwardingClientCall;
32 import io.grpc.ForwardingClientCallListener;
33 import io.grpc.ManagedChannel;
34 import io.grpc.Metadata;
35 import io.grpc.MethodDescriptor;
36 import io.grpc.ServerCall;
37 import io.grpc.ServerCall.Listener;
38 import io.grpc.ServerCallHandler;
39 import io.grpc.ServerInterceptor;
40 import io.grpc.internal.AbstractServerImplBuilder;
41 import io.grpc.internal.GrpcUtil;
42 import io.grpc.netty.NettyChannelBuilder;
43 import io.grpc.netty.NettyServerBuilder;
44 import io.grpc.testing.integration.Messages.Payload;
45 import io.grpc.testing.integration.Messages.PayloadType;
46 import io.grpc.testing.integration.Messages.SimpleRequest;
47 import io.grpc.testing.integration.Messages.SimpleResponse;
48 import java.io.FilterInputStream;
49 import java.io.FilterOutputStream;
50 import java.io.IOException;
51 import java.io.InputStream;
52 import java.io.OutputStream;
53 import org.junit.Before;
54 import org.junit.BeforeClass;
55 import org.junit.Test;
56 import org.junit.runner.RunWith;
57 import org.junit.runners.JUnit4;
58 
59 /**
60  * Tests that compression is turned on.
61  */
62 @RunWith(JUnit4.class)
63 public class TransportCompressionTest extends AbstractInteropTest {
64 
65   // Masquerade as identity.
66   private static final Fzip FZIPPER = new Fzip("gzip", new Codec.Gzip());
67   private volatile boolean expectFzip;
68 
69   private static final DecompressorRegistry decompressors = DecompressorRegistry.emptyInstance()
70       .with(Codec.Identity.NONE, false)
71       .with(FZIPPER, true);
72   private static final CompressorRegistry compressors = CompressorRegistry.newEmptyInstance();
73 
74   @Before
beforeTests()75   public void beforeTests() {
76     FZIPPER.anyRead = false;
77     FZIPPER.anyWritten = false;
78   }
79 
80   @BeforeClass
registerCompressors()81   public static void registerCompressors() {
82     compressors.register(FZIPPER);
83     compressors.register(Codec.Identity.NONE);
84   }
85 
86   @Override
getServerBuilder()87   protected AbstractServerImplBuilder<?> getServerBuilder() {
88     return NettyServerBuilder.forPort(0)
89         .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
90         .compressorRegistry(compressors)
91         .decompressorRegistry(decompressors)
92         .intercept(new ServerInterceptor() {
93             @Override
94             public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
95                 Metadata headers, ServerCallHandler<ReqT, RespT> next) {
96               Listener<ReqT> listener = next.startCall(call, headers);
97               // TODO(carl-mastrangelo): check that encoding was set.
98               call.setMessageCompression(true);
99               return listener;
100             }
101           });
102   }
103 
104   @Test
105   public void compresses() {
106     expectFzip = true;
107     final SimpleRequest request = SimpleRequest.newBuilder()
108         .setResponseSize(314159)
109         .setResponseCompressed(BoolValue.newBuilder().setValue(true))
110         .setResponseType(PayloadType.COMPRESSABLE)
111         .setPayload(Payload.newBuilder()
112             .setBody(ByteString.copyFrom(new byte[271828])))
113         .build();
114     final SimpleResponse goldenResponse = SimpleResponse.newBuilder()
115         .setPayload(Payload.newBuilder()
116             .setType(PayloadType.COMPRESSABLE)
117             .setBody(ByteString.copyFrom(new byte[314159])))
118         .build();
119 
120 
121     assertEquals(goldenResponse, blockingStub.unaryCall(request));
122     // Assert that compression took place
123     assertTrue(FZIPPER.anyRead);
124     assertTrue(FZIPPER.anyWritten);
125   }
126 
127   @Override
128   protected ManagedChannel createChannel() {
129     NettyChannelBuilder builder = NettyChannelBuilder.forAddress("localhost", getPort())
130         .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
131         .decompressorRegistry(decompressors)
132         .compressorRegistry(compressors)
133         .intercept(new ClientInterceptor() {
134           @Override
135           public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
136               MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
137             final ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
138             return new ForwardingClientCall<ReqT, RespT>() {
139 
140               @Override
141               protected ClientCall<ReqT, RespT> delegate() {
142                 return call;
143               }
144 
145               @Override
146               public void start(
147                   final ClientCall.Listener<RespT> responseListener, Metadata headers) {
148                 ClientCall.Listener<RespT> listener = new ForwardingClientCallListener<RespT>() {
149 
150                   @Override
151                   protected io.grpc.ClientCall.Listener<RespT> delegate() {
152                     return responseListener;
153                   }
154 
155                   @Override
156                   public void onHeaders(Metadata headers) {
157                     super.onHeaders(headers);
158                     if (expectFzip) {
159                       String encoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
160                       assertEquals(encoding, FZIPPER.getMessageEncoding());
161                     }
162                   }
163                 };
164                 super.start(listener, headers);
165                 setMessageCompression(true);
166               }
167             };
168           }
169         })
170         .usePlaintext();
171     io.grpc.internal.TestingAccessor.setStatsImplementation(
172         builder, createClientCensusStatsModule());
173     return builder.build();
174   }
175 
176   /**
177    * Fzip is a custom compressor.
178    */
179   static class Fzip implements Codec {
180     volatile boolean anyRead;
181     volatile boolean anyWritten;
182     volatile Codec delegate;
183 
184     private final String actualName;
185 
186     public Fzip(String actualName, Codec delegate) {
187       this.actualName = actualName;
188       this.delegate = delegate;
189     }
190 
191     @Override
192     public String getMessageEncoding() {
193       return actualName;
194     }
195 
196     @Override
197     public OutputStream compress(OutputStream os) throws IOException {
198       return new FilterOutputStream(delegate.compress(os)) {
199         @Override
200         public void write(int b) throws IOException {
201           super.write(b);
202           anyWritten = true;
203         }
204       };
205     }
206 
207     @Override
208     public InputStream decompress(InputStream is) throws IOException {
209       return new FilterInputStream(delegate.decompress(is)) {
210         @Override
211         public int read() throws IOException {
212           int val = super.read();
213           anyRead = true;
214           return val;
215         }
216 
217         @Override
218         public int read(byte[] b, int off, int len) throws IOException {
219           int total = super.read(b, off, len);
220           anyRead = true;
221           return total;
222         }
223       };
224     }
225   }
226 }
227