/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow;

import java.util.ArrayList;
import java.util.List;

/**
 * Driver for {@link Graph} execution.
 *
 * <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a
 * {@link Graph} are executed to compute {@link Tensor}s. For example:
 *
 * <pre>{@code
 * // Let's say graph is an instance of the Graph class
 * // for the computation y = 3 * x
 *
 * try (Session s = new Session(graph)) {
 *   try (Tensor x = Tensor.create(2.0f);
 *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
 *       System.out.println(y.floatValue());  // Will print 6.0f
 *   }
 *   try (Tensor x = Tensor.create(1.1f);
 *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
 *       System.out.println(y.floatValue());  // Will print 3.3f
 *   }
 * }
 * }</pre>
 *
 * <p><b>WARNING:</b> A {@code Session} owns resources that <b>must</b> be explicitly freed by
 * invoking {@link #close()}.
 *
 * <p>Instances of a Session are thread-safe.
 */
public final class Session implements AutoCloseable {

  /**
   * Construct a new session with the associated {@link Graph}.
   *
   * @param g The {@link Graph} the created Session will operate on.
   */
  public Session(Graph g) {
    this(g, null);
  }

  /**
   * Construct a new session with the associated {@link Graph} and configuration options.
   *
   * @param g The {@link Graph} the created Session will operate on.
   * @param config Configuration parameters for the session specified as a serialized <a
   *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto</a>
   *     protocol buffer. May be {@code null}, in which case no configuration is passed to the
   *     native session.
   * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto
   *     protocol buffer.
   */
  public Session(Graph g, byte[] config) {
    graph = g;
    // Temporary reference that only pins the graph while the native session is being allocated.
    Graph.Reference r = g.ref();
    try {
      nativeHandle =
          (config == null) ? allocate(r.nativeHandle()) : allocate2(r.nativeHandle(), null, config);
      // Long-lived reference held until close().
      graphRef = g.ref();
    } finally {
      r.close();
    }
  }

  /**
   * Wrap an existing session with the associated {@link Graph}.
   *
   * @param g The {@link Graph} the wrapped session operates on.
   * @param nativeHandle Handle to an already-allocated native session; this object takes over
   *     deleting it in {@link #close()}.
   */
  Session(Graph g, long nativeHandle) {
    graph = g;
    this.nativeHandle = nativeHandle;
    graphRef = g.ref();
  }

  /**
   * Release resources associated with the Session.
   *
   * <p>Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session
   * is not usable after close returns.
   *
   * <p>If the calling thread is interrupted while waiting for active executions, its interrupt
   * status is restored and this method returns <em>without</em> deleting the native session.
   */
  @Override
  public void close() {
    graphRef.close();
    synchronized (nativeHandleLock) {
      if (nativeHandle == 0) {
        // Already closed (or never had a native session): nothing to delete.
        return;
      }
      // Runner.Reference.close() notifies on this lock when the last active run finishes.
      while (numActiveRuns > 0) {
        try {
          nativeHandleLock.wait();
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          // Possible leak of the Session and Graph in this case?
          return;
        }
      }
      delete(nativeHandle);
      nativeHandle = 0;
    }
  }

  /**
   * Run {@link Operation}s and evaluate {@link Tensor}s.
   *
   * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to
   * evaluate the {@link Tensor}s to fetch. The {@link #feed(String,int,Tensor)} call allows callers
   * to override the value of {@link Tensor}s in the graph by substituting the provided {@link
   * Tensor}s for the outputs of the operations provided to {@link #feed(String,int,Tensor)}.
   */
  public final class Runner {
    /**
     * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces.
     *
     * @param operation Is either the string name of the operation, in which case this method is a
     *     shorthand for {@code feed(operation, 0)}, or it is a string of the form
     *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
     *     feed(operation_name, output_index)}. These colon-separated names are commonly used in the
     *     {@code SignatureDef} protocol buffer messages that are included in {@link
     *     SavedModelBundle#metaGraphDef()}.
     * @param t The tensor to substitute.
     * @return this Runner, to allow call chaining.
     * @throws IllegalArgumentException if the graph has no operation with the resolved name.
     */
    public Runner feed(String operation, Tensor<?> t) {
      return feed(parseOutput(operation), t);
    }

