1 // Copyright 2020 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Provides infrastructure for de/serializing descriptors embedded in Rust data structures.
6 //!
7 //! # Example
8 //!
9 //! ```
10 //! use serde_json::to_string;
11 //! use sys_util::{
12 //!     FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors,
13 //!     deserialize_with_descriptors,
14 //! };
15 //! use tempfile::tempfile;
16 //!
17 //! let tmp_f = tempfile().unwrap();
18 //!
19 //! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File.
20 //! let data = FileSerdeWrapper(tmp_f);
21 //!
22 //! // Wraps Serialize types to collect side channel descriptors as Serialize is called.
23 //! let data_wrapper = SerializeDescriptors::new(&data);
24 //!
//! // Use the wrapper with any serializer to serialize data as normal, grabbing descriptors
26 //! // as the data structures are serialized by the serializer.
27 //! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
28 //!
29 //! // If data_wrapper contains any side channel descriptor refs
30 //! // (it contains tmp_f in this case), we can retrieve the actual descriptors
31 //! // from the side channel using into_descriptors().
32 //! let out_descriptors = data_wrapper.into_descriptors();
33 //!
34 //! // When sending out_json over some transport, also send out_descriptors.
35 //!
//! // For this example, we aren't really transporting data across processes, but we do need to
37 //! // convert the descriptor type.
38 //! let mut safe_descriptors = out_descriptors
39 //!     .iter()
40 //!     .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
41 //!     .collect();
42 //! std::mem::forget(data); // Prevent double drop of tmp_f.
43 //!
//! // The deserialize_with_descriptors function is used to give the descriptor deserializers
//! // access to side channel descriptors.
46 //! let res: FileSerdeWrapper =
47 //!     deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors)
48 //!        .expect("failed to deserialize");
49 //! ```
50 
51 use std::cell::{Cell, RefCell};
52 use std::convert::TryInto;
53 use std::fmt;
54 use std::fs::File;
55 use std::ops::{Deref, DerefMut};
56 use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
57 
58 use serde::de::{self, Error, Visitor};
59 use serde::ser;
60 use serde::{Deserialize, Deserializer, Serialize, Serializer};
61 
62 use crate::{RawDescriptor, SafeDescriptor};
63 
thread_local! {
    // Raw descriptors collected on this thread while a `SerializeDescriptors` is being
    // serialized. `None` means no collection is in progress: `init_descriptor_dst` installs an
    // empty `Vec`, `push_descriptor` appends to it, and `take_descriptor_dst` removes it again.
    static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default();
}
67 
68 /// Initializes the thread local storage for descriptor serialization. Fails if it was already
69 /// initialized without an intervening `take_descriptor_dst` on this thread.
init_descriptor_dst() -> Result<(), &'static str>70 fn init_descriptor_dst() -> Result<(), &'static str> {
71     DESCRIPTOR_DST.with(|d| {
72         let mut descriptors = d.borrow_mut();
73         if descriptors.is_some() {
74             return Err(
75                 "attempt to initialize descriptor destination that was already initialized",
76             );
77         }
78         *descriptors = Some(Default::default());
79         Ok(())
80     })
81 }
82 
83 /// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call
84 /// to `init_descriptor_dst` on this thread.
take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str>85 fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> {
86     match DESCRIPTOR_DST.with(|d| d.replace(None)) {
87         Some(d) => Ok(d),
88         None => Err("attempt to take descriptor destination before it was initialized"),
89     }
90 }
91 
92 /// Pushes a descriptor on the thread local destination of descriptors, returning the index in which
93 /// the descriptor was pushed.
94 //
95 /// Returns Err if the thread local destination was not already initialized.
push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str>96 fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> {
97     DESCRIPTOR_DST.with(|d| {
98         d.borrow_mut()
99             .as_mut()
100             .ok_or("attempt to serialize descriptor without descriptor destination")
101             .map(|descriptors| {
102                 let index = descriptors.len();
103                 descriptors.push(rd);
104                 index
105             })
106     })
107 }
108 
109 /// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct.
110 ///
111 /// If there is no parent `SerializeDescriptors` being serialized, this will return an error.
112 ///
113 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
114 /// "...")]` attribute which will make use of this function.
serialize_descriptor<S: Serializer>( rd: &RawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>115 pub fn serialize_descriptor<S: Serializer>(
116     rd: &RawDescriptor,
117     se: S,
118 ) -> std::result::Result<S::Ok, S::Error> {
119     let index = push_descriptor(*rd).map_err(ser::Error::custom)?;
120     se.serialize_u32(
121         index
122             .try_into()
123             .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?,
124     )
125 }
126 
/// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when
/// given to an ordinary `Serializer`.
///
/// While the wrapper is serialized, every descriptor passed to `serialize_descriptor` is written to
/// the output as its index, and the descriptors themselves are collected in index order. Retrieve
/// them with `into_descriptors` after serializing.
///
/// This is the corresponding type to use for serialization before using
/// `deserialize_with_descriptors`.
///
/// # Examples
///
/// ```
/// use serde_json::to_string;
/// use sys_util::{FileSerdeWrapper, SerializeDescriptors};
/// use tempfile::tempfile;
///
/// let tmp_f = tempfile().unwrap();
/// let data = FileSerdeWrapper(tmp_f);
/// let data_wrapper = SerializeDescriptors::new(&data);
///
/// // Serializes `data` as normal...
/// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
/// // If `serialize_descriptor` was called, we can capture the descriptors from here.
/// let out_descriptors = data_wrapper.into_descriptors();
/// ```
// Field 0 is the borrowed value to serialize. Field 1 receives the descriptors collected by the
// most recent `serialize` call; it is a `Cell` because `Serialize::serialize` only gets `&self`.
pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>);
150 
151 impl<'a, T: Serialize> SerializeDescriptors<'a, T> {
new(inner: &'a T) -> Self152     pub fn new(inner: &'a T) -> Self {
153         Self(inner, Default::default())
154     }
155 
into_descriptors(self) -> Vec<RawDescriptor>156     pub fn into_descriptors(self) -> Vec<RawDescriptor> {
157         self.1.into_inner()
158     }
159 }
160 
161 impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,162     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
163     where
164         S: Serializer,
165     {
166         init_descriptor_dst().map_err(ser::Error::custom)?;
167 
168         // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to
169         // take_descriptor_dst afterwards.
170         let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer)));
171         self.1.set(take_descriptor_dst().unwrap());
172         match res {
173             Ok(r) => r,
174             Err(e) => resume_unwind(e),
175         }
176     }
177 }
178 
thread_local! {
    // Descriptors available to `deserialize_descriptor` on this thread. `None` means no
    // `deserialize_with_descriptors` is in progress. Each slot is an `Option` so that
    // `take_descriptor` can move a descriptor out while leaving unused ones in place.
    static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default();
}
182 
183 /// Sets the thread local storage of descriptors for deserialization. Fails if this was already
184 /// called without a call to `take_descriptor_src` on this thread.
185 ///
186 /// This is given as a collection of `Option` so that unused descriptors can be returned.
set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str>187 fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> {
188     DESCRIPTOR_SRC.with(|d| {
189         let mut src = d.borrow_mut();
190         if src.is_some() {
191             return Err("attempt to set descriptor source that was already set");
192         }
193         *src = Some(descriptors);
194         Ok(())
195     })
196 }
197 
198 /// Takes the thread local storage of descriptors for deserialization. Fails if the storage was
199 /// already taken or never set with `set_descriptor_src`.
200 ///
201 /// If deserialization was done, the descriptors will mostly come back as `None` unless some of them
202 /// were unused.
take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str>203 fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> {
204     DESCRIPTOR_SRC.with(|d| {
205         d.replace(None)
206             .ok_or("attempt to take descriptor source which was never set")
207     })
208 }
209 
210 /// Takes a descriptor at the given index from the thread local source of descriptors.
211 //
212 /// Returns None if the thread local source was not already initialized.
take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str>213 fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> {
214     DESCRIPTOR_SRC.with(|d| {
215         d.borrow_mut()
216             .as_mut()
217             .ok_or("attempt to deserialize descriptor without descriptor source")?
218             .get_mut(index)
219             .ok_or("attempt to deserialize out of bounds descriptor")?
220             .take()
221             .ok_or("attempt to deserialize descriptor that was already taken")
222     })
223 }
224 
225 /// Deserializes a descriptor provided via `deserialize_with_descriptors`.
226 ///
227 /// If `deserialize_with_descriptors` is not in the call chain, this will return an error.
228 ///
229 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
230 /// "...")]` attribute which will make use of this function.
deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error> where D: Deserializer<'de>,231 pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error>
232 where
233     D: Deserializer<'de>,
234 {
235     struct DescriptorVisitor;
236 
237     impl<'de> Visitor<'de> for DescriptorVisitor {
238         type Value = u32;
239 
240         fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
241             formatter.write_str("an integer which fits into a u32")
242         }
243 
244         fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> {
245             Ok(value as _)
246         }
247 
248         fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> {
249             Ok(value as _)
250         }
251 
252         fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> {
253             Ok(value)
254         }
255 
256         fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
257             value.try_into().map_err(E::custom)
258         }
259 
260         fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> {
261             value.try_into().map_err(E::custom)
262         }
263 
264         fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> {
265             value.try_into().map_err(E::custom)
266         }
267 
268         fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> {
269             value.try_into().map_err(E::custom)
270         }
271 
272         fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> {
273             value.try_into().map_err(E::custom)
274         }
275 
276         fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
277             value.try_into().map_err(E::custom)
278         }
279 
280         fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> {
281             value.try_into().map_err(E::custom)
282         }
283     }
284 
285     let index = de.deserialize_u32(DescriptorVisitor)? as usize;
286     take_descriptor(index).map_err(D::Error::custom)
287 }
288 
289 /// Allows the use of any serde deserializer within a closure while providing access to the a set of
290 /// descriptors for use in `deserialize_descriptor`.
291 ///
292 /// This is the corresponding call to use deserialize after using `SerializeDescriptors`.
293 ///
294 /// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an
295 /// error.
deserialize_with_descriptors<F, T, E>( f: F, descriptors: &mut Vec<Option<SafeDescriptor>>, ) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: de::Error,296 pub fn deserialize_with_descriptors<F, T, E>(
297     f: F,
298     descriptors: &mut Vec<Option<SafeDescriptor>>,
299 ) -> Result<T, E>
300 where
301     F: FnOnce() -> Result<T, E>,
302     E: de::Error,
303 {
304     let swap_descriptors = std::mem::take(descriptors);
305     set_descriptor_src(swap_descriptors).map_err(E::custom)?;
306 
307     // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to
308     // take_descriptor_src afterwards.
309     let res = catch_unwind(AssertUnwindSafe(f));
310 
311     // unwrap is used because set_descriptor_src is always called before this, so it should never
312     // panic.
313     *descriptors = take_descriptor_src().unwrap();
314 
315     match res {
316         Ok(r) => r,
317         Err(e) => resume_unwind(e),
318     }
319 }
320 
321 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
322 /// attribute. It only works with fields with `RawDescriptor` type.
323 ///
324 /// # Examples
325 ///
326 /// ```
327 /// use serde::{Deserialize, Serialize};
328 /// use sys_util::RawDescriptor;
329 ///
330 /// #[derive(Serialize, Deserialize)]
331 /// struct RawContainer {
332 ///     #[serde(with = "sys_util::with_raw_descriptor")]
333 ///     rd: RawDescriptor,
334 /// }
335 /// ```
336 pub mod with_raw_descriptor {
337     use crate::{IntoRawDescriptor, RawDescriptor};
338     use serde::Deserializer;
339 
340     pub use super::serialize_descriptor as serialize;
341 
deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error> where D: Deserializer<'de>,342     pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error>
343     where
344         D: Deserializer<'de>,
345     {
346         super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor)
347     }
348 }
349 
350 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
351 /// attribute.
352 ///
353 /// # Examples
354 ///
355 /// ```
356 /// use std::fs::File;
357 /// use serde::{Deserialize, Serialize};
358 /// use sys_util::RawDescriptor;
359 ///
360 /// #[derive(Serialize, Deserialize)]
361 /// struct FileContainer {
362 ///     #[serde(with = "sys_util::with_as_descriptor")]
363 ///     file: File,
364 /// }
365 /// ```
366 pub mod with_as_descriptor {
367     use crate::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor};
368     use serde::{Deserializer, Serializer};
369 
serialize<S: Serializer>( rd: &dyn AsRawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>370     pub fn serialize<S: Serializer>(
371         rd: &dyn AsRawDescriptor,
372         se: S,
373     ) -> std::result::Result<S::Ok, S::Error> {
374         super::serialize_descriptor(&rd.as_raw_descriptor(), se)
375     }
376 
deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error> where D: Deserializer<'de>, T: FromRawDescriptor,377     pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
378     where
379         D: Deserializer<'de>,
380         T: FromRawDescriptor,
381     {
382         super::deserialize_descriptor(de)
383             .map(IntoRawDescriptor::into_raw_descriptor)
384             .map(|rd| unsafe { T::from_raw_descriptor(rd) })
385     }
386 }
387 
/// A simple wrapper around `File` that implements `Serialize`/`Deserialize`. It is useful when
/// the `#[serde(with = "with_as_descriptor")]` attribute is infeasible, such as for a field with
/// type `Option<File>`.
///
/// Because of `#[serde(transparent)]`, the wrapper serializes exactly as its inner `File` does
/// through `with_as_descriptor`.
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);
394 
395 impl fmt::Debug for FileSerdeWrapper {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result396     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
397         self.0.fmt(f)
398     }
399 }
400 
401 impl From<File> for FileSerdeWrapper {
from(file: File) -> Self402     fn from(file: File) -> Self {
403         FileSerdeWrapper(file)
404     }
405 }
406 
407 impl From<FileSerdeWrapper> for File {
from(f: FileSerdeWrapper) -> File408     fn from(f: FileSerdeWrapper) -> File {
409         f.0
410     }
411 }
412 
413 impl Deref for FileSerdeWrapper {
414     type Target = File;
deref(&self) -> &Self::Target415     fn deref(&self) -> &Self::Target {
416         &self.0
417     }
418 }
419 
420 impl DerefMut for FileSerdeWrapper {
deref_mut(&mut self) -> &mut Self::Target421     fn deref_mut(&mut self) -> &mut Self::Target {
422         &mut self.0
423     }
424 }
425 
#[cfg(test)]
mod tests {
    use crate::{
        deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper,
        FromRawDescriptor, RawDescriptor, SafeDescriptor, SerializeDescriptors,
    };

