1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/summary/summary_db_writer.h"
16 
17 #include <deque>
18 
19 #include "tensorflow/core/summary/summary_converter.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/summary.pb.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/db/sqlite.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/util/event.pb.h"
28 
29 // TODO(jart): Break this up into multiple files with excellent unit tests.
30 // TODO(jart): Make decision to write in separate op.
31 // TODO(jart): Add really good busy handling.
32 
// Expands the macro `m` once for every element type this writer can store.
// CheckSupportedType() and AsScalar() below use it to generate one
// `switch` case per supported dtype.
// clang-format off
#define CALL_SUPPORTED_TYPES(m) \
  TF_CALL_tstring(m)             \
  TF_CALL_half(m)               \
  TF_CALL_float(m)              \
  TF_CALL_double(m)             \
  TF_CALL_complex64(m)          \
  TF_CALL_complex128(m)         \
  TF_CALL_int8(m)               \
  TF_CALL_int16(m)              \
  TF_CALL_int32(m)              \
  TF_CALL_int64(m)              \
  TF_CALL_uint8(m)              \
  TF_CALL_uint16(m)             \
  TF_CALL_uint32(m)             \
  TF_CALL_uint64(m)
// clang-format on
50 
51 namespace tensorflow {
52 namespace {
53 
54 // https://www.sqlite.org/fileformat.html#record_format
55 const uint64 kIdTiers[] = {
56     0x7fffffULL,        // 23-bit (3 bytes on disk)
57     0x7fffffffULL,      // 31-bit (4 bytes on disk)
58     0x7fffffffffffULL,  // 47-bit (5 bytes on disk)
59                         // remaining bits for future use
60 };
61 const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
62 const int kIdCollisionDelayMicros = 10;
63 const int kMaxIdCollisions = 21;  // sum(2**i*10µs for i in range(21))~=21s
64 const int64 kAbsent = 0LL;
65 
66 const char* kScalarPluginName = "scalars";
67 const char* kImagePluginName = "images";
68 const char* kAudioPluginName = "audio";
69 const char* kHistogramPluginName = "histograms";
70 
71 const int64 kReserveMinBytes = 32;
72 const double kReserveMultiplier = 1.5;
73 const int64 kPreallocateRows = 1000;
74 
75 // Flush is a misnomer because what we're actually doing is having lots
76 // of commits inside any SqliteTransaction that writes potentially
77 // hundreds of megs but doesn't need the transaction to maintain its
78 // invariants. This ensures the WAL read penalty is small and might
79 // allow writers in other processes a chance to schedule.
80 const uint64 kFlushBytes = 1024 * 1024;
81 
DoubleTime(uint64 micros)82 double DoubleTime(uint64 micros) {
83   // TODO(@jart): Follow precise definitions for time laid out in schema.
84   // TODO(@jart): Use monotonic clock from gRPC codebase.
85   return static_cast<double>(micros) / 1.0e6;
86 }
87 
StringifyShape(const TensorShape & shape)88 string StringifyShape(const TensorShape& shape) {
89   string result;
90   bool first = true;
91   for (const auto& dim : shape) {
92     if (first) {
93       first = false;
94     } else {
95       strings::StrAppend(&result, ",");
96     }
97     strings::StrAppend(&result, dim.size);
98   }
99   return result;
100 }
101 
CheckSupportedType(const Tensor & t)102 Status CheckSupportedType(const Tensor& t) {
103 #define CASE(T)                  \
104   case DataTypeToEnum<T>::value: \
105     break;
106   switch (t.dtype()) {
107     CALL_SUPPORTED_TYPES(CASE)
108     default:
109       return errors::Unimplemented(DataTypeString(t.dtype()),
110                                    " tensors unsupported on platform");
111   }
112   return Status::OK();
113 #undef CASE
114 }
115 
// Returns a rank-0 tensor with the same dtype as `t`, holding the first
// element of `t`. If `t` has a dtype outside CALL_SUPPORTED_TYPES, or has
// no elements at all, a DT_FLOAT scalar set to NaN is returned instead.
Tensor AsScalar(const Tensor& t) {
  // Reading element 0 of an empty tensor would be out of bounds, so fall
  // back to the same NaN placeholder used for unsupported dtypes.
  if (t.NumElements() == 0) {
    Tensor nan_scalar{DT_FLOAT, {}};
    nan_scalar.scalar<float>()() = NAN;
    return nan_scalar;
  }
  Tensor t2{t.dtype(), {}};
#define CASE(T)                        \
  case DataTypeToEnum<T>::value:       \
    t2.scalar<T>()() = t.flat<T>()(0); \
    break;
  switch (t.dtype()) {
    CALL_SUPPORTED_TYPES(CASE)
    default:
      t2 = {DT_FLOAT, {}};
      t2.scalar<float>()() = NAN;
      break;
  }
#undef CASE
  return t2;
}
132 
PatchPluginName(SummaryMetadata * metadata,const char * name)133 void PatchPluginName(SummaryMetadata* metadata, const char* name) {
134   if (metadata->plugin_data().plugin_name().empty()) {
135     metadata->mutable_plugin_data()->set_plugin_name(name);
136   }
137 }
138 
SetDescription(Sqlite * db,int64 id,const StringPiece & markdown)139 Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
140   const char* sql = R"sql(
141     INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
142   )sql";
143   SqliteStatement insert_desc;
144   TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc));
145   insert_desc.BindInt(1, id);
146   insert_desc.BindText(2, markdown);
147   return insert_desc.StepAndReset();
148 }
149 
150 /// \brief Generates unique IDs randomly in the [1,2**63-1] range.
151 ///
152 /// This class starts off generating IDs in the [1,2**23-1] range,
153 /// because it's human friendly and occupies 4 bytes max on disk with
154 /// SQLite's zigzag varint encoding. Then, each time a collision
155 /// happens, the random space is increased by 8 bits.
156 ///
157 /// This class uses exponential back-off so writes gradually slow down
158 /// as IDs become exhausted but reads are still possible.
159 ///
160 /// This class is thread safe.
161 class IdAllocator {
162  public:
IdAllocator(Env * env,Sqlite * db)163   IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} {
164     DCHECK(env_ != nullptr);
165     DCHECK(db_ != nullptr);
166   }
167 
CreateNewId(int64 * id)168   Status CreateNewId(int64* id) TF_LOCKS_EXCLUDED(mu_) {
169     mutex_lock lock(mu_);
170     Status s;
171     SqliteStatement stmt;
172     TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt));
173     for (int i = 0; i < kMaxIdCollisions; ++i) {
174       int64 tid = MakeRandomId();
175       stmt.BindInt(1, tid);
176       s = stmt.StepAndReset();
177       if (s.ok()) {
178         *id = tid;
179         break;
180       }
181       // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc
182       if (s.code() != error::INVALID_ARGUMENT) break;
183       if (tier_ < kMaxIdTier) {
184         LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of "
185                   << kMaxIdTier << ") so auto-adjusting to a higher tier";
186         ++tier_;
187       } else {
188         LOG(WARNING) << "IdAllocator (attempt #" << i << ") "
189                      << "resulted in a collision at the highest tier; this "
190                         "is problematic if it happens often; you can try "
191                         "pruning the Ids table; you can also file a bug "
192                         "asking for the ID space to be increased; otherwise "
193                         "writes will gradually slow down over time until they "
194                         "become impossible";
195       }
196       env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros);
197     }
198     return s;
199   }
200 
201  private:
MakeRandomId()202   int64 MakeRandomId() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
203     int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]);
204     if (id == kAbsent) ++id;
205     return id;
206   }
207 
208   mutex mu_;
209   Env* const env_;
210   Sqlite* const db_;
211   int tier_ TF_GUARDED_BY(mu_) = 0;
212 
213   TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator);
214 };
215 
216 class GraphWriter {
217  public:
Save(Sqlite * db,SqliteTransaction * txn,IdAllocator * ids,GraphDef * graph,uint64 now,int64 run_id,int64 * graph_id)218   static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
219                      GraphDef* graph, uint64 now, int64 run_id, int64* graph_id)
220       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
221     TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
222     GraphWriter saver{db, txn, graph, now, *graph_id};
223     saver.MapNameToNodeId();
224     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
225     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
226     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
227     return Status::OK();
228   }
229 
230  private:
GraphWriter(Sqlite * db,SqliteTransaction * txn,GraphDef * graph,uint64 now,int64 graph_id)231   GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
232               int64 graph_id)
233       : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}
234 
MapNameToNodeId()235   void MapNameToNodeId() {
236     size_t toto = static_cast<size_t>(graph_->node_size());
237     name_copies_.reserve(toto);
238     name_to_node_id_.reserve(toto);
239     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
240       // Copy name into memory region, since we call clear_name() later.
241       // Then wrap in StringPiece so we can compare slices without copy.
242       name_copies_.emplace_back(graph_->node(node_id).name());
243       name_to_node_id_.emplace(name_copies_.back(), node_id);
244     }
245   }
246 
SaveNodeInputs()247   Status SaveNodeInputs() {
248     const char* sql = R"sql(
249       INSERT INTO NodeInputs (
250         graph_id,
251         node_id,
252         idx,
253         input_node_id,
254         input_node_idx,
255         is_control
256       ) VALUES (?, ?, ?, ?, ?, ?)
257     )sql";
258     SqliteStatement insert;
259     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
260     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
261       const NodeDef& node = graph_->node(node_id);
262       for (int idx = 0; idx < node.input_size(); ++idx) {
263         StringPiece name = node.input(idx);
264         int64 input_node_id;
265         int64 input_node_idx = 0;
266         int64 is_control = 0;
267         size_t i = name.rfind(':');
268         if (i != StringPiece::npos) {
269           if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
270                                      &input_node_idx)) {
271             return errors::DataLoss("Bad NodeDef.input: ", name);
272           }
273           name.remove_suffix(name.size() - i);
274         }
275         if (!name.empty() && name[0] == '^') {
276           name.remove_prefix(1);
277           is_control = 1;
278         }
279         auto e = name_to_node_id_.find(name);
280         if (e == name_to_node_id_.end()) {
281           return errors::DataLoss("Could not find node: ", name);
282         }
283         input_node_id = e->second;
284         insert.BindInt(1, graph_id_);
285         insert.BindInt(2, node_id);
286         insert.BindInt(3, idx);
287         insert.BindInt(4, input_node_id);
288         insert.BindInt(5, input_node_idx);
289         insert.BindInt(6, is_control);
290         unflushed_bytes_ += insert.size();
291         TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
292                                         " -> ", name);
293         TF_RETURN_IF_ERROR(MaybeFlush());
294       }
295     }
296     return Status::OK();
297   }
298 
SaveNodes()299   Status SaveNodes() {
300     const char* sql = R"sql(
301       INSERT INTO Nodes (
302         graph_id,
303         node_id,
304         node_name,
305         op,
306         device,
307         node_def)
308       VALUES (?, ?, ?, ?, ?, ?)
309     )sql";
310     SqliteStatement insert;
311     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
312     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
313       NodeDef* node = graph_->mutable_node(node_id);
314       insert.BindInt(1, graph_id_);
315       insert.BindInt(2, node_id);
316       insert.BindText(3, node->name());
317       insert.BindText(4, node->op());
318       insert.BindText(5, node->device());
319       node->clear_name();
320       node->clear_op();
321       node->clear_device();
322       node->clear_input();
323       string node_def;
324       if (node->SerializeToString(&node_def)) {
325         insert.BindBlobUnsafe(6, node_def);
326       }
327       unflushed_bytes_ += insert.size();
328       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
329       TF_RETURN_IF_ERROR(MaybeFlush());
330     }
331     return Status::OK();
332   }
333 
SaveGraph(int64 run_id)334   Status SaveGraph(int64 run_id) {
335     const char* sql = R"sql(
336       INSERT OR REPLACE INTO Graphs (
337         run_id,
338         graph_id,
339         inserted_time,
340         graph_def
341       ) VALUES (?, ?, ?, ?)
342     )sql";
343     SqliteStatement insert;
344     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
345     if (run_id != kAbsent) insert.BindInt(1, run_id);
346     insert.BindInt(2, graph_id_);
347     insert.BindDouble(3, DoubleTime(now_));
348     graph_->clear_node();
349     string graph_def;
350     if (graph_->SerializeToString(&graph_def)) {
351       insert.BindBlobUnsafe(4, graph_def);
352     }
353     return insert.StepAndReset();
354   }
355 
MaybeFlush()356   Status MaybeFlush() {
357     if (unflushed_bytes_ >= kFlushBytes) {
358       TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
359                                       unflushed_bytes_, " bytes");
360       unflushed_bytes_ = 0;
361     }
362     return Status::OK();
363   }
364 
365   Sqlite* const db_;
366   SqliteTransaction* const txn_;
367   uint64 unflushed_bytes_ = 0;
368   GraphDef* const graph_;
369   const uint64 now_;
370   const int64 graph_id_;
371   std::vector<string> name_copies_;
372   std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_;
373 
374   TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
375 };
376 
377 /// \brief Run metadata manager.
378 ///
379 /// This class gives us Tag IDs we can pass to SeriesWriter. In order
380 /// to do that, rows are created in the Ids, Tags, Runs, Experiments,
381 /// and Users tables.
382 ///
383 /// This class is thread safe.
384 class RunMetadata {
385  public:
RunMetadata(IdAllocator * ids,const string & experiment_name,const string & run_name,const string & user_name)386   RunMetadata(IdAllocator* ids, const string& experiment_name,
387               const string& run_name, const string& user_name)
388       : ids_{ids},
389         experiment_name_{experiment_name},
390         run_name_{run_name},
391         user_name_{user_name} {
392     DCHECK(ids_ != nullptr);
393   }
394 
experiment_name()395   const string& experiment_name() { return experiment_name_; }
run_name()396   const string& run_name() { return run_name_; }
user_name()397   const string& user_name() { return user_name_; }
398 
run_id()399   int64 run_id() TF_LOCKS_EXCLUDED(mu_) {
400     mutex_lock lock(mu_);
401     return run_id_;
402   }
403 
SetGraph(Sqlite * db,uint64 now,double computed_time,std::unique_ptr<GraphDef> g)404   Status SetGraph(Sqlite* db, uint64 now, double computed_time,
405                   std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
406       TF_LOCKS_EXCLUDED(mu_) {
407     int64 run_id;
408     {
409       mutex_lock lock(mu_);
410       TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
411       run_id = run_id_;
412     }
413     int64 graph_id;
414     SqliteTransaction txn(*db);  // only to increase performance
415     TF_RETURN_IF_ERROR(
416         GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
417     return txn.Commit();
418   }
419 
GetTagId(Sqlite * db,uint64 now,double computed_time,const string & tag_name,int64 * tag_id,const SummaryMetadata & metadata)420   Status GetTagId(Sqlite* db, uint64 now, double computed_time,
421                   const string& tag_name, int64* tag_id,
422                   const SummaryMetadata& metadata) TF_LOCKS_EXCLUDED(mu_) {
423     mutex_lock lock(mu_);
424     TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
425     auto e = tag_ids_.find(tag_name);
426     if (e != tag_ids_.end()) {
427       *tag_id = e->second;
428       return Status::OK();
429     }
430     TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
431     tag_ids_[tag_name] = *tag_id;
432     TF_RETURN_IF_ERROR(
433         SetDescription(db, *tag_id, metadata.summary_description()));
434     const char* sql = R"sql(
435       INSERT INTO Tags (
436         run_id,
437         tag_id,
438         tag_name,
439         inserted_time,
440         display_name,
441         plugin_name,
442         plugin_data
443       ) VALUES (
444         :run_id,
445         :tag_id,
446         :tag_name,
447         :inserted_time,
448         :display_name,
449         :plugin_name,
450         :plugin_data
451       )
452     )sql";
453     SqliteStatement insert;
454     TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
455     if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
456     insert.BindInt(":tag_id", *tag_id);
457     insert.BindTextUnsafe(":tag_name", tag_name);
458     insert.BindDouble(":inserted_time", DoubleTime(now));
459     insert.BindTextUnsafe(":display_name", metadata.display_name());
460     insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
461     insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
462     return insert.StepAndReset();
463   }
464 
465  private:
InitializeUser(Sqlite * db,uint64 now)466   Status InitializeUser(Sqlite* db, uint64 now)
467       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
468     if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
469     const char* get_sql = R"sql(
470       SELECT user_id FROM Users WHERE user_name = ?
471     )sql";
472     SqliteStatement get;
473     TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
474     get.BindText(1, user_name_);
475     bool is_done;
476     TF_RETURN_IF_ERROR(get.Step(&is_done));
477     if (!is_done) {
478       user_id_ = get.ColumnInt(0);
479       return Status::OK();
480     }
481     TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
482     const char* insert_sql = R"sql(
483       INSERT INTO Users (
484         user_id,
485         user_name,
486         inserted_time
487       ) VALUES (?, ?, ?)
488     )sql";
489     SqliteStatement insert;
490     TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
491     insert.BindInt(1, user_id_);
492     insert.BindText(2, user_name_);
493     insert.BindDouble(3, DoubleTime(now));
494     TF_RETURN_IF_ERROR(insert.StepAndReset());
495     return Status::OK();
496   }
497 
InitializeExperiment(Sqlite * db,uint64 now,double computed_time)498   Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
499       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
500     if (experiment_name_.empty()) return Status::OK();
501     if (experiment_id_ == kAbsent) {
502       TF_RETURN_IF_ERROR(InitializeUser(db, now));
503       const char* get_sql = R"sql(
504         SELECT
505           experiment_id,
506           started_time
507         FROM
508           Experiments
509         WHERE
510           user_id IS ?
511           AND experiment_name = ?
512       )sql";
513       SqliteStatement get;
514       TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
515       if (user_id_ != kAbsent) get.BindInt(1, user_id_);
516       get.BindText(2, experiment_name_);
517       bool is_done;
518       TF_RETURN_IF_ERROR(get.Step(&is_done));
519       if (!is_done) {
520         experiment_id_ = get.ColumnInt(0);
521         experiment_started_time_ = get.ColumnInt(1);
522       } else {
523         TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
524         experiment_started_time_ = computed_time;
525         const char* insert_sql = R"sql(
526           INSERT INTO Experiments (
527             user_id,
528             experiment_id,
529             experiment_name,
530             inserted_time,
531             started_time,
532             is_watching
533           ) VALUES (?, ?, ?, ?, ?, ?)
534         )sql";
535         SqliteStatement insert;
536         TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
537         if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
538         insert.BindInt(2, experiment_id_);
539         insert.BindText(3, experiment_name_);
540         insert.BindDouble(4, DoubleTime(now));
541         insert.BindDouble(5, computed_time);
542         insert.BindInt(6, 0);
543         TF_RETURN_IF_ERROR(insert.StepAndReset());
544       }
545     }
546     if (computed_time < experiment_started_time_) {
547       experiment_started_time_ = computed_time;
548       const char* update_sql = R"sql(
549         UPDATE
550           Experiments
551         SET
552           started_time = ?
553         WHERE
554           experiment_id = ?
555       )sql";
556       SqliteStatement update;
557       TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
558       update.BindDouble(1, computed_time);
559       update.BindInt(2, experiment_id_);
560       TF_RETURN_IF_ERROR(update.StepAndReset());
561     }
562     return Status::OK();
563   }
564 
InitializeRun(Sqlite * db,uint64 now,double computed_time)565   Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
566       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
567     if (run_name_.empty()) return Status::OK();
568     TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
569     if (run_id_ == kAbsent) {
570       TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
571       run_started_time_ = computed_time;
572       const char* insert_sql = R"sql(
573         INSERT OR REPLACE INTO Runs (
574           experiment_id,
575           run_id,
576           run_name,
577           inserted_time,
578           started_time
579         ) VALUES (?, ?, ?, ?, ?)
580       )sql";
581       SqliteStatement insert;
582       TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
583       if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
584       insert.BindInt(2, run_id_);
585       insert.BindText(3, run_name_);
586       insert.BindDouble(4, DoubleTime(now));
587       insert.BindDouble(5, computed_time);
588       TF_RETURN_IF_ERROR(insert.StepAndReset());
589     }
590     if (computed_time < run_started_time_) {
591       run_started_time_ = computed_time;
592       const char* update_sql = R"sql(
593         UPDATE
594           Runs
595         SET
596           started_time = ?
597         WHERE
598           run_id = ?
599       )sql";
600       SqliteStatement update;
601       TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
602       update.BindDouble(1, computed_time);
603       update.BindInt(2, run_id_);
604       TF_RETURN_IF_ERROR(update.StepAndReset());
605     }
606     return Status::OK();
607   }
608 
609   mutex mu_;
610   IdAllocator* const ids_;
611   const string experiment_name_;
612   const string run_name_;
613   const string user_name_;
614   int64 experiment_id_ TF_GUARDED_BY(mu_) = kAbsent;
615   int64 run_id_ TF_GUARDED_BY(mu_) = kAbsent;
616   int64 user_id_ TF_GUARDED_BY(mu_) = kAbsent;
617   double experiment_started_time_ TF_GUARDED_BY(mu_) = 0.0;
618   double run_started_time_ TF_GUARDED_BY(mu_) = 0.0;
619   std::unordered_map<string, int64> tag_ids_ TF_GUARDED_BY(mu_);
620 
621   TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
622 };
623 
624 /// \brief Tensor writer for a single series, e.g. Tag.
625 ///
626 /// This class is thread safe.
627 class SeriesWriter {
628  public:
SeriesWriter(int64 series,RunMetadata * meta)629   SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} {
630     DCHECK(series_ > 0);
631   }
632 
Append(Sqlite * db,int64 step,uint64 now,double computed_time,const Tensor & t)633   Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
634                 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
635       TF_LOCKS_EXCLUDED(mu_) {
636     mutex_lock lock(mu_);
637     if (rowids_.empty()) {
638       Status s = Reserve(db, t);
639       if (!s.ok()) {
640         rowids_.clear();
641         return s;
642       }
643     }
644     int64 rowid = rowids_.front();
645     Status s = Write(db, rowid, step, computed_time, t);
646     if (s.ok()) {
647       ++count_;
648     }
649     rowids_.pop_front();
650     return s;
651   }
652 
Finish(Sqlite * db)653   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
654       TF_LOCKS_EXCLUDED(mu_) {
655     mutex_lock lock(mu_);
656     // Delete unused pre-allocated Tensors.
657     if (!rowids_.empty()) {
658       SqliteTransaction txn(*db);
659       const char* sql = R"sql(
660         DELETE FROM Tensors WHERE rowid = ?
661       )sql";
662       SqliteStatement deleter;
663       TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
664       for (size_t i = count_; i < rowids_.size(); ++i) {
665         deleter.BindInt(1, rowids_.front());
666         TF_RETURN_IF_ERROR(deleter.StepAndReset());
667         rowids_.pop_front();
668       }
669       TF_RETURN_IF_ERROR(txn.Commit());
670       rowids_.clear();
671     }
672     return Status::OK();
673   }
674 
675  private:
Write(Sqlite * db,int64 rowid,int64 step,double computed_time,const Tensor & t)676   Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time,
677                const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
678     if (t.dtype() == DT_STRING) {
679       if (t.dims() == 0) {
680         return Update(db, step, computed_time, t, t.scalar<tstring>()(), rowid);
681       } else {
682         SqliteTransaction txn(*db);
683         TF_RETURN_IF_ERROR(
684             Update(db, step, computed_time, t, StringPiece(), rowid));
685         TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
686         return txn.Commit();
687       }
688     } else {
689       return Update(db, step, computed_time, t, t.tensor_data(), rowid);
690     }
691   }
692 
Update(Sqlite * db,int64 step,double computed_time,const Tensor & t,const StringPiece & data,int64 rowid)693   Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
694                 const StringPiece& data, int64 rowid) {
695     const char* sql = R"sql(
696       UPDATE OR REPLACE
697         Tensors
698       SET
699         step = ?,
700         computed_time = ?,
701         dtype = ?,
702         shape = ?,
703         data = ?
704       WHERE
705         rowid = ?
706     )sql";
707     SqliteStatement stmt;
708     TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
709     stmt.BindInt(1, step);
710     stmt.BindDouble(2, computed_time);
711     stmt.BindInt(3, t.dtype());
712     stmt.BindText(4, StringifyShape(t.shape()));
713     stmt.BindBlobUnsafe(5, data);
714     stmt.BindInt(6, rowid);
715     TF_RETURN_IF_ERROR(stmt.StepAndReset());
716     return Status::OK();
717   }
718 
UpdateNdString(Sqlite * db,const Tensor & t,int64 tensor_rowid)719   Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid)
720       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
721     DCHECK_EQ(t.dtype(), DT_STRING);
722     DCHECK_GT(t.dims(), 0);
723     const char* deleter_sql = R"sql(
724       DELETE FROM TensorStrings WHERE tensor_rowid = ?
725     )sql";
726     SqliteStatement deleter;
727     TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
728     deleter.BindInt(1, tensor_rowid);
729     TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
730     const char* inserter_sql = R"sql(
731       INSERT INTO TensorStrings (
732         tensor_rowid,
733         idx,
734         data
735       ) VALUES (?, ?, ?)
736     )sql";
737     SqliteStatement inserter;
738     TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
739     auto flat = t.flat<tstring>();
740     for (int64 i = 0; i < flat.size(); ++i) {
741       inserter.BindInt(1, tensor_rowid);
742       inserter.BindInt(2, i);
743       inserter.BindBlobUnsafe(3, flat(i));
744       TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
745     }
746     return Status::OK();
747   }
748 
Reserve(Sqlite * db,const Tensor & t)749   Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
750       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
751     SqliteTransaction txn(*db);  // only for performance
752     unflushed_bytes_ = 0;
753     if (t.dtype() == DT_STRING) {
754       if (t.dims() == 0) {
755         TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<tstring>()().size()));
756       } else {
757         TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
758       }
759     } else {
760       TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
761     }
762     return txn.Commit();
763   }
764 
ReserveData(Sqlite * db,SqliteTransaction * txn,size_t size)765   Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
766       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
767           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
768     int64 space =
769         static_cast<int64>(static_cast<double>(size) * kReserveMultiplier);
770     if (space < kReserveMinBytes) space = kReserveMinBytes;
771     return ReserveTensors(db, txn, space);
772   }
773 
ReserveTensors(Sqlite * db,SqliteTransaction * txn,int64 reserved_bytes)774   Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
775                         int64 reserved_bytes)
776       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
777           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
778     const char* sql = R"sql(
779       INSERT INTO Tensors (
780         series,
781         data
782       ) VALUES (?, ZEROBLOB(?))
783     )sql";
784     SqliteStatement insert;
785     TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
786     // TODO(jart): Maybe preallocate index pages by setting step. This
787     //             is tricky because UPDATE OR REPLACE can have a side
788     //             effect of deleting preallocated rows.
789     for (int64 i = 0; i < kPreallocateRows; ++i) {
790       insert.BindInt(1, series_);
791       insert.BindInt(2, reserved_bytes);
792       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
793       rowids_.push_back(db->last_insert_rowid());
794       unflushed_bytes_ += reserved_bytes;
795       TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
796     }
797     return Status::OK();
798   }
799 
MaybeFlush(Sqlite * db,SqliteTransaction * txn)800   Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
801       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
802           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
803     if (unflushed_bytes_ >= kFlushBytes) {
804       TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
805                                       unflushed_bytes_, " bytes");
806       unflushed_bytes_ = 0;
807     }
808     return Status::OK();
809   }
810 
811   mutex mu_;
812   const int64 series_;
813   RunMetadata* const meta_;
814   uint64 count_ TF_GUARDED_BY(mu_) = 0;
815   std::deque<int64> rowids_ TF_GUARDED_BY(mu_);
816   uint64 unflushed_bytes_ TF_GUARDED_BY(mu_) = 0;
817 
818   TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
819 };
820 
821 /// \brief Tensor writer for a single Run.
822 ///
823 /// This class farms out tensors to SeriesWriter instances. It also
824 /// keeps track of whether or not someone is watching the TensorBoard
825 /// GUI, so it can avoid writes when possible.
826 ///
827 /// This class is thread safe.
828 class RunWriter {
829  public:
RunWriter(RunMetadata * meta)830   explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
831 
Append(Sqlite * db,int64 tag_id,int64 step,uint64 now,double computed_time,const Tensor & t)832   Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
833                 double computed_time, const Tensor& t)
834       SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) {
835     SeriesWriter* writer = GetSeriesWriter(tag_id);
836     return writer->Append(db, step, now, computed_time, t);
837   }
838 
Finish(Sqlite * db)839   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
840       TF_LOCKS_EXCLUDED(mu_) {
841     mutex_lock lock(mu_);
842     if (series_writers_.empty()) return Status::OK();
843     for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) {
844       if (!i->second) continue;
845       TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db),
846                                       "finish tag_id=", i->first);
847       i->second.reset();
848     }
849     return Status::OK();
850   }
851 
852  private:
GetSeriesWriter(int64 tag_id)853   SeriesWriter* GetSeriesWriter(int64 tag_id) TF_LOCKS_EXCLUDED(mu_) {
854     mutex_lock sl(mu_);
855     auto spot = series_writers_.find(tag_id);
856     if (spot == series_writers_.end()) {
857       SeriesWriter* writer = new SeriesWriter(tag_id, meta_);
858       series_writers_[tag_id].reset(writer);
859       return writer;
860     } else {
861       return spot->second.get();
862     }
863   }
864 
865   mutex mu_;
866   RunMetadata* const meta_;
867   std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_
868       TF_GUARDED_BY(mu_);
869 
870   TF_DISALLOW_COPY_AND_ASSIGN(RunWriter);
871 };
872 
873 /// \brief SQLite implementation of SummaryWriterInterface.
874 ///
875 /// This class is thread safe.
876 class SummaryDbWriter : public SummaryWriterInterface {
877  public:
SummaryDbWriter(Env * env,Sqlite * db,const string & experiment_name,const string & run_name,const string & user_name)878   SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
879                   const string& run_name, const string& user_name)
880       : SummaryWriterInterface(),
881         env_{env},
882         db_{db},
883         ids_{env_, db_},
884         meta_{&ids_, experiment_name, run_name, user_name},
885         run_{&meta_} {
886     DCHECK(env_ != nullptr);
887     db_->Ref();
888   }
889 
~SummaryDbWriter()890   ~SummaryDbWriter() override {
891     core::ScopedUnref unref(db_);
892     Status s = run_.Finish(db_);
893     if (!s.ok()) {
894       // TODO(jart): Retry on transient errors here.
895       LOG(ERROR) << s.ToString();
896     }
897     int64 run_id = meta_.run_id();
898     if (run_id == kAbsent) return;
899     const char* sql = R"sql(
900       UPDATE Runs SET finished_time = ? WHERE run_id = ?
901     )sql";
902     SqliteStatement update;
903     s = db_->Prepare(sql, &update);
904     if (s.ok()) {
905       update.BindDouble(1, DoubleTime(env_->NowMicros()));
906       update.BindInt(2, run_id);
907       s = update.StepAndReset();
908     }
909     if (!s.ok()) {
910       LOG(ERROR) << "Failed to set Runs[" << run_id
911                  << "].finish_time: " << s.ToString();
912     }
913   }
914 
Flush()915   Status Flush() override { return Status::OK(); }
916 
WriteTensor(int64 global_step,Tensor t,const string & tag,const string & serialized_metadata)917   Status WriteTensor(int64 global_step, Tensor t, const string& tag,
918                      const string& serialized_metadata) override {
919     TF_RETURN_IF_ERROR(CheckSupportedType(t));
920     SummaryMetadata metadata;
921     if (!metadata.ParseFromString(serialized_metadata)) {
922       return errors::InvalidArgument("Bad serialized_metadata");
923     }
924     return Write(global_step, t, tag, metadata);
925   }
926 
WriteScalar(int64 global_step,Tensor t,const string & tag)927   Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
928     TF_RETURN_IF_ERROR(CheckSupportedType(t));
929     SummaryMetadata metadata;
930     PatchPluginName(&metadata, kScalarPluginName);
931     return Write(global_step, AsScalar(t), tag, metadata);
932   }
933 
WriteGraph(int64 global_step,std::unique_ptr<GraphDef> g)934   Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
935     uint64 now = env_->NowMicros();
936     return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g));
937   }
938 
WriteEvent(std::unique_ptr<Event> e)939   Status WriteEvent(std::unique_ptr<Event> e) override {
940     return MigrateEvent(std::move(e));
941   }
942 
WriteHistogram(int64 global_step,Tensor t,const string & tag)943   Status WriteHistogram(int64 global_step, Tensor t,
944                         const string& tag) override {
945     uint64 now = env_->NowMicros();
946     std::unique_ptr<Event> e{new Event};
947     e->set_step(global_step);
948     e->set_wall_time(DoubleTime(now));
949     TF_RETURN_IF_ERROR(
950         AddTensorAsHistogramToSummary(t, tag, e->mutable_summary()));
951     return MigrateEvent(std::move(e));
952   }
953 
WriteImage(int64 global_step,Tensor t,const string & tag,int max_images,Tensor bad_color)954   Status WriteImage(int64 global_step, Tensor t, const string& tag,
955                     int max_images, Tensor bad_color) override {
956     uint64 now = env_->NowMicros();
957     std::unique_ptr<Event> e{new Event};
958     e->set_step(global_step);
959     e->set_wall_time(DoubleTime(now));
960     TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color,
961                                                  e->mutable_summary()));
962     return MigrateEvent(std::move(e));
963   }
964 
WriteAudio(int64 global_step,Tensor t,const string & tag,int max_outputs,float sample_rate)965   Status WriteAudio(int64 global_step, Tensor t, const string& tag,
966                     int max_outputs, float sample_rate) override {
967     uint64 now = env_->NowMicros();
968     std::unique_ptr<Event> e{new Event};
969     e->set_step(global_step);
970     e->set_wall_time(DoubleTime(now));
971     TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary(
972         t, tag, max_outputs, sample_rate, e->mutable_summary()));
973     return MigrateEvent(std::move(e));
974   }
975 
DebugString() const976   string DebugString() const override { return "SummaryDbWriter"; }
977 
978  private:
Write(int64 step,const Tensor & t,const string & tag,const SummaryMetadata & metadata)979   Status Write(int64 step, const Tensor& t, const string& tag,
980                const SummaryMetadata& metadata) {
981     uint64 now = env_->NowMicros();
982     double computed_time = DoubleTime(now);
983     int64 tag_id;
984     TF_RETURN_IF_ERROR(
985         meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
986     TF_RETURN_WITH_CONTEXT_IF_ERROR(
987         run_.Append(db_, tag_id, step, now, computed_time, t),
988         meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
989         "/", tag, "@", step);
990     return Status::OK();
991   }
992 
MigrateEvent(std::unique_ptr<Event> e)993   Status MigrateEvent(std::unique_ptr<Event> e) {
994     switch (e->what_case()) {
995       case Event::WhatCase::kSummary: {
996         uint64 now = env_->NowMicros();
997         auto summaries = e->mutable_summary();
998         for (int i = 0; i < summaries->value_size(); ++i) {
999           Summary::Value* value = summaries->mutable_value(i);
1000           TF_RETURN_WITH_CONTEXT_IF_ERROR(
1001               MigrateSummary(e.get(), value, now), meta_.user_name(), "/",
1002               meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(),
1003               "@", e->step());
1004         }
1005         break;
1006       }
1007       case Event::WhatCase::kGraphDef:
1008         TF_RETURN_WITH_CONTEXT_IF_ERROR(
1009             MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/",
1010             meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@",
1011             e->step());
1012         break;
1013       default:
1014         // TODO(@jart): Handle other stuff.
1015         break;
1016     }
1017     return Status::OK();
1018   }
1019 
MigrateGraph(const Event * e,const string & graph_def)1020   Status MigrateGraph(const Event* e, const string& graph_def) {
1021     uint64 now = env_->NowMicros();
1022     std::unique_ptr<GraphDef> graph{new GraphDef};
1023     if (!ParseProtoUnlimited(graph.get(), graph_def)) {
1024       return errors::InvalidArgument("bad proto");
1025     }
1026     return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph));
1027   }
1028 
MigrateSummary(const Event * e,Summary::Value * s,uint64 now)1029   Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) {
1030     switch (s->value_case()) {
1031       case Summary::Value::ValueCase::kTensor:
1032         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor");
1033         break;
1034       case Summary::Value::ValueCase::kSimpleValue:
1035         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar");
1036         break;
1037       case Summary::Value::ValueCase::kHisto:
1038         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo");
1039         break;
1040       case Summary::Value::ValueCase::kImage:
1041         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image");
1042         break;
1043       case Summary::Value::ValueCase::kAudio:
1044         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio");
1045         break;
1046       default:
1047         break;
1048     }
1049     return Status::OK();
1050   }
1051 
MigrateTensor(const Event * e,Summary::Value * s,uint64 now)1052   Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) {
1053     Tensor t;
1054     if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto");
1055     TF_RETURN_IF_ERROR(CheckSupportedType(t));
1056     int64 tag_id;
1057     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1058                                       &tag_id, s->metadata()));
1059     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1060   }
1061 
1062   // TODO(jart): Refactor Summary -> Tensor logic into separate file.
1063 
MigrateScalar(const Event * e,Summary::Value * s,uint64 now)1064   Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) {
1065     // See tensorboard/plugins/scalar/summary.py and data_compat.py
1066     Tensor t{DT_FLOAT, {}};
1067     t.scalar<float>()() = s->simple_value();
1068     int64 tag_id;
1069     PatchPluginName(s->mutable_metadata(), kScalarPluginName);
1070     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1071                                       &tag_id, s->metadata()));
1072     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1073   }
1074 
MigrateHistogram(const Event * e,Summary::Value * s,uint64 now)1075   Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
1076     const HistogramProto& histo = s->histo();
1077     int k = histo.bucket_size();
1078     if (k != histo.bucket_limit_size()) {
1079       return errors::InvalidArgument("size mismatch");
1080     }
1081     // See tensorboard/plugins/histogram/summary.py and data_compat.py
1082     Tensor t{DT_DOUBLE, {k, 3}};
1083     auto data = t.flat<double>();
1084     for (int i = 0, j = 0; i < k; ++i) {
1085       // TODO(nickfelt): reconcile with TensorBoard's data_compat.py
1086       // From summary.proto
1087       // Parallel arrays encoding the bucket boundaries and the bucket values.
1088       // bucket(i) is the count for the bucket i.  The range for
1089       // a bucket is:
1090       //   i == 0:  -DBL_MAX .. bucket_limit(0)
1091       //   i != 0:  bucket_limit(i-1) .. bucket_limit(i)
1092       double left_edge = (i == 0) ? std::numeric_limits<double>::min()
1093                                   : histo.bucket_limit(i - 1);
1094 
1095       data(j++) = left_edge;
1096       data(j++) = histo.bucket_limit(i);
1097       data(j++) = histo.bucket(i);
1098     }
1099     int64 tag_id;
1100     PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
1101     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1102                                       &tag_id, s->metadata()));
1103     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1104   }
1105 
MigrateImage(const Event * e,Summary::Value * s,uint64 now)1106   Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
1107     // See tensorboard/plugins/image/summary.py and data_compat.py
1108     Tensor t{DT_STRING, {3}};
1109     auto img = s->mutable_image();
1110     t.flat<tstring>()(0) = strings::StrCat(img->width());
1111     t.flat<tstring>()(1) = strings::StrCat(img->height());
1112     t.flat<tstring>()(2) = std::move(*img->mutable_encoded_image_string());
1113     int64 tag_id;
1114     PatchPluginName(s->mutable_metadata(), kImagePluginName);
1115     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1116                                       &tag_id, s->metadata()));
1117     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1118   }
1119 
MigrateAudio(const Event * e,Summary::Value * s,uint64 now)1120   Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
1121     // See tensorboard/plugins/audio/summary.py and data_compat.py
1122     Tensor t{DT_STRING, {1, 2}};
1123     auto wav = s->mutable_audio();
1124     t.flat<tstring>()(0) = std::move(*wav->mutable_encoded_audio_string());
1125     t.flat<tstring>()(1) = "";
1126     int64 tag_id;
1127     PatchPluginName(s->mutable_metadata(), kAudioPluginName);
1128     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1129                                       &tag_id, s->metadata()));
1130     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1131   }
1132 
1133   Env* const env_;
1134   Sqlite* const db_;
1135   IdAllocator ids_;
1136   RunMetadata meta_;
1137   RunWriter run_;
1138 };
1139 
1140 }  // namespace
1141 
CreateSummaryDbWriter(Sqlite * db,const string & experiment_name,const string & run_name,const string & user_name,Env * env,SummaryWriterInterface ** result)1142 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
1143                              const string& run_name, const string& user_name,
1144                              Env* env, SummaryWriterInterface** result) {
1145   *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
1146   return Status::OK();
1147 }
1148 
1149 }  // namespace tensorflow
1150