1# Copyright (C) 2020 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from urllib.parse import urlparse
16
17from .http import TraceProcessorHttp
18from .loader import get_loader
19from .protos import ProtoFactory
20from .shell import load_shell
21
22
class TraceProcessorException(Exception):
  """Error raised by TraceProcessor.

  Raised when a trace_processor call returns a response with its error field
  set (the error text becomes the exception message), and for client-side
  failures such as unknown cell types or missing optional dependencies.
  """

  def __init__(self, message):
    super(TraceProcessorException, self).__init__(message)
29
30
class TraceProcessor:
  """Client for a trace_processor instance reached over its HTTP API.

  The instance is either an already running trace_processor at `addr` or a
  trace_processor_shell started through load_shell(). TraceProcessor can be
  used as a context manager; close() is called when the `with` block exits.
  """

  # Values of these constants correspond to the QueryResponse message at
  # protos/perfetto/trace_processor/trace_processor.proto
  # Value 0 corresponds to CELL_INVALID; cells of that type are rejected with
  # TraceProcessorException. NULL cells are returned as None (NaN in pandas
  # dataframes).
  QUERY_CELL_NULL_FIELD_ID = 1
  QUERY_CELL_VARINT_FIELD_ID = 2
  QUERY_CELL_FLOAT64_FIELD_ID = 3
  QUERY_CELL_STRING_FIELD_ID = 4
  QUERY_CELL_BLOB_FIELD_ID = 5

  class Row(object):
    """One row of a query result returned to the user.

    Each column name is stored as an attribute of the instance, holding the
    value of that column in this row.
    """

    def __str__(self):
      return str(self.__dict__)

    def __repr__(self):
      # __repr__ must return a str; returning the dict itself made repr(row)
      # (and printing a list of rows) raise TypeError.
      return repr(self.__dict__)

  class QueryResultIterator:
    """Iterates over the rows of a query result.

    `batches` is the list of result batches of a query response. Each batch
    holds a flat `cells` list of cell types, len(column_names) entries per
    row, plus one list of values per non-string type (`varint_cells`,
    `float64_cells`, `blob_cells`) which is consumed in order. String values
    are packed into `string_cells`, each terminated by a NUL character.
    Batches are consumed in order until one with `is_last_batch` set.
    """

    def __init__(self, column_names, batches):
      self.__column_names = column_names
      self.__batches = batches
      self.__batch_index = 0
      self.__load_batch()

    def __load_batch(self):
      """Resets the read cursors for the batch at self.__batch_index."""
      # Index into `cells` of the first cell type of the next unread row.
      self.__next_index = 0
      # Read offsets into the per-type value lists, keyed by cell type. The
      # lists are read by index rather than popped: pop(0) is O(n) per cell
      # and mutated the response.
      self.__cell_offsets = {}
      self.__string_index = 0
      self.__strings = []
      if self.__batch_index < len(self.__batches):
        # TODO(lalitm): Look into changing string_cells to bytes in the protobuf
        packed = self.__batches[self.__batch_index].string_cells
        if isinstance(packed, bytes):
          packed = packed.decode('utf-8')
        # Every string is NUL-terminated, so splitting yields them in order
        # (followed by an unused empty trailing element).
        self.__strings = packed.split('\0')

    def get_cell_list(self, proto_index):
      """Returns the value list of the current batch for a non-string type.

      Args:
        proto_index: One of the QUERY_CELL_*_FIELD_ID constants.

      Returns:
        None for NULL cells, otherwise the batch's list of values of that type.

      Raises:
        TraceProcessorException: for any other type (including strings, which
          are stored separately, and CELL_INVALID).
      """
      if proto_index == TraceProcessor.QUERY_CELL_NULL_FIELD_ID:
        return None
      batch = self.__batches[self.__batch_index]
      if proto_index == TraceProcessor.QUERY_CELL_VARINT_FIELD_ID:
        return batch.varint_cells
      if proto_index == TraceProcessor.QUERY_CELL_FLOAT64_FIELD_ID:
        return batch.float64_cells
      if proto_index == TraceProcessor.QUERY_CELL_BLOB_FIELD_ID:
        return batch.blob_cells
      raise TraceProcessorException('Invalid cell type')

    def cells(self):
      """Returns the cell types of the current batch."""
      return self.__batches[self.__batch_index].cells

    def __advance_to_row(self):
      """Positions the iterator on the next unread row.

      Skips exhausted (including empty) batches. Returns False once the last
      batch, or the end of `batches`, has been reached with no rows left.
      """
      while self.__batch_index < len(self.__batches):
        batch = self.__batches[self.__batch_index]
        if self.__next_index < len(batch.cells):
          return True
        if batch.is_last_batch:
          return False
        self.__batch_index += 1
        self.__load_batch()
      return False

    def __read_cell(self, cell_type, null_value):
      """Consumes and returns the next value of `cell_type` in the batch."""
      if cell_type == TraceProcessor.QUERY_CELL_STRING_FIELD_ID:
        value = self.__strings[self.__string_index]
        self.__string_index += 1
        return value
      cell_list = self.get_cell_list(cell_type)
      if cell_list is None:
        return null_value
      offset = self.__cell_offsets.get(cell_type, 0)
      self.__cell_offsets[cell_type] = offset + 1
      return cell_list[offset]

    def __read_row_values(self, null_value=None):
      """Consumes the current row and returns its values in column order.

      NULL cells are returned as `null_value`.
      """
      width = len(self.__column_names)
      start = self.__next_index
      row_types = self.cells()[start:start + width]
      values = [self.__read_cell(t, null_value) for t in row_types]
      self.__next_index = start + width
      return values

    def as_pandas_dataframe(self):
      """Returns the rows not yet consumed as a pandas DataFrame.

      To get the full result, call this directly after TraceProcessor.query()
      and before iterating. Column dtypes are inferred by pandas and NULL
      cells become NaN.

      Raises:
        TraceProcessorException: numpy or pandas is not installed.
      """
      try:
        import numpy as np
        import pandas as pd
      except ModuleNotFoundError as e:
        raise TraceProcessorException(
            'The sufficient libraries are not installed') from e

      rows = []
      while self.__advance_to_row():
        # np.nan rather than np.NAN: the upper-case alias was removed in
        # NumPy 2.0.
        rows.append(self.__read_row_values(null_value=np.nan))
      # Building the frame once is linear; appending rows one at a time
      # through df.loc was quadratic.
      return pd.DataFrame(rows, columns=self.__column_names)

    def __iter__(self):
      return self

    def __next__(self):
      if not self.__advance_to_row():
        raise StopIteration
      row = TraceProcessor.Row()
      for column_name, value in zip(self.__column_names,
                                    self.__read_row_values()):
        setattr(row, column_name, value)
      return row

  def __init__(self, addr=None, file_path=None, bin_path=None,
               unique_port=True):
    """Connects to (or starts) a trace_processor and optionally loads a trace.

    Args:
      addr: Address of an already running trace_processor, either a URL or
        a bare host[:port]. When not given, trace_processor_shell is started
        via load_shell().
      file_path: Optional path of a trace file to parse into the instance.
      bin_path: Forwarded to load_shell(); presumably the path of the
        trace_processor_shell binary to run.
      unique_port: Forwarded to load_shell().
    """
    # Load trace_processor_shell or access via given address
    if addr:
      tp = TraceProcessorHttp(TraceProcessor._host_port(addr))
    else:
      url, self.subprocess = load_shell(
          bin_path=bin_path, unique_port=unique_port)
      tp = TraceProcessorHttp(url)
    self.http = tp
    self.protos = ProtoFactory()

    # Parse trace by its file_path into the loaded instance of trace_processor
    if file_path:
      get_loader().parse_file(self.http, file_path)

  @staticmethod
  def _host_port(addr):
    """Returns the network location of `addr`, with or without a scheme."""
    # urlparse() only recognises a netloc introduced by '//'. Without it,
    # 'localhost:9001' parses as scheme 'localhost' and path '9001' on
    # Python 3.9+, so make the network location explicit.
    parsed = urlparse(addr if '//' in addr else '//' + addr)
    return parsed.netloc if parsed.netloc else parsed.path

  def query(self, sql):
    """Executes passed in SQL query using class defined HTTP API, and returns
    the response as a QueryResultIterator. Raises TraceProcessorException if
    the response returns with an error.

    Args:
      sql: SQL query written as a String

    Returns:
      A class which can iterate through each row of the results table. This
      can also be converted to a pandas dataframe by calling the
      as_pandas_dataframe() function after calling query.
    """
    response = self.http.execute_query(sql)
    if response.error:
      raise TraceProcessorException(response.error)

    return TraceProcessor.QueryResultIterator(response.column_names,
                                              response.batch)

  def metric(self, metrics):
    """Returns the metrics data corresponding to the passed in trace metric.
    Raises TraceProcessorException if the response returns with an error.

    Args:
      metrics: A list of valid metrics as defined in TraceMetrics

    Returns:
      The metrics data as a proto message
    """
    response = self.http.compute_metric(metrics)
    if response.error:
      raise TraceProcessorException(response.error)

    trace_metrics = self.protos.TraceMetrics()
    trace_metrics.ParseFromString(response.metrics)
    return trace_metrics

  def enable_metatrace(self):
    """Enable metatrace for the currently running trace_processor.
    """
    return self.http.enable_metatrace()

  def disable_and_read_metatrace(self):
    """Disable and return the metatrace formed from the currently running
    trace_processor. This must be enabled before attempting to disable. This
    returns the serialized bytes of the metatrace data directly. Raises
    TraceProcessorException if the response returns with an error.
    """
    response = self.http.disable_and_read_metatrace()
    if response.error:
      raise TraceProcessorException(response.error)

    return response.metatrace

  def close(self):
    """Kills the trace_processor_shell started by this instance, if any, and
    closes the HTTP connection.

    A shell process only exists when no `addr` was passed to the constructor.
    """
    if hasattr(self, 'subprocess'):
      self.subprocess.kill()
    self.http.conn.close()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Always release resources; returning None lets exceptions propagate.
    self.close()
246