nautilus_core/
serialization.rs1use bytes::Bytes;
19use serde::{
20 Deserializer,
21 de::{Unexpected, Visitor},
22};
23
/// Serde visitor used by [`from_bool_as_u8`] to read a flag as a `u8`.
///
/// Accepts a boolean (`false` -> `0`, `true` -> `1`) or an already-numeric
/// `0`/`1`; any other integer is reported as an invalid value.
struct BoolVisitor;
25use serde::{Deserialize, Serialize};
26
27pub trait Serializable: Serialize + for<'de> Deserialize<'de> {
29 fn from_json_bytes(data: &[u8]) -> Result<Self, serde_json::Error> {
35 serde_json::from_slice(data)
36 }
37
38 fn to_json_bytes(&self) -> Result<Bytes, serde_json::Error> {
44 serde_json::to_vec(self).map(Bytes::from)
45 }
46}
47
48pub use self::msgpack::{FromMsgPack, MsgPackSerializable, ToMsgPack};
49
50pub mod msgpack {
55 use bytes::Bytes;
56 use serde::{Deserialize, Serialize};
57
58 use super::Serializable;
59
60 pub trait FromMsgPack: for<'de> Deserialize<'de> + Sized {
62 fn from_msgpack_bytes(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
68 rmp_serde::from_slice(data)
69 }
70 }
71
72 pub trait ToMsgPack: Serialize {
74 fn to_msgpack_bytes(&self) -> Result<Bytes, rmp_serde::encode::Error> {
80 rmp_serde::to_vec_named(self).map(Bytes::from)
81 }
82 }
83
84 pub trait MsgPackSerializable: Serializable + FromMsgPack + ToMsgPack {}
88
89 impl<T> FromMsgPack for T where T: Serializable {}
90
91 impl<T> ToMsgPack for T where T: Serializable {}
92
93 impl<T> MsgPackSerializable for T where T: Serializable {}
94}
95
96impl Visitor<'_> for BoolVisitor {
97 type Value = u8;
98
99 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 formatter.write_str("a boolean as u8")
101 }
102
103 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
104 where
105 E: serde::de::Error,
106 {
107 Ok(u8::from(value))
108 }
109
110 #[allow(
111 clippy::cast_possible_truncation,
112 reason = "Intentional for parsing, value range validated"
113 )]
114 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
115 where
116 E: serde::de::Error,
117 {
118 if value > 1 {
123 Err(E::invalid_value(Unexpected::Unsigned(value), &self))
124 } else {
125 Ok(value as u8)
126 }
127 }
128}
129
130pub fn from_bool_as_u8<'de, D>(deserializer: D) -> Result<u8, D::Error>
136where
137 D: Deserializer<'de>,
138{
139 deserializer.deserialize_any(BoolVisitor)
140}
141
#[cfg(test)]
mod tests {
    use rstest::*;
    use serde::{Deserialize, Serialize};

    use super::{
        Serializable, from_bool_as_u8,
        msgpack::{FromMsgPack, ToMsgPack},
    };

    /// Target struct whose `value` field is decoded through `from_bool_as_u8`.
    #[derive(Deserialize)]
    pub struct TestStruct {
        #[serde(deserialize_with = "from_bool_as_u8")]
        pub value: u8,
    }

    /// Source struct used to produce encoded payloads carrying a real boolean.
    #[derive(Serialize)]
    struct BoolPayload {
        value: bool,
    }

    #[rstest]
    #[case(r#"{"value": true}"#, 1)]
    #[case(r#"{"value": false}"#, 0)]
    fn test_deserialize_bool_as_u8_with_boolean(#[case] json_str: &str, #[case] expected: u8) {
        let test_struct: TestStruct = serde_json::from_str(json_str).unwrap();
        assert_eq!(test_struct.value, expected);
    }

    #[rstest]
    #[case(r#"{"value": 1}"#, 1)]
    #[case(r#"{"value": 0}"#, 0)]
    fn test_deserialize_bool_as_u8_with_u64(#[case] json_str: &str, #[case] expected: u8) {
        let test_struct: TestStruct = serde_json::from_str(json_str).unwrap();
        assert_eq!(test_struct.value, expected);
    }

    // Integers outside {0, 1}, including negatives, must be rejected.
    #[rstest]
    #[case(r#"{"value": 2}"#)]
    #[case(r#"{"value": 255}"#)]
    #[case(r#"{"value": -1}"#)]
    fn test_deserialize_bool_as_u8_with_invalid_integer(#[case] json_str: &str) {
        let result: Result<TestStruct, _> = serde_json::from_str(json_str);
        assert!(result.is_err());
    }

    // Non-boolean, non-integer inputs (and a missing field) must be rejected.
    #[rstest]
    #[case(r#"{"value": "true"}"#)]
    #[case(r#"{"value": null}"#)]
    #[case(r#"{"value": 1.0}"#)]
    #[case(r#"{}"#)]
    fn test_deserialize_bool_as_u8_with_invalid_type(#[case] json_str: &str) {
        let result: Result<TestStruct, _> = serde_json::from_str(json_str);
        assert!(result.is_err());
    }

    // The helper must also work through the MessagePack path, not just JSON.
    #[rstest]
    #[case(true, 1)]
    #[case(false, 0)]
    fn test_deserialize_bool_as_u8_from_msgpack(#[case] flag: bool, #[case] expected: u8) {
        let bytes = rmp_serde::to_vec_named(&BoolPayload { value: flag }).unwrap();
        let test_struct: TestStruct = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(test_struct.value, expected);
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct SerializableTestStruct {
        id: u32,
        name: String,
        value: f64,
    }

    impl Serializable for SerializableTestStruct {}

    #[rstest]
    fn test_serializable_json_roundtrip() {
        let original = SerializableTestStruct {
            id: 42,
            name: "test".to_string(),
            value: std::f64::consts::PI,
        };

        let json_bytes = original.to_json_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_json_bytes(&json_bytes).unwrap();

        assert_eq!(original, deserialized);
    }

    #[rstest]
    fn test_serializable_msgpack_roundtrip() {
        let original = SerializableTestStruct {
            id: 123,
            name: "msgpack_test".to_string(),
            value: std::f64::consts::E,
        };

        let msgpack_bytes = original.to_msgpack_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_msgpack_bytes(&msgpack_bytes).unwrap();

        assert_eq!(original, deserialized);
    }

    #[rstest]
    fn test_serializable_json_invalid_data() {
        let invalid_json = b"invalid json data";
        let result = SerializableTestStruct::from_json_bytes(invalid_json);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_serializable_msgpack_invalid_data() {
        let invalid_msgpack = b"invalid msgpack data";
        let result = SerializableTestStruct::from_msgpack_bytes(invalid_msgpack);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_serializable_json_empty_values() {
        let test_struct = SerializableTestStruct {
            id: 0,
            name: String::new(),
            value: 0.0,
        };

        let json_bytes = test_struct.to_json_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_json_bytes(&json_bytes).unwrap();

        assert_eq!(test_struct, deserialized);
    }

    #[rstest]
    fn test_serializable_msgpack_empty_values() {
        let test_struct = SerializableTestStruct {
            id: 0,
            name: String::new(),
            value: 0.0,
        };

        let msgpack_bytes = test_struct.to_msgpack_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_msgpack_bytes(&msgpack_bytes).unwrap();

        assert_eq!(test_struct, deserialized);
    }
}