1 /*
2  * Copyright 2013 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.google.auto.factory.processor;
17 
18 import static com.google.auto.common.GeneratedAnnotationSpecs.generatedAnnotationSpec;
19 import static com.squareup.javapoet.MethodSpec.constructorBuilder;
20 import static com.squareup.javapoet.MethodSpec.methodBuilder;
21 import static com.squareup.javapoet.TypeSpec.classBuilder;
22 import static javax.lang.model.element.Modifier.FINAL;
23 import static javax.lang.model.element.Modifier.PRIVATE;
24 import static javax.lang.model.element.Modifier.PUBLIC;
25 import static javax.lang.model.element.Modifier.STATIC;
26 
27 import com.google.common.base.Function;
28 import com.google.common.base.Joiner;
29 import com.google.common.collect.FluentIterable;
30 import com.google.common.collect.ImmutableList;
31 import com.google.common.collect.ImmutableSet;
32 import com.google.common.collect.ImmutableSetMultimap;
33 import com.google.common.collect.Iterables;
34 import com.google.common.collect.Sets;
35 import com.squareup.javapoet.AnnotationSpec;
36 import com.squareup.javapoet.ClassName;
37 import com.squareup.javapoet.CodeBlock;
38 import com.squareup.javapoet.JavaFile;
39 import com.squareup.javapoet.MethodSpec;
40 import com.squareup.javapoet.ParameterSpec;
41 import com.squareup.javapoet.ParameterizedTypeName;
42 import com.squareup.javapoet.TypeName;
43 import com.squareup.javapoet.TypeSpec;
44 import com.squareup.javapoet.TypeVariableName;
45 import java.io.IOException;
46 import java.util.Iterator;
47 import javax.annotation.processing.Filer;
48 import javax.annotation.processing.ProcessingEnvironment;
49 import javax.inject.Inject;
50 import javax.inject.Provider;
51 import javax.lang.model.SourceVersion;
52 import javax.lang.model.element.AnnotationMirror;
53 import javax.lang.model.type.TypeKind;
54 import javax.lang.model.type.TypeMirror;
55 import javax.lang.model.type.TypeVariable;
56 import javax.lang.model.util.Elements;
57 
58 final class FactoryWriter {
59 
60   private final Filer filer;
61   private final Elements elements;
62   private final SourceVersion sourceVersion;
63   private final ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated;
64 
FactoryWriter( ProcessingEnvironment processingEnv, ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated)65   FactoryWriter(
66       ProcessingEnvironment processingEnv,
67       ImmutableSetMultimap<String, PackageAndClass> factoriesBeingCreated) {
68     this.filer = processingEnv.getFiler();
69     this.elements = processingEnv.getElementUtils();
70     this.sourceVersion = processingEnv.getSourceVersion();
71     this.factoriesBeingCreated = factoriesBeingCreated;
72   }
73 
74   private static final Joiner ARGUMENT_JOINER = Joiner.on(", ");
75 
writeFactory(FactoryDescriptor descriptor)76   void writeFactory(FactoryDescriptor descriptor)
77       throws IOException {
78     String factoryName = descriptor.name().className();
79     TypeSpec.Builder factory =
80         classBuilder(factoryName)
81             .addOriginatingElement(descriptor.declaration().targetType());
82     generatedAnnotationSpec(
83             elements,
84             sourceVersion,
85             AutoFactoryProcessor.class,
86             "https://github.com/google/auto/tree/master/factory")
87         .ifPresent(factory::addAnnotation);
88     if (!descriptor.allowSubclasses()) {
89       factory.addModifiers(FINAL);
90     }
91     if (descriptor.publicType()) {
92       factory.addModifiers(PUBLIC);
93     }
94 
95     factory.superclass(TypeName.get(descriptor.extendingType()));
96     for (TypeMirror implementingType : descriptor.implementingTypes()) {
97       factory.addSuperinterface(TypeName.get(implementingType));
98     }
99 
100     ImmutableSet<TypeVariableName> factoryTypeVariables = getFactoryTypeVariables(descriptor);
101 
102     addFactoryTypeParameters(factory, factoryTypeVariables);
103     addConstructorAndProviderFields(factory, descriptor);
104     addFactoryMethods(factory, descriptor, factoryTypeVariables);
105     addImplementationMethods(factory, descriptor);
106     addCheckNotNullMethod(factory, descriptor);
107 
108     JavaFile.builder(descriptor.name().packageName(), factory.build())
109         .skipJavaLangImports(true)
110         .build()
111         .writeTo(filer);
112   }
113 
addFactoryTypeParameters( TypeSpec.Builder factory, ImmutableSet<TypeVariableName> typeVariableNames)114   private static void addFactoryTypeParameters(
115       TypeSpec.Builder factory, ImmutableSet<TypeVariableName> typeVariableNames) {
116     factory.addTypeVariables(typeVariableNames);
117   }
118 
addConstructorAndProviderFields( TypeSpec.Builder factory, FactoryDescriptor descriptor)119   private void addConstructorAndProviderFields(
120       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
121     MethodSpec.Builder constructor = constructorBuilder().addAnnotation(Inject.class);
122     if (descriptor.publicType()) {
123       constructor.addModifiers(PUBLIC);
124     }
125     Iterator<ProviderField> providerFields = descriptor.providers().values().iterator();
126     for (int argumentIndex = 1; providerFields.hasNext(); argumentIndex++) {
127       ProviderField provider = providerFields.next();
128       TypeName typeName = resolveTypeName(provider.key().type().get()).box();
129       TypeName providerType = ParameterizedTypeName.get(ClassName.get(Provider.class), typeName);
130       factory.addField(providerType, provider.name(), PRIVATE, FINAL);
131       if (provider.key().qualifier().isPresent()) {
132         // only qualify the constructor parameter
133         providerType = providerType.annotated(AnnotationSpec.get(provider.key().qualifier().get()));
134       }
135       constructor.addParameter(providerType, provider.name());
136       constructor.addStatement("this.$1L = checkNotNull($1L, $2L)", provider.name(), argumentIndex);
137     }
138 
139     factory.addMethod(constructor.build());
140   }
141 
addFactoryMethods( TypeSpec.Builder factory, FactoryDescriptor descriptor, ImmutableSet<TypeVariableName> factoryTypeVariables)142   private void addFactoryMethods(
143       TypeSpec.Builder factory,
144       FactoryDescriptor descriptor,
145       ImmutableSet<TypeVariableName> factoryTypeVariables) {
146     for (FactoryMethodDescriptor methodDescriptor : descriptor.methodDescriptors()) {
147       MethodSpec.Builder method =
148           MethodSpec.methodBuilder(methodDescriptor.name())
149               .addTypeVariables(getMethodTypeVariables(methodDescriptor, factoryTypeVariables))
150               .returns(TypeName.get(methodDescriptor.returnType()))
151               .varargs(methodDescriptor.isVarArgs());
152       if (methodDescriptor.overridingMethod()) {
153         method.addAnnotation(Override.class);
154       }
155       if (methodDescriptor.publicMethod()) {
156         method.addModifiers(PUBLIC);
157       }
158       CodeBlock.Builder args = CodeBlock.builder();
159       method.addParameters(parameters(methodDescriptor.passedParameters()));
160       Iterator<Parameter> parameters = methodDescriptor.creationParameters().iterator();
161       for (int argumentIndex = 1; parameters.hasNext(); argumentIndex++) {
162         Parameter parameter = parameters.next();
163         boolean checkNotNull = !parameter.nullable().isPresent();
164         CodeBlock argument;
165         if (methodDescriptor.passedParameters().contains(parameter)) {
166           argument = CodeBlock.of(parameter.name());
167           if (parameter.isPrimitive()) {
168             checkNotNull = false;
169           }
170         } else {
171           ProviderField provider = descriptor.providers().get(parameter.key());
172           argument = CodeBlock.of(provider.name());
173           if (parameter.isProvider()) {
174             // Providers are checked for nullness in the Factory's constructor.
175             checkNotNull = false;
176           } else {
177             argument = CodeBlock.of("$L.get()", argument);
178           }
179         }
180         if (checkNotNull) {
181           argument = CodeBlock.of("checkNotNull($L, $L)", argument, argumentIndex);
182         }
183         args.add(argument);
184         if (parameters.hasNext()) {
185           args.add(", ");
186         }
187       }
188       method.addStatement("return new $T($L)", methodDescriptor.returnType(), args.build());
189       factory.addMethod(method.build());
190     }
191   }
192 
addImplementationMethods( TypeSpec.Builder factory, FactoryDescriptor descriptor)193   private void addImplementationMethods(
194       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
195     for (ImplementationMethodDescriptor methodDescriptor :
196         descriptor.implementationMethodDescriptors()) {
197       MethodSpec.Builder implementationMethod =
198           methodBuilder(methodDescriptor.name())
199               .addAnnotation(Override.class)
200               .returns(TypeName.get(methodDescriptor.returnType()))
201               .varargs(methodDescriptor.isVarArgs());
202       if (methodDescriptor.publicMethod()) {
203         implementationMethod.addModifiers(PUBLIC);
204       }
205       implementationMethod.addParameters(parameters(methodDescriptor.passedParameters()));
206       implementationMethod.addStatement(
207           "return create($L)",
208           FluentIterable.from(methodDescriptor.passedParameters())
209               .transform(
210                   new Function<Parameter, String>() {
211                     @Override
212                     public String apply(Parameter parameter) {
213                       return parameter.name();
214                     }
215                   })
216               .join(ARGUMENT_JOINER));
217       factory.addMethod(implementationMethod.build());
218     }
219   }
220 
221   /**
222    * {@link ParameterSpec}s to match {@code parameters}. Note that the type of the {@link
223    * ParameterSpec}s match {@link Parameter#type()} and not {@link Key#type()}.
224    */
parameters(Iterable<Parameter> parameters)225   private ImmutableList<ParameterSpec> parameters(Iterable<Parameter> parameters) {
226     ImmutableList.Builder<ParameterSpec> builder = ImmutableList.builder();
227     for (Parameter parameter : parameters) {
228       ParameterSpec.Builder parameterBuilder =
229           ParameterSpec.builder(resolveTypeName(parameter.type().get()), parameter.name());
230       for (AnnotationMirror annotation :
231           Iterables.concat(parameter.nullable().asSet(), parameter.key().qualifier().asSet())) {
232         parameterBuilder.addAnnotation(AnnotationSpec.get(annotation));
233       }
234       builder.add(parameterBuilder.build());
235     }
236     return builder.build();
237   }
238 
addCheckNotNullMethod( TypeSpec.Builder factory, FactoryDescriptor descriptor)239   private static void addCheckNotNullMethod(
240       TypeSpec.Builder factory, FactoryDescriptor descriptor) {
241     if (shouldGenerateCheckNotNull(descriptor)) {
242       TypeVariableName typeVariable = TypeVariableName.get("T");
243       factory.addMethod(
244           methodBuilder("checkNotNull")
245               .addModifiers(PRIVATE, STATIC)
246               .addTypeVariable(typeVariable)
247               .returns(typeVariable)
248               .addParameter(typeVariable, "reference")
249               .addParameter(TypeName.INT, "argumentIndex")
250               .beginControlFlow("if (reference == null)")
251               .addStatement(
252                   "throw new $T($S + argumentIndex)",
253                   NullPointerException.class,
254                   "@AutoFactory method argument is null but is not marked @Nullable. Argument "
255                       + "index: ")
256               .endControlFlow()
257               .addStatement("return reference")
258               .build());
259     }
260   }
261 
shouldGenerateCheckNotNull(FactoryDescriptor descriptor)262   private static boolean shouldGenerateCheckNotNull(FactoryDescriptor descriptor) {
263     if (!descriptor.providers().isEmpty()) {
264       return true;
265     }
266     for (FactoryMethodDescriptor method : descriptor.methodDescriptors()) {
267       for (Parameter parameter : method.creationParameters()) {
268         if (!parameter.nullable().isPresent() && !parameter.type().get().getKind().isPrimitive()) {
269           return true;
270         }
271       }
272     }
273     return false;
274   }
275 
276   /**
277    * Returns an appropriate {@code TypeName} for the given type. If the type is an
278    * {@code ErrorType}, and if it is a simple-name reference to one of the {@code *Factory}
279    * classes that we are going to generate, then we return its fully-qualified name. In every other
280    * case we just return {@code TypeName.get(type)}. Specifically, if it is an {@code ErrorType}
281    * referencing some other type, or referencing one of the classes we are going to generate but
282    * using its fully-qualified name, then we leave it as-is. JavaPoet treats {@code TypeName.get(t)}
283    * the same for {@code ErrorType} as for {@code DeclaredType}, which means that if this is a name
284    * that will eventually be generated then the code we write that references the type will in fact
285    * compile.
286    *
287    * <p>A simpler alternative would be to defer processing to a later round if we find an
288    * {@code @AutoFactory} class that references undefined types, under the assumption that something
289    * else will generate those types in the meanwhile. However, this would fail if for example
290    * {@code @AutoFactory class Foo} has a constructor parameter of type {@code BarFactory} and
291    * {@code @AutoFactory class Bar} has a constructor parameter of type {@code FooFactory}. We did
292    * in fact find instances of this in Google's source base.
293    */
resolveTypeName(TypeMirror type)294   private TypeName resolveTypeName(TypeMirror type) {
295     if (type.getKind() != TypeKind.ERROR) {
296       return TypeName.get(type);
297     }
298     ImmutableSet<PackageAndClass> factoryNames = factoriesBeingCreated.get(type.toString());
299     if (factoryNames.size() == 1) {
300       PackageAndClass packageAndClass = Iterables.getOnlyElement(factoryNames);
301       return ClassName.get(packageAndClass.packageName(), packageAndClass.className());
302     }
303     return TypeName.get(type);
304   }
305 
getFactoryTypeVariables( FactoryDescriptor descriptor)306   private static ImmutableSet<TypeVariableName> getFactoryTypeVariables(
307       FactoryDescriptor descriptor) {
308     ImmutableSet.Builder<TypeVariableName> typeVariables = ImmutableSet.builder();
309     for (ProviderField provider : descriptor.providers().values()) {
310       typeVariables.addAll(getReferencedTypeParameterNames(provider.key().type().get()));
311     }
312     return typeVariables.build();
313   }
314 
getMethodTypeVariables( FactoryMethodDescriptor methodDescriptor, ImmutableSet<TypeVariableName> factoryTypeVariables)315   private static ImmutableSet<TypeVariableName> getMethodTypeVariables(
316       FactoryMethodDescriptor methodDescriptor,
317       ImmutableSet<TypeVariableName> factoryTypeVariables) {
318     ImmutableSet.Builder<TypeVariableName> typeVariables = ImmutableSet.builder();
319     typeVariables.addAll(getReferencedTypeParameterNames(methodDescriptor.returnType()));
320     for (Parameter parameter : methodDescriptor.passedParameters()) {
321       typeVariables.addAll(getReferencedTypeParameterNames(parameter.type().get()));
322     }
323     return Sets.difference(typeVariables.build(), factoryTypeVariables).immutableCopy();
324   }
325 
getReferencedTypeParameterNames(TypeMirror type)326   private static ImmutableSet<TypeVariableName> getReferencedTypeParameterNames(TypeMirror type) {
327     ImmutableSet.Builder<TypeVariableName> typeVariableNames = ImmutableSet.builder();
328     for (TypeVariable typeVariable : TypeVariables.getReferencedTypeVariables(type)) {
329       typeVariableNames.add(TypeVariableName.get(typeVariable));
330     }
331     return typeVariableNames.build();
332   }
333 }
334