// Copyright 2020 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. //! Provides infrastructure for de/serializing descriptors embedded in Rust data structures. //! //! # Example //! //! ``` //! use serde_json::to_string; //! use sys_util::{ //! FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors, //! deserialize_with_descriptors, //! }; //! use tempfile::tempfile; //! //! let tmp_f = tempfile().unwrap(); //! //! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File. //! let data = FileSerdeWrapper(tmp_f); //! //! // Wraps Serialize types to collect side channel descriptors as Serialize is called. //! let data_wrapper = SerializeDescriptors::new(&data); //! //! // Use the wrapper with any serializer to serialize data is normal, grabbing descriptors //! // as the data structures are serialized by the serializer. //! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize"); //! //! // If data_wrapper contains any side channel descriptor refs //! // (it contains tmp_f in this case), we can retrieve the actual descriptors //! // from the side channel using into_descriptors(). //! let out_descriptors = data_wrapper.into_descriptors(); //! //! // When sending out_json over some transport, also send out_descriptors. //! //! // For this example, we aren't really transporting data across the process, but we do need to //! // convert the descriptor type. //! let mut safe_descriptors = out_descriptors //! .iter() //! .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) })) //! .collect(); //! std::mem::forget(data); // Prevent double drop of tmp_f. //! //! // The deserialize_with_descriptors function is used give the descriptor deserializers access //! // to side channel descriptors. //! let res: FileSerdeWrapper = //! 
deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors) //! .expect("failed to deserialize"); //! ``` use std::cell::{Cell, RefCell}; use std::convert::TryInto; use std::fmt; use std::fs::File; use std::ops::{Deref, DerefMut}; use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; use serde::de::{self, Error, Visitor}; use serde::ser; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::{RawDescriptor, SafeDescriptor}; thread_local! { static DESCRIPTOR_DST: RefCell>> = Default::default(); } /// Initializes the thread local storage for descriptor serialization. Fails if it was already /// initialized without an intervening `take_descriptor_dst` on this thread. fn init_descriptor_dst() -> Result<(), &'static str> { DESCRIPTOR_DST.with(|d| { let mut descriptors = d.borrow_mut(); if descriptors.is_some() { return Err( "attempt to initialize descriptor destination that was already initialized", ); } *descriptors = Some(Default::default()); Ok(()) }) } /// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call /// to `init_descriptor_dst` on this thread. fn take_descriptor_dst() -> Result, &'static str> { match DESCRIPTOR_DST.with(|d| d.replace(None)) { Some(d) => Ok(d), None => Err("attempt to take descriptor destination before it was initialized"), } } /// Pushes a descriptor on the thread local destination of descriptors, returning the index in which /// the descriptor was pushed. // /// Returns Err if the thread local destination was not already initialized. fn push_descriptor(rd: RawDescriptor) -> Result { DESCRIPTOR_DST.with(|d| { d.borrow_mut() .as_mut() .ok_or("attempt to serialize descriptor without descriptor destination") .map(|descriptors| { let index = descriptors.len(); descriptors.push(rd); index }) }) } /// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct. 
/// /// If there is no parent `SerializeDescriptors` being serialized, this will return an error. /// /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with = /// "...")]` attribute which will make use of this function. pub fn serialize_descriptor( rd: &RawDescriptor, se: S, ) -> std::result::Result { let index = push_descriptor(*rd).map_err(ser::Error::custom)?; se.serialize_u32( index .try_into() .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?, ) } /// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when /// given to an ordinary `Serializer`. /// /// This is the corresponding type to use for serialization before using /// `deserialize_with_descriptors`. /// /// # Examples /// /// ``` /// use serde_json::to_string; /// use sys_util::{FileSerdeWrapper, SerializeDescriptors}; /// use tempfile::tempfile; /// /// let tmp_f = tempfile().unwrap(); /// let data = FileSerdeWrapper(tmp_f); /// let data_wrapper = SerializeDescriptors::new(&data); /// /// // Serializes `v` as normal... /// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize"); /// // If `serialize_descriptor` was called, we can capture the descriptors from here. /// let out_descriptors = data_wrapper.into_descriptors(); /// ``` pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell>); impl<'a, T: Serialize> SerializeDescriptors<'a, T> { pub fn new(inner: &'a T) -> Self { Self(inner, Default::default()) } pub fn into_descriptors(self) -> Vec { self.1.into_inner() } } impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> { fn serialize(&self, serializer: S) -> Result where S: Serializer, { init_descriptor_dst().map_err(ser::Error::custom)?; // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to // take_descriptor_dst afterwards. 
let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer))); self.1.set(take_descriptor_dst().unwrap()); match res { Ok(r) => r, Err(e) => resume_unwind(e), } } } thread_local! { static DESCRIPTOR_SRC: RefCell>>> = Default::default(); } /// Sets the thread local storage of descriptors for deserialization. Fails if this was already /// called without a call to `take_descriptor_src` on this thread. /// /// This is given as a collection of `Option` so that unused descriptors can be returned. fn set_descriptor_src(descriptors: Vec>) -> Result<(), &'static str> { DESCRIPTOR_SRC.with(|d| { let mut src = d.borrow_mut(); if src.is_some() { return Err("attempt to set descriptor source that was already set"); } *src = Some(descriptors); Ok(()) }) } /// Takes the thread local storage of descriptors for deserialization. Fails if the storage was /// already taken or never set with `set_descriptor_src`. /// /// If deserialization was done, the descriptors will mostly come back as `None` unless some of them /// were unused. fn take_descriptor_src() -> Result>, &'static str> { DESCRIPTOR_SRC.with(|d| { d.replace(None) .ok_or("attempt to take descriptor source which was never set") }) } /// Takes a descriptor at the given index from the thread local source of descriptors. // /// Returns None if the thread local source was not already initialized. fn take_descriptor(index: usize) -> Result { DESCRIPTOR_SRC.with(|d| { d.borrow_mut() .as_mut() .ok_or("attempt to deserialize descriptor without descriptor source")? .get_mut(index) .ok_or("attempt to deserialize out of bounds descriptor")? .take() .ok_or("attempt to deserialize descriptor that was already taken") }) } /// Deserializes a descriptor provided via `deserialize_with_descriptors`. /// /// If `deserialize_with_descriptors` is not in the call chain, this will return an error. 
/// /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with = /// "...")]` attribute which will make use of this function. pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result where D: Deserializer<'de>, { struct DescriptorVisitor; impl<'de> Visitor<'de> for DescriptorVisitor { type Value = u32; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str("an integer which fits into a u32") } fn visit_u8(self, value: u8) -> Result { Ok(value as _) } fn visit_u16(self, value: u16) -> Result { Ok(value as _) } fn visit_u32(self, value: u32) -> Result { Ok(value) } fn visit_u64(self, value: u64) -> Result { value.try_into().map_err(E::custom) } fn visit_u128(self, value: u128) -> Result { value.try_into().map_err(E::custom) } fn visit_i8(self, value: i8) -> Result { value.try_into().map_err(E::custom) } fn visit_i16(self, value: i16) -> Result { value.try_into().map_err(E::custom) } fn visit_i32(self, value: i32) -> Result { value.try_into().map_err(E::custom) } fn visit_i64(self, value: i64) -> Result { value.try_into().map_err(E::custom) } fn visit_i128(self, value: i128) -> Result { value.try_into().map_err(E::custom) } } let index = de.deserialize_u32(DescriptorVisitor)? as usize; take_descriptor(index).map_err(D::Error::custom) } /// Allows the use of any serde deserializer within a closure while providing access to the a set of /// descriptors for use in `deserialize_descriptor`. /// /// This is the corresponding call to use deserialize after using `SerializeDescriptors`. /// /// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an /// error. 
pub fn deserialize_with_descriptors( f: F, descriptors: &mut Vec>, ) -> Result where F: FnOnce() -> Result, E: de::Error, { let swap_descriptors = std::mem::take(descriptors); set_descriptor_src(swap_descriptors).map_err(E::custom)?; // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to // take_descriptor_src afterwards. let res = catch_unwind(AssertUnwindSafe(f)); // unwrap is used because set_descriptor_src is always called before this, so it should never // panic. *descriptors = take_descriptor_src().unwrap(); match res { Ok(r) => r, Err(e) => resume_unwind(e), } } /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]` /// attribute. It only works with fields with `RawDescriptor` type. /// /// # Examples /// /// ``` /// use serde::{Deserialize, Serialize}; /// use sys_util::RawDescriptor; /// /// #[derive(Serialize, Deserialize)] /// struct RawContainer { /// #[serde(with = "sys_util::with_raw_descriptor")] /// rd: RawDescriptor, /// } /// ``` pub mod with_raw_descriptor { use crate::{IntoRawDescriptor, RawDescriptor}; use serde::Deserializer; pub use super::serialize_descriptor as serialize; pub fn deserialize<'de, D>(de: D) -> std::result::Result where D: Deserializer<'de>, { super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor) } } /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]` /// attribute. 
///
/// Works with any field type implementing `AsRawDescriptor` (for serialization) and
/// `FromRawDescriptor` (for deserialization).
///
/// # Examples
///
/// ```
/// use std::fs::File;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Serialize, Deserialize)]
/// struct FileContainer {
///     #[serde(with = "sys_util::with_as_descriptor")]
///     file: File,
/// }
/// ```
pub mod with_as_descriptor {
    use crate::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor};
    use serde::{Deserializer, Serializer};

    /// Serializes the raw descriptor of `rd` through `serialize_descriptor`. `rd` is only
    /// borrowed, so it keeps its descriptor.
    pub fn serialize<S: Serializer>(
        rd: &dyn AsRawDescriptor,
        se: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        super::serialize_descriptor(&rd.as_raw_descriptor(), se)
    }

    /// Deserializes a descriptor through `deserialize_descriptor` and builds a `T` from it.
    pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
    where
        D: Deserializer<'de>,
        T: FromRawDescriptor,
    {
        super::deserialize_descriptor(de)
            .map(IntoRawDescriptor::into_raw_descriptor)
            // SAFETY: `take_descriptor` leaves `None` behind, so each descriptor in the source can
            // be handed out at most once; `rd` was just released from that single `SafeDescriptor`
            // and nothing else refers to it, so `T` becomes its sole owner.
            .map(|rd| unsafe { T::from_raw_descriptor(rd) })
    }
}

/// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when
/// the `#[serde(with = "with_as_descriptor")]` attribute is infeasible, such as for a field with
/// type `Option<File>`.
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);

impl fmt::Debug for FileSerdeWrapper {
    // Formats exactly like the wrapped `File`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}

impl From<File> for FileSerdeWrapper {
    fn from(file: File) -> Self {
        Self(file)
    }
}

impl From<FileSerdeWrapper> for File {
    fn from(wrapper: FileSerdeWrapper) -> File {
        let FileSerdeWrapper(file) = wrapper;
        file
    }
}

impl Deref for FileSerdeWrapper {
    type Target = File;

    fn deref(&self) -> &File {
        &self.0
    }
}

impl DerefMut for FileSerdeWrapper {
    fn deref_mut(&mut self) -> &mut File {
        &mut self.0
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper,
        FromRawDescriptor, RawDescriptor, SafeDescriptor, SerializeDescriptors,
    };

    use std::collections::HashMap;
    use std::fs::File;
    use std::mem::ManuallyDrop;
    use std::os::unix::io::AsRawFd;

    use serde::{de::DeserializeOwned, Deserialize, Serialize};
    use tempfile::tempfile;

    /// Deserializes `json` with `descriptors` as the side channel and asserts that every
    /// descriptor was consumed by the deserialization.
    fn deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T {
        let mut safe_descriptors: Vec<Option<SafeDescriptor>> = descriptors
            .iter()
            .copied()
            .map(|rd| Some(unsafe { SafeDescriptor::from_raw_descriptor(rd) }))
            .collect();

        let res = deserialize_with_descriptors(|| serde_json::from_str(json), &mut safe_descriptors)
            .unwrap();

        assert!(safe_descriptors.iter().all(Option::is_none));
        res
    }

