1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.processor;
17 
18 import com.google.common.base.CaseFormat;
19 import com.google.common.base.Strings;
20 import com.google.common.collect.HashMultimap;
21 import com.google.common.collect.Multimap;
22 import com.squareup.javapoet.ClassName;
23 import com.squareup.javapoet.FieldSpec;
24 import com.squareup.javapoet.JavaFile;
25 import com.squareup.javapoet.MethodSpec;
26 import com.squareup.javapoet.ParameterSpec;
27 import com.squareup.javapoet.ParameterizedTypeName;
28 import com.squareup.javapoet.TypeName;
29 import com.squareup.javapoet.TypeSpec;
30 import com.squareup.javapoet.TypeVariableName;
31 import com.squareup.javapoet.WildcardTypeName;
32 import java.io.IOException;
33 import java.util.Collection;
34 import java.util.Collections;
35 import java.util.HashMap;
36 import java.util.Map;
37 import java.util.Set;
38 import java.util.regex.Matcher;
39 import java.util.regex.Pattern;
40 import javax.annotation.processing.AbstractProcessor;
41 import javax.annotation.processing.Filer;
42 import javax.annotation.processing.Messager;
43 import javax.annotation.processing.ProcessingEnvironment;
44 import javax.annotation.processing.RoundEnvironment;
45 import javax.lang.model.SourceVersion;
46 import javax.lang.model.element.AnnotationMirror;
47 import javax.lang.model.element.AnnotationValue;
48 import javax.lang.model.element.Element;
49 import javax.lang.model.element.ExecutableElement;
50 import javax.lang.model.element.Modifier;
51 import javax.lang.model.element.TypeElement;
52 import javax.lang.model.element.TypeParameterElement;
53 import javax.lang.model.element.VariableElement;
54 import javax.lang.model.type.TypeMirror;
55 import javax.lang.model.type.TypeVariable;
56 import javax.lang.model.util.ElementFilter;
57 import javax.lang.model.util.Elements;
58 import javax.tools.Diagnostic.Kind;
59 
60 /**
61  * A compile-time Processor that aggregates classes annotated with {@link
62  * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
63  * refer to the {@link org.tensorflow.op.annotation.Operator} annotation for details about the API
64  * generated for each annotated class.
65  *
66  * <p>Note that this processor can only be invoked once, in a single compilation run that includes
67  * all the {@code Operator} annotated source classes. The reason is that the {@code Ops} API is an
68  * "aggregating" API, and annotation processing does not permit modifying an already generated
69  * class.
70  *
71  * @see org.tensorflow.op.annotation.Operator
72  */
73 public final class OperatorProcessor extends AbstractProcessor {
74 
75   @Override
getSupportedSourceVersion()76   public SourceVersion getSupportedSourceVersion() {
77     return SourceVersion.latest();
78   }
79 
80   @Override
init(ProcessingEnvironment processingEnv)81   public synchronized void init(ProcessingEnvironment processingEnv) {
82     super.init(processingEnv);
83     messager = processingEnv.getMessager();
84     filer = processingEnv.getFiler();
85     elements = processingEnv.getElementUtils();
86   }
87 
88   @Override
process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv)89   public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
90     // Nothing needs to be done at the end of all rounds.
91     if (roundEnv.processingOver()) {
92       return false;
93     }
94 
95     // Nothing to look at in this round.
96     if (annotations.size() == 0) {
97       return false;
98     }
99 
100     // We expect to be registered for exactly one annotation.
101     if (annotations.size() != 1) {
102       throw new IllegalStateException(
103           "Unexpected - multiple annotations registered: " + annotations);
104     }
105     TypeElement annotation = annotations.iterator().next();
106     Set<? extends Element> annotated = roundEnv.getElementsAnnotatedWith(annotation);
107 
108     // If there are no annotated elements, claim the annotation but do nothing.
109     if (annotated.size() == 0) {
110       return false;
111     }
112 
113     // This processor has to aggregate all op classes in one round, as it generates a single Ops
114     // API class which cannot be modified once generated. If we find an annotation after we've
115     // generated our code, flag the location of each such class.
116     if (hasRun) {
117       for (Element e : annotated) {
118         error(
119             e,
120             "The Operator processor has already processed @Operator annotated sources\n"
121                 + "and written out an Ops API. It cannot process additional @Operator sources.\n"
122                 + "One reason this can happen is if other annotation processors generate\n"
123                 + "new @Operator source files.");
124       }
125       return false;
126     }
127 
128     // Collect all classes tagged with our annotation.
129     Multimap<String, MethodSpec> groupedMethods = HashMultimap.create();
130     if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) {
131       return false;
132     }
133 
134     // Nothing to do when there are no tagged classes.
135     if (groupedMethods.isEmpty()) {
136       return false;
137     }
138 
139     // Validate operator classes and generate Op API.
140     writeApi(groupedMethods);
141 
142     hasRun = true;
143     return false;
144   }
145 
146   @Override
getSupportedAnnotationTypes()147   public Set<String> getSupportedAnnotationTypes() {
148     return Collections.singleton("org.tensorflow.op.annotation.Operator");
149   }
150 
151   private static final Pattern JAVADOC_TAG_PATTERN =
152       Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
153   private static final TypeName T_OP = ClassName.get("org.tensorflow.op", "Op");
154   private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
155   private static final TypeName T_OPERATOR =
156       ClassName.get("org.tensorflow.op.annotation", "Operator");
157   private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
158   private static final TypeName T_EXEC_ENV =
159       ClassName.get("org.tensorflow", "ExecutionEnvironment");
160   private static final TypeName T_EAGER_SESSION = ClassName.get("org.tensorflow", "EagerSession");
161   private static final TypeName T_STRING = ClassName.get(String.class);
162   // Operand<?>
163   private static final TypeName T_OPERAND =
164       ParameterizedTypeName.get(
165           ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class));
166   // Iterable<Operand<?>>
167   private static final TypeName T_ITERABLE_OPERAND =
168       ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND);
169 
170   private Filer filer;
171   private Messager messager;
172   private Elements elements;
173   private boolean hasRun = false;
174 
error(Element e, String message, Object... args)175   private void error(Element e, String message, Object... args) {
176     if (args != null && args.length > 0) {
177       message = String.format(message, args);
178     }
179     messager.printMessage(Kind.ERROR, message, e);
180   }
181 
write(TypeSpec spec)182   private void write(TypeSpec spec) {
183     try {
184       JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
185     } catch (IOException e) {
186       throw new AssertionError(e);
187     }
188   }
189 
writeApi(Multimap<String, MethodSpec> groupedMethods)190   private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
191     Map<String, ClassName> groups = new HashMap<>();
192 
193     // Generate a API class for each group collected other than the default one (= empty string)
194     for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
195       if (!entry.getKey().isEmpty()) {
196         TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
197         write(groupClass);
198         groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name));
199       }
200     }
201     // Generate the top API class, adding any methods added to the default group
202     TypeSpec topClass = buildTopClass(groups, groupedMethods.get(""));
203     write(topClass);
204   }
205 
collectOpsMethods( RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation)206   private boolean collectOpsMethods(
207       RoundEnvironment roundEnv,
208       Multimap<String, MethodSpec> groupedMethods,
209       TypeElement annotation) {
210     boolean result = true;
211     for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
212       // @Operator can only apply to types, so e must be a TypeElement.
213       if (!(e instanceof TypeElement)) {
214         error(
215             e,
216             "@Operator can only be applied to classes, but this is a %s",
217             e.getKind().toString());
218         result = false;
219         continue;
220       }
221       TypeElement opClass = (TypeElement) e;
222       // Skip deprecated operations for now, as we do not guarantee API stability yet
223       if (opClass.getAnnotation(Deprecated.class) == null) {
224         collectOpMethods(groupedMethods, opClass, annotation);
225       }
226     }
227     return result;
228   }
229 
collectOpMethods( Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation)230   private void collectOpMethods(
231       Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
232     AnnotationMirror am = getAnnotationMirror(opClass, annotation);
233     String groupName = getAnnotationElementValueAsString("group", am);
234     String methodName = getAnnotationElementValueAsString("name", am);
235     ClassName opClassName = ClassName.get(opClass);
236     if (Strings.isNullOrEmpty(methodName)) {
237       methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
238     }
239     // Build a method for each @Operator found in the class path. There should be one method per
240     // operation factory called
241     // "create", which takes in parameter a scope and, optionally, a list of arguments
242     for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
243       if (opMethod.getModifiers().contains(Modifier.STATIC)
244           && opMethod.getSimpleName().contentEquals("create")) {
245         MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
246         groupedMethods.put(groupName, method);
247       }
248     }
249   }
250 
buildOpMethod( String methodName, ClassName opClassName, ExecutableElement factoryMethod)251   private MethodSpec buildOpMethod(
252       String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
253     MethodSpec.Builder builder =
254         MethodSpec.methodBuilder(methodName)
255             .addModifiers(Modifier.PUBLIC)
256             .returns(TypeName.get(factoryMethod.getReturnType()))
257             .varargs(factoryMethod.isVarArgs())
258             .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
259 
260     for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
261       TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
262       builder.addTypeVariable(tvn);
263     }
264     for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
265       builder.addException(TypeName.get(thrownType));
266     }
267     StringBuilder call = new StringBuilder("return $T.create(scope");
268     boolean first = true;
269     for (VariableElement param : factoryMethod.getParameters()) {
270       ParameterSpec p = ParameterSpec.get(param);
271       if (first) {
272         first = false;
273         continue;
274       }
275       call.append(", ");
276       call.append(p.name);
277       builder.addParameter(p);
278     }
279     call.append(")");
280     builder.addStatement(call.toString(), opClassName);
281     return builder.build();
282   }
283 
buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod)284   private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
285     StringBuilder javadoc = new StringBuilder();
286     javadoc.append("Builds an {@link ").append(opClassName.simpleName()).append("} operation\n\n");
287 
288     // Add all javadoc tags found in the operator factory method but the first one, which should be
289     // in all cases the
290     // 'scope' parameter that is implicitly passed by this API
291     Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
292     boolean firstParam = true;
293 
294     while (tagMatcher.find()) {
295       String tag = tagMatcher.group();
296       if (tag.startsWith("@param") && firstParam) {
297         firstParam = false;
298       } else {
299         javadoc.append(tag).append('\n');
300       }
301     }
302     javadoc.append("@see ").append(opClassName).append("\n");
303 
304     return javadoc.toString();
305   }
306 
buildGroupClass(String group, Collection<MethodSpec> methods)307   private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
308     MethodSpec.Builder ctorBuilder =
309         MethodSpec.constructorBuilder()
310             .addParameter(T_SCOPE, "scope")
311             .addStatement("this.scope = scope");
312 
313     TypeSpec.Builder builder =
314         TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
315             .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
316             .addJavadoc(
317                 "An API for building {@code $L} operations as {@link $T Op}s\n\n"
318                     + "@see {@link $T}\n",
319                 group,
320                 T_OP,
321                 T_OPS)
322             .addMethods(methods)
323             .addMethod(ctorBuilder.build());
324 
325     builder.addField(
326         FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
327 
328     return builder.build();
329   }
330 
buildTopClass( Map<String, ClassName> groupToClass, Collection<MethodSpec> methods)331   private static TypeSpec buildTopClass(
332       Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
333     MethodSpec.Builder ctorBuilder =
334         MethodSpec.constructorBuilder()
335             .addModifiers(Modifier.PRIVATE)
336             .addParameter(T_SCOPE, "scope")
337             .addStatement("this.scope = scope", T_SCOPE);
338 
339     for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
340       ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
341     }
342 
343     TypeSpec.Builder opsBuilder =
344         TypeSpec.classBuilder("Ops")
345             .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
346             .addJavadoc(
347                 "An API for building operations as {@link $T Op}s\n<p>\n"
348                     + "Any operation wrapper found in the classpath properly annotated as an"
349                     + "{@link $T @Operator} is exposed\n"
350                     + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
351                     + "try (Graph g = new Graph()) {\n"
352                     + "  Ops ops = Ops.create(g);\n"
353                     + "  // Operations are typed classes with convenience\n"
354                     + "  // builders in Ops.\n"
355                     + "  Constant three = ops.constant(3);\n"
356                     + "  // Single-result operations implement the Operand\n"
357                     + "  // interface, so this works too.\n"
358                     + "  Operand four = ops.constant(4);\n"
359                     + "  // Most builders are found within a group, and accept\n"
360                     + "  // Operand types as operands\n"
361                     + "  Operand nine = ops.math.add(four, ops.constant(5));\n"
362                     + "  // Multi-result operations however offer methods to\n"
363                     + "  // select a particular result for use.\n"
364                     + "  Operand result = \n"
365                     + "      ops.math.add(ops.unique(s, a).y(), b);\n"
366                     + "  // Optional attributes\n"
367                     + "  ops.linalg.matMul(a, b, MatMul.transposeA(true));\n"
368                     + "  // Naming operators\n"
369                     + "  ops.withName(\"foo\").constant(5); // name \"foo\"\n"
370                     + "  // Names can exist in a hierarchy\n"
371                     + "  Ops sub = ops.withSubScope(\"sub\");\n"
372                     + "  sub.withName(\"bar\").constant(4); // \"sub/bar\"\n"
373                     + "}\n"
374                     + "}</pre>\n",
375                 T_OP,
376                 T_OPERATOR)
377             .addMethods(methods)
378             .addMethod(ctorBuilder.build());
379 
380     opsBuilder.addMethod(
381         MethodSpec.methodBuilder("withSubScope")
382             .addModifiers(Modifier.PUBLIC)
383             .addParameter(T_STRING, "childScopeName")
384             .returns(T_OPS)
385             .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
386             .addJavadoc(
387                 "Returns an API that builds operations with the provided name prefix.\n"
388                     + "\n@see {@link $T#withSubScope(String)}\n",
389                 T_SCOPE)
390             .build());
391 
392     opsBuilder.addMethod(
393         MethodSpec.methodBuilder("withName")
394             .addModifiers(Modifier.PUBLIC)
395             .addParameter(T_STRING, "opName")
396             .returns(T_OPS)
397             .addStatement("return new Ops(scope.withName(opName))")
398             .addJavadoc(
399                 "Returns an API that uses the provided name for an op.\n\n"
400                     + "@see {@link $T#withName(String)}\n",
401                 T_SCOPE)
402             .build());
403 
404     opsBuilder.addMethod(
405         MethodSpec.methodBuilder("withControlDependencies")
406             .addModifiers(Modifier.PUBLIC)
407             .addParameter(T_ITERABLE_OPERAND, "controls")
408             .returns(T_OPS)
409             .addStatement("return new Ops(scope.withControlDependencies(controls))")
410             .addJavadoc(
411                 "Returns an API that adds operations to the graph with the provided control"
412                     + " dependencies.\n\n"
413                     + "@see {@link $T#withControlDependencies(Iterable<Operand<?>>)}\n",
414                 T_SCOPE)
415             .build());
416 
417     opsBuilder.addField(
418         FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
419 
420     opsBuilder.addMethod(
421         MethodSpec.methodBuilder("scope")
422             .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
423             .returns(T_SCOPE)
424             .addStatement("return scope")
425             .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
426             .build());
427 
428     for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
429       opsBuilder.addField(
430           FieldSpec.builder(entry.getValue(), entry.getKey())
431               .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
432               .build());
433 
434       opsBuilder.addMethod(
435           MethodSpec.methodBuilder(entry.getKey())
436               .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
437               .returns(entry.getValue())
438               .addStatement("return $L", entry.getKey())
439               .addJavadoc("Returns an API for building {@code $L} operations\n", entry.getKey())
440               .build());
441     }
442 
443     opsBuilder.addMethod(
444         MethodSpec.methodBuilder("create")
445             .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
446             .addParameter(T_EXEC_ENV, "env")
447             .returns(T_OPS)
448             .addStatement("return new Ops(new $T(env))", T_SCOPE)
449             .addJavadoc(
450                 "Creates an API for building operations in the provided execution environment\n")
451             .build());
452 
453     opsBuilder.addMethod(
454         MethodSpec.methodBuilder("create")
455             .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
456             .returns(T_OPS)
457             .addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION)
458             .addJavadoc(
459                 "Creates an API for building operations in the default eager execution"
460                     + " environment\n\n"
461                     + "<p>Invoking this method is equivalent to {@code"
462                     + " Ops.create(EagerSession.getDefault())}.\n")
463             .build());
464 
465     return opsBuilder.build();
466   }
467 
getAnnotationMirror(Element element, TypeElement annotation)468   private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) {
469     for (AnnotationMirror am : element.getAnnotationMirrors()) {
470       if (am.getAnnotationType().asElement().equals(annotation)) {
471         return am;
472       }
473     }
474     throw new IllegalArgumentException(
475         "Annotation "
476             + annotation.getSimpleName()
477             + " not present on element "
478             + element.getSimpleName());
479   }
480 
getAnnotationElementValueAsString(String elementName, AnnotationMirror am)481   private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
482     for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
483         am.getElementValues().entrySet()) {
484       if (entry.getKey().getSimpleName().contentEquals(elementName)) {
485         return entry.getValue().getValue().toString();
486       }
487     }
488     return "";
489   }
490 }
491