1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <android-base/logging.h>
#include <libnl++/bits.h>

#include <linux/netlink.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <optional>
#include <utility>
25 
26 namespace android::nl {
27 
28 /**
29  * Buffer wrapper containing netlink structure (e.g. nlmsghdr, nlattr).
30  *
31  * This is a C++-style, memory safe(r) and generic implementation of linux/netlink.h macros.
32  *
33  * While netlink structures contain information about their total length (with payload), they can
34  * not be trusted - the value may either be larger than the buffer message is allocated in or
35  * smaller than the header itself (so it couldn't even fit itself).
36  *
37  * As a solution, Buffer<> keeps track of two lengths (both attribute for header with payload):
38  * - buffer length - how much memory was allocated to a given structure
39  * - declared length - what nlmsg_len or nla_len says how long the structure is
40  *
41  * In most cases buffer length would be larger than declared length (or equal - modulo alignment -
42  * for continuous data). If that's not the case, there is a potential of ouf-of-bounds read which
43  * this template attempts to protect against.
44  */
45 template <typename T>
46 class Buffer {
47   public:
48     /**
49      * Constructs empty buffer of size 0.
50      */
Buffer()51     Buffer() : mData(nullptr), mBufferEnd(nullptr) {}
52 
53     /**
54      * Buffer constructor.
55      *
56      * \param data A pointer to the data the Buffer wraps.
57      * \param bufLen Length of the buffer.
58      */
Buffer(const T * data,size_t bufLen)59     Buffer(const T* data, size_t bufLen) : mData(data), mBufferEnd(pointerAdd(data, bufLen)) {}
60 
61     const T* operator->() const {
62         CHECK(firstOk()) << "buffer can't fit the first element's header";
63         return mData;
64     }
65 
getFirst()66     std::pair<bool, const T&> getFirst() const {
67         if (!ok()) {
68             static const T empty = {};
69             return {false, empty};
70         }
71         return {true, *mData};
72     }
73 
74     /**
75      * Copy the first element of the buffer.
76      *
77      * This is a memory-safe cast operation, useful for reading e.g. uint32_t values
78      * from 1-byte buffer. If the buffer is smaller than the copied type, the rest is
79      * padded with default constructor output (usually zeros).
80      */
copyFirst()81     T copyFirst() const {
82         T val = {};
83         memcpy(&val, mData, std::min(sizeof(val), remainingLength()));
84         return val;
85     }
86 
firstOk()87     bool firstOk() const { return sizeof(T) <= remainingLength(); }
88 
89     template <typename D>
90     const Buffer<D> data(size_t offset = 0) const {
91         return {impl::data<const T, const D>(mData, offset), dataEnd()};
92     }
93 
94     template <typename B>
getOffset(Buffer<B> inner)95     std::optional<uintptr_t> getOffset(Buffer<B> inner) const {
96         const auto selfStart = uintptr_t(mData);
97         const auto selfEnd = uintptr_t(mBufferEnd);
98         const auto innerStart = uintptr_t(inner.mData);
99         const auto innerEnd = uintptr_t(inner.mBufferEnd);
100 
101         if (innerStart < selfStart || innerEnd > selfEnd) return std::nullopt;
102 
103         return innerStart - selfStart;
104     }
105 
106     class iterator {
107       public:
iterator()108         iterator() : mCurrent(nullptr, size_t(0)) {
109             CHECK(isEnd()) << "end() iterator should indicate it's beyond end";
110         }
iterator(const Buffer<T> & buf)111         iterator(const Buffer<T>& buf) : mCurrent(buf) {}
112 
113         iterator operator++() {
114             // mBufferEnd stays the same
115             mCurrent.mData = reinterpret_cast<const T*>(  //
116                     uintptr_t(mCurrent.mData) + impl::align(mCurrent.declaredLength()));
117 
118             return *this;
119         }
120 
121         bool operator==(const iterator& other) const {
122             // all iterators beyond end are the same
123             if (isEnd() && other.isEnd()) return true;
124 
125             return uintptr_t(other.mCurrent.mData) == uintptr_t(mCurrent.mData);
126         }
127 
128         const Buffer<T>& operator*() const { return mCurrent; }
129 
isEnd()130         bool isEnd() const { return !mCurrent.ok(); }
131 
132       protected:
133         Buffer<T> mCurrent;
134     };
begin()135     iterator begin() const { return {*this}; }
end()136     iterator end() const { return {}; }
137 
138     class raw_iterator : public iterator {
139       public:
140         iterator operator++() {
141             this->mCurrent.mData++;  // ignore alignment
142             return *this;
143         }
144         const T& operator*() const { return *this->mCurrent.mData; }
145     };
146 
147     class raw_view {
148       public:
raw_view(const Buffer<T> & buffer)149         raw_view(const Buffer<T>& buffer) : mBuffer(buffer) {}
begin()150         raw_iterator begin() const { return {mBuffer}; }
end()151         raw_iterator end() const { return {}; }
152 
ptr()153         const T* ptr() const { return mBuffer.mData; }
len()154         size_t len() const { return mBuffer.remainingLength(); }
155 
156       private:
157         const Buffer<T> mBuffer;
158     };
159 
getRaw()160     raw_view getRaw() const { return {*this}; }
161 
162   private:
163     const T* mData;
164     const void* mBufferEnd;
165 
Buffer(const T * data,const void * bufferEnd)166     Buffer(const T* data, const void* bufferEnd) : mData(data), mBufferEnd(bufferEnd) {}
167 
ok()168     bool ok() const { return declaredLength() <= remainingLength(); }
169 
170     // to be specialized individually for each T with payload after a header
declaredLengthImpl()171     inline size_t declaredLengthImpl() const { return sizeof(T); }
172 
declaredLength()173     size_t declaredLength() const {
174         // We can't even fit a header, so let's return some absurd high value to trip off
175         // buffer overflow checks.
176         static constexpr size_t badHeaderLength = std::numeric_limits<size_t>::max() / 2;
177 
178         if (sizeof(T) > remainingLength()) return badHeaderLength;
179         const auto len = declaredLengthImpl();
180         if (sizeof(T) > len) return badHeaderLength;
181         return len;
182     }
183 
remainingLength()184     size_t remainingLength() const {
185         auto len = intptr_t(mBufferEnd) - intptr_t(mData);
186         return (len >= 0) ? len : 0;
187     }
188 
dataEnd()189     const void* dataEnd() const {
190         auto declaredEnd = pointerAdd(mData, declaredLength());
191         return std::min(declaredEnd, mBufferEnd);
192     }
193 
pointerAdd(const void * ptr,size_t len)194     static const void* pointerAdd(const void* ptr, size_t len) {
195         return reinterpret_cast<const void*>(uintptr_t(ptr) + len);
196     }
197 
198     template <typename D>
199     friend class Buffer;  // calling private constructor of data buffers
200 };
201 
202 template <>
declaredLengthImpl()203 inline size_t Buffer<nlmsghdr>::declaredLengthImpl() const {
204     return mData->nlmsg_len;
205 }
206 
207 template <>
declaredLengthImpl()208 inline size_t Buffer<nlattr>::declaredLengthImpl() const {
209     return mData->nla_len;
210 }
211 
212 }  // namespace android::nl
213