/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"

#include <algorithm>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {

namespace {

constexpr int64 kInvalidStepId = -1;

// Index of the time-sorted memory_profile_snapshots list, and the
// MemoryActivityMetadata proto it contains.
using IndexMetaPair = std::pair<int64 /*index*/, const MemoryActivityMetadata*>;
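// A negative index refers to an entry in the special_allocations list instead
// of a real snapshot; see InsertSpecialAllocations and
// ActiveAllocation.special_index below.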

bool IsMemoryAllocation(int64 event_type) {
  return event_type == HostEventType::kMemoryAllocation;
}

bool IsMemoryDeallocation(int64 event_type) {
  return event_type == HostEventType::kMemoryDeallocation;
}

void UpdateProfileSummary(const MemoryAggregationStats& stats,
                          int64 time_offset_ps, MemoryProfileSummary* summary) {
  // Update the peak memory usage over the allocator's lifetime.
  summary->set_peak_bytes_usage_lifetime(stats.peak_bytes_in_use());
  MemoryAggregationStats* peak_stats = summary->mutable_peak_stats();
  // If we reach (or stay at) peak memory usage within the profiling window,
  // update the memory profile summary.
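  // For example (illustrative numbers only): with stack_reserved_bytes = 1 KiB
  // and heap_allocated_bytes = 3 KiB, a previous within-window peak of 3.5 KiB
  // is replaced by the new 4 KiB total, and peak_stats_time_ps is moved to
  // this event's time offset.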
  if (stats.stack_reserved_bytes() + stats.heap_allocated_bytes() >=
      peak_stats->peak_bytes_in_use()) {
    *peak_stats = stats;
    peak_stats->set_peak_bytes_in_use(stats.stack_reserved_bytes() +
                                      stats.heap_allocated_bytes());
    summary->set_peak_stats_time_ps(time_offset_ps);
    summary->set_memory_capacity(stats.stack_reserved_bytes() +
                                 stats.heap_allocated_bytes() +
                                 stats.free_memory_bytes());
  }
}

// Generate memory profile proto by processing host trace XPlane.
MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) {
  XPlaneVisitor plane = CreateTfXPlaneVisitor(host_trace);
  MemoryProfile memory_profile;
  // Iterate over all XEvents in the XPlane, and add the XStats to a new
  // MemoryProfileSnapshot if the EventType is kMemoryAllocation or
  // kMemoryDeallocation.
  plane.ForEachLine([&](const XLineVisitor& line) {
    line.ForEachEvent([&](const XEventVisitor& event) {
      int64 event_type = event.Type().value_or(kUnknownHostEventType);
      if (!(IsMemoryAllocation(event_type) ||
            IsMemoryDeallocation(event_type))) {
        return;
      }

      MemoryAggregationStats stats;
      MemoryActivityMetadata metadata;
      if (IsMemoryAllocation(event_type)) {
        metadata.set_memory_activity(ALLOCATION);
      } else if (IsMemoryDeallocation(event_type)) {
        metadata.set_memory_activity(DEALLOCATION);
      }
      metadata.set_step_id(kInvalidStepId);

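      // memory_id identifies the allocator this snapshot belongs to: either a
      // numeric id (host index or device ordinal) or the allocator name,
      // depending on which stat the event carries.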
      std::string memory_id;
      event.ForEachStat([&](const XStatVisitor& stat) {
        if (!stat.Type().has_value()) return;
        switch (stat.Type().value()) {
          case StatType::kIndexOnHost:
          case StatType::kDeviceOrdinal:
            memory_id = absl::StrCat(stat.IntValue());
            break;
          case StatType::kAllocatorName:
            memory_id = std::string(stat.StrOrRefValue());
            break;
          case StatType::kBytesReserved:
            stats.set_stack_reserved_bytes(stat.IntValue());
            break;
          case StatType::kBytesAllocated:
            stats.set_heap_allocated_bytes(stat.IntValue());
            break;
          case StatType::kBytesAvailable:
            stats.set_free_memory_bytes(stat.IntValue());
            break;
          case StatType::kFragmentation:
            stats.set_fragmentation(stat.DoubleValue());
            break;
          case StatType::kPeakBytesInUse:
            stats.set_peak_bytes_in_use(stat.IntValue());
            break;
          case StatType::kRequestedBytes:
            metadata.set_requested_bytes(stat.IntValue());
            break;
          case StatType::kAllocationBytes:
            metadata.set_allocation_bytes(stat.IntValue());
            break;
          case StatType::kAddress:
            metadata.set_address(stat.IntValue());
            break;
          case StatType::kTfOp:
            metadata.set_tf_op_name(std::string(stat.StrOrRefValue()));
            break;
          case StatType::kGroupId:
            metadata.set_step_id(stat.IntValue());
            break;
          case StatType::kRegionType:
            metadata.set_region_type(std::string(stat.StrOrRefValue()));
            break;
          case StatType::kDataType:
            metadata.set_data_type(tensorflow::DataTypeString(
                static_cast<tensorflow::DataType>(stat.IntValue())));
            break;
          case StatType::kTensorShapes:
            metadata.set_tensor_shape(std::string(stat.StrOrRefValue()));
            break;
        }
      });

      MemoryProfileSummary* summary =
          (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
              .mutable_profile_summary();
      UpdateProfileSummary(stats, event.OffsetPs(), summary);

      MemoryProfileSnapshot* snapshot =
          (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
              .add_memory_profile_snapshots();
      snapshot->set_time_offset_ps(event.OffsetPs());
      *snapshot->mutable_aggregation_stats() = std::move(stats);
      *snapshot->mutable_activity_metadata() = std::move(metadata);
    });
  });
  return memory_profile;
}

// Fix invalid step ids of snapshots at the beginning/end of the profile or at
// the step boundaries. The snapshots with invalid step ids at the beginning get
// 0 for their step ids. Those at the step boundaries or at the end get the
// previous snapshot's step id + 1.
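// For example, a sequence of snapshot step ids [-1, -1, 2, -1, 3, -1] is
// rewritten to [0, 0, 2, 3, 3, 4].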
void UpdateStepId(PerAllocatorMemoryProfile* memory_profile) {
  int64 last_valid_step_id = -1;
  // Snapshots are already sorted in time.
  for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) {
    DCHECK(snapshot.has_activity_metadata());
    if (snapshot.mutable_activity_metadata()->step_id() == kInvalidStepId) {
      snapshot.mutable_activity_metadata()->set_step_id(last_valid_step_id + 1);
    } else {
      last_valid_step_id = snapshot.mutable_activity_metadata()->step_id();
    }
  }
}

// Update the MemoryActivityMetadata for each deallocation event by copying from
// the matching allocation.
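// Matching is done by chunk address: a deallocation takes its metadata from
// the allocation that is currently live at the same address.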
void UpdateDeallocation(PerAllocatorMemoryProfile* memory_profile) {
  absl::flat_hash_map<uint64 /*address*/, const MemoryActivityMetadata*>
      addr_metadata_map;
  for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) {
    // Match the deallocation with the previous allocation based on address.
    uint64 address = snapshot.activity_metadata().address();
    if (snapshot.activity_metadata().memory_activity() == DEALLOCATION) {
      if (addr_metadata_map.contains(address)) {
        const MemoryActivityMetadata* alloc_meta = addr_metadata_map[address];
        snapshot.mutable_activity_metadata()->set_tf_op_name(
            alloc_meta->tf_op_name());
        snapshot.mutable_activity_metadata()->set_region_type(
            alloc_meta->region_type());
        snapshot.mutable_activity_metadata()->set_data_type(
            alloc_meta->data_type());
        snapshot.mutable_activity_metadata()->set_tensor_shape(
            alloc_meta->tensor_shape());
        // If (unexpectedly) another deallocation hits the same chunk address
        // later, leave its metadata as is (empty or already captured).
        addr_metadata_map.erase(address);
      } else {
        VLOG(2)
            << "Can't find matching memory allocation for this deallocation: "
            << snapshot.DebugString();
      }
    } else if (!addr_metadata_map.contains(address)) {  // Allocation.
      addr_metadata_map[address] = &snapshot.activity_metadata();
    } else {
      VLOG(2) << "There are two allocations recorded for the same address: "
              << address
              << ". The later allocation event is: " << snapshot.DebugString();
    }
  }
  VLOG(2) << "Number of allocations that cannot find matching deallocations: "
          << addr_metadata_map.size();
}

// Return the step id for the peak memory usage data point.
int64 GetPeakMemoryStep(int64 peak_bytes_profile,
                        const PerAllocatorMemoryProfile* memory_profile) {
  int64 peak_bytes_profile_step_id = 0;
  for (const auto& snapshot : memory_profile->memory_profile_snapshots()) {
    // Get the step id of the peak memory usage.
    if (peak_bytes_profile ==
        snapshot.aggregation_stats().heap_allocated_bytes() +
            snapshot.aggregation_stats().stack_reserved_bytes()) {
      DCHECK(snapshot.has_activity_metadata());
      peak_bytes_profile_step_id = snapshot.activity_metadata().step_id();
    }
  }
  return peak_bytes_profile_step_id;
}

// Functor that compares (index, metadata) pairs to sort in the order of
// allocation bytes and requested bytes (descending), as well as TF Op name,
// region type, data type, and tensor shape (ascending).
struct MetadataComparator {
  bool operator()(const IndexMetaPair& a, const IndexMetaPair& b) const {
    const MemoryActivityMetadata* a_meta = a.second;
    const MemoryActivityMetadata* b_meta = b.second;
    DCHECK_NE(a_meta, nullptr);
    DCHECK_NE(b_meta, nullptr);

    auto lhs =
        std::make_tuple(-a_meta->allocation_bytes(), -a_meta->requested_bytes(),
                        a_meta->tf_op_name(), a_meta->region_type(),
                        a_meta->data_type(), a_meta->tensor_shape());
    auto rhs =
        std::make_tuple(-b_meta->allocation_bytes(), -b_meta->requested_bytes(),
                        b_meta->tf_op_name(), b_meta->region_type(),
                        b_meta->data_type(), b_meta->tensor_shape());
    return lhs < rhs;
  }
};

// If applicable, add entries to the active_allocs vector and the
// special_allocations proto for the unmapped memory usage (in heap) and the
// stack reservation at peak.
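// The negative indices pushed below (-1, -2, ...) refer to entries in
// special_allocations rather than to real snapshots; the frontend recovers the
// position as -index - 1 (see ActiveAllocation.special_index in
// ProcessActiveAllocations).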
void InsertSpecialAllocations(int64 unmapped_allocation_bytes, int64 step_id,
                              PerAllocatorMemoryProfile* memory_profile,
                              std::vector<IndexMetaPair>* active_allocs) {
  int index = 0;
  if (unmapped_allocation_bytes > 0) {
    MemoryActivityMetadata* special_allocation =
        memory_profile->add_special_allocations();
    special_allocation->set_memory_activity(ALLOCATION);
    special_allocation->set_requested_bytes(unmapped_allocation_bytes);
    special_allocation->set_allocation_bytes(unmapped_allocation_bytes);
    special_allocation->set_address(0);
    special_allocation->set_tf_op_name("unused preallocated device memory");
    special_allocation->set_step_id(step_id);
    special_allocation->set_region_type("persist/dynamic");
    special_allocation->set_data_type(
        tensorflow::DataTypeString(static_cast<tensorflow::DataType>(0)));
    special_allocation->set_tensor_shape("unknown");
    active_allocs->push_back({--index, special_allocation});
  }
  int64 stack_bytes =
      memory_profile->profile_summary().peak_stats().stack_reserved_bytes();
  if (stack_bytes > 0) {
    MemoryActivityMetadata* special_allocation =
        memory_profile->add_special_allocations();
    special_allocation->set_memory_activity(ALLOCATION);
    special_allocation->set_requested_bytes(stack_bytes);
    special_allocation->set_allocation_bytes(stack_bytes);
    special_allocation->set_address(0);
    special_allocation->set_tf_op_name("stack");
    special_allocation->set_step_id(step_id);
    special_allocation->set_region_type("stack");
    special_allocation->set_data_type(
        tensorflow::DataTypeString(static_cast<tensorflow::DataType>(0)));
    special_allocation->set_tensor_shape("unknown");
    active_allocs->push_back({--index, special_allocation});
  }
}

bool operator==(const IndexMetaPair& a, const IndexMetaPair& b) {
  const MemoryActivityMetadata* a_meta = a.second;
  const MemoryActivityMetadata* b_meta = b.second;
  return a_meta->allocation_bytes() == b_meta->allocation_bytes() &&
         a_meta->requested_bytes() == b_meta->requested_bytes() &&
         a_meta->tf_op_name() == b_meta->tf_op_name() &&
         a_meta->region_type() == b_meta->region_type() &&
         a_meta->data_type() == b_meta->data_type() &&
         a_meta->tensor_shape() == b_meta->tensor_shape();
}

// Generate the memory breakdown table of active allocations at the peak usage
// (within the profiling window) and fill each ActiveAllocation proto (i.e. a
// row).
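// Accounting sketch: start from the peak heap usage, subtract the bytes of
// every allocation seen within the peak step, and add back the bytes of every
// deallocation; what remains is heap memory allocated before the peak step
// ("unmapped" to any snapshot in this step) that is still in use at the peak.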
void ProcessActiveAllocations(int64 peak_bytes_profile_step_id,
                              PerAllocatorMemoryProfile* memory_profile) {
  int64 unmapped_allocation_bytes =
      memory_profile->profile_summary().peak_stats().heap_allocated_bytes();
  int64 unmapped_deallocation_bytes = 0;
  absl::flat_hash_map<int64 /*address*/, IndexMetaPair> active_alloc_map;
  // Only account for the memory activities in the step that includes peak
  // memory usage.
  for (int i = 0; i < memory_profile->memory_profile_snapshots_size(); i++) {
    const auto& snapshot = memory_profile->memory_profile_snapshots().at(i);
    DCHECK(snapshot.has_activity_metadata());
    const MemoryActivityMetadata& metadata = snapshot.activity_metadata();
    if (snapshot.time_offset_ps() >
        memory_profile->profile_summary().peak_stats_time_ps())
      break;
    if (metadata.step_id() != peak_bytes_profile_step_id) continue;

    if (metadata.memory_activity() == ALLOCATION) {
      active_alloc_map[metadata.address()] = {i, &metadata};
      unmapped_allocation_bytes -= metadata.allocation_bytes();
    } else {
      DCHECK_EQ(metadata.memory_activity(), DEALLOCATION);
      if (active_alloc_map.contains(metadata.address())) {
        active_alloc_map.erase(metadata.address());
      } else {
        unmapped_deallocation_bytes += metadata.allocation_bytes();
      }
      unmapped_allocation_bytes += metadata.allocation_bytes();
    }
  }
  // Separate the persistent memory from memory that was freed during this step
  // but allocated in an earlier step.
  unmapped_allocation_bytes -= unmapped_deallocation_bytes;

  VLOG(2) << "unmapped_allocation_bytes=" << unmapped_allocation_bytes
          << ", unmapped_deallocation_bytes=" << unmapped_deallocation_bytes;

  // Use a pair of (index, MemoryActivityMetadata*) so that we can sort by the
  // metadata, and fetch the metadata by indexing the time-sorted snapshots at
  // the frontend.
  std::vector<IndexMetaPair> active_allocs;
  for (const auto& address_and_index_meta : active_alloc_map) {
    active_allocs.push_back(address_and_index_meta.second);
  }

  InsertSpecialAllocations(unmapped_allocation_bytes,
                           peak_bytes_profile_step_id, memory_profile,
                           &active_allocs);

  std::sort(active_allocs.begin(), active_allocs.end(), MetadataComparator());

  // Fill the sorted active_allocations proto messages at peak memory usage.
  // Merge identical allocations and show occurrences.
  for (int i = 0, end = active_allocs.size(); i < end; i++) {
    ActiveAllocation* allocation = memory_profile->add_active_allocations();
    allocation->set_snapshot_index(active_allocs[i].first);
    if (active_allocs[i].first < 0) {
      allocation->set_special_index(-active_allocs[i].first - 1);
    } else {
      allocation->set_special_index(-1);
    }
    allocation->set_num_occurrences(1);
    const int last_alloc = active_allocs.size() - 1;
    while (i < last_alloc && active_allocs[i] == active_allocs[i + 1]) {
      allocation->set_num_occurrences(allocation->num_occurrences() + 1);
      i++;
    }
  }

  VLOG(2) << "Distinct active allocation count="
          << memory_profile->active_allocations_size();
}

struct Sample {
  int64 orig_index;  // original index into the snapshots list.
  MemoryProfileSnapshot* snapshot;
};

// This function samples max_num_snapshots from snapshots. We first keep the
// snapshots referenced by active_allocations in the samples. After this, if
// there is still room for more samples, we pick more from snapshots. Then, we
// sort the samples in time (so that they can be correctly displayed on the
// timeline). Finally, we adjust the original indices (into snapshots) in
// active_allocations to the new indices in the samples.
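// The extra samples are taken from the snapshots with the least free memory,
// i.e. the moments of highest memory pressure, so the sampled timeline keeps
// the most interesting points.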
void SampleSnapshots(
    int64 max_num_snapshots,
    protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots,
    protobuf::RepeatedPtrField<ActiveAllocation>* active_allocations) {
  if (snapshots->size() <= max_num_snapshots) return;

  std::vector<Sample> samples;

  // First, puts the snapshots referenced by active_allocations in samples[].
  absl::flat_hash_set<int64> allocation_snapshot_indices;
  for (const auto& allocation : *active_allocations) {
    auto orig_index = allocation.snapshot_index();
    if (orig_index < 0) continue;
    allocation_snapshot_indices.insert(orig_index);
    samples.push_back({orig_index, &(*snapshots)[orig_index]});
    if (allocation_snapshot_indices.size() >= max_num_snapshots) break;
  }

  // Second, extracts remaining samples from snapshots.
  int64 num_samples_remained =
      max_num_snapshots - allocation_snapshot_indices.size();
  if (num_samples_remained > 0) {
    std::vector<Sample> remaining;
    for (int64 i = 0; i < snapshots->size(); i++) {
      if (allocation_snapshot_indices.contains(i)) continue;
      // snapshots[i] is not yet sampled; put it in remaining[] for further
      // consideration.
      remaining.push_back({i, &(*snapshots)[i]});
    }
    // Moves the num_samples_remained snapshots with the least free bytes to
    // the beginning of remaining[].
    absl::c_partial_sort(
        remaining, remaining.begin() + num_samples_remained,
        [](const Sample& a, const Sample& b) {
          return a.snapshot->aggregation_stats().free_memory_bytes() <
                 b.snapshot->aggregation_stats().free_memory_bytes();
        });
    // Copies the first num_samples_remained snapshots in remaining[] to
    // samples[].
    for (int64 i = 0; i < num_samples_remained; i++)
      samples.push_back(remaining[i]);
  }

  // Third, sorts samples[] in ascending order of time_offset_ps.
  absl::c_sort(samples, [](const Sample& a, const Sample& b) {
    return a.snapshot->time_offset_ps() < b.snapshot->time_offset_ps();
  });

  // Fourth, constructs a map from the original snapshot index to the sample
  // index.
  absl::flat_hash_map</*original=*/int64, /*new=*/int64> index_map;
  for (int64 i = 0; i < samples.size(); i++) {
    index_map[samples[i].orig_index] = i;
  }

  // Fifth, changes the original snapshot indices in active_allocations to the
  // sample indices.
  for (auto& allocation : *active_allocations) {
    auto orig_index = allocation.snapshot_index();
    if (orig_index < 0) continue;
    auto new_index = gtl::FindWithDefault(index_map, orig_index, -1);
    allocation.set_snapshot_index(new_index);
  }

  // Sixth, replaces *snapshots with samples[].
  protobuf::RepeatedPtrField<MemoryProfileSnapshot> new_snapshots;
  new_snapshots.Reserve(samples.size());
  for (const auto& sample : samples) {
    *new_snapshots.Add() = std::move(*sample.snapshot);
  }
  *snapshots = std::move(new_snapshots);
}

// Post-process the memory profile to correctly update proto fields, and break
// down peak memory usage for each allocator.
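// The per-allocator post-processing below runs, in order: sort the snapshots
// by time, patch invalid step ids, copy allocation metadata onto matching
// deallocations, locate the peak-usage step, build the active-allocation
// breakdown at that peak, and finally down-sample the snapshots.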
void ProcessMemoryProfileProto(int64 max_num_snapshots,
                               MemoryProfile* memory_profile) {
  memory_profile->set_num_hosts(1);
  // Add the memory ids of allocators that have snapshots to the selection
  // list, then sort them.
  for (const auto& id_and_allocator_profile :
       memory_profile->memory_profile_per_allocator()) {
    if (!id_and_allocator_profile.second.memory_profile_snapshots().empty()) {
      memory_profile->add_memory_ids(id_and_allocator_profile.first);
    }
  }
  absl::c_sort(*memory_profile->mutable_memory_ids());

  for (auto& id_and_allocator_profile :
       *memory_profile->mutable_memory_profile_per_allocator()) {
    PerAllocatorMemoryProfile* allocator_memory_profile =
        &id_and_allocator_profile.second;
    protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots =
        allocator_memory_profile->mutable_memory_profile_snapshots();
    // Sort the memory_profile_snapshots by time_offset_ps (ascending) in the
    // proto.
    absl::c_sort(*snapshots, [](const MemoryProfileSnapshot& a,
                                const MemoryProfileSnapshot& b) {
      return a.time_offset_ps() < b.time_offset_ps();
    });

    UpdateStepId(allocator_memory_profile);
    UpdateDeallocation(allocator_memory_profile);

    int64 peak_step_id =
        GetPeakMemoryStep(allocator_memory_profile->profile_summary()
                              .peak_stats()
                              .peak_bytes_in_use(),
                          allocator_memory_profile);
    ProcessActiveAllocations(peak_step_id, allocator_memory_profile);
    SampleSnapshots(max_num_snapshots, snapshots,
                    allocator_memory_profile->mutable_active_allocations());
  }
}

template <typename Proto>
Status ConvertProtoToJson(const Proto& proto_output, std::string* json_output) {
  protobuf::util::JsonPrintOptions json_options;
  json_options.always_print_primitive_fields = true;
  auto status = protobuf::util::MessageToJsonString(proto_output, json_output,
                                                    json_options);
  if (!status.ok()) {
    // Convert error_msg, which may be a google::protobuf::StringPiece or an
    // absl::string_view depending on the protobuf version, to an
    // absl::string_view for the error message.
    auto error_msg = status.message();
    return errors::Internal(
        "Could not convert proto to JSON string: ",
        absl::string_view(error_msg.data(), error_msg.length()));
  }
  return Status::OK();
}

}  // namespace

MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane,
                                           int64 max_num_snapshots) {
  MemoryProfile memory_profile = GenerateMemoryProfile(&host_plane);
  ProcessMemoryProfileProto(max_num_snapshots, &memory_profile);
  return memory_profile;
}
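
// Example call (illustrative only; the snapshot cap shown here is arbitrary,
// and the real default, if any, comes from the declaration in
// xplane_to_memory_profile.h):
//   MemoryProfile profile =
//       ConvertXPlaneToMemoryProfile(*host_plane, /*max_num_snapshots=*/1000);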

Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace,
                                        std::string* json_output) {
  if (const XPlane* host_plane =
          FindPlaneWithName(xspace, kHostThreadsPlaneName)) {
    MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane);
    TF_RETURN_IF_ERROR(ConvertProtoToJson(memory_profile, json_output));
  }
  return Status::OK();
}

}  // namespace profiler
}  // namespace tensorflow