1#!/usr/bin/python
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Classes for generic netlink."""
18
19import collections
20from socket import *  # pylint: disable=wildcard-import
21import struct
22
23import cstruct
24import netlink
25
### Generic netlink constants. See include/uapi/linux/genetlink.h.
# The generic netlink control family. This is the fixed family ID that
# GetFamily() sends its lookup request to.
GENL_ID_CTRL = 16

# Commands.
# Used by GenericNetlinkControl.GetFamily() to look up a family by name.
CTRL_CMD_GETFAMILY = 3

# Attributes.
# Top-level attributes of control messages. GenericNetlinkControl._Decode()
# maps an attribute type back to one of these names via _GetConstantName().
CTRL_ATTR_FAMILY_ID = 1
CTRL_ATTR_FAMILY_NAME = 2
CTRL_ATTR_VERSION = 3
CTRL_ATTR_HDRSIZE = 4
CTRL_ATTR_MAXATTR = 5
CTRL_ATTR_OPS = 6
CTRL_ATTR_MCAST_GROUPS = 7

# Attributes nested inside CTRL_ATTR_OPS.
# NOTE(review): these reuse the values 1 and 2 and share the "CTRL_ATTR_"
# prefix with CTRL_ATTR_FAMILY_ID / CTRL_ATTR_FAMILY_NAME, so a prefix-based
# name lookup can match either; which one wins depends on how
# _GetConstantName scans the module -- confirm before renaming anything here.
CTRL_ATTR_OP_ID = 1
CTRL_ATTR_OP_FLAGS = 2


# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# struct genlmsghdr: format "BBxx" -- presumably (struct-module codes) two
# unsigned bytes, cmd and version, followed by two bytes of padding.
Genlmsghdr = cstruct.Struct("genlmsghdr", "BBxx", "cmd version")
50
51
class GenericNetlink(netlink.NetlinkSocket):
  """Common base for sockets that speak the generic netlink protocol.

  Every request sent or dump issued through this class is prefixed with a
  Genlmsghdr carrying the command and version.
  """

  NL_DEBUG = []

  def __init__(self):
    super(GenericNetlink, self).__init__(netlink.NETLINK_GENERIC)

  def _SendCommand(self, family, command, version, data, flags):
    """Sends a single generic netlink request.

    Args:
      family: numeric generic netlink family ID to address.
      command: the genlmsghdr command byte.
      version: the genlmsghdr version byte.
      data: already-packed attributes appended after the genlmsghdr.
      flags: netlink message flags (e.g. netlink.NLM_F_REQUEST).
    """
    payload = Genlmsghdr((command, version)).Pack() + data
    self._SendNlRequest(family, payload, flags)

  def _Dump(self, family, command, version):
    """Issues a dump request and returns the base class's _Dump result.

    The dump request carries a genlmsghdr built from command and version,
    replies are parsed with Genlmsghdr, and no extra attributes are sent.
    """
    header = Genlmsghdr((command, version))
    return super(GenericNetlink, self)._Dump(family, header, Genlmsghdr, "")
67
68
class GenericNetlinkControl(GenericNetlink):
  """Generic netlink control class.

  This interface is used to manage other generic netlink families. We currently
  use it only to look up the numeric family ID the kernel assigned to a generic
  netlink family name (e.g. "tcp_metrics")."""

  # One entry of a CTRL_ATTR_OPS list: a command ID and its flags. Built once
  # here instead of on every _DecodeOps() call, since creating a namedtuple
  # class is comparatively expensive.
  _Op = collections.namedtuple("Op", ["id", "flags"])

  def _DecodeOps(self, data):
    """Decodes the payload of a CTRL_ATTR_OPS attribute.

    The payload is a sequence of nested attributes. Each one must contain a
    CTRL_ATTR_OP_ID attribute followed by a CTRL_ATTR_OP_FLAGS attribute,
    each holding a native-endian unsigned 32-bit integer.

    Args:
      data: the raw CTRL_ATTR_OPS attribute payload.

    Returns:
      A list of Op(id, flags) namedtuples, in payload order.

    Raises:
      ValueError: if a nested attribute is not of the expected type.
    """
    ops = []
    while data:
      # Skip the 4-byte header (length, index) of this nested entry.
      # NOTE(review): the nest length is not used to bound parsing; this
      # assumes each entry holds exactly OP_ID followed by OP_FLAGS.
      data = data[4:]

      nla, nla_data, data = self._ReadNlAttr(data)
      if nla.nla_type != CTRL_ATTR_OP_ID:
        raise ValueError("Expected CTRL_ATTR_OP_ID, got %d" % nla.nla_type)
      op_id = struct.unpack("=I", nla_data)[0]

      nla, nla_data, data = self._ReadNlAttr(data)
      if nla.nla_type != CTRL_ATTR_OP_FLAGS:
        # Must be nla_type: the previous nla.type raised AttributeError here.
        raise ValueError("Expected CTRL_ATTR_OP_FLAGS, got %d" % nla.nla_type)
      op_flags = struct.unpack("=I", nla_data)[0]

      ops.append(self._Op(op_id, op_flags))
    return ops

  def _Decode(self, command, msg, nla_type, nla_data):
    """Decodes generic netlink control attributes to human-readable format.

    Args:
      command: the generic netlink command (not used by this decoder).
      msg: the message header (not used by this decoder).
      nla_type: the numeric attribute type.
      nla_data: the raw attribute payload.

    Returns:
      A (name, data) tuple. name is whatever _GetConstantName returns for a
      CTRL_ATTR_* lookup of nla_type. data is the decoded value, or the raw
      payload for attributes this method does not decode.
    """
    # NOTE(review): CTRL_ATTR_OP_* share values and the "CTRL_ATTR_" prefix
    # with CTRL_ATTR_FAMILY_ID/NAME; correctness relies on _GetConstantName
    # picking the FAMILY_* names -- confirm.
    name = self._GetConstantName(__name__, nla_type, "CTRL_ATTR_")

    if name == "CTRL_ATTR_FAMILY_ID":
      data = struct.unpack("=H", nla_data)[0]
    elif name == "CTRL_ATTR_FAMILY_NAME":
      # Drop the NUL terminator. A bytes literal works whether the payload is
      # a Python 2 str or a Python 3 bytes object.
      data = nla_data.strip(b"\x00")
    elif name in ["CTRL_ATTR_VERSION", "CTRL_ATTR_HDRSIZE", "CTRL_ATTR_MAXATTR"]:
      data = struct.unpack("=I", nla_data)[0]
    elif name == "CTRL_ATTR_OPS":
      data = self._DecodeOps(nla_data)
    else:
      data = nla_data

    return name, data

  def GetFamily(self, name):
    """Returns the family ID for the specified family name.

    Sends CTRL_CMD_GETFAMILY to the control family and reads one reply.

    Args:
      name: the generic netlink family name, e.g. "tcp_metrics".

    Returns:
      The CTRL_ATTR_FAMILY_ID attribute of the reply (presumably an int, as
      decoded by _Decode -- depends on _GetMsg using _Decode).
    """
    data = self._NlAttrStr(CTRL_ATTR_FAMILY_NAME, name)
    self._SendCommand(GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 0, data,
                      netlink.NLM_F_REQUEST)
    _, attrs = self._GetMsg(Genlmsghdr)
    return attrs["CTRL_ATTR_FAMILY_ID"]
119
120
if __name__ == "__main__":
  # Smoke test: print the family ID the kernel reports for "tcp_metrics".
  control = GenericNetlinkControl()
  family_id = control.GetFamily("tcp_metrics")
  print(family_id)
124