    #[test]
    fn raw() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct RawContainer {
            #[serde(with = "with_raw_descriptor")]
            rd: RawDescriptor,
        }

        // Specifically chosen to not overlap a real descriptor to avoid having to allocate any
        // descriptors for this test.
        let fake_rd = 5_123_457 as _;
        let original = RawContainer { rd: fake_rd };

        let wrapper = SerializeDescriptors::new(&original);
        let json = serde_json::to_string(&wrapper).unwrap();
        let descriptors = wrapper.into_descriptors();

        let round_tripped: RawContainer = deserialize(&json, &descriptors);
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn file() {
        #[derive(Serialize, Deserialize)]
        struct FileContainer {
            #[serde(with = "with_as_descriptor")]
            file: File,
        }

        let original = FileContainer {
            file: tempfile().unwrap(),
        };

        let wrapper = SerializeDescriptors::new(&original);
        let json = serde_json::to_string(&wrapper).unwrap();
        let descriptors = wrapper.into_descriptors();

        // The deserialized value takes over the same descriptor, so the original must not close it.
        let original = ManuallyDrop::new(original);
        let round_tripped: FileContainer = deserialize(&json, &descriptors);
        assert_eq!(original.file.as_raw_fd(), round_tripped.file.as_raw_fd());
    }