    use std::collections::HashMap;
    use std::fs::File;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::AsRawFd;

    use serde::{de::DeserializeOwned, Deserialize, Serialize};
    use tempfile::tempfile;

    /// Deserializes `json` with `descriptors` supplied as the descriptor source. Asserts that
    /// every descriptor was consumed.
    ///
    /// Each raw descriptor is wrapped in a `SafeDescriptor`, so ownership passes into the
    /// deserialized value. The tests below wrap their originals in `ManuallyDrop` to avoid a
    /// double close.
    fn deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T {
        let mut safe_descriptors = descriptors
            .iter()
            // SAFETY: the caller hands over ownership of each descriptor (see above).
            .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
            .collect();

        let res =
            deserialize_with_descriptors(|| serde_json::from_str(json), &mut safe_descriptors)
                .unwrap();

        // A slot still holding `Some` would mean a descriptor was never referenced by `json`.
        assert!(safe_descriptors.iter().all(|v| v.is_none()));

        res
    }

    // Round-trips a bare `RawDescriptor` field through `with_raw_descriptor`.
    #[test]
    fn raw() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct RawContainer {
            #[serde(with = "with_raw_descriptor")]
            rd: RawDescriptor,
        }
        // Specifically chosen to not overlap a real descriptor to avoid having to allocate any
        // descriptors for this test.
        let fake_rd = 5_123_457 as _;
        let v = RawContainer { rd: fake_rd };
        let v_serialize = SerializeDescriptors::new(&v);
        let json = serde_json::to_string(&v_serialize).unwrap();
        let descriptors = v_serialize.into_descriptors();
        let res = deserialize(&json, &descriptors);
        assert_eq!(v, res);
    }

    // Round-trips a `File` field through `with_as_descriptor`.
    #[test]
    fn file() {
        #[derive(Serialize, Deserialize)]
        struct FileContainer {
            #[serde(with = "with_as_descriptor")]
            file: File,
        }

        let v = FileContainer {
            file: tempfile().unwrap(),
        };
        let v_serialize = SerializeDescriptors::new(&v);
        let json = serde_json::to_string(&v_serialize).unwrap();
        let descriptors = v_serialize.into_descriptors();
        // `res` will own the same descriptor as `v.file`; keep `v` from closing it a second time.
        let v = ManuallyDrop::new(v);
        let res: FileContainer = deserialize(&json, &descriptors);
        assert_eq!(v.file.as_raw_fd(), res.file.as_raw_fd());
    }

    // `FileSerdeWrapper` inside `Option`, where `#[serde(with)]` cannot be applied directly.
    #[test]
    fn option() {
        #[derive(Serialize, Deserialize)]
        struct TestOption {
            a: Option<FileSerdeWrapper>,
            b: Option<FileSerdeWrapper>,
        }

        let v = TestOption {
            a: None,
            b: Some(tempfile().unwrap().into()),
        };
        let v_serialize = SerializeDescriptors::new(&v);
        let json = serde_json::to_string(&v_serialize).unwrap();
        let descriptors = v_serialize.into_descriptors();
        // `res.b` will own the same descriptor as `v.b`; avoid a double close.
        let v = ManuallyDrop::new(v);
        let res: TestOption = deserialize(&json, &descriptors);
        assert!(res.a.is_none());
        assert!(res.b.is_some());
        assert_eq!(
            v.b.as_ref().unwrap().as_raw_fd(),
            res.b.unwrap().as_raw_fd()
        );
    }

    // Several descriptors in one value; each must map back to the matching key.
    #[test]
    fn map() {
        let mut v: HashMap<String, FileSerdeWrapper> = HashMap::new();
        v.insert("a".into(), tempfile().unwrap().into());
        v.insert("b".into(), tempfile().unwrap().into());
        v.insert("c".into(), tempfile().unwrap().into());
        let v_serialize = SerializeDescriptors::new(&v);
        let json = serde_json::to_string(&v_serialize).unwrap();
        let descriptors = v_serialize.into_descriptors();
        // Prevent the files in `v` from dropping while allowing the HashMap itself to drop. It is
        // done this way to prevent a double close of the files (which should reside in `res`)
        // without triggering the leak sanitizer on `v`'s HashMap heap memory.
        let v: HashMap<_, _> = v
            .into_iter()
            .map(|(k, v)| (k, ManuallyDrop::new(v)))
            .collect();
        let res: HashMap<String, FileSerdeWrapper> = deserialize(&json, &descriptors);

        assert_eq!(v.len(), res.len());
        for (k, v) in v.iter() {
            assert_eq!(res.get(k).unwrap().as_raw_fd(), v.as_raw_fd());
        }
    }
}
542