1 /*
2  * Copyright (C) 2006 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.inject.servlet;
18 
19 import static com.google.inject.Asserts.assertContains;
20 import static com.google.inject.Asserts.reserialize;
21 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest;
22 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletResponse;
23 import static java.lang.annotation.ElementType.FIELD;
24 import static java.lang.annotation.ElementType.METHOD;
25 import static java.lang.annotation.ElementType.PARAMETER;
26 import static java.lang.annotation.RetentionPolicy.RUNTIME;
27 
28 import com.google.common.collect.ImmutableMap;
29 import com.google.common.collect.Lists;
30 import com.google.inject.AbstractModule;
31 import com.google.inject.BindingAnnotation;
32 import com.google.inject.CreationException;
33 import com.google.inject.Guice;
34 import com.google.inject.Inject;
35 import com.google.inject.Injector;
36 import com.google.inject.Key;
37 import com.google.inject.Module;
38 import com.google.inject.Provider;
39 import com.google.inject.Provides;
40 import com.google.inject.ProvisionException;
41 import com.google.inject.internal.Errors;
42 import com.google.inject.name.Named;
43 import com.google.inject.name.Names;
44 import com.google.inject.servlet.ServletScopes.NullObject;
45 import com.google.inject.util.Providers;
46 import java.io.IOException;
47 import java.io.Serializable;
48 import java.lang.annotation.Retention;
49 import java.lang.annotation.Target;
50 import java.util.Map;
51 import javax.servlet.Filter;
52 import javax.servlet.FilterChain;
53 import javax.servlet.FilterConfig;
54 import javax.servlet.ServletException;
55 import javax.servlet.ServletRequest;
56 import javax.servlet.ServletResponse;
57 import javax.servlet.http.HttpServlet;
58 import javax.servlet.http.HttpServletRequest;
59 import javax.servlet.http.HttpServletRequestWrapper;
60 import javax.servlet.http.HttpServletResponse;
61 import javax.servlet.http.HttpServletResponseWrapper;
62 import javax.servlet.http.HttpSession;
63 import junit.framework.TestCase;
64 
65 /** @author crazybob@google.com (Bob Lee) */
66 public class ServletTest extends TestCase {
67   private static final Key<HttpServletRequest> HTTP_REQ_KEY = Key.get(HttpServletRequest.class);
68   private static final Key<HttpServletResponse> HTTP_RESP_KEY = Key.get(HttpServletResponse.class);
69   private static final Key<Map<String, String[]>> REQ_PARAMS_KEY =
70       new Key<Map<String, String[]>>(RequestParameters.class) {};
71 
72   private static final Key<InRequest> IN_REQUEST_NULL_KEY = Key.get(InRequest.class, Null.class);
73   private static final Key<InSession> IN_SESSION_KEY = Key.get(InSession.class);
74   private static final Key<InSession> IN_SESSION_NULL_KEY = Key.get(InSession.class, Null.class);
75 
76   @Override
setUp()77   public void setUp() {
78     //we need to clear the reference to the pipeline every test =(
79     GuiceFilter.reset();
80   }
81 
testScopeExceptions()82   public void testScopeExceptions() throws Exception {
83     Injector injector =
84         Guice.createInjector(
85             new AbstractModule() {
86               @Override
87               protected void configure() {
88                 install(new ServletModule());
89               }
90 
91               @Provides
92               @RequestScoped
93               String provideString() {
94                 return "foo";
95               }
96 
97               @Provides
98               @SessionScoped
99               Integer provideInteger() {
100                 return 1;
101               }
102 
103               @Provides
104               @RequestScoped
105               @Named("foo")
106               String provideNamedString() {
107                 return "foo";
108               }
109             });
110 
111     try {
112       injector.getInstance(String.class);
113       fail();
114     } catch (ProvisionException oose) {
115       assertContains(oose.getMessage(), "Cannot access scoped [java.lang.String].");
116     }
117 
118     try {
119       injector.getInstance(Integer.class);
120       fail();
121     } catch (ProvisionException oose) {
122       assertContains(oose.getMessage(), "Cannot access scoped [java.lang.Integer].");
123     }
124 
125     Key<?> key = Key.get(String.class, Names.named("foo"));
126     try {
127       injector.getInstance(key);
128       fail();
129     } catch (ProvisionException oose) {
130       assertContains(oose.getMessage(), "Cannot access scoped [" + Errors.convert(key) + "]");
131     }
132   }
133 
testRequestAndResponseBindings()134   public void testRequestAndResponseBindings() throws Exception {
135     final Injector injector = createInjector();
136     final HttpServletRequest request = newFakeHttpServletRequest();
137     final HttpServletResponse response = newFakeHttpServletResponse();
138 
139     final boolean[] invoked = new boolean[1];
140     GuiceFilter filter = new GuiceFilter();
141     FilterChain filterChain =
142         new FilterChain() {
143           @Override
144           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
145             invoked[0] = true;
146             assertSame(request, servletRequest);
147             assertSame(request, injector.getInstance(ServletRequest.class));
148             assertSame(request, injector.getInstance(HTTP_REQ_KEY));
149 
150             assertSame(response, servletResponse);
151             assertSame(response, injector.getInstance(ServletResponse.class));
152             assertSame(response, injector.getInstance(HTTP_RESP_KEY));
153 
154             assertSame(servletRequest.getParameterMap(), injector.getInstance(REQ_PARAMS_KEY));
155           }
156         };
157     filter.doFilter(request, response, filterChain);
158 
159     assertTrue(invoked[0]);
160   }
161 
testRequestAndResponseBindings_wrappingFilter()162   public void testRequestAndResponseBindings_wrappingFilter() throws Exception {
163     final HttpServletRequest request = newFakeHttpServletRequest();
164     final ImmutableMap<String, String[]> wrappedParamMap =
165         ImmutableMap.of("wrap", new String[] {"a", "b"});
166     final HttpServletRequestWrapper requestWrapper =
167         new HttpServletRequestWrapper(request) {
168           @Override
169           public Map getParameterMap() {
170             return wrappedParamMap;
171           }
172 
173           @Override
174           public Object getAttribute(String attr) {
175             // Ensure that attributes are stored on the original request object.
176             throw new UnsupportedOperationException();
177           }
178         };
179     final HttpServletResponse response = newFakeHttpServletResponse();
180     final HttpServletResponseWrapper responseWrapper = new HttpServletResponseWrapper(response);
181 
182     final boolean[] filterInvoked = new boolean[1];
183     final Injector injector =
184         createInjector(
185             new ServletModule() {
186               @Override
187               protected void configureServlets() {
188                 filter("/*")
189                     .through(
190                         new Filter() {
191                           @Inject Provider<ServletRequest> servletReqProvider;
192                           @Inject Provider<HttpServletRequest> reqProvider;
193                           @Inject Provider<ServletResponse> servletRespProvider;
194                           @Inject Provider<HttpServletResponse> respProvider;
195 
196                           @Override
197                           public void init(FilterConfig filterConfig) {}
198 
199                           @Override
200                           public void doFilter(
201                               ServletRequest req, ServletResponse resp, FilterChain chain)
202                               throws IOException, ServletException {
203                             filterInvoked[0] = true;
204                             assertSame(req, servletReqProvider.get());
205                             assertSame(req, reqProvider.get());
206 
207                             assertSame(resp, servletRespProvider.get());
208                             assertSame(resp, respProvider.get());
209 
210                             chain.doFilter(requestWrapper, responseWrapper);
211 
212                             assertSame(req, reqProvider.get());
213                             assertSame(resp, respProvider.get());
214                           }
215 
216                           @Override
217                           public void destroy() {}
218                         });
219               }
220             });
221 
222     GuiceFilter filter = new GuiceFilter();
223     final boolean[] chainInvoked = new boolean[1];
224     FilterChain filterChain =
225         new FilterChain() {
226           @Override
227           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
228             chainInvoked[0] = true;
229             assertSame(requestWrapper, servletRequest);
230             assertSame(requestWrapper, injector.getInstance(ServletRequest.class));
231             assertSame(requestWrapper, injector.getInstance(HTTP_REQ_KEY));
232 
233             assertSame(responseWrapper, servletResponse);
234             assertSame(responseWrapper, injector.getInstance(ServletResponse.class));
235             assertSame(responseWrapper, injector.getInstance(HTTP_RESP_KEY));
236 
237             assertSame(servletRequest.getParameterMap(), injector.getInstance(REQ_PARAMS_KEY));
238 
239             InRequest inRequest = injector.getInstance(InRequest.class);
240             assertSame(inRequest, injector.getInstance(InRequest.class));
241           }
242         };
243     filter.doFilter(request, response, filterChain);
244 
245     assertTrue(chainInvoked[0]);
246     assertTrue(filterInvoked[0]);
247   }
248 
testRequestAndResponseBindings_matchesPassedParameters()249   public void testRequestAndResponseBindings_matchesPassedParameters() throws Exception {
250     final int[] filterInvoked = new int[1];
251     final boolean[] servletInvoked = new boolean[1];
252     createInjector(
253         new ServletModule() {
254           @Override
255           protected void configureServlets() {
256             final HttpServletRequest[] previousReq = new HttpServletRequest[1];
257             final HttpServletResponse[] previousResp = new HttpServletResponse[1];
258 
259             final Provider<ServletRequest> servletReqProvider = getProvider(ServletRequest.class);
260             final Provider<HttpServletRequest> reqProvider = getProvider(HttpServletRequest.class);
261             final Provider<ServletResponse> servletRespProvider =
262                 getProvider(ServletResponse.class);
263             final Provider<HttpServletResponse> respProvider =
264                 getProvider(HttpServletResponse.class);
265 
266             Filter filter =
267                 new Filter() {
268                   @Override
269                   public void init(FilterConfig filterConfig) {}
270 
271                   @Override
272                   public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
273                       throws IOException, ServletException {
274                     filterInvoked[0]++;
275                     assertSame(req, servletReqProvider.get());
276                     assertSame(req, reqProvider.get());
277                     if (previousReq[0] != null) {
278                       assertEquals(req, previousReq[0]);
279                     }
280 
281                     assertSame(resp, servletRespProvider.get());
282                     assertSame(resp, respProvider.get());
283                     if (previousResp[0] != null) {
284                       assertEquals(resp, previousResp[0]);
285                     }
286 
287                     chain.doFilter(
288                         previousReq[0] = new HttpServletRequestWrapper((HttpServletRequest) req),
289                         previousResp[0] =
290                             new HttpServletResponseWrapper((HttpServletResponse) resp));
291 
292                     assertSame(req, reqProvider.get());
293                     assertSame(resp, respProvider.get());
294                   }
295 
296                   @Override
297                   public void destroy() {}
298                 };
299 
300             filter("/*").through(filter);
301             filter("/*").through(filter); // filter twice to test wrapping in filters
302             serve("/*")
303                 .with(
304                     new HttpServlet() {
305                       @Override
306                       protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
307                         servletInvoked[0] = true;
308                         assertSame(req, servletReqProvider.get());
309                         assertSame(req, reqProvider.get());
310 
311                         assertSame(resp, servletRespProvider.get());
312                         assertSame(resp, respProvider.get());
313                       }
314                     });
315           }
316         });
317 
318     GuiceFilter filter = new GuiceFilter();
319     filter.doFilter(
320         newFakeHttpServletRequest(),
321         newFakeHttpServletResponse(),
322         new FilterChain() {
323           @Override
324           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
325             throw new IllegalStateException("Shouldn't get here");
326           }
327         });
328 
329     assertEquals(2, filterInvoked[0]);
330     assertTrue(servletInvoked[0]);
331   }
332 
testNewRequestObject()333   public void testNewRequestObject() throws CreationException, IOException, ServletException {
334     final Injector injector = createInjector();
335     final HttpServletRequest request = newFakeHttpServletRequest();
336 
337     GuiceFilter filter = new GuiceFilter();
338     final boolean[] invoked = new boolean[1];
339     FilterChain filterChain =
340         new FilterChain() {
341           @Override
342           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
343             invoked[0] = true;
344             assertNotNull(injector.getInstance(InRequest.class));
345             assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
346           }
347         };
348 
349     filter.doFilter(request, null, filterChain);
350 
351     assertTrue(invoked[0]);
352   }
353 
testExistingRequestObject()354   public void testExistingRequestObject() throws CreationException, IOException, ServletException {
355     final Injector injector = createInjector();
356     final HttpServletRequest request = newFakeHttpServletRequest();
357 
358     GuiceFilter filter = new GuiceFilter();
359     final boolean[] invoked = new boolean[1];
360     FilterChain filterChain =
361         new FilterChain() {
362           @Override
363           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
364             invoked[0] = true;
365 
366             InRequest inRequest = injector.getInstance(InRequest.class);
367             assertSame(inRequest, injector.getInstance(InRequest.class));
368 
369             assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
370             assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
371           }
372         };
373 
374     filter.doFilter(request, null, filterChain);
375 
376     assertTrue(invoked[0]);
377   }
378 
testNewSessionObject()379   public void testNewSessionObject() throws CreationException, IOException, ServletException {
380     final Injector injector = createInjector();
381     final HttpServletRequest request = newFakeHttpServletRequest();
382 
383     GuiceFilter filter = new GuiceFilter();
384     final boolean[] invoked = new boolean[1];
385     FilterChain filterChain =
386         new FilterChain() {
387           @Override
388           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
389             invoked[0] = true;
390             assertNotNull(injector.getInstance(InSession.class));
391             assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
392           }
393         };
394 
395     filter.doFilter(request, null, filterChain);
396 
397     assertTrue(invoked[0]);
398   }
399 
testExistingSessionObject()400   public void testExistingSessionObject() throws CreationException, IOException, ServletException {
401     final Injector injector = createInjector();
402     final HttpServletRequest request = newFakeHttpServletRequest();
403 
404     GuiceFilter filter = new GuiceFilter();
405     final boolean[] invoked = new boolean[1];
406     FilterChain filterChain =
407         new FilterChain() {
408           @Override
409           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
410             invoked[0] = true;
411 
412             InSession inSession = injector.getInstance(InSession.class);
413             assertSame(inSession, injector.getInstance(InSession.class));
414 
415             assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
416             assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
417           }
418         };
419 
420     filter.doFilter(request, null, filterChain);
421 
422     assertTrue(invoked[0]);
423   }
424 
testHttpSessionIsSerializable()425   public void testHttpSessionIsSerializable() throws Exception {
426     final Injector injector = createInjector();
427     final HttpServletRequest request = newFakeHttpServletRequest();
428     final HttpSession session = request.getSession();
429 
430     GuiceFilter filter = new GuiceFilter();
431     final boolean[] invoked = new boolean[1];
432     FilterChain filterChain =
433         new FilterChain() {
434           @Override
435           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
436             invoked[0] = true;
437             assertNotNull(injector.getInstance(InSession.class));
438             assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
439           }
440         };
441 
442     filter.doFilter(request, null, filterChain);
443 
444     assertTrue(invoked[0]);
445 
446     HttpSession deserializedSession = reserialize(session);
447 
448     String inSessionKey = IN_SESSION_KEY.toString();
449     String inSessionNullKey = IN_SESSION_NULL_KEY.toString();
450     assertTrue(deserializedSession.getAttribute(inSessionKey) instanceof InSession);
451     assertEquals(NullObject.INSTANCE, deserializedSession.getAttribute(inSessionNullKey));
452   }
453 
testGuiceFilterConstructors()454   public void testGuiceFilterConstructors() throws Exception {
455     final RuntimeException servletException = new RuntimeException();
456     final RuntimeException chainException = new RuntimeException();
457     final Injector injector =
458         createInjector(
459             new ServletModule() {
460               @Override
461               protected void configureServlets() {
462                 serve("/*")
463                     .with(
464                         new HttpServlet() {
465                           @Override
466                           protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
467                             throw servletException;
468                           }
469                         });
470               }
471             });
472     final HttpServletRequest request = newFakeHttpServletRequest();
473     FilterChain filterChain =
474         new FilterChain() {
475           @Override
476           public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
477             throw chainException;
478           }
479         };
480 
481     try {
482       new GuiceFilter().doFilter(request, null, filterChain);
483       fail();
484     } catch (RuntimeException e) {
485       assertSame(servletException, e);
486     }
487     try {
488       injector.getInstance(GuiceFilter.class).doFilter(request, null, filterChain);
489       fail();
490     } catch (RuntimeException e) {
491       assertSame(servletException, e);
492     }
493     try {
494       injector
495           .getInstance(Key.get(GuiceFilter.class, ScopingOnly.class))
496           .doFilter(request, null, filterChain);
497       fail();
498     } catch (RuntimeException e) {
499       assertSame(chainException, e);
500     }
501   }
502 
createInjector(Module... modules)503   private Injector createInjector(Module... modules) throws CreationException {
504     return Guice.createInjector(
505         Lists.<Module>asList(
506             new AbstractModule() {
507               @Override
508               protected void configure() {
509                 install(new ServletModule());
510                 bind(InSession.class);
511                 bind(IN_SESSION_NULL_KEY)
512                     .toProvider(Providers.<InSession>of(null))
513                     .in(SessionScoped.class);
514                 bind(InRequest.class);
515                 bind(IN_REQUEST_NULL_KEY)
516                     .toProvider(Providers.<InRequest>of(null))
517                     .in(RequestScoped.class);
518               }
519             },
520             modules));
521   }
522 
523   @SessionScoped
524   static class InSession implements Serializable {}
525 
526   @RequestScoped
527   static class InRequest {}
528 
529   @BindingAnnotation
530   @Retention(RUNTIME)
531   @Target({PARAMETER, METHOD, FIELD})
532   @interface Null {}
533 }
534