    #[test]
    fn option() {
        #[derive(Serialize, Deserialize)]
        struct TestOption {
            a: Option<FileSerdeWrapper>,
            b: Option<FileSerdeWrapper>,
        }

        let original = TestOption {
            a: None,
            b: Some(tempfile().unwrap().into()),
        };

        let wrapper = SerializeDescriptors::new(&original);
        let json = serde_json::to_string(&wrapper).unwrap();
        let descriptors = wrapper.into_descriptors();

        let original = ManuallyDrop::new(original);
        let round_tripped: TestOption = deserialize(&json, &descriptors);
        assert!(round_tripped.a.is_none());
        assert!(round_tripped.b.is_some());
        assert_eq!(
            original.b.as_ref().unwrap().as_raw_fd(),
            round_tripped.b.unwrap().as_raw_fd()
        );
    }

    #[test]
    fn map() {
        let mut original: HashMap<String, FileSerdeWrapper> = HashMap::new();
        for &key in &["a", "b", "c"] {
            original.insert(key.into(), tempfile().unwrap().into());
        }

        let wrapper = SerializeDescriptors::new(&original);
        let json = serde_json::to_string(&wrapper).unwrap();
        let descriptors = wrapper.into_descriptors();

        // Prevent the files in `original` from dropping while allowing the HashMap itself to drop.
        // It is done this way to prevent a double close of the files (which should reside in
        // `round_tripped`) without triggering the leak sanitizer on `original`'s HashMap heap
        // memory.
        let original: HashMap<_, _> = original
            .into_iter()
            .map(|(key, file)| (key, ManuallyDrop::new(file)))
            .collect();

        let round_tripped: HashMap<String, FileSerdeWrapper> = deserialize(&json, &descriptors);
        assert_eq!(original.len(), round_tripped.len());
        for (key, file) in &original {
            assert_eq!(round_tripped.get(key).unwrap().as_raw_fd(), file.as_raw_fd());
        }
    }
}