1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 package org.tensorflow.contrib.android;
17 
18 import android.content.res.AssetManager;
19 import android.os.Build.VERSION;
20 import android.os.Trace;
21 import android.text.TextUtils;
22 import android.util.Log;
23 import java.io.ByteArrayOutputStream;
24 import java.io.FileInputStream;
25 import java.io.IOException;
26 import java.io.InputStream;
27 import java.nio.ByteBuffer;
28 import java.nio.DoubleBuffer;
29 import java.nio.FloatBuffer;
30 import java.nio.IntBuffer;
31 import java.nio.LongBuffer;
32 import java.util.ArrayList;
33 import java.util.List;
34 import org.tensorflow.Graph;
35 import org.tensorflow.Operation;
36 import org.tensorflow.Session;
37 import org.tensorflow.Tensor;
38 import org.tensorflow.TensorFlow;
39 import org.tensorflow.Tensors;
40 import org.tensorflow.types.UInt8;
41 
42 /**
43  * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
44  * for inference.
45  *
46  * <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
47  * example usage.
48  */
49 public class TensorFlowInferenceInterface {
50   private static final String TAG = "TensorFlowInferenceInterface";
51   private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
52 
53   /*
54    * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
55    *
56    * @param assetManager The AssetManager to use to load the model file.
57    * @param model The filepath to the GraphDef proto representing the model.
58    */
TensorFlowInferenceInterface(AssetManager assetManager, String model)59   public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
60     prepareNativeRuntime();
61 
62     this.modelName = model;
63     this.g = new Graph();
64     this.sess = new Session(g);
65     this.runner = sess.runner();
66 
67     final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
68     InputStream is = null;
69     try {
70       String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model;
71       is = assetManager.open(aname);
72     } catch (IOException e) {
73       if (hasAssetPrefix) {
74         throw new RuntimeException("Failed to load model from '" + model + "'", e);
75       }
76       // Perhaps the model file is not an asset but is on disk.
77       try {
78         is = new FileInputStream(model);
79       } catch (IOException e2) {
80         throw new RuntimeException("Failed to load model from '" + model + "'", e);
81       }
82     }
83 
84     try {
85       if (VERSION.SDK_INT >= 18) {
86         Trace.beginSection("initializeTensorFlow");
87         Trace.beginSection("readGraphDef");
88       }
89 
90       // TODO(ashankar): Can we somehow mmap the contents instead of copying them?
91       byte[] graphDef = new byte[is.available()];
92       final int numBytesRead = is.read(graphDef);
93       if (numBytesRead != graphDef.length) {
94         throw new IOException(
95             "read error: read only "
96                 + numBytesRead
97                 + " of the graph, expected to read "
98                 + graphDef.length);
99       }
100 
101       if (VERSION.SDK_INT >= 18) {
102         Trace.endSection(); // readGraphDef.
103       }
104 
105       loadGraph(graphDef, g);
106       is.close();
107       Log.i(TAG, "Successfully loaded model from '" + model + "'");
108 
109       if (VERSION.SDK_INT >= 18) {
110         Trace.endSection(); // initializeTensorFlow.
111       }
112     } catch (IOException e) {
113       throw new RuntimeException("Failed to load model from '" + model + "'", e);
114     }
115   }
116 
117   /*
118    * Load a TensorFlow model from provided InputStream.
119    * Note: The InputStream will not be closed after loading model, users need to
120    * close it themselves.
121    *
122    * @param is The InputStream to use to load the model.
123    */
TensorFlowInferenceInterface(InputStream is)124   public TensorFlowInferenceInterface(InputStream is) {
125     prepareNativeRuntime();
126 
127     // modelName is redundant for model loading from input stream, here is for
128     // avoiding error in initialization as modelName is marked final.
129     this.modelName = "";
130     this.g = new Graph();
131     this.sess = new Session(g);
132     this.runner = sess.runner();
133 
134     try {
135       if (VERSION.SDK_INT >= 18) {
136         Trace.beginSection("initializeTensorFlow");
137         Trace.beginSection("readGraphDef");
138       }
139 
140       int baosInitSize = is.available() > 16384 ? is.available() : 16384;
141       ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
142       int numBytesRead;
143       byte[] buf = new byte[16384];
144       while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
145         baos.write(buf, 0, numBytesRead);
146       }
147       byte[] graphDef = baos.toByteArray();
148 
149       if (VERSION.SDK_INT >= 18) {
150         Trace.endSection(); // readGraphDef.
151       }
152 
153       loadGraph(graphDef, g);
154       Log.i(TAG, "Successfully loaded model from the input stream");
155 
156       if (VERSION.SDK_INT >= 18) {
157         Trace.endSection(); // initializeTensorFlow.
158       }
159     } catch (IOException e) {
160       throw new RuntimeException("Failed to load model from the input stream", e);
161     }
162   }
163 
164   /*
165    * Construct a TensorFlowInferenceInterface with provided Graph
166    *
167    * @param g The Graph to use to construct this interface.
168    */
TensorFlowInferenceInterface(Graph g)169   public TensorFlowInferenceInterface(Graph g) {
170     prepareNativeRuntime();
171 
172     // modelName is redundant here, here is for
173     // avoiding error in initialization as modelName is marked final.
174     this.modelName = "";
175     this.g = g;
176     this.sess = new Session(g);
177     this.runner = sess.runner();
178   }
179 
180   /**
181    * Runs inference between the previously registered input nodes (via feed*) and the requested
182    * output nodes. Output nodes can then be queried with the fetch* methods.
183    *
184    * @param outputNames A list of output nodes which should be filled by the inference pass.
185    */
run(String[] outputNames)186   public void run(String[] outputNames) {
187     run(outputNames, false);
188   }
189 
190   /**
191    * Runs inference between the previously registered input nodes (via feed*) and the requested
192    * output nodes. Output nodes can then be queried with the fetch* methods.
193    *
194    * @param outputNames A list of output nodes which should be filled by the inference pass.
195    */
run(String[] outputNames, boolean enableStats)196   public void run(String[] outputNames, boolean enableStats) {
197     run(outputNames, enableStats, new String[] {});
198   }
199 
200   /** An overloaded version of runInference that allows supplying targetNodeNames as well */
run(String[] outputNames, boolean enableStats, String[] targetNodeNames)201   public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
202     // Release any Tensors from the previous run calls.
203     closeFetches();
204 
205     // Add fetches.
206     for (String o : outputNames) {
207       fetchNames.add(o);
208       TensorId tid = TensorId.parse(o);
209       runner.fetch(tid.name, tid.outputIndex);
210     }
211 
212     // Add targets.
213     for (String t : targetNodeNames) {
214       runner.addTarget(t);
215     }
216 
217     // Run the session.
218     try {
219       if (enableStats) {
220         Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
221         fetchTensors = r.outputs;
222 
223         if (runStats == null) {
224           runStats = new RunStats();
225         }
226         runStats.add(r.metadata);
227       } else {
228         fetchTensors = runner.run();
229       }
230     } catch (RuntimeException e) {
231       // Ideally the exception would have been let through, but since this interface predates the
232       // TensorFlow Java API, must return -1.
233       Log.e(
234           TAG,
235           "Failed to run TensorFlow inference with inputs:["
236               + TextUtils.join(", ", feedNames)
237               + "], outputs:["
238               + TextUtils.join(", ", fetchNames)
239               + "]");
240       throw e;
241     } finally {
242       // Always release the feeds (to save resources) and reset the runner, this run is
243       // over.
244       closeFeeds();
245       runner = sess.runner();
246     }
247   }
248 
249   /** Returns a reference to the Graph describing the computation run during inference. */
graph()250   public Graph graph() {
251     return g;
252   }
253 
graphOperation(String operationName)254   public Operation graphOperation(String operationName) {
255     final Operation operation = g.operation(operationName);
256     if (operation == null) {
257       throw new RuntimeException(
258           "Node '" + operationName + "' does not exist in model '" + modelName + "'");
259     }
260     return operation;
261   }
262 
263   /** Returns the last stat summary string if logging is enabled. */
getStatString()264   public String getStatString() {
265     return (runStats == null) ? "" : runStats.summary();
266   }
267 
268   /**
269    * Cleans up the state associated with this Object.
270    *
271    * <p>The TenosrFlowInferenceInterface object is no longer usable after this method returns.
272    */
close()273   public void close() {
274     closeFeeds();
275     closeFetches();
276     sess.close();
277     g.close();
278     if (runStats != null) {
279       runStats.close();
280     }
281     runStats = null;
282   }
283 
284   @Override
finalize()285   protected void finalize() throws Throwable {
286     try {
287       close();
288     } finally {
289       super.finalize();
290     }
291   }
292 
293   // Methods for taking a native Tensor and filling it with values from Java arrays.
294 
295   /**
296    * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
297    * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
298    * as many elements as that of the destination Tensor. If {@link src} has more elements than the
299    * destination has capacity, the copy is truncated.
300    */
feed(String inputName, boolean[] src, long... dims)301   public void feed(String inputName, boolean[] src, long... dims) {
302     byte[] b = new byte[src.length];
303 
304     for (int i = 0; i < src.length; i++) {
305       b[i] = src[i] ? (byte) 1 : (byte) 0;
306     }
307 
308     addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
309   }
310 
311   /**
312    * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
313    * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
314    * as many elements as that of the destination Tensor. If {@link src} has more elements than the
315    * destination has capacity, the copy is truncated.
316    */
feed(String inputName, float[] src, long... dims)317   public void feed(String inputName, float[] src, long... dims) {
318     addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
319   }
320 
321   /**
322    * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
323    * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
324    * as many elements as that of the destination Tensor. If {@link src} has more elements than the
325    * destination has capacity, the copy is truncated.
326    */
feed(String inputName, int[] src, long... dims)327   public void feed(String inputName, int[] src, long... dims) {
328     addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
329   }
330 
331   /**
332    * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
333    * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
334    * as many elements as that of the destination Tensor. If {@link src} has more elements than the
335    * destination has capacity, the copy is truncated.
336    */
feed(String inputName, long[] src, long... dims)337   public void feed(String inputName, long[] src, long... dims) {
338     addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
339   }
340 
341   /**
342    * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
343    * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
344    * as many elements as that of the destination Tensor. If {@link src} has more elements than the
345    * destination has capacity, the copy is truncated.
346    */
feed(String inputName, double[] src, long... dims)347   public void feed(String inputName, double[] src, long... dims) {
348     addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
349   }
350 
351   /**
352    * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
353    * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
354    * as many elements as that of the destination Tensor. If {@link src} has more elements than the
355    * destination has capacity, the copy is truncated.
356    */
feed(String inputName, byte[] src, long... dims)357   public void feed(String inputName, byte[] src, long... dims) {
358     addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
359   }
360 
361   /**
362    * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
363    * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
364    * a Java {@code String} (which is a sequence of characters).
365    */
feedString(String inputName, byte[] src)366   public void feedString(String inputName, byte[] src) {
367     addFeed(inputName, Tensors.create(src));
368   }
369 
370   /**
371    * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
372    * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
373    * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
374    */
feedString(String inputName, byte[][] src)375   public void feedString(String inputName, byte[][] src) {
376     addFeed(inputName, Tensors.create(src));
377   }
378 
379   // Methods for taking a native Tensor and filling it with src from Java native IO buffers.
380 
381   /**
382    * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
383    * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
384    * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
385    * elements as that of the destination Tensor. If {@link src} has more elements than the
386    * destination has capacity, the copy is truncated.
387    */
feed(String inputName, FloatBuffer src, long... dims)388   public void feed(String inputName, FloatBuffer src, long... dims) {
389     addFeed(inputName, Tensor.create(dims, src));
390   }
391 
392   /**
393    * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
394    * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
395    * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
396    * elements as that of the destination Tensor. If {@link src} has more elements than the
397    * destination has capacity, the copy is truncated.
398    */
feed(String inputName, IntBuffer src, long... dims)399   public void feed(String inputName, IntBuffer src, long... dims) {
400     addFeed(inputName, Tensor.create(dims, src));
401   }
402 
403   /**
404    * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
405    * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
406    * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
407    * elements as that of the destination Tensor. If {@link src} has more elements than the
408    * destination has capacity, the copy is truncated.
409    */
feed(String inputName, LongBuffer src, long... dims)410   public void feed(String inputName, LongBuffer src, long... dims) {
411     addFeed(inputName, Tensor.create(dims, src));
412   }
413 
414   /**
415    * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
416    * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
417    * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
418    * elements as that of the destination Tensor. If {@link src} has more elements than the
419    * destination has capacity, the copy is truncated.
420    */
feed(String inputName, DoubleBuffer src, long... dims)421   public void feed(String inputName, DoubleBuffer src, long... dims) {
422     addFeed(inputName, Tensor.create(dims, src));
423   }
424 
425   /**
426    * Given a source buffer with shape {@link dims} and content {@link src}, both stored as
427    * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
428    * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
429    * elements as that of the destination Tensor. If {@link src} has more elements than the
430    * destination has capacity, the copy is truncated.
431    */
feed(String inputName, ByteBuffer src, long... dims)432   public void feed(String inputName, ByteBuffer src, long... dims) {
433     addFeed(inputName, Tensor.create(UInt8.class, dims, src));
434   }
435 
436   /**
437    * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
438    * dst} must have length greater than or equal to that of the source Tensor. This operation will
439    * not affect dst's content past the source Tensor's size.
440    */
fetch(String outputName, float[] dst)441   public void fetch(String outputName, float[] dst) {
442     fetch(outputName, FloatBuffer.wrap(dst));
443   }
444 
445   /**
446    * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
447    * dst} must have length greater than or equal to that of the source Tensor. This operation will
448    * not affect dst's content past the source Tensor's size.
449    */
fetch(String outputName, int[] dst)450   public void fetch(String outputName, int[] dst) {
451     fetch(outputName, IntBuffer.wrap(dst));
452   }
453 
454   /**
455    * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
456    * dst} must have length greater than or equal to that of the source Tensor. This operation will
457    * not affect dst's content past the source Tensor's size.
458    */
fetch(String outputName, long[] dst)459   public void fetch(String outputName, long[] dst) {
460     fetch(outputName, LongBuffer.wrap(dst));
461   }
462 
463   /**
464    * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
465    * dst} must have length greater than or equal to that of the source Tensor. This operation will
466    * not affect dst's content past the source Tensor's size.
467    */
fetch(String outputName, double[] dst)468   public void fetch(String outputName, double[] dst) {
469     fetch(outputName, DoubleBuffer.wrap(dst));
470   }
471 
472   /**
473    * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
474    * dst} must have length greater than or equal to that of the source Tensor. This operation will
475    * not affect dst's content past the source Tensor's size.
476    */
fetch(String outputName, byte[] dst)477   public void fetch(String outputName, byte[] dst) {
478     fetch(outputName, ByteBuffer.wrap(dst));
479   }
480 
481   /**
482    * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
483    * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
484    * or equal to that of the source Tensor. This operation will not affect dst's content past the
485    * source Tensor's size.
486    */
fetch(String outputName, FloatBuffer dst)487   public void fetch(String outputName, FloatBuffer dst) {
488     getTensor(outputName).writeTo(dst);
489   }
490 
491   /**
492    * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
493    * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
494    * or equal to that of the source Tensor. This operation will not affect dst's content past the
495    * source Tensor's size.
496    */
fetch(String outputName, IntBuffer dst)497   public void fetch(String outputName, IntBuffer dst) {
498     getTensor(outputName).writeTo(dst);
499   }
500 
501   /**
502    * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
503    * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
504    * or equal to that of the source Tensor. This operation will not affect dst's content past the
505    * source Tensor's size.
506    */
fetch(String outputName, LongBuffer dst)507   public void fetch(String outputName, LongBuffer dst) {
508     getTensor(outputName).writeTo(dst);
509   }
510 
511   /**
512    * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
513    * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
514    * or equal to that of the source Tensor. This operation will not affect dst's content past the
515    * source Tensor's size.
516    */
fetch(String outputName, DoubleBuffer dst)517   public void fetch(String outputName, DoubleBuffer dst) {
518     getTensor(outputName).writeTo(dst);
519   }
520 
521   /**
522    * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
523    * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
524    * or equal to that of the source Tensor. This operation will not affect dst's content past the
525    * source Tensor's size.
526    */
fetch(String outputName, ByteBuffer dst)527   public void fetch(String outputName, ByteBuffer dst) {
528     getTensor(outputName).writeTo(dst);
529   }
530 
prepareNativeRuntime()531   private void prepareNativeRuntime() {
532     Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
533     try {
534       // Hack to see if the native libraries have been loaded.
535       new RunStats();
536       Log.i(TAG, "TensorFlow native methods already loaded");
537     } catch (UnsatisfiedLinkError e1) {
538       Log.i(
539           TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
540       try {
541         System.loadLibrary("tensorflow_inference");
542         Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
543       } catch (UnsatisfiedLinkError e2) {
544         throw new RuntimeException(
545             "Native TF methods not found; check that the correct native"
546                 + " libraries are present in the APK.");
547       }
548     }
549   }
550 
loadGraph(byte[] graphDef, Graph g)551   private void loadGraph(byte[] graphDef, Graph g) throws IOException {
552     final long startMs = System.currentTimeMillis();
553 
554     if (VERSION.SDK_INT >= 18) {
555       Trace.beginSection("importGraphDef");
556     }
557 
558     try {
559       g.importGraphDef(graphDef);
560     } catch (IllegalArgumentException e) {
561       throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
562     }
563 
564     if (VERSION.SDK_INT >= 18) {
565       Trace.endSection(); // importGraphDef.
566     }
567 
568     final long endMs = System.currentTimeMillis();
569     Log.i(
570         TAG,
571         "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
572   }
573 
addFeed(String inputName, Tensor<?> t)574   private void addFeed(String inputName, Tensor<?> t) {
575     // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
576     TensorId tid = TensorId.parse(inputName);
577     runner.feed(tid.name, tid.outputIndex, t);
578     feedNames.add(inputName);
579     feedTensors.add(t);
580   }
581 
582   private static class TensorId {
583     String name;
584     int outputIndex;
585 
586     // Parse output names into a TensorId.
587     //
588     // E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
parse(String name)589     public static TensorId parse(String name) {
590       TensorId tid = new TensorId();
591       int colonIndex = name.lastIndexOf(':');
592       if (colonIndex < 0) {
593         tid.outputIndex = 0;
594         tid.name = name;
595         return tid;
596       }
597       try {
598         tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
599         tid.name = name.substring(0, colonIndex);
600       } catch (NumberFormatException e) {
601         tid.outputIndex = 0;
602         tid.name = name;
603       }
604       return tid;
605     }
606   }
607 
getTensor(String outputName)608   private Tensor<?> getTensor(String outputName) {
609     int i = 0;
610     for (String n : fetchNames) {
611       if (n.equals(outputName)) {
612         return fetchTensors.get(i);
613       }
614       ++i;
615     }
616     throw new RuntimeException(
617         "Node '" + outputName + "' was not provided to run(), so it cannot be read");
618   }
619 
closeFeeds()620   private void closeFeeds() {
621     for (Tensor<?> t : feedTensors) {
622       t.close();
623     }
624     feedTensors.clear();
625     feedNames.clear();
626   }
627 
closeFetches()628   private void closeFetches() {
629     for (Tensor<?> t : fetchTensors) {
630       t.close();
631     }
632     fetchTensors.clear();
633     fetchNames.clear();
634   }
635 
636   // Immutable state.
637   private final String modelName;
638   private final Graph g;
639   private final Session sess;
640 
641   // State reset on every call to run.
642   private Session.Runner runner;
643   private List<String> feedNames = new ArrayList<String>();
644   private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
645   private List<String> fetchNames = new ArrayList<String>();
646   private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
647 
648   // Mutable state.
649   private RunStats runStats;
650 }
651