• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow;
17 
18 import java.util.ArrayList;
19 import java.util.List;
20 
21 /**
22  * Driver for {@link Graph} execution.
23  *
24  * <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a
25  * {@link Graph} are executed to compute {@link Tensor}s. For example:
26  *
27  * <pre>{@code
28  * // Let's say graph is an instance of the Graph class
29  * // for the computation y = 3 * x
30  *
31  * try (Session s = new Session(graph)) {
32  *   try (Tensor x = Tensor.create(2.0f);
33  *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
34  *       System.out.println(y.floatValue());  // Will print 6.0f
35  *   }
36  *   try (Tensor x = Tensor.create(1.1f);
37  *       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
38  *       System.out.println(y.floatValue());  // Will print 3.3f
39  *   }
40  * }
41  * }</pre>
42  *
 * <p><b>WARNING:</b> A {@code Session} owns resources that <b>must</b> be explicitly freed by
 * invoking {@link #close()}.
45  *
46  * <p>Instances of a Session are thread-safe.
47  */
48 public final class Session implements AutoCloseable {
49 
50   /** Construct a new session with the associated {@link Graph}. */
Session(Graph g)51   public Session(Graph g) {
52     this(g, null);
53   }
54 
55   /**
56    * Construct a new session with the associated {@link Graph} and configuration options.
57    *
58    * @param g The {@link Graph} the created Session will operate on.
59    * @param config Configuration parameters for the session specified as a serialized <a
60    *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto</a>
61    *     protocol buffer.
62    * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto
63    *     protocol buffer.
64    */
Session(Graph g, byte[] config)65   public Session(Graph g, byte[] config) {
66     graph = g;
67     Graph.Reference r = g.ref();
68     try {
69       nativeHandle =
70           (config == null) ? allocate(r.nativeHandle()) : allocate2(r.nativeHandle(), null, config);
71       graphRef = g.ref();
72     } finally {
73       r.close();
74     }
75   }
76 
77   /** Wrap an existing session with the associated {@link Graph}. */
Session(Graph g, long nativeHandle)78   Session(Graph g, long nativeHandle) {
79     graph = g;
80     this.nativeHandle = nativeHandle;
81     graphRef = g.ref();
82   }
83 
84   /**
85    * Release resources associated with the Session.
86    *
87    * <p>Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session
88    * is not usable after close returns.
89    */
90   @Override
close()91   public void close() {
92     graphRef.close();
93     synchronized (nativeHandleLock) {
94       if (nativeHandle == 0) {
95         return;
96       }
97       while (numActiveRuns > 0) {
98         try {
99           nativeHandleLock.wait();
100         } catch (InterruptedException e) {
101           Thread.currentThread().interrupt();
102           // Possible leak of the Session and Graph in this case?
103           return;
104         }
105       }
106       delete(nativeHandle);
107       nativeHandle = 0;
108     }
109   }
110 
111   /**
112    * Run {@link Operation}s and evaluate {@link Tensor}s.
113    *
114    * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to
115    * evaluate the {@link Tensor}s to fetch. The {@link #feed(String,int,Tensor)} call allows callers
116    * to override the value of {@link Tensor}s in the graph by substituting the provided {@link
117    * Tensor}s for the outputs of the operations provided to {@link #feed(String,int,Tensor)}.
118    */
119   public final class Runner {
120     /**
121      * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces.
122      *
123      * @param operation Is either the string name of the operation, in which case this method is a
124      *     shorthand for {@code feed(operation, 0)}, or it is a string of the form
125      *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
126      *     feed(operation_name, output_index)}. These colon-separated names are commonly used in the
127      *     {@code SignatureDef} protocol buffer messages that are included in {@link
128      *     SavedModelBundle#metaGraphDef()}.
129      */
feed(String operation, Tensor<?> t)130     public Runner feed(String operation, Tensor<?> t) {
131       return feed(parseOutput(operation), t);
132     }
133 
134     /**
135      * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t}
136      * for the value it produces.
137      *
138      * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
139      * one {@code t} is being provided for.
140      */
feed(String operation, int index, Tensor<?> t)141     public Runner feed(String operation, int index, Tensor<?> t) {
142       Operation op = operationByName(operation);
143       if (op != null) {
144         inputs.add(op.output(index));
145         inputTensors.add(t);
146       }
147       return this;
148     }
149 
150     /**
151      * Use {@code t} instead of the Tensor referred to by executing the operation referred to by
152      * {@code operand}.
153      */
feed(Operand<?> operand, Tensor<?> t)154     public Runner feed(Operand<?> operand, Tensor<?> t) {
155       inputs.add(operand.asOutput());
156       inputTensors.add(t);
157       return this;
158     }
159 
160     /**
161      * Make {@link #run()} return the output of {@code operation}.
162      *
163      * @param operation Is either the string name of the operation, in which case this method is a
164      *     shorthand for {@code fetch(operation, 0)}, or it is a string of the form
165      *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
166      *     fetch(operation_name, output_index)}. These colon-separated names are commonly used in
167      *     the {@code SignatureDef} protocol buffer messages that are included in {@link
168      *     SavedModelBundle#metaGraphDef()}.
169      */
fetch(String operation)170     public Runner fetch(String operation) {
171       return fetch(parseOutput(operation));
172     }
173 
174     /**
175      * Make {@link #run()} return the {@code index}-th output of {@code operation}.
176      *
177      * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
178      * one to return.
179      */
fetch(String operation, int index)180     public Runner fetch(String operation, int index) {
181       Operation op = operationByName(operation);
182       if (op != null) {
183         outputs.add(op.output(index));
184       }
185       return this;
186     }
187 
188     /**
189      * Makes {@link #run()} return the Tensor referred to by {@code output}.
190      */
fetch(Output<?> output)191     public Runner fetch(Output<?> output) {
192       outputs.add(output);
193       return this;
194     }
195 
196     /**
197      * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
198      */
fetch(Operand<?> operand)199     public Runner fetch(Operand<?> operand) {
200       return fetch(operand.asOutput());
201     }
202 
203     /**
204      * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
205      */
addTarget(String operation)206     public Runner addTarget(String operation) {
207       Operation op = operationByName(operation);
208       if (op != null) {
209         targets.add(op);
210       }
211       return this;
212     }
213 
214     /**
215      * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
216      */
addTarget(Operation operation)217     public Runner addTarget(Operation operation) {
218       targets.add(operation);
219       return this;
220     }
221 
222     /**
223      * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s.
224      */
addTarget(Operand<?> operand)225     public Runner addTarget(Operand<?> operand) {
226       return addTarget(operand.asOutput().op());
227     }
228 
229     /**
230      * (Experimental method): set options (typically for debugging) for this run.
231      *
232      * <p>The options are presented as a serialized <a
233      * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
234      * protocol buffer</a>.
235      *
236      * <p>The org.tensorflow package is free of any protocol buffer dependencies in order to remain
237      * friendly to resource constrained systems (where something like <a
238      * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
239      * be more appropriate). A cost of that is this lack of type-safety in this API function. This
240      * choice is under review and this function may be replaced by more type-safe equivalents at any
241      * time.
242      */
setOptions(byte[] options)243     public Runner setOptions(byte[] options) {
244       this.runOptions = options;
245       return this;
246     }
247 
248     /**
249      * Execute the graph fragments necessary to compute all requested fetches.
250      *
251      * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor}s, i.e., the
252      * caller must call {@link Tensor#close()} on all elements of the returned list to free up
253      * resources.
254      *
255      * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it
256      * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
257      * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
258      * {@code Map<Output, Tensor>}?
259      *
260      * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to
261      * extract output tensors in a type-safe way.
262      */
run()263     public List<Tensor<?>> run() {
264       return runHelper(false).outputs;
265     }
266 
267     /**
268      * Execute graph fragments to compute requested fetches and return metadata about the run.
269      *
270      * <p>This is exactly like {@link #run()}, but in addition to the requested Tensors, also
271      * returns metadata about the graph execution in the form of a serialized <a
272      * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
273      * protocol buffer</a>.
274      */
runAndFetchMetadata()275     public Run runAndFetchMetadata() {
276       return runHelper(true);
277     }
278 
runHelper(boolean wantMetadata)279     private Run runHelper(boolean wantMetadata) {
280       long[] inputTensorHandles = new long[inputTensors.size()];
281       long[] inputOpHandles = new long[inputs.size()];
282       int[] inputOpIndices = new int[inputs.size()];
283       long[] outputOpHandles = new long[outputs.size()];
284       int[] outputOpIndices = new int[outputs.size()];
285       long[] targetOpHandles = new long[targets.size()];
286       long[] outputTensorHandles = new long[outputs.size()];
287 
288       // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
289       // validity of the Graph and graphRef ensures that.
290       int idx = 0;
291       for (Tensor<?> t : inputTensors) {
292         inputTensorHandles[idx++] = t.getNativeHandle();
293       }
294       idx = 0;
295       for (Output<?> o : inputs) {
296         inputOpHandles[idx] = o.op().getUnsafeNativeHandle();
297         inputOpIndices[idx] = o.index();
298         idx++;
299       }
300       idx = 0;
301       for (Output<?> o : outputs) {
302         outputOpHandles[idx] = o.op().getUnsafeNativeHandle();
303         outputOpIndices[idx] = o.index();
304         idx++;
305       }
306       idx = 0;
307       for (Operation op : targets) {
308         targetOpHandles[idx++] = op.getUnsafeNativeHandle();
309       }
310       Reference runRef = new Reference();
311       byte[] metadata = null;
312       try {
313         metadata =
314             Session.run(
315                 nativeHandle,
316                 runOptions,
317                 inputTensorHandles,
318                 inputOpHandles,
319                 inputOpIndices,
320                 outputOpHandles,
321                 outputOpIndices,
322                 targetOpHandles,
323                 wantMetadata,
324                 outputTensorHandles);
325       } finally {
326         runRef.close();
327       }
328       List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
329       for (long h : outputTensorHandles) {
330         try {
331           outputs.add(Tensor.fromHandle(h));
332         } catch (Exception e) {
333           for (Tensor<?> t : outputs) {
334             t.close();
335           }
336           outputs.clear();
337           throw e;
338         }
339       }
340       Run ret = new Run();
341       ret.outputs = outputs;
342       ret.metadata = metadata;
343       return ret;
344     }
345 
346     private class Reference implements AutoCloseable {
Reference()347       public Reference() {
348         synchronized (nativeHandleLock) {
349           if (nativeHandle == 0) {
350             throw new IllegalStateException("run() cannot be called on the Session after close()");
351           }
352           ++numActiveRuns;
353         }
354       }
355 
356       @Override
close()357       public void close() {
358         synchronized (nativeHandleLock) {
359           if (nativeHandle == 0) {
360             return;
361           }
362           if (--numActiveRuns == 0) {
363             nativeHandleLock.notifyAll();
364           }
365         }
366       }
367     }
368 
operationByName(String opName)369     private Operation operationByName(String opName) {
370       Operation op = graph.operation(opName);
371       if (op == null) {
372         throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph");
373       }
374       return op;
375     }
376 
377     @SuppressWarnings("rawtypes")
parseOutput(String opName)378     private Output<?> parseOutput(String opName) {
379       int colon = opName.lastIndexOf(':');
380       if (colon == -1 || colon == opName.length() - 1) {
381         return new Output(operationByName(opName), 0);
382       }
383       try {
384         String op = opName.substring(0, colon);
385         int index = Integer.parseInt(opName.substring(colon + 1));
386         return new Output(operationByName(op), index);
387       } catch (NumberFormatException e) {
388         return new Output(operationByName(opName), 0);
389       }
390     }
391 
392     private ArrayList<Output<?>> inputs = new ArrayList<Output<?>>();
393     private ArrayList<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>();
394     private ArrayList<Output<?>> outputs = new ArrayList<Output<?>>();
395     private ArrayList<Operation> targets = new ArrayList<Operation>();
396     private byte[] runOptions = null;
397   }
398 
399   /** Create a Runner to execute graph operations and evaluate Tensors. */
runner()400   public Runner runner() {
401     return new Runner();
402   }
403 
404   /**
405    * Output tensors and metadata obtained when executing a session.
406    *
407    * <p>See {@link Runner#runAndFetchMetadata()}
408    */
409   public static final class Run {
410     /** Tensors from requested fetches. */
411     public List<Tensor<?>> outputs;
412 
413     /**
414      * (Experimental): Metadata about the run.
415      *
416      * <p>A serialized <a
417      * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
418      * protocol buffer</a>. The org.tensorflow package is free of any protocol buffer dependencies
419      * in order to remain friendly to resource constrained systems (where something like <a
420      * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
421      * be more appropriate). A cost of that is this opaque blob. This choice is under review and
422      * this field may be replaced by more type-safe equivalents at any time.
423      */
424     public byte[] metadata;
425   }
426 
427   private final Graph graph;
428   private final Graph.Reference graphRef;
429 
430   private final Object nativeHandleLock = new Object();
431   private long nativeHandle;
432   private int numActiveRuns;
433 
434   // TODO(ashankar): Remove after TensorFlow 1.2 has been released with allocate2().
allocate(long graphHandle)435   private static native long allocate(long graphHandle);
436 
allocate2(long graphHandle, String target, byte[] config)437   private static native long allocate2(long graphHandle, String target, byte[] config);
438 
delete(long handle)439   private static native void delete(long handle);
440 
441   /**
442    * Execute a session.
443    *
444    * <p>The author apologizes for the ugliness of the long argument list of this method. However,
445    * take solace in the fact that this is a private method meant to cross the JNI boundary.
446    *
447    * @param handle to the C API TF_Session object (Session.nativeHandle)
448    * @param runOptions serialized representation of a RunOptions protocol buffer, or null
449    * @param inputOpHandles (see inputOpIndices)
450    * @param inputOpIndices (see inputTensorHandles)
451    * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values
452    *     that are being "fed" (do not need to be computed) during graph execution.
453    *     inputTensorHandles[i] (which correponds to a Tensor.nativeHandle) is considered to be the
454    *     inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that
455    *     inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
456    * @param outputOpHandles (see outputOpIndices)
457    * @param outputOpIndices together with outputOpHandles identifies the set of values that should
458    *     be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is
459    *     required that outputOpHandles.length == outputOpIndices.length.
460    * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose
461    *     output will not be returned
462    * @param wantRunMetadata indicates whether metadata about this execution should be returned.
463    * @param outputTensorHandles will be filled in with handles to the outputs requested. It is
464    *     required that outputTensorHandles.length == outputOpHandles.length.
465    * @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol
466    *     buffer, false otherwise.
467    */
run( long handle, byte[] runOptions, long[] inputTensorHandles, long[] inputOpHandles, int[] inputOpIndices, long[] outputOpHandles, int[] outputOpIndices, long[] targetOpHandles, boolean wantRunMetadata, long[] outputTensorHandles)468   private static native byte[] run(
469       long handle,
470       byte[] runOptions,
471       long[] inputTensorHandles,
472       long[] inputOpHandles,
473       int[] inputOpIndices,
474       long[] outputOpHandles,
475       int[] outputOpIndices,
476       long[] targetOpHandles,
477       boolean wantRunMetadata,
478       long[] outputTensorHandles);
479 }
480