    /**
     * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t}
     * for the value it produces.
     *
     * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
     * one {@code t} is being provided for.
     *
     * @param operation Name of the operation in the graph.
     * @param index Which output of {@code operation} is being fed.
     * @param t The tensor to substitute.
     * @return this Runner, to allow call chaining.
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}.
     */
    public Runner feed(String operation, int index, Tensor<?> t) {
      // operationByName never returns null; it throws for unknown names.
      Operation op = operationByName(operation);
      inputs.add(op.output(index));
      inputTensors.add(t);
      return this;
    }

    /**
     * Use {@code t} instead of the Tensor referred to by executing the operation referred to by
     * {@code operand}.
     *
     * @param operand Identifies the output whose value is replaced.
     * @param t The tensor to substitute.
     * @return this Runner, to allow call chaining.
     */
    public Runner feed(Operand<?> operand, Tensor<?> t) {
      inputs.add(operand.asOutput());
      inputTensors.add(t);
      return this;
    }

    /**
     * Make {@link #run()} return the output of {@code operation}.
     *
     * @param operation Is either the string name of the operation, in which case this method is a
     *     shorthand for {@code fetch(operation, 0)}, or it is a string of the form
     *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
     *     fetch(operation_name, output_index)}. These colon-separated names are commonly used in
     *     the {@code SignatureDef} protocol buffer messages that are included in {@link
     *     SavedModelBundle#metaGraphDef()}.
     * @return this Runner, to allow call chaining.
     * @throws IllegalArgumentException if the graph has no operation with the resolved name.
     */
    public Runner fetch(String operation) {
      return fetch(parseOutput(operation));
    }

    /**
     * Make {@link #run()} return the {@code index}-th output of {@code operation}.
     *
     * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
     * one to return.
     *
     * @param operation Name of the operation in the graph.
     * @param index Which output of {@code operation} to fetch.
     * @return this Runner, to allow call chaining.
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}.
     */
    public Runner fetch(String operation, int index) {
      // operationByName never returns null; it throws for unknown names.
      Operation op = operationByName(operation);
      outputs.add(op.output(index));
      return this;
    }

    /**
     * Makes {@link #run()} return the Tensor referred to by {@code output}.
     *
     * @param output The output to fetch.
     * @return this Runner, to allow call chaining.
     */
    public Runner fetch(Output<?> output) {
      outputs.add(output);
      return this;
    }

    /**
     * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
     *
     * @param operand The operand whose output is fetched.
     * @return this Runner, to allow call chaining.
     */
    public Runner fetch(Operand<?> operand) {
      return fetch(operand.asOutput());
    }

    /**
     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
     *
     * @param operation Name of the operation in the graph.
     * @return this Runner, to allow call chaining.
     * @throws IllegalArgumentException if the graph has no operation named {@code operation}.
     */
    public Runner addTarget(String operation) {
      // operationByName never returns null; it throws for unknown names.
      targets.add(operationByName(operation));
      return this;
    }

    /**
     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
     *
     * @param operation The operation to execute.
     * @return this Runner, to allow call chaining.
     */
    public Runner addTarget(Operation operation) {
      targets.add(operation);
      return this;
    }

    /**
     * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s.
     *
     * @param operand The operand whose producing operation is executed.
     * @return this Runner, to allow call chaining.
     */
    public Runner addTarget(Operand<?> operand) {
      return addTarget(operand.asOutput().op());
    }

    /**
     * (Experimental method): set options (typically for debugging) for this run.
     *
     * <p>The options are presented as a serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
     * protocol buffer</a>.
     *
     * <p>The org.tensorflow package is free of any protocol buffer dependencies in order to remain
     * friendly to resource constrained systems (where something like <a
     * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
     * be more appropriate). A cost of that is this lack of type-safety in this API function. This
     * choice is under review and this function may be replaced by more type-safe equivalents at any
     * time.
     *
     * @param options Serialized RunOptions; the array is stored as-is, not copied.
     * @return this Runner, to allow call chaining.
     */
    public Runner setOptions(byte[] options) {
      this.runOptions = options;
      return this;
    }

    /**
     * Execute the graph fragments necessary to compute all requested fetches.
     *
     * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor}s, i.e., the
     * caller must call {@link Tensor#close()} on all elements of the returned list to free up
     * resources.
     *
     * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it
     * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
     * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
     * {@code Map<Output, Tensor>}?
     *
     * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to
     * extract output tensors in a type-safe way.
     *
     * @return one Tensor per requested fetch, in the order the fetches were added.
     * @throws IllegalStateException if the Session has already been closed.
     */
    public List<Tensor<?>> run() {
      return runHelper(false).outputs;
    }

