1 use crate::descriptor::DescriptorProto;
2 use crate::descriptor::FileDescriptorProto;
3 use crate::descriptorx::find_message_by_rust_name;
4 use crate::reflect::acc::FieldAccessor;
5 use crate::reflect::find_message_or_enum::find_message_or_enum;
6 use crate::reflect::find_message_or_enum::MessageOrEnum;
7 use crate::reflect::FieldDescriptor;
8 use crate::Message;
9 use std::collections::HashMap;
10 use std::marker;
11 
12 trait MessageFactory: Send + Sync + 'static {
new_instance(&self) -> Box<dyn Message>13     fn new_instance(&self) -> Box<dyn Message>;
14 }
15 
/// `MessageFactory` implementation for the concrete message type `M`.
///
/// Holds no data; `M` is carried only at the type level.
struct MessageFactoryImpl<M>(marker::PhantomData<M>);
17 
18 impl<M> MessageFactory for MessageFactoryImpl<M>
19 where
20     M: 'static + Message + Default + Clone + PartialEq,
21 {
new_instance(&self) -> Box<dyn Message>22     fn new_instance(&self) -> Box<dyn Message> {
23         let m: M = Default::default();
24         Box::new(m)
25     }
26 }
27 
28 /// Dynamic message type
29 pub struct MessageDescriptor {
30     full_name: String,
31     proto: &'static DescriptorProto,
32     factory: &'static dyn MessageFactory,
33     fields: Vec<FieldDescriptor>,
34 
35     index_by_name: HashMap<String, usize>,
36     index_by_name_or_json_name: HashMap<String, usize>,
37     index_by_number: HashMap<u32, usize>,
38 }
39 
40 impl MessageDescriptor {
41     /// Get underlying `DescriptorProto` object.
get_proto(&self) -> &DescriptorProto42     pub fn get_proto(&self) -> &DescriptorProto {
43         self.proto
44     }
45 
46     /// Get a message descriptor for given message type
for_type<M: Message>() -> &'static MessageDescriptor47     pub fn for_type<M: Message>() -> &'static MessageDescriptor {
48         M::descriptor_static()
49     }
50 
compute_full_name(package: &str, path_to_package: &str, proto: &DescriptorProto) -> String51     fn compute_full_name(package: &str, path_to_package: &str, proto: &DescriptorProto) -> String {
52         let mut full_name = package.to_owned();
53         if path_to_package.len() != 0 {
54             if full_name.len() != 0 {
55                 full_name.push('.');
56             }
57             full_name.push_str(path_to_package);
58         }
59         if full_name.len() != 0 {
60             full_name.push('.');
61         }
62         full_name.push_str(proto.get_name());
63         full_name
64     }
65 
66     // Non-generic part of `new` is a separate function
67     // to reduce code bloat from multiple instantiations.
new_non_generic_by_rust_name( rust_name: &'static str, fields: Vec<FieldAccessor>, file: &'static FileDescriptorProto, factory: &'static dyn MessageFactory, ) -> MessageDescriptor68     fn new_non_generic_by_rust_name(
69         rust_name: &'static str,
70         fields: Vec<FieldAccessor>,
71         file: &'static FileDescriptorProto,
72         factory: &'static dyn MessageFactory,
73     ) -> MessageDescriptor {
74         let proto = find_message_by_rust_name(file, rust_name);
75 
76         let mut field_proto_by_name = HashMap::new();
77         for field_proto in proto.message.get_field() {
78             field_proto_by_name.insert(field_proto.get_name(), field_proto);
79         }
80 
81         let mut index_by_name = HashMap::new();
82         let mut index_by_name_or_json_name = HashMap::new();
83         let mut index_by_number = HashMap::new();
84 
85         let mut full_name = file.get_package().to_string();
86         if full_name.len() > 0 {
87             full_name.push('.');
88         }
89         full_name.push_str(proto.message.get_name());
90 
91         let fields: Vec<_> = fields
92             .into_iter()
93             .map(|f| {
94                 let proto = *field_proto_by_name.get(&f.name).unwrap();
95                 FieldDescriptor::new(f, proto)
96             })
97             .collect();
98         for (i, f) in fields.iter().enumerate() {
99             assert!(index_by_number
100                 .insert(f.proto().get_number() as u32, i)
101                 .is_none());
102             assert!(index_by_name
103                 .insert(f.proto().get_name().to_owned(), i)
104                 .is_none());
105             assert!(index_by_name_or_json_name
106                 .insert(f.proto().get_name().to_owned(), i)
107                 .is_none());
108 
109             let json_name = f.json_name().to_owned();
110 
111             if json_name != f.proto().get_name() {
112                 assert!(index_by_name_or_json_name.insert(json_name, i).is_none());
113             }
114         }
115         MessageDescriptor {
116             full_name,
117             proto: proto.message,
118             factory,
119             fields,
120             index_by_name,
121             index_by_name_or_json_name,
122             index_by_number,
123         }
124     }
125 
126     // Non-generic part of `new` is a separate function
127     // to reduce code bloat from multiple instantiations.
new_non_generic_by_pb_name( protobuf_name_to_package: &'static str, fields: Vec<FieldAccessor>, file_descriptor_proto: &'static FileDescriptorProto, factory: &'static dyn MessageFactory, ) -> MessageDescriptor128     fn new_non_generic_by_pb_name(
129         protobuf_name_to_package: &'static str,
130         fields: Vec<FieldAccessor>,
131         file_descriptor_proto: &'static FileDescriptorProto,
132         factory: &'static dyn MessageFactory,
133     ) -> MessageDescriptor {
134         let (path_to_package, proto) =
135             match find_message_or_enum(file_descriptor_proto, protobuf_name_to_package) {
136                 (path_to_package, MessageOrEnum::Message(m)) => (path_to_package, m),
137                 (_, MessageOrEnum::Enum(_)) => panic!("not a message"),
138             };
139 
140         let mut field_proto_by_name = HashMap::new();
141         for field_proto in proto.get_field() {
142             field_proto_by_name.insert(field_proto.get_name(), field_proto);
143         }
144 
145         let mut index_by_name = HashMap::new();
146         let mut index_by_name_or_json_name = HashMap::new();
147         let mut index_by_number = HashMap::new();
148 
149         let full_name = MessageDescriptor::compute_full_name(
150             file_descriptor_proto.get_package(),
151             &path_to_package,
152             &proto,
153         );
154         let fields: Vec<_> = fields
155             .into_iter()
156             .map(|f| {
157                 let proto = *field_proto_by_name.get(&f.name).unwrap();
158                 FieldDescriptor::new(f, proto)
159             })
160             .collect();
161 
162         for (i, f) in fields.iter().enumerate() {
163             assert!(index_by_number
164                 .insert(f.proto().get_number() as u32, i)
165                 .is_none());
166             assert!(index_by_name
167                 .insert(f.proto().get_name().to_owned(), i)
168                 .is_none());
169             assert!(index_by_name_or_json_name
170                 .insert(f.proto().get_name().to_owned(), i)
171                 .is_none());
172 
173             let json_name = f.json_name().to_owned();
174 
175             if json_name != f.proto().get_name() {
176                 assert!(index_by_name_or_json_name.insert(json_name, i).is_none());
177             }
178         }
179         MessageDescriptor {
180             full_name,
181             proto,
182             factory,
183             fields,
184             index_by_name,
185             index_by_name_or_json_name,
186             index_by_number,
187         }
188     }
189 
190     /// Construct a new message descriptor.
191     ///
192     /// This operation is called from generated code and rarely
193     /// need to be called directly.
194     #[doc(hidden)]
195     #[deprecated(
196         since = "2.12",
197         note = "Please regenerate .rs files from .proto files to use newer APIs"
198     )]
new<M: 'static + Message + Default + Clone + PartialEq>( rust_name: &'static str, fields: Vec<FieldAccessor>, file: &'static FileDescriptorProto, ) -> MessageDescriptor199     pub fn new<M: 'static + Message + Default + Clone + PartialEq>(
200         rust_name: &'static str,
201         fields: Vec<FieldAccessor>,
202         file: &'static FileDescriptorProto,
203     ) -> MessageDescriptor {
204         let factory = &MessageFactoryImpl(marker::PhantomData::<M>);
205         MessageDescriptor::new_non_generic_by_rust_name(rust_name, fields, file, factory)
206     }
207 
208     /// Construct a new message descriptor.
209     ///
210     /// This operation is called from generated code and rarely
211     /// need to be called directly.
212     #[doc(hidden)]
new_pb_name<M: 'static + Message + Default + Clone + PartialEq>( protobuf_name_to_package: &'static str, fields: Vec<FieldAccessor>, file_descriptor_proto: &'static FileDescriptorProto, ) -> MessageDescriptor213     pub fn new_pb_name<M: 'static + Message + Default + Clone + PartialEq>(
214         protobuf_name_to_package: &'static str,
215         fields: Vec<FieldAccessor>,
216         file_descriptor_proto: &'static FileDescriptorProto,
217     ) -> MessageDescriptor {
218         let factory = &MessageFactoryImpl(marker::PhantomData::<M>);
219         MessageDescriptor::new_non_generic_by_pb_name(
220             protobuf_name_to_package,
221             fields,
222             file_descriptor_proto,
223             factory,
224         )
225     }
226 
227     /// New empty message
new_instance(&self) -> Box<dyn Message>228     pub fn new_instance(&self) -> Box<dyn Message> {
229         self.factory.new_instance()
230     }
231 
232     /// Message name as given in `.proto` file
name(&self) -> &'static str233     pub fn name(&self) -> &'static str {
234         self.proto.get_name()
235     }
236 
237     /// Fully qualified protobuf message name
full_name(&self) -> &str238     pub fn full_name(&self) -> &str {
239         &self.full_name[..]
240     }
241 
242     /// Message field descriptors.
fields(&self) -> &[FieldDescriptor]243     pub fn fields(&self) -> &[FieldDescriptor] {
244         &self.fields
245     }
246 
247     /// Find message field by protobuf field name
248     ///
249     /// Note: protobuf field name might be different for Rust field name.
get_field_by_name<'a>(&'a self, name: &str) -> Option<&'a FieldDescriptor>250     pub fn get_field_by_name<'a>(&'a self, name: &str) -> Option<&'a FieldDescriptor> {
251         let &index = self.index_by_name.get(name)?;
252         Some(&self.fields[index])
253     }
254 
255     /// Find message field by field name or field JSON name
get_field_by_name_or_json_name<'a>(&'a self, name: &str) -> Option<&'a FieldDescriptor>256     pub fn get_field_by_name_or_json_name<'a>(&'a self, name: &str) -> Option<&'a FieldDescriptor> {
257         let &index = self.index_by_name_or_json_name.get(name)?;
258         Some(&self.fields[index])
259     }
260 
261     /// Find message field by field name
get_field_by_number(&self, number: u32) -> Option<&FieldDescriptor>262     pub fn get_field_by_number(&self, number: u32) -> Option<&FieldDescriptor> {
263         let &index = self.index_by_number.get(&number)?;
264         Some(&self.fields[index])
265     }
266 
267     /// Find field by name
268     // TODO: deprecate
field_by_name<'a>(&'a self, name: &str) -> &'a FieldDescriptor269     pub fn field_by_name<'a>(&'a self, name: &str) -> &'a FieldDescriptor {
270         // TODO: clone is weird
271         let &index = self.index_by_name.get(&name.to_string()).unwrap();
272         &self.fields[index]
273     }
274 
275     /// Find field by number
276     // TODO: deprecate
field_by_number<'a>(&'a self, number: u32) -> &'a FieldDescriptor277     pub fn field_by_number<'a>(&'a self, number: u32) -> &'a FieldDescriptor {
278         let &index = self.index_by_number.get(&number).unwrap();
279         &self.fields[index]
280     }
281 }
282