1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 package org.tensorflow.contrib.android; 17 18 import android.content.res.AssetManager; 19 import android.os.Build.VERSION; 20 import android.os.Trace; 21 import android.text.TextUtils; 22 import android.util.Log; 23 import java.io.ByteArrayOutputStream; 24 import java.io.FileInputStream; 25 import java.io.IOException; 26 import java.io.InputStream; 27 import java.nio.ByteBuffer; 28 import java.nio.DoubleBuffer; 29 import java.nio.FloatBuffer; 30 import java.nio.IntBuffer; 31 import java.nio.LongBuffer; 32 import java.util.ArrayList; 33 import java.util.List; 34 import org.tensorflow.Graph; 35 import org.tensorflow.Operation; 36 import org.tensorflow.Session; 37 import org.tensorflow.Tensor; 38 import org.tensorflow.TensorFlow; 39 import org.tensorflow.Tensors; 40 import org.tensorflow.types.UInt8; 41 42 /** 43 * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface 44 * for inference. 45 * 46 * <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an 47 * example usage. 48 */ 49 public class TensorFlowInferenceInterface { 50 private static final String TAG = "TensorFlowInferenceInterface"; 51 private static final String ASSET_FILE_PREFIX = "file:///android_asset/"; 52 53 /* 54 * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. 55 * 56 * @param assetManager The AssetManager to use to load the model file. 57 * @param model The filepath to the GraphDef proto representing the model. 58 */ TensorFlowInferenceInterface(AssetManager assetManager, String model)59 public TensorFlowInferenceInterface(AssetManager assetManager, String model) { 60 prepareNativeRuntime(); 61 62 this.modelName = model; 63 this.g = new Graph(); 64 this.sess = new Session(g); 65 this.runner = sess.runner(); 66 67 final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX); 68 InputStream is = null; 69 try { 70 String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model; 71 is = assetManager.open(aname); 72 } catch (IOException e) { 73 if (hasAssetPrefix) { 74 throw new RuntimeException("Failed to load model from '" + model + "'", e); 75 } 76 // Perhaps the model file is not an asset but is on disk. 77 try { 78 is = new FileInputStream(model); 79 } catch (IOException e2) { 80 throw new RuntimeException("Failed to load model from '" + model + "'", e); 81 } 82 } 83 84 try { 85 if (VERSION.SDK_INT >= 18) { 86 Trace.beginSection("initializeTensorFlow"); 87 Trace.beginSection("readGraphDef"); 88 } 89 90 // TODO(ashankar): Can we somehow mmap the contents instead of copying them? 91 byte[] graphDef = new byte[is.available()]; 92 final int numBytesRead = is.read(graphDef); 93 if (numBytesRead != graphDef.length) { 94 throw new IOException( 95 "read error: read only " 96 + numBytesRead 97 + " of the graph, expected to read " 98 + graphDef.length); 99 } 100 101 if (VERSION.SDK_INT >= 18) { 102 Trace.endSection(); // readGraphDef. 103 } 104 105 loadGraph(graphDef, g); 106 is.close(); 107 Log.i(TAG, "Successfully loaded model from '" + model + "'"); 108 109 if (VERSION.SDK_INT >= 18) { 110 Trace.endSection(); // initializeTensorFlow. 111 } 112 } catch (IOException e) { 113 throw new RuntimeException("Failed to load model from '" + model + "'", e); 114 } 115 } 116 117 /* 118 * Load a TensorFlow model from provided InputStream. 119 * Note: The InputStream will not be closed after loading model, users need to 120 * close it themselves. 121 * 122 * @param is The InputStream to use to load the model. 123 */ TensorFlowInferenceInterface(InputStream is)124 public TensorFlowInferenceInterface(InputStream is) { 125 prepareNativeRuntime(); 126 127 // modelName is redundant for model loading from input stream, here is for 128 // avoiding error in initialization as modelName is marked final. 129 this.modelName = ""; 130 this.g = new Graph(); 131 this.sess = new Session(g); 132 this.runner = sess.runner(); 133 134 try { 135 if (VERSION.SDK_INT >= 18) { 136 Trace.beginSection("initializeTensorFlow"); 137 Trace.beginSection("readGraphDef"); 138 } 139 140 int baosInitSize = is.available() > 16384 ? is.available() : 16384; 141 ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize); 142 int numBytesRead; 143 byte[] buf = new byte[16384]; 144 while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) { 145 baos.write(buf, 0, numBytesRead); 146 } 147 byte[] graphDef = baos.toByteArray(); 148 149 if (VERSION.SDK_INT >= 18) { 150 Trace.endSection(); // readGraphDef. 151 } 152 153 loadGraph(graphDef, g); 154 Log.i(TAG, "Successfully loaded model from the input stream"); 155 156 if (VERSION.SDK_INT >= 18) { 157 Trace.endSection(); // initializeTensorFlow. 158 } 159 } catch (IOException e) { 160 throw new RuntimeException("Failed to load model from the input stream", e); 161 } 162 } 163 164 /* 165 * Construct a TensorFlowInferenceInterface with provided Graph 166 * 167 * @param g The Graph to use to construct this interface. 168 */ TensorFlowInferenceInterface(Graph g)169 public TensorFlowInferenceInterface(Graph g) { 170 prepareNativeRuntime(); 171 172 // modelName is redundant here, here is for 173 // avoiding error in initialization as modelName is marked final. 174 this.modelName = ""; 175 this.g = g; 176 this.sess = new Session(g); 177 this.runner = sess.runner(); 178 } 179 180 /** 181 * Runs inference between the previously registered input nodes (via feed*) and the requested 182 * output nodes. Output nodes can then be queried with the fetch* methods. 183 * 184 * @param outputNames A list of output nodes which should be filled by the inference pass. 185 */ run(String[] outputNames)186 public void run(String[] outputNames) { 187 run(outputNames, false); 188 } 189 190 /** 191 * Runs inference between the previously registered input nodes (via feed*) and the requested 192 * output nodes. Output nodes can then be queried with the fetch* methods. 193 * 194 * @param outputNames A list of output nodes which should be filled by the inference pass. 195 */ run(String[] outputNames, boolean enableStats)196 public void run(String[] outputNames, boolean enableStats) { 197 run(outputNames, enableStats, new String[] {}); 198 } 199 200 /** An overloaded version of runInference that allows supplying targetNodeNames as well */ run(String[] outputNames, boolean enableStats, String[] targetNodeNames)201 public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) { 202 // Release any Tensors from the previous run calls. 203 closeFetches(); 204 205 // Add fetches. 206 for (String o : outputNames) { 207 fetchNames.add(o); 208 TensorId tid = TensorId.parse(o); 209 runner.fetch(tid.name, tid.outputIndex); 210 } 211 212 // Add targets. 213 for (String t : targetNodeNames) { 214 runner.addTarget(t); 215 } 216 217 // Run the session. 218 try { 219 if (enableStats) { 220 Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata(); 221 fetchTensors = r.outputs; 222 223 if (runStats == null) { 224 runStats = new RunStats(); 225 } 226 runStats.add(r.metadata); 227 } else { 228 fetchTensors = runner.run(); 229 } 230 } catch (RuntimeException e) { 231 // Ideally the exception would have been let through, but since this interface predates the 232 // TensorFlow Java API, must return -1. 233 Log.e( 234 TAG, 235 "Failed to run TensorFlow inference with inputs:[" 236 + TextUtils.join(", ", feedNames) 237 + "], outputs:[" 238 + TextUtils.join(", ", fetchNames) 239 + "]"); 240 throw e; 241 } finally { 242 // Always release the feeds (to save resources) and reset the runner, this run is 243 // over. 244 closeFeeds(); 245 runner = sess.runner(); 246 } 247 } 248 249 /** Returns a reference to the Graph describing the computation run during inference. */ graph()250 public Graph graph() { 251 return g; 252 } 253 graphOperation(String operationName)254 public Operation graphOperation(String operationName) { 255 final Operation operation = g.operation(operationName); 256 if (operation == null) { 257 throw new RuntimeException( 258 "Node '" + operationName + "' does not exist in model '" + modelName + "'"); 259 } 260 return operation; 261 } 262 263 /** Returns the last stat summary string if logging is enabled. */ getStatString()264 public String getStatString() { 265 return (runStats == null) ? "" : runStats.summary(); 266 } 267 268 /** 269 * Cleans up the state associated with this Object. 270 * 271 * <p>The TenosrFlowInferenceInterface object is no longer usable after this method returns. 272 */ close()273 public void close() { 274 closeFeeds(); 275 closeFetches(); 276 sess.close(); 277 g.close(); 278 if (runStats != null) { 279 runStats.close(); 280 } 281 runStats = null; 282 } 283 284 @Override finalize()285 protected void finalize() throws Throwable { 286 try { 287 close(); 288 } finally { 289 super.finalize(); 290 } 291 } 292 293 // Methods for taking a native Tensor and filling it with values from Java arrays. 294 295 /** 296 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 297 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 298 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 299 * destination has capacity, the copy is truncated. 300 */ feed(String inputName, boolean[] src, long... dims)301 public void feed(String inputName, boolean[] src, long... dims) { 302 byte[] b = new byte[src.length]; 303 304 for (int i = 0; i < src.length; i++) { 305 b[i] = src[i] ? (byte) 1 : (byte) 0; 306 } 307 308 addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b))); 309 } 310 311 /** 312 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 313 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 314 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 315 * destination has capacity, the copy is truncated. 316 */ feed(String inputName, float[] src, long... dims)317 public void feed(String inputName, float[] src, long... dims) { 318 addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src))); 319 } 320 321 /** 322 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 323 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 324 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 325 * destination has capacity, the copy is truncated. 326 */ feed(String inputName, int[] src, long... dims)327 public void feed(String inputName, int[] src, long... dims) { 328 addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src))); 329 } 330 331 /** 332 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 333 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 334 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 335 * destination has capacity, the copy is truncated. 336 */ feed(String inputName, long[] src, long... dims)337 public void feed(String inputName, long[] src, long... dims) { 338 addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src))); 339 } 340 341 /** 342 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 343 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 344 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 345 * destination has capacity, the copy is truncated. 346 */ feed(String inputName, double[] src, long... dims)347 public void feed(String inputName, double[] src, long... dims) { 348 addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src))); 349 } 350 351 /** 352 * Given a source array with shape {@link dims} and content {@link src}, copy the contents into 353 * the input Tensor with name {@link inputName}. The source array {@link src} must have at least 354 * as many elements as that of the destination Tensor. If {@link src} has more elements than the 355 * destination has capacity, the copy is truncated. 356 */ feed(String inputName, byte[] src, long... dims)357 public void feed(String inputName, byte[] src, long... dims) { 358 addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src))); 359 } 360 361 /** 362 * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued 363 * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not 364 * a Java {@code String} (which is a sequence of characters). 365 */ feedString(String inputName, byte[] src)366 public void feedString(String inputName, byte[] src) { 367 addFeed(inputName, Tensors.create(src)); 368 } 369 370 /** 371 * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a 372 * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an 373 * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters). 374 */ feedString(String inputName, byte[][] src)375 public void feedString(String inputName, byte[][] src) { 376 addFeed(inputName, Tensors.create(src)); 377 } 378 379 // Methods for taking a native Tensor and filling it with src from Java native IO buffers. 380 381 /** 382 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 383 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 384 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 385 * elements as that of the destination Tensor. If {@link src} has more elements than the 386 * destination has capacity, the copy is truncated. 387 */ feed(String inputName, FloatBuffer src, long... dims)388 public void feed(String inputName, FloatBuffer src, long... dims) { 389 addFeed(inputName, Tensor.create(dims, src)); 390 } 391 392 /** 393 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 394 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 395 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 396 * elements as that of the destination Tensor. If {@link src} has more elements than the 397 * destination has capacity, the copy is truncated. 398 */ feed(String inputName, IntBuffer src, long... dims)399 public void feed(String inputName, IntBuffer src, long... dims) { 400 addFeed(inputName, Tensor.create(dims, src)); 401 } 402 403 /** 404 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 405 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 406 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 407 * elements as that of the destination Tensor. If {@link src} has more elements than the 408 * destination has capacity, the copy is truncated. 409 */ feed(String inputName, LongBuffer src, long... dims)410 public void feed(String inputName, LongBuffer src, long... dims) { 411 addFeed(inputName, Tensor.create(dims, src)); 412 } 413 414 /** 415 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 416 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 417 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 418 * elements as that of the destination Tensor. If {@link src} has more elements than the 419 * destination has capacity, the copy is truncated. 420 */ feed(String inputName, DoubleBuffer src, long... dims)421 public void feed(String inputName, DoubleBuffer src, long... dims) { 422 addFeed(inputName, Tensor.create(dims, src)); 423 } 424 425 /** 426 * Given a source buffer with shape {@link dims} and content {@link src}, both stored as 427 * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input 428 * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many 429 * elements as that of the destination Tensor. If {@link src} has more elements than the 430 * destination has capacity, the copy is truncated. 431 */ feed(String inputName, ByteBuffer src, long... dims)432 public void feed(String inputName, ByteBuffer src, long... dims) { 433 addFeed(inputName, Tensor.create(UInt8.class, dims, src)); 434 } 435 436 /** 437 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 438 * dst} must have length greater than or equal to that of the source Tensor. This operation will 439 * not affect dst's content past the source Tensor's size. 440 */ fetch(String outputName, float[] dst)441 public void fetch(String outputName, float[] dst) { 442 fetch(outputName, FloatBuffer.wrap(dst)); 443 } 444 445 /** 446 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 447 * dst} must have length greater than or equal to that of the source Tensor. This operation will 448 * not affect dst's content past the source Tensor's size. 449 */ fetch(String outputName, int[] dst)450 public void fetch(String outputName, int[] dst) { 451 fetch(outputName, IntBuffer.wrap(dst)); 452 } 453 454 /** 455 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 456 * dst} must have length greater than or equal to that of the source Tensor. This operation will 457 * not affect dst's content past the source Tensor's size. 458 */ fetch(String outputName, long[] dst)459 public void fetch(String outputName, long[] dst) { 460 fetch(outputName, LongBuffer.wrap(dst)); 461 } 462 463 /** 464 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 465 * dst} must have length greater than or equal to that of the source Tensor. This operation will 466 * not affect dst's content past the source Tensor's size. 467 */ fetch(String outputName, double[] dst)468 public void fetch(String outputName, double[] dst) { 469 fetch(outputName, DoubleBuffer.wrap(dst)); 470 } 471 472 /** 473 * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link 474 * dst} must have length greater than or equal to that of the source Tensor. This operation will 475 * not affect dst's content past the source Tensor's size. 476 */ fetch(String outputName, byte[] dst)477 public void fetch(String outputName, byte[] dst) { 478 fetch(outputName, ByteBuffer.wrap(dst)); 479 } 480 481 /** 482 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 483 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 484 * or equal to that of the source Tensor. This operation will not affect dst's content past the 485 * source Tensor's size. 486 */ fetch(String outputName, FloatBuffer dst)487 public void fetch(String outputName, FloatBuffer dst) { 488 getTensor(outputName).writeTo(dst); 489 } 490 491 /** 492 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 493 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 494 * or equal to that of the source Tensor. This operation will not affect dst's content past the 495 * source Tensor's size. 496 */ fetch(String outputName, IntBuffer dst)497 public void fetch(String outputName, IntBuffer dst) { 498 getTensor(outputName).writeTo(dst); 499 } 500 501 /** 502 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 503 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 504 * or equal to that of the source Tensor. This operation will not affect dst's content past the 505 * source Tensor's size. 506 */ fetch(String outputName, LongBuffer dst)507 public void fetch(String outputName, LongBuffer dst) { 508 getTensor(outputName).writeTo(dst); 509 } 510 511 /** 512 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 513 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 514 * or equal to that of the source Tensor. This operation will not affect dst's content past the 515 * source Tensor's size. 516 */ fetch(String outputName, DoubleBuffer dst)517 public void fetch(String outputName, DoubleBuffer dst) { 518 getTensor(outputName).writeTo(dst); 519 } 520 521 /** 522 * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and 523 * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than 524 * or equal to that of the source Tensor. This operation will not affect dst's content past the 525 * source Tensor's size. 526 */ fetch(String outputName, ByteBuffer dst)527 public void fetch(String outputName, ByteBuffer dst) { 528 getTensor(outputName).writeTo(dst); 529 } 530 prepareNativeRuntime()531 private void prepareNativeRuntime() { 532 Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); 533 try { 534 // Hack to see if the native libraries have been loaded. 535 new RunStats(); 536 Log.i(TAG, "TensorFlow native methods already loaded"); 537 } catch (UnsatisfiedLinkError e1) { 538 Log.i( 539 TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference"); 540 try { 541 System.loadLibrary("tensorflow_inference"); 542 Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)"); 543 } catch (UnsatisfiedLinkError e2) { 544 throw new RuntimeException( 545 "Native TF methods not found; check that the correct native" 546 + " libraries are present in the APK."); 547 } 548 } 549 } 550 loadGraph(byte[] graphDef, Graph g)551 private void loadGraph(byte[] graphDef, Graph g) throws IOException { 552 final long startMs = System.currentTimeMillis(); 553 554 if (VERSION.SDK_INT >= 18) { 555 Trace.beginSection("importGraphDef"); 556 } 557 558 try { 559 g.importGraphDef(graphDef); 560 } catch (IllegalArgumentException e) { 561 throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage()); 562 } 563 564 if (VERSION.SDK_INT >= 18) { 565 Trace.endSection(); // importGraphDef. 566 } 567 568 final long endMs = System.currentTimeMillis(); 569 Log.i( 570 TAG, 571 "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); 572 } 573 addFeed(String inputName, Tensor<?> t)574 private void addFeed(String inputName, Tensor<?> t) { 575 // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. 576 TensorId tid = TensorId.parse(inputName); 577 runner.feed(tid.name, tid.outputIndex, t); 578 feedNames.add(inputName); 579 feedTensors.add(t); 580 } 581 582 private static class TensorId { 583 String name; 584 int outputIndex; 585 586 // Parse output names into a TensorId. 587 // 588 // E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1) parse(String name)589 public static TensorId parse(String name) { 590 TensorId tid = new TensorId(); 591 int colonIndex = name.lastIndexOf(':'); 592 if (colonIndex < 0) { 593 tid.outputIndex = 0; 594 tid.name = name; 595 return tid; 596 } 597 try { 598 tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1)); 599 tid.name = name.substring(0, colonIndex); 600 } catch (NumberFormatException e) { 601 tid.outputIndex = 0; 602 tid.name = name; 603 } 604 return tid; 605 } 606 } 607 getTensor(String outputName)608 private Tensor<?> getTensor(String outputName) { 609 int i = 0; 610 for (String n : fetchNames) { 611 if (n.equals(outputName)) { 612 return fetchTensors.get(i); 613 } 614 ++i; 615 } 616 throw new RuntimeException( 617 "Node '" + outputName + "' was not provided to run(), so it cannot be read"); 618 } 619 closeFeeds()620 private void closeFeeds() { 621 for (Tensor<?> t : feedTensors) { 622 t.close(); 623 } 624 feedTensors.clear(); 625 feedNames.clear(); 626 } 627 closeFetches()628 private void closeFetches() { 629 for (Tensor<?> t : fetchTensors) { 630 t.close(); 631 } 632 fetchTensors.clear(); 633 fetchNames.clear(); 634 } 635 636 // Immutable state. 637 private final String modelName; 638 private final Graph g; 639 private final Session sess; 640 641 // State reset on every call to run. 642 private Session.Runner runner; 643 private List<String> feedNames = new ArrayList<String>(); 644 private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>(); 645 private List<String> fetchNames = new ArrayList<String>(); 646 private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>(); 647 648 // Mutable state. 649 private RunStats runStats; 650 } 651