    /**
     * Execute graph fragments to compute requested fetches and return metadata about the run.
     *
     * <p>This is exactly like {@link #run()}, but in addition to the requested Tensors, also
     * returns metadata about the graph execution in the form of a serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
     * protocol buffer</a>.
     *
     * @return the fetched Tensors together with the serialized run metadata.
     * @throws IllegalStateException if the Session has already been closed.
     */
    public Run runAndFetchMetadata() {
      return runHelper(true);
    }

    /**
     * Flattens the accumulated feeds, fetches and targets into the parallel handle/index arrays
     * expected by the native {@code run} call, invokes it, and wraps the resulting tensor handles.
     *
     * @param wantMetadata whether the native call should also produce run metadata.
     * @return a {@link Run} holding the fetched Tensors and the metadata returned by the native
     *     call.
     */
    private Run runHelper(boolean wantMetadata) {
      long[] inputTensorHandles = new long[inputTensors.size()];
      long[] inputOpHandles = new long[inputs.size()];
      int[] inputOpIndices = new int[inputs.size()];
      long[] outputOpHandles = new long[outputs.size()];
      int[] outputOpIndices = new int[outputs.size()];
      long[] targetOpHandles = new long[targets.size()];
      long[] outputTensorHandles = new long[outputs.size()];

      // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
      // validity of the Graph and graphRef ensures that.
      int idx = 0;
      for (Tensor<?> t : inputTensors) {
        inputTensorHandles[idx++] = t.getNativeHandle();
      }
      idx = 0;
      for (Output<?> o : inputs) {
        inputOpHandles[idx] = o.op().getUnsafeNativeHandle();
        inputOpIndices[idx] = o.index();
        idx++;
      }
      idx = 0;
      for (Output<?> o : outputs) {
        outputOpHandles[idx] = o.op().getUnsafeNativeHandle();
        outputOpIndices[idx] = o.index();
        idx++;
      }
      idx = 0;
      for (Operation op : targets) {
        targetOpHandles[idx++] = op.getUnsafeNativeHandle();
      }

      // Registers this execution as active so that Session.close() waits for it; throws if the
      // session is already closed.
      Reference runRef = new Reference();
      byte[] metadata = null;
      try {
        metadata =
            Session.run(
                nativeHandle,
                runOptions,
                inputTensorHandles,
                inputOpHandles,
                inputOpIndices,
                outputOpHandles,
                outputOpIndices,
                targetOpHandles,
                wantMetadata,
                outputTensorHandles);
      } finally {
        runRef.close();
      }

      // Named distinctly from the Runner.outputs field (the requested fetches) to avoid shadowing.
      List<Tensor<?>> fetched = new ArrayList<Tensor<?>>();
      for (long h : outputTensorHandles) {
        try {
          fetched.add(Tensor.fromHandle(h));
        } catch (Exception e) {
          // Do not hand back a partial result: release what was wrapped so far, then rethrow.
          for (Tensor<?> t : fetched) {
            t.close();
          }
          fetched.clear();
          throw e;
        }
      }
      Run ret = new Run();
      ret.outputs = fetched;
      ret.metadata = metadata;
      return ret;
    }

    /**
     * Marks one in-flight execution on the enclosing Session: construction increments {@code
     * numActiveRuns}, {@link #close()} decrements it and wakes any thread waiting in {@link
     * Session#close()} once the count reaches zero.
     */
    private class Reference implements AutoCloseable {
      public Reference() {
        synchronized (nativeHandleLock) {
          if (nativeHandle == 0) {
            throw new IllegalStateException("run() cannot be called on the Session after close()");
          }
          ++numActiveRuns;
        }
      }

      @Override
      public void close() {
        synchronized (nativeHandleLock) {
          if (nativeHandle == 0) {
            return;
          }
          if (--numActiveRuns == 0) {
            nativeHandleLock.notifyAll();
          }
        }
      }
    }

