1 /*
2  * Copyright (C) 2006 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.google.inject.spring;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 
21 import com.google.inject.Binder;
22 import com.google.inject.Inject;
23 import com.google.inject.Provider;
24 import com.google.inject.name.Names;
25 import org.springframework.beans.factory.BeanFactory;
26 import org.springframework.beans.factory.ListableBeanFactory;
27 
28 /**
29  * Integrates Guice with Spring.
30  *
31  * @author crazybob@google.com (Bob Lee)
32  */
33 public class SpringIntegration {
SpringIntegration()34   private SpringIntegration() {}
35 
36   /**
37    * Creates a provider which looks up objects from Spring using the given name. Expects a binding
38    * to {@link org.springframework.beans.factory.BeanFactory}. Example usage:
39    *
40    * <pre>
41    * bind(DataSource.class)
42    *   .toProvider(fromSpring(DataSource.class, "dataSource"));
43    * </pre>
44    */
fromSpring(Class<T> type, String name)45   public static <T> Provider<T> fromSpring(Class<T> type, String name) {
46     return new InjectableSpringProvider<T>(type, name);
47   }
48 
49   /**
50    * Binds all Spring beans from the given factory by name. For a Spring bean named "foo", this
51    * method creates a binding to the bean's type and {@code @Named("foo")}.
52    *
53    * @see com.google.inject.name.Named
54    * @see com.google.inject.name.Names#named(String)
55    */
bindAll(Binder binder, ListableBeanFactory beanFactory)56   public static void bindAll(Binder binder, ListableBeanFactory beanFactory) {
57     binder = binder.skipSources(SpringIntegration.class);
58 
59     for (String name : beanFactory.getBeanDefinitionNames()) {
60       Class<?> type = beanFactory.getType(name);
61       bindBean(binder, beanFactory, name, type);
62     }
63   }
64 
bindBean( Binder binder, ListableBeanFactory beanFactory, String name, Class<T> type)65   static <T> void bindBean(
66       Binder binder, ListableBeanFactory beanFactory, String name, Class<T> type) {
67     SpringProvider<T> provider = SpringProvider.newInstance(type, name);
68     try {
69       provider.initialize(beanFactory);
70     } catch (Exception e) {
71       binder.addError(e);
72       return;
73     }
74 
75     binder.bind(type).annotatedWith(Names.named(name)).toProvider(provider);
76   }
77 
78   static class SpringProvider<T> implements Provider<T> {
79 
80     BeanFactory beanFactory;
81     boolean singleton;
82     final Class<T> type;
83     final String name;
84 
SpringProvider(Class<T> type, String name)85     public SpringProvider(Class<T> type, String name) {
86       this.type = checkNotNull(type, "type");
87       this.name = checkNotNull(name, "name");
88     }
89 
newInstance(Class<T> type, String name)90     static <T> SpringProvider<T> newInstance(Class<T> type, String name) {
91       return new SpringProvider<T>(type, name);
92     }
93 
initialize(BeanFactory beanFactory)94     void initialize(BeanFactory beanFactory) {
95       this.beanFactory = beanFactory;
96       if (!beanFactory.isTypeMatch(name, type)) {
97         throw new ClassCastException(
98             "Spring bean named '" + name + "' does not implement " + type.getName() + ".");
99       }
100       singleton = beanFactory.isSingleton(name);
101     }
102 
103     @Override
get()104     public T get() {
105       return singleton ? getSingleton() : type.cast(beanFactory.getBean(name));
106     }
107 
108     volatile T instance;
109 
getSingleton()110     private T getSingleton() {
111       if (instance == null) {
112         instance = type.cast(beanFactory.getBean(name));
113       }
114       return instance;
115     }
116   }
117 
118   static class InjectableSpringProvider<T> extends SpringProvider<T> {
119 
InjectableSpringProvider(Class<T> type, String name)120     InjectableSpringProvider(Class<T> type, String name) {
121       super(type, name);
122     }
123 
124     @Inject
125     @Override
initialize(BeanFactory beanFactory)126     void initialize(BeanFactory beanFactory) {
127       super.initialize(beanFactory);
128     }
129   }
130 }
131