1 /*
2  * Copyright (C) 2008 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.inject.servlet;
18 
19 import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST;
20 import static org.easymock.EasyMock.anyObject;
21 import static org.easymock.EasyMock.createMock;
22 import static org.easymock.EasyMock.eq;
23 import static org.easymock.EasyMock.expect;
24 import static org.easymock.EasyMock.expectLastCall;
25 import static org.easymock.EasyMock.replay;
26 import static org.easymock.EasyMock.verify;
27 
28 import com.google.common.collect.ImmutableList;
29 import com.google.common.collect.Sets;
30 import com.google.inject.Binding;
31 import com.google.inject.Injector;
32 import com.google.inject.Key;
33 import com.google.inject.Provider;
34 import com.google.inject.TypeLiteral;
35 import com.google.inject.spi.BindingScopingVisitor;
36 import com.google.inject.util.Providers;
37 import java.io.IOException;
38 import java.util.ArrayList;
39 import java.util.Date;
40 import java.util.HashMap;
41 import java.util.List;
42 import java.util.UUID;
43 import javax.servlet.RequestDispatcher;
44 import javax.servlet.ServletException;
45 import javax.servlet.http.HttpServlet;
46 import javax.servlet.http.HttpServletRequest;
47 import javax.servlet.http.HttpServletResponse;
48 import junit.framework.TestCase;
49 
50 /**
51  * Tests forwarding and inclusion (RequestDispatcher actions from the servlet spec).
52  *
53  * @author Dhanji R. Prasanna (dhanji@gmail com)
54  */
55 public class ServletPipelineRequestDispatcherTest extends TestCase {
56   private static final Key<HttpServlet> HTTP_SERLVET_KEY = Key.get(HttpServlet.class);
57   private static final String A_KEY = "thinglyDEgintly" + new Date() + UUID.randomUUID();
58   private static final String A_VALUE =
59       ServletPipelineRequestDispatcherTest.class.toString() + new Date() + UUID.randomUUID();
60 
testIncludeManagedServlet()61   public final void testIncludeManagedServlet() throws IOException, ServletException {
62     String pattern = "blah.html";
63     final ServletDefinition servletDefinition =
64         new ServletDefinition(
65             Key.get(HttpServlet.class),
66             UriPatternType.get(UriPatternType.SERVLET, pattern),
67             new HashMap<String, String>(),
68             null);
69 
70     final Injector injector = createMock(Injector.class);
71     final Binding binding = createMock(Binding.class);
72     final HttpServletRequest requestMock = createMock(HttpServletRequest.class);
73 
74     expect(requestMock.getAttribute(A_KEY)).andReturn(A_VALUE);
75 
76     requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
77     requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);
78 
79     final boolean[] run = new boolean[1];
80     final HttpServlet mockServlet =
81         new HttpServlet() {
82           @Override
83           protected void service(
84               HttpServletRequest request, HttpServletResponse httpServletResponse)
85               throws ServletException, IOException {
86             run[0] = true;
87 
88             final Object o = request.getAttribute(A_KEY);
89             assertEquals("Wrong attrib returned - " + o, A_VALUE, o);
90           }
91         };
92 
93     expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject())).andReturn(true);
94     expect(injector.getBinding(Key.get(HttpServlet.class))).andReturn(binding);
95     expect(injector.getInstance(HTTP_SERLVET_KEY)).andReturn(mockServlet);
96 
97     final Key<ServletDefinition> servetDefsKey = Key.get(TypeLiteral.get(ServletDefinition.class));
98 
99     Binding<ServletDefinition> mockBinding = createMock(Binding.class);
100     expect(injector.findBindingsByType(eq(servetDefsKey.getTypeLiteral())))
101         .andReturn(ImmutableList.<Binding<ServletDefinition>>of(mockBinding));
102     Provider<ServletDefinition> bindingProvider = Providers.of(servletDefinition);
103     expect(mockBinding.getProvider()).andReturn(bindingProvider);
104 
105     replay(injector, binding, requestMock, mockBinding);
106 
107     // Have to init the Servlet before we can dispatch to it.
108     servletDefinition.init(null, injector, Sets.<HttpServlet>newIdentityHashSet());
109 
110     final RequestDispatcher dispatcher =
111         new ManagedServletPipeline(injector).getRequestDispatcher(pattern);
112 
113     assertNotNull(dispatcher);
114     dispatcher.include(requestMock, createMock(HttpServletResponse.class));
115 
116     assertTrue("Include did not dispatch to our servlet!", run[0]);
117 
118     verify(injector, requestMock, mockBinding);
119   }
120 
testForwardToManagedServlet()121   public final void testForwardToManagedServlet() throws IOException, ServletException {
122     String pattern = "blah.html";
123     final ServletDefinition servletDefinition =
124         new ServletDefinition(
125             Key.get(HttpServlet.class),
126             UriPatternType.get(UriPatternType.SERVLET, pattern),
127             new HashMap<String, String>(),
128             null);
129 
130     final Injector injector = createMock(Injector.class);
131     final Binding binding = createMock(Binding.class);
132     final HttpServletRequest requestMock = createMock(HttpServletRequest.class);
133     final HttpServletResponse mockResponse = createMock(HttpServletResponse.class);
134 
135     expect(requestMock.getAttribute(A_KEY)).andReturn(A_VALUE);
136 
137     requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
138     requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);
139 
140     expect(mockResponse.isCommitted()).andReturn(false);
141 
142     mockResponse.resetBuffer();
143     expectLastCall().once();
144 
145     final List<String> paths = new ArrayList<>();
146     final HttpServlet mockServlet =
147         new HttpServlet() {
148           @Override
149           protected void service(
150               HttpServletRequest request, HttpServletResponse httpServletResponse)
151               throws ServletException, IOException {
152             paths.add(request.getRequestURI());
153 
154             final Object o = request.getAttribute(A_KEY);
155             assertEquals("Wrong attrib returned - " + o, A_VALUE, o);
156           }
157         };
158 
159     expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject())).andReturn(true);
160     expect(injector.getBinding(Key.get(HttpServlet.class))).andReturn(binding);
161 
162     expect(injector.getInstance(HTTP_SERLVET_KEY)).andReturn(mockServlet);
163 
164     final Key<ServletDefinition> servetDefsKey = Key.get(TypeLiteral.get(ServletDefinition.class));
165 
166     Binding<ServletDefinition> mockBinding = createMock(Binding.class);
167     expect(injector.findBindingsByType(eq(servetDefsKey.getTypeLiteral())))
168         .andReturn(ImmutableList.<Binding<ServletDefinition>>of(mockBinding));
169     Provider<ServletDefinition> bindingProvider = Providers.of(servletDefinition);
170     expect(mockBinding.getProvider()).andReturn(bindingProvider);
171 
172     replay(injector, binding, requestMock, mockResponse, mockBinding);
173 
174     // Have to init the Servlet before we can dispatch to it.
175     servletDefinition.init(null, injector, Sets.<HttpServlet>newIdentityHashSet());
176 
177     final RequestDispatcher dispatcher =
178         new ManagedServletPipeline(injector).getRequestDispatcher(pattern);
179 
180     assertNotNull(dispatcher);
181     dispatcher.forward(requestMock, mockResponse);
182 
183     assertTrue("Include did not dispatch to our servlet!", paths.contains(pattern));
184 
185     verify(injector, requestMock, mockResponse, mockBinding);
186   }
187 
testForwardToManagedServletFailureOnCommittedBuffer()188   public final void testForwardToManagedServletFailureOnCommittedBuffer()
189       throws IOException, ServletException {
190     IllegalStateException expected = null;
191     try {
192       forwardToManagedServletFailureOnCommittedBuffer();
193     } catch (IllegalStateException ise) {
194       expected = ise;
195     } finally {
196       assertNotNull("Expected IllegalStateException was not thrown", expected);
197     }
198   }
199 
forwardToManagedServletFailureOnCommittedBuffer()200   public final void forwardToManagedServletFailureOnCommittedBuffer()
201       throws IOException, ServletException {
202     String pattern = "blah.html";
203     final ServletDefinition servletDefinition =
204         new ServletDefinition(
205             Key.get(HttpServlet.class),
206             UriPatternType.get(UriPatternType.SERVLET, pattern),
207             new HashMap<String, String>(),
208             null);
209 
210     final Injector injector = createMock(Injector.class);
211     final Binding binding = createMock(Binding.class);
212     final HttpServletRequest mockRequest = createMock(HttpServletRequest.class);
213     final HttpServletResponse mockResponse = createMock(HttpServletResponse.class);
214 
215     expect(mockResponse.isCommitted()).andReturn(true);
216 
217     final HttpServlet mockServlet =
218         new HttpServlet() {
219           @Override
220           protected void service(
221               HttpServletRequest request, HttpServletResponse httpServletResponse)
222               throws ServletException, IOException {
223 
224             final Object o = request.getAttribute(A_KEY);
225             assertEquals("Wrong attrib returned - " + o, A_VALUE, o);
226           }
227         };
228 
229     expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject())).andReturn(true);
230     expect(injector.getBinding(Key.get(HttpServlet.class))).andReturn(binding);
231 
232     expect(injector.getInstance(Key.get(HttpServlet.class))).andReturn(mockServlet);
233 
234     final Key<ServletDefinition> servetDefsKey = Key.get(TypeLiteral.get(ServletDefinition.class));
235 
236     Binding<ServletDefinition> mockBinding = createMock(Binding.class);
237     expect(injector.findBindingsByType(eq(servetDefsKey.getTypeLiteral())))
238         .andReturn(ImmutableList.<Binding<ServletDefinition>>of(mockBinding));
239     Provider<ServletDefinition> bindingProvider = Providers.of(servletDefinition);
240     expect(mockBinding.getProvider()).andReturn(bindingProvider);
241 
242     replay(injector, binding, mockRequest, mockResponse, mockBinding);
243 
244     // Have to init the Servlet before we can dispatch to it.
245     servletDefinition.init(null, injector, Sets.<HttpServlet>newIdentityHashSet());
246 
247     final RequestDispatcher dispatcher =
248         new ManagedServletPipeline(injector).getRequestDispatcher(pattern);
249 
250     assertNotNull(dispatcher);
251 
252     try {
253       dispatcher.forward(mockRequest, mockResponse);
254     } finally {
255       verify(injector, mockRequest, mockResponse, mockBinding);
256     }
257   }
258 
testWrappedRequestUriAndUrlConsistency()259   public final void testWrappedRequestUriAndUrlConsistency() {
260     final HttpServletRequest mockRequest = createMock(HttpServletRequest.class);
261     expect(mockRequest.getScheme()).andReturn("http");
262     expect(mockRequest.getServerName()).andReturn("the.server");
263     expect(mockRequest.getServerPort()).andReturn(12345);
264     replay(mockRequest);
265     HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri");
266     assertEquals("/new-uri", wrappedRequest.getRequestURI());
267     assertEquals("http://the.server:12345/new-uri", wrappedRequest.getRequestURL().toString());
268   }
269 
testWrappedRequestUrlNegativePort()270   public final void testWrappedRequestUrlNegativePort() {
271     final HttpServletRequest mockRequest = createMock(HttpServletRequest.class);
272     expect(mockRequest.getScheme()).andReturn("http");
273     expect(mockRequest.getServerName()).andReturn("the.server");
274     expect(mockRequest.getServerPort()).andReturn(-1);
275     replay(mockRequest);
276     HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri");
277     assertEquals("/new-uri", wrappedRequest.getRequestURI());
278     assertEquals("http://the.server/new-uri", wrappedRequest.getRequestURL().toString());
279   }
280 
testWrappedRequestUrlDefaultPort()281   public final void testWrappedRequestUrlDefaultPort() {
282     final HttpServletRequest mockRequest = createMock(HttpServletRequest.class);
283     expect(mockRequest.getScheme()).andReturn("http");
284     expect(mockRequest.getServerName()).andReturn("the.server");
285     expect(mockRequest.getServerPort()).andReturn(80);
286     replay(mockRequest);
287     HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri");
288     assertEquals("/new-uri", wrappedRequest.getRequestURI());
289     assertEquals("http://the.server/new-uri", wrappedRequest.getRequestURL().toString());
290   }
291 
testWrappedRequestUrlDefaultHttpsPort()292   public final void testWrappedRequestUrlDefaultHttpsPort() {
293     final HttpServletRequest mockRequest = createMock(HttpServletRequest.class);
294     expect(mockRequest.getScheme()).andReturn("https");
295     expect(mockRequest.getServerName()).andReturn("the.server");
296     expect(mockRequest.getServerPort()).andReturn(443);
297     replay(mockRequest);
298     HttpServletRequest wrappedRequest = ManagedServletPipeline.wrapRequest(mockRequest, "/new-uri");
299     assertEquals("/new-uri", wrappedRequest.getRequestURI());
300     assertEquals("https://the.server/new-uri", wrappedRequest.getRequestURL().toString());
301   }
302 }
303