    /**
     * Looks up an operation in the session's graph by name.
     *
     * @param opName Name of the operation.
     * @return the operation; never {@code null}.
     * @throws IllegalArgumentException if the graph has no operation named {@code opName}.
     */
    private Operation operationByName(String opName) {
      Operation op = graph.operation(opName);
      if (op == null) {
        throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph");
      }
      return op;
    }

    /**
     * Resolves either a plain operation name (output 0) or a name of the form
     * <tt>operation_name:output_index</tt>.
     *
     * <p>If there is no colon, the colon is the last character, or the text after the last colon
     * is not an integer, the whole string is treated as the operation name and output 0 is used.
     *
     * @throws IllegalArgumentException if the graph has no operation with the resolved name.
     */
    @SuppressWarnings("rawtypes")
    private Output<?> parseOutput(String opName) {
      int colon = opName.lastIndexOf(':');
      if (colon == -1 || colon == opName.length() - 1) {
        return new Output(operationByName(opName), 0);
      }
      try {
        String op = opName.substring(0, colon);
        int index = Integer.parseInt(opName.substring(colon + 1));
        return new Output(operationByName(op), index);
      } catch (NumberFormatException e) {
        // The suffix is not an output index, so the colon is part of the operation name itself.
        return new Output(operationByName(opName), 0);
      }
    }

    // inputs.get(i) is the graph output being replaced by inputTensors.get(i).
    private final List<Output<?>> inputs = new ArrayList<Output<?>>();
    private final List<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>();
    private final List<Output<?>> outputs = new ArrayList<Output<?>>();
    private final List<Operation> targets = new ArrayList<Operation>();
    private byte[] runOptions = null;
  }

  /** Create a Runner to execute graph operations and evaluate Tensors. */
  public Runner runner() {
    return new Runner();
  }

  /**
   * Output tensors and metadata obtained when executing a session.
   *
   * <p>See {@link Runner#runAndFetchMetadata()}
   */
  public static final class Run {
    /** Tensors from requested fetches. */
    public List<Tensor<?>> outputs;

    /**
     * (Experimental): Metadata about the run.
     *
     * <p>A serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
     * protocol buffer</a>. The org.tensorflow package is free of any protocol buffer dependencies
     * in order to remain friendly to resource constrained systems (where something like <a
     * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
     * be more appropriate). A cost of that is this opaque blob. This choice is under review and
     * this field may be replaced by more type-safe equivalents at any time.
     */
    public byte[] metadata;
  }

  private final Graph graph;
  private final Graph.Reference graphRef;

  // Guards nativeHandle and numActiveRuns; also the monitor close() waits on.
  private final Object nativeHandleLock = new Object();
  private long nativeHandle;
  private int numActiveRuns;

  // TODO(ashankar): Remove after TensorFlow 1.2 has been released with allocate2().
  private static native long allocate(long graphHandle);

  private static native long allocate2(long graphHandle, String target, byte[] config);

  private static native void delete(long handle);

  /**
   * Execute a session.
   *
   * <p>The author apologizes for the ugliness of the long argument list of this method. However,
   * take solace in the fact that this is a private method meant to cross the JNI boundary.
   *
   * @param handle to the C API TF_Session object (Session.nativeHandle)
   * @param runOptions serialized representation of a RunOptions protocol buffer, or null
   * @param inputOpHandles (see inputOpIndices)
   * @param inputOpIndices (see inputTensorHandles)
   * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values
   *     that are being "fed" (do not need to be computed) during graph execution.
   *     inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the
   *     inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that
   *     inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
   * @param outputOpHandles (see outputOpIndices)
   * @param outputOpIndices together with outputOpHandles identifies the set of values that should
   *     be computed: the outputOpIndices[i]-th output of the Operation outputOpHandles[i]. It is
   *     required that outputOpHandles.length == outputOpIndices.length.
   * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose
   *     output will not be returned
   * @param wantRunMetadata indicates whether metadata about this execution should be returned.
   * @param outputTensorHandles will be filled in with handles to the outputs requested. It is
   *     required that outputTensorHandles.length == outputOpHandles.length.
   * @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol
   *     buffer, null otherwise.
   */
  private static native byte[] run(
      long handle,
      byte[] runOptions,
      long[] inputTensorHandles,
      long[] inputOpHandles,
      int[] inputOpIndices,
      long[] outputOpHandles,
      int[] outputOpIndices,
      long[] targetOpHandles,
      boolean wantRunMetadata,
      long[] outputTensorHandles);
}