1 use smallvec::{smallvec, SmallVec};
2 use std::ffi::{CStr, CString, NulError};
3 
4 /// Similar to std::ffi::CString, but avoids heap allocating if the string is
5 /// small enough. Also guarantees it's input is UTF-8 -- used for cases where we
6 /// need to pass a NUL-terminated string to SQLite, and we have a `&str`.
7 #[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
8 pub(crate) struct SmallCString(smallvec::SmallVec<[u8; 16]>);
9 
10 impl SmallCString {
11     #[inline]
new(s: &str) -> Result<Self, NulError>12     pub fn new(s: &str) -> Result<Self, NulError> {
13         if s.as_bytes().contains(&0u8) {
14             return Err(Self::fabricate_nul_error(s));
15         }
16         let mut buf = SmallVec::with_capacity(s.len() + 1);
17         buf.extend_from_slice(s.as_bytes());
18         buf.push(0);
19         let res = Self(buf);
20         res.debug_checks();
21         Ok(res)
22     }
23 
24     #[inline]
as_str(&self) -> &str25     pub fn as_str(&self) -> &str {
26         self.debug_checks();
27         // Constructor takes a &str so this is safe.
28         unsafe { std::str::from_utf8_unchecked(self.as_bytes_without_nul()) }
29     }
30 
31     /// Get the bytes not including the NUL terminator. E.g. the bytes which
32     /// make up our `str`:
33     /// - `SmallCString::new("foo").as_bytes_without_nul() == b"foo"`
34     /// - `SmallCString::new("foo").as_bytes_with_nul() == b"foo\0"
35     #[inline]
as_bytes_without_nul(&self) -> &[u8]36     pub fn as_bytes_without_nul(&self) -> &[u8] {
37         self.debug_checks();
38         &self.0[..self.len()]
39     }
40 
41     /// Get the bytes behind this str *including* the NUL terminator. This
42     /// should never return an empty slice.
43     #[inline]
as_bytes_with_nul(&self) -> &[u8]44     pub fn as_bytes_with_nul(&self) -> &[u8] {
45         self.debug_checks();
46         &self.0
47     }
48 
49     #[inline]
50     #[cfg(debug_assertions)]
debug_checks(&self)51     fn debug_checks(&self) {
52         debug_assert_ne!(self.0.len(), 0);
53         debug_assert_eq!(self.0[self.0.len() - 1], 0);
54         let strbytes = &self.0[..(self.0.len() - 1)];
55         debug_assert!(!strbytes.contains(&0));
56         debug_assert!(std::str::from_utf8(strbytes).is_ok());
57     }
58 
59     #[inline]
60     #[cfg(not(debug_assertions))]
debug_checks(&self)61     fn debug_checks(&self) {}
62 
63     #[inline]
len(&self) -> usize64     pub fn len(&self) -> usize {
65         debug_assert_ne!(self.0.len(), 0);
66         self.0.len() - 1
67     }
68 
69     #[inline]
70     #[allow(unused)] // clippy wants this function.
is_empty(&self) -> bool71     pub fn is_empty(&self) -> bool {
72         self.len() == 0
73     }
74 
75     #[inline]
as_cstr(&self) -> &CStr76     pub fn as_cstr(&self) -> &CStr {
77         let bytes = self.as_bytes_with_nul();
78         debug_assert!(CStr::from_bytes_with_nul(bytes).is_ok());
79         unsafe { CStr::from_bytes_with_nul_unchecked(bytes) }
80     }
81 
82     #[cold]
fabricate_nul_error(b: &str) -> NulError83     fn fabricate_nul_error(b: &str) -> NulError {
84         CString::new(b).unwrap_err()
85     }
86 }
87 
88 impl Default for SmallCString {
89     #[inline]
default() -> Self90     fn default() -> Self {
91         Self(smallvec![0])
92     }
93 }
94 
95 impl std::fmt::Debug for SmallCString {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result96     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97         f.debug_tuple("SmallCString").field(&self.as_str()).finish()
98     }
99 }
100 
101 impl std::ops::Deref for SmallCString {
102     type Target = CStr;
103     #[inline]
deref(&self) -> &CStr104     fn deref(&self) -> &CStr {
105         self.as_cstr()
106     }
107 }
108 
109 impl PartialEq<SmallCString> for str {
110     #[inline]
eq(&self, s: &SmallCString) -> bool111     fn eq(&self, s: &SmallCString) -> bool {
112         s.as_bytes_without_nul() == self.as_bytes()
113     }
114 }
115 
116 impl PartialEq<str> for SmallCString {
117     #[inline]
eq(&self, s: &str) -> bool118     fn eq(&self, s: &str) -> bool {
119         self.as_bytes_without_nul() == s.as_bytes()
120     }
121 }
122 
123 impl std::borrow::Borrow<str> for SmallCString {
124     #[inline]
borrow(&self) -> &str125     fn borrow(&self) -> &str {
126         self.as_str()
127     }
128 }
129 
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_small_cstring() {
        // We don't go through the normal machinery for default, so make sure
        // things work.
        assert_eq!(SmallCString::default().0, SmallCString::new("").unwrap().0);
        assert_eq!(SmallCString::new("foo").unwrap().len(), 3);
        assert_eq!(
            SmallCString::new("foo").unwrap().as_bytes_with_nul(),
            b"foo\0"
        );
        assert_eq!(
            SmallCString::new("foo").unwrap().as_bytes_without_nul(),
            b"foo",
        );

        // U+1F600 (grinning face) encodes as the 4-byte UTF-8 sequence
        // f0 9f 98 80. It is written as an escape so the test does not depend
        // on the source file's encoding surviving intact.
        assert_eq!(SmallCString::new("\u{1F600}").unwrap().len(), 4);
        assert_eq!(
            SmallCString::new("\u{1F600}").unwrap().0.as_slice(),
            b"\xf0\x9f\x98\x80\0",
        );
        assert_eq!(
            SmallCString::new("\u{1F600}").unwrap().as_bytes_without_nul(),
            b"\xf0\x9f\x98\x80",
        );

        assert_eq!(SmallCString::new("").unwrap().len(), 0);
        assert!(SmallCString::new("").unwrap().is_empty());

        assert_eq!(SmallCString::new("").unwrap().0.as_slice(), b"\0");
        assert_eq!(SmallCString::new("").unwrap().as_bytes_without_nul(), b"");

        assert!(SmallCString::new("\0").is_err());
        assert!(SmallCString::new("\0abc").is_err());
        assert!(SmallCString::new("abc\0").is_err());
        // The fabricated error must report the same position as `CString::new`.
        assert_eq!(SmallCString::new("ab\0c").unwrap_err().nul_position(), 2);
    }

    #[test]
    fn test_small_cstring_views_and_traits() {
        let foo = SmallCString::new("foo").unwrap();
        assert_eq!(foo.as_str(), "foo");
        assert_eq!(foo.as_cstr().to_bytes(), b"foo");
        // `Deref<Target = CStr>` exposes `CStr` methods directly.
        assert_eq!(foo.to_bytes_with_nul(), b"foo\0");
        assert_eq!(format!("{:?}", foo), r#"SmallCString("foo")"#);

        // Both directions of the `str` comparisons.
        assert!(foo == *"foo");
        assert!(*"foo" == foo);
        assert!(foo != *"fo");
        assert!(*"food" != foo);

        let borrowed: &str = <SmallCString as std::borrow::Borrow<str>>::borrow(&foo);
        assert_eq!(borrowed, "foo");

        // Since `Borrow<str>` is implemented, the derived ordering on the
        // NUL-terminated buffer must agree with `str` ordering.
        let a = SmallCString::new("a").unwrap();
        let ab = SmallCString::new("ab").unwrap();
        let b = SmallCString::new("b").unwrap();
        assert!(a < ab);
        assert_eq!(a.cmp(&ab), "a".cmp("ab"));
        assert_eq!(b.cmp(&ab), "b".cmp("ab"));

        // Strings too long for the inline buffer spill to the heap but behave
        // the same.
        let long = "x".repeat(64);
        let big = SmallCString::new(&long).unwrap();
        assert_eq!(big.len(), 64);
        assert_eq!(big.as_str(), long.as_str());
        assert_eq!(big.as_bytes_with_nul().last(), Some(&0u8));
    }
}
170