1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.processor; 17 18 import com.google.common.base.CaseFormat; 19 import com.google.common.base.Strings; 20 import com.google.common.collect.HashMultimap; 21 import com.google.common.collect.Multimap; 22 import com.squareup.javapoet.ClassName; 23 import com.squareup.javapoet.FieldSpec; 24 import com.squareup.javapoet.JavaFile; 25 import com.squareup.javapoet.MethodSpec; 26 import com.squareup.javapoet.ParameterSpec; 27 import com.squareup.javapoet.ParameterizedTypeName; 28 import com.squareup.javapoet.TypeName; 29 import com.squareup.javapoet.TypeSpec; 30 import com.squareup.javapoet.TypeVariableName; 31 import com.squareup.javapoet.WildcardTypeName; 32 import java.io.IOException; 33 import java.util.Collection; 34 import java.util.Collections; 35 import java.util.HashMap; 36 import java.util.Map; 37 import java.util.Set; 38 import java.util.regex.Matcher; 39 import java.util.regex.Pattern; 40 import javax.annotation.processing.AbstractProcessor; 41 import javax.annotation.processing.Filer; 42 import javax.annotation.processing.Messager; 43 import javax.annotation.processing.ProcessingEnvironment; 44 import javax.annotation.processing.RoundEnvironment; 45 import javax.lang.model.SourceVersion; 46 import javax.lang.model.element.AnnotationMirror; 47 import javax.lang.model.element.AnnotationValue; 48 import javax.lang.model.element.Element; 49 import javax.lang.model.element.ExecutableElement; 50 import javax.lang.model.element.Modifier; 51 import javax.lang.model.element.TypeElement; 52 import javax.lang.model.element.TypeParameterElement; 53 import javax.lang.model.element.VariableElement; 54 import javax.lang.model.type.TypeMirror; 55 import javax.lang.model.type.TypeVariable; 56 import javax.lang.model.util.ElementFilter; 57 import javax.lang.model.util.Elements; 58 import javax.tools.Diagnostic.Kind; 59 60 /** 61 * A compile-time Processor that aggregates classes annotated with {@link 62 * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please 63 * refer to the {@link org.tensorflow.op.annotation.Operator} annotation for details about the API 64 * generated for each annotated class. 65 * 66 * <p>Note that this processor can only be invoked once, in a single compilation run that includes 67 * all the {@code Operator} annotated source classes. The reason is that the {@code Ops} API is an 68 * "aggregating" API, and annotation processing does not permit modifying an already generated 69 * class. 70 * 71 * @see org.tensorflow.op.annotation.Operator 72 */ 73 public final class OperatorProcessor extends AbstractProcessor { 74 75 @Override getSupportedSourceVersion()76 public SourceVersion getSupportedSourceVersion() { 77 return SourceVersion.latest(); 78 } 79 80 @Override init(ProcessingEnvironment processingEnv)81 public synchronized void init(ProcessingEnvironment processingEnv) { 82 super.init(processingEnv); 83 messager = processingEnv.getMessager(); 84 filer = processingEnv.getFiler(); 85 elements = processingEnv.getElementUtils(); 86 } 87 88 @Override process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv)89 public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) { 90 // Nothing needs to be done at the end of all rounds. 91 if (roundEnv.processingOver()) { 92 return false; 93 } 94 95 // Nothing to look at in this round. 96 if (annotations.size() == 0) { 97 return false; 98 } 99 100 // We expect to be registered for exactly one annotation. 101 if (annotations.size() != 1) { 102 throw new IllegalStateException( 103 "Unexpected - multiple annotations registered: " + annotations); 104 } 105 TypeElement annotation = annotations.iterator().next(); 106 Set<? extends Element> annotated = roundEnv.getElementsAnnotatedWith(annotation); 107 108 // If there are no annotated elements, claim the annotation but do nothing. 109 if (annotated.size() == 0) { 110 return false; 111 } 112 113 // This processor has to aggregate all op classes in one round, as it generates a single Ops 114 // API class which cannot be modified once generated. If we find an annotation after we've 115 // generated our code, flag the location of each such class. 116 if (hasRun) { 117 for (Element e : annotated) { 118 error( 119 e, 120 "The Operator processor has already processed @Operator annotated sources\n" 121 + "and written out an Ops API. It cannot process additional @Operator sources.\n" 122 + "One reason this can happen is if other annotation processors generate\n" 123 + "new @Operator source files."); 124 } 125 return false; 126 } 127 128 // Collect all classes tagged with our annotation. 129 Multimap<String, MethodSpec> groupedMethods = HashMultimap.create(); 130 if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { 131 return false; 132 } 133 134 // Nothing to do when there are no tagged classes. 135 if (groupedMethods.isEmpty()) { 136 return false; 137 } 138 139 // Validate operator classes and generate Op API. 140 writeApi(groupedMethods); 141 142 hasRun = true; 143 return false; 144 } 145 146 @Override getSupportedAnnotationTypes()147 public Set<String> getSupportedAnnotationTypes() { 148 return Collections.singleton("org.tensorflow.op.annotation.Operator"); 149 } 150 151 private static final Pattern JAVADOC_TAG_PATTERN = 152 Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); 153 private static final TypeName T_OP = ClassName.get("org.tensorflow.op", "Op"); 154 private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); 155 private static final TypeName T_OPERATOR = 156 ClassName.get("org.tensorflow.op.annotation", "Operator"); 157 private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); 158 private static final TypeName T_EXEC_ENV = 159 ClassName.get("org.tensorflow", "ExecutionEnvironment"); 160 private static final TypeName T_EAGER_SESSION = ClassName.get("org.tensorflow", "EagerSession"); 161 private static final TypeName T_STRING = ClassName.get(String.class); 162 // Operand<?> 163 private static final TypeName T_OPERAND = 164 ParameterizedTypeName.get( 165 ClassName.get("org.tensorflow", "Operand"), WildcardTypeName.subtypeOf(Object.class)); 166 // Iterable<Operand<?>> 167 private static final TypeName T_ITERABLE_OPERAND = 168 ParameterizedTypeName.get(ClassName.get(Iterable.class), T_OPERAND); 169 170 private Filer filer; 171 private Messager messager; 172 private Elements elements; 173 private boolean hasRun = false; 174 error(Element e, String message, Object... args)175 private void error(Element e, String message, Object... args) { 176 if (args != null && args.length > 0) { 177 message = String.format(message, args); 178 } 179 messager.printMessage(Kind.ERROR, message, e); 180 } 181 write(TypeSpec spec)182 private void write(TypeSpec spec) { 183 try { 184 JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); 185 } catch (IOException e) { 186 throw new AssertionError(e); 187 } 188 } 189 writeApi(Multimap<String, MethodSpec> groupedMethods)190 private void writeApi(Multimap<String, MethodSpec> groupedMethods) { 191 Map<String, ClassName> groups = new HashMap<>(); 192 193 // Generate a API class for each group collected other than the default one (= empty string) 194 for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) { 195 if (!entry.getKey().isEmpty()) { 196 TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); 197 write(groupClass); 198 groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name)); 199 } 200 } 201 // Generate the top API class, adding any methods added to the default group 202 TypeSpec topClass = buildTopClass(groups, groupedMethods.get("")); 203 write(topClass); 204 } 205 collectOpsMethods( RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation)206 private boolean collectOpsMethods( 207 RoundEnvironment roundEnv, 208 Multimap<String, MethodSpec> groupedMethods, 209 TypeElement annotation) { 210 boolean result = true; 211 for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { 212 // @Operator can only apply to types, so e must be a TypeElement. 213 if (!(e instanceof TypeElement)) { 214 error( 215 e, 216 "@Operator can only be applied to classes, but this is a %s", 217 e.getKind().toString()); 218 result = false; 219 continue; 220 } 221 TypeElement opClass = (TypeElement) e; 222 // Skip deprecated operations for now, as we do not guarantee API stability yet 223 if (opClass.getAnnotation(Deprecated.class) == null) { 224 collectOpMethods(groupedMethods, opClass, annotation); 225 } 226 } 227 return result; 228 } 229 collectOpMethods( Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation)230 private void collectOpMethods( 231 Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { 232 AnnotationMirror am = getAnnotationMirror(opClass, annotation); 233 String groupName = getAnnotationElementValueAsString("group", am); 234 String methodName = getAnnotationElementValueAsString("name", am); 235 ClassName opClassName = ClassName.get(opClass); 236 if (Strings.isNullOrEmpty(methodName)) { 237 methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); 238 } 239 // Build a method for each @Operator found in the class path. There should be one method per 240 // operation factory called 241 // "create", which takes in parameter a scope and, optionally, a list of arguments 242 for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { 243 if (opMethod.getModifiers().contains(Modifier.STATIC) 244 && opMethod.getSimpleName().contentEquals("create")) { 245 MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); 246 groupedMethods.put(groupName, method); 247 } 248 } 249 } 250 buildOpMethod( String methodName, ClassName opClassName, ExecutableElement factoryMethod)251 private MethodSpec buildOpMethod( 252 String methodName, ClassName opClassName, ExecutableElement factoryMethod) { 253 MethodSpec.Builder builder = 254 MethodSpec.methodBuilder(methodName) 255 .addModifiers(Modifier.PUBLIC) 256 .returns(TypeName.get(factoryMethod.getReturnType())) 257 .varargs(factoryMethod.isVarArgs()) 258 .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); 259 260 for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { 261 TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); 262 builder.addTypeVariable(tvn); 263 } 264 for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { 265 builder.addException(TypeName.get(thrownType)); 266 } 267 StringBuilder call = new StringBuilder("return $T.create(scope"); 268 boolean first = true; 269 for (VariableElement param : factoryMethod.getParameters()) { 270 ParameterSpec p = ParameterSpec.get(param); 271 if (first) { 272 first = false; 273 continue; 274 } 275 call.append(", "); 276 call.append(p.name); 277 builder.addParameter(p); 278 } 279 call.append(")"); 280 builder.addStatement(call.toString(), opClassName); 281 return builder.build(); 282 } 283 buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod)284 private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { 285 StringBuilder javadoc = new StringBuilder(); 286 javadoc.append("Builds an {@link ").append(opClassName.simpleName()).append("} operation\n\n"); 287 288 // Add all javadoc tags found in the operator factory method but the first one, which should be 289 // in all cases the 290 // 'scope' parameter that is implicitly passed by this API 291 Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); 292 boolean firstParam = true; 293 294 while (tagMatcher.find()) { 295 String tag = tagMatcher.group(); 296 if (tag.startsWith("@param") && firstParam) { 297 firstParam = false; 298 } else { 299 javadoc.append(tag).append('\n'); 300 } 301 } 302 javadoc.append("@see ").append(opClassName).append("\n"); 303 304 return javadoc.toString(); 305 } 306 buildGroupClass(String group, Collection<MethodSpec> methods)307 private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) { 308 MethodSpec.Builder ctorBuilder = 309 MethodSpec.constructorBuilder() 310 .addParameter(T_SCOPE, "scope") 311 .addStatement("this.scope = scope"); 312 313 TypeSpec.Builder builder = 314 TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") 315 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 316 .addJavadoc( 317 "An API for building {@code $L} operations as {@link $T Op}s\n\n" 318 + "@see {@link $T}\n", 319 group, 320 T_OP, 321 T_OPS) 322 .addMethods(methods) 323 .addMethod(ctorBuilder.build()); 324 325 builder.addField( 326 FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); 327 328 return builder.build(); 329 } 330 buildTopClass( Map<String, ClassName> groupToClass, Collection<MethodSpec> methods)331 private static TypeSpec buildTopClass( 332 Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { 333 MethodSpec.Builder ctorBuilder = 334 MethodSpec.constructorBuilder() 335 .addModifiers(Modifier.PRIVATE) 336 .addParameter(T_SCOPE, "scope") 337 .addStatement("this.scope = scope", T_SCOPE); 338 339 for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { 340 ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); 341 } 342 343 TypeSpec.Builder opsBuilder = 344 TypeSpec.classBuilder("Ops") 345 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 346 .addJavadoc( 347 "An API for building operations as {@link $T Op}s\n<p>\n" 348 + "Any operation wrapper found in the classpath properly annotated as an" 349 + "{@link $T @Operator} is exposed\n" 350 + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" 351 + "try (Graph g = new Graph()) {\n" 352 + " Ops ops = Ops.create(g);\n" 353 + " // Operations are typed classes with convenience\n" 354 + " // builders in Ops.\n" 355 + " Constant three = ops.constant(3);\n" 356 + " // Single-result operations implement the Operand\n" 357 + " // interface, so this works too.\n" 358 + " Operand four = ops.constant(4);\n" 359 + " // Most builders are found within a group, and accept\n" 360 + " // Operand types as operands\n" 361 + " Operand nine = ops.math.add(four, ops.constant(5));\n" 362 + " // Multi-result operations however offer methods to\n" 363 + " // select a particular result for use.\n" 364 + " Operand result = \n" 365 + " ops.math.add(ops.unique(s, a).y(), b);\n" 366 + " // Optional attributes\n" 367 + " ops.linalg.matMul(a, b, MatMul.transposeA(true));\n" 368 + " // Naming operators\n" 369 + " ops.withName(\"foo\").constant(5); // name \"foo\"\n" 370 + " // Names can exist in a hierarchy\n" 371 + " Ops sub = ops.withSubScope(\"sub\");\n" 372 + " sub.withName(\"bar\").constant(4); // \"sub/bar\"\n" 373 + "}\n" 374 + "}</pre>\n", 375 T_OP, 376 T_OPERATOR) 377 .addMethods(methods) 378 .addMethod(ctorBuilder.build()); 379 380 opsBuilder.addMethod( 381 MethodSpec.methodBuilder("withSubScope") 382 .addModifiers(Modifier.PUBLIC) 383 .addParameter(T_STRING, "childScopeName") 384 .returns(T_OPS) 385 .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) 386 .addJavadoc( 387 "Returns an API that builds operations with the provided name prefix.\n" 388 + "\n@see {@link $T#withSubScope(String)}\n", 389 T_SCOPE) 390 .build()); 391 392 opsBuilder.addMethod( 393 MethodSpec.methodBuilder("withName") 394 .addModifiers(Modifier.PUBLIC) 395 .addParameter(T_STRING, "opName") 396 .returns(T_OPS) 397 .addStatement("return new Ops(scope.withName(opName))") 398 .addJavadoc( 399 "Returns an API that uses the provided name for an op.\n\n" 400 + "@see {@link $T#withName(String)}\n", 401 T_SCOPE) 402 .build()); 403 404 opsBuilder.addMethod( 405 MethodSpec.methodBuilder("withControlDependencies") 406 .addModifiers(Modifier.PUBLIC) 407 .addParameter(T_ITERABLE_OPERAND, "controls") 408 .returns(T_OPS) 409 .addStatement("return new Ops(scope.withControlDependencies(controls))") 410 .addJavadoc( 411 "Returns an API that adds operations to the graph with the provided control" 412 + " dependencies.\n\n" 413 + "@see {@link $T#withControlDependencies(Iterable<Operand<?>>)}\n", 414 T_SCOPE) 415 .build()); 416 417 opsBuilder.addField( 418 FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); 419 420 opsBuilder.addMethod( 421 MethodSpec.methodBuilder("scope") 422 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 423 .returns(T_SCOPE) 424 .addStatement("return scope") 425 .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) 426 .build()); 427 428 for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { 429 opsBuilder.addField( 430 FieldSpec.builder(entry.getValue(), entry.getKey()) 431 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 432 .build()); 433 434 opsBuilder.addMethod( 435 MethodSpec.methodBuilder(entry.getKey()) 436 .addModifiers(Modifier.PUBLIC, Modifier.FINAL) 437 .returns(entry.getValue()) 438 .addStatement("return $L", entry.getKey()) 439 .addJavadoc("Returns an API for building {@code $L} operations\n", entry.getKey()) 440 .build()); 441 } 442 443 opsBuilder.addMethod( 444 MethodSpec.methodBuilder("create") 445 .addModifiers(Modifier.PUBLIC, Modifier.STATIC) 446 .addParameter(T_EXEC_ENV, "env") 447 .returns(T_OPS) 448 .addStatement("return new Ops(new $T(env))", T_SCOPE) 449 .addJavadoc( 450 "Creates an API for building operations in the provided execution environment\n") 451 .build()); 452 453 opsBuilder.addMethod( 454 MethodSpec.methodBuilder("create") 455 .addModifiers(Modifier.PUBLIC, Modifier.STATIC) 456 .returns(T_OPS) 457 .addStatement("return new Ops(new $T($T.getDefault()))", T_SCOPE, T_EAGER_SESSION) 458 .addJavadoc( 459 "Creates an API for building operations in the default eager execution" 460 + " environment\n\n" 461 + "<p>Invoking this method is equivalent to {@code" 462 + " Ops.create(EagerSession.getDefault())}.\n") 463 .build()); 464 465 return opsBuilder.build(); 466 } 467 getAnnotationMirror(Element element, TypeElement annotation)468 private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) { 469 for (AnnotationMirror am : element.getAnnotationMirrors()) { 470 if (am.getAnnotationType().asElement().equals(annotation)) { 471 return am; 472 } 473 } 474 throw new IllegalArgumentException( 475 "Annotation " 476 + annotation.getSimpleName() 477 + " not present on element " 478 + element.getSimpleName()); 479 } 480 getAnnotationElementValueAsString(String elementName, AnnotationMirror am)481 private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { 482 for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : 483 am.getElementValues().entrySet()) { 484 if (entry.getKey().getSimpleName().contentEquals(elementName)) { 485 return entry.getValue().getValue().toString(); 486 } 487 } 488 return ""; 489 } 490 } 491