1use std::str::FromStr;
26
27use bytes::Bytes;
28use rust_decimal::Decimal;
29use serde::{
30 Deserialize, Deserializer, Serialize, Serializer,
31 de::{Error, Unexpected, Visitor},
32 ser::SerializeSeq,
33};
34use ustr::Ustr;
35
/// Serde visitor used by [`from_bool_as_u8`]: converts a boolean (or an integer
/// `0`/`1`) into a `u8` flag (`false`/`0` -> `0`, `true`/`1` -> `1`).
struct BoolVisitor;
37
38pub trait Serializable: Serialize + for<'de> Deserialize<'de> {
40 fn from_json_bytes(data: &[u8]) -> Result<Self, serde_json::Error> {
46 serde_json::from_slice(data)
47 }
48
49 fn to_json_bytes(&self) -> Result<Bytes, serde_json::Error> {
55 serde_json::to_vec(self).map(Bytes::from)
56 }
57}
58
59pub use self::msgpack::{FromMsgPack, MsgPackSerializable, ToMsgPack};
60
61pub mod msgpack {
66 use bytes::Bytes;
67 use serde::{Deserialize, Serialize};
68
69 use super::Serializable;
70
71 pub trait FromMsgPack: for<'de> Deserialize<'de> + Sized {
73 fn from_msgpack_bytes(data: &[u8]) -> Result<Self, rmp_serde::decode::Error> {
79 rmp_serde::from_slice(data)
80 }
81 }
82
83 pub trait ToMsgPack: Serialize {
85 fn to_msgpack_bytes(&self) -> Result<Bytes, rmp_serde::encode::Error> {
91 rmp_serde::to_vec_named(self).map(Bytes::from)
92 }
93 }
94
95 pub trait MsgPackSerializable: Serializable + FromMsgPack + ToMsgPack {}
99
100 impl<T> FromMsgPack for T where T: Serializable {}
101
102 impl<T> ToMsgPack for T where T: Serializable {}
103
104 impl<T> MsgPackSerializable for T where T: Serializable {}
105}
106
107impl Visitor<'_> for BoolVisitor {
108 type Value = u8;
109
110 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 formatter.write_str("a boolean as u8")
112 }
113
114 fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
115 where
116 E: serde::de::Error,
117 {
118 Ok(u8::from(value))
119 }
120
121 #[allow(
122 clippy::cast_possible_truncation,
123 reason = "Intentional for parsing, value range validated"
124 )]
125 fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
126 where
127 E: serde::de::Error,
128 {
129 if value > 1 {
134 Err(E::invalid_value(Unexpected::Unsigned(value), &self))
135 } else {
136 Ok(value as u8)
137 }
138 }
139}
140
/// Always returns `true`.
///
/// Being a plain `fn() -> bool`, it can be used wherever a default-value provider
/// is expected, such as `#[serde(default = "...")]`.
#[must_use]
pub const fn default_true() -> bool {
    true
}
148
/// Always returns `false`.
///
/// Being a plain `fn() -> bool`, it can be used wherever a default-value provider
/// is expected, such as `#[serde(default = "...")]`.
#[must_use]
pub const fn default_false() -> bool {
    false
}
156
157pub fn from_bool_as_u8<'de, D>(deserializer: D) -> Result<u8, D::Error>
163where
164 D: Deserializer<'de>,
165{
166 deserializer.deserialize_any(BoolVisitor)
167}
168
169pub fn deserialize_decimal_from_str<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
178where
179 D: Deserializer<'de>,
180{
181 let s = String::deserialize(deserializer)?;
182 Decimal::from_str(&s).map_err(D::Error::custom)
183}
184
185pub fn deserialize_decimal_or_zero<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
193where
194 D: Deserializer<'de>,
195{
196 let s: String = Deserialize::deserialize(deserializer)?;
197 if s.is_empty() || s == "0" {
198 Ok(Decimal::ZERO)
199 } else {
200 Decimal::from_str(&s).map_err(D::Error::custom)
201 }
202}
203
204pub fn deserialize_optional_decimal<'de, D>(deserializer: D) -> Result<Option<Decimal>, D::Error>
212where
213 D: Deserializer<'de>,
214{
215 let s: String = Deserialize::deserialize(deserializer)?;
216 if s.is_empty() || s == "0" {
217 Ok(None)
218 } else {
219 Decimal::from_str(&s).map(Some).map_err(D::Error::custom)
220 }
221}
222
223pub fn deserialize_optional_decimal_from_str<'de, D>(
232 deserializer: D,
233) -> Result<Option<Decimal>, D::Error>
234where
235 D: Deserializer<'de>,
236{
237 let opt = Option::<String>::deserialize(deserializer)?;
238 match opt {
239 Some(s) if !s.is_empty() => Decimal::from_str(&s).map(Some).map_err(D::Error::custom),
240 _ => Ok(None),
241 }
242}
243
244pub fn deserialize_optional_decimal_or_zero<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
252where
253 D: Deserializer<'de>,
254{
255 let opt: Option<String> = Deserialize::deserialize(deserializer)?;
256 match opt {
257 None => Ok(Decimal::ZERO),
258 Some(s) if s.is_empty() || s == "0" => Ok(Decimal::ZERO),
259 Some(s) => Decimal::from_str(&s).map_err(D::Error::custom),
260 }
261}
262
263pub fn deserialize_vec_decimal_from_str<'de, D>(deserializer: D) -> Result<Vec<Decimal>, D::Error>
269where
270 D: Deserializer<'de>,
271{
272 let strings = Vec::<String>::deserialize(deserializer)?;
273 strings
274 .into_iter()
275 .map(|s| Decimal::from_str(&s).map_err(D::Error::custom))
276 .collect()
277}
278
279pub fn serialize_decimal_as_str<S>(decimal: &Decimal, serializer: S) -> Result<S::Ok, S::Error>
285where
286 S: Serializer,
287{
288 serializer.serialize_str(&decimal.normalize().to_string())
289}
290
291pub fn serialize_optional_decimal_as_str<S>(
297 decimal: &Option<Decimal>,
298 serializer: S,
299) -> Result<S::Ok, S::Error>
300where
301 S: Serializer,
302{
303 match decimal {
304 Some(d) => serializer.serialize_str(&d.normalize().to_string()),
305 None => serializer.serialize_none(),
306 }
307}
308
309pub fn serialize_vec_decimal_as_str<S>(
315 decimals: &Vec<Decimal>,
316 serializer: S,
317) -> Result<S::Ok, S::Error>
318where
319 S: Serializer,
320{
321 let mut seq = serializer.serialize_seq(Some(decimals.len()))?;
322 for decimal in decimals {
323 seq.serialize_element(&decimal.normalize().to_string())?;
324 }
325 seq.end()
326}
327
328pub fn parse_decimal(s: &str) -> anyhow::Result<Decimal> {
334 Decimal::from_str(s).map_err(|e| anyhow::anyhow!("Failed to parse decimal from '{s}': {e}"))
335}
336
337pub fn parse_optional_decimal(s: &Option<String>) -> anyhow::Result<Option<Decimal>> {
343 match s {
344 None => Ok(None),
345 Some(s) if s.is_empty() => Ok(None),
346 Some(s) => parse_decimal(s).map(Some),
347 }
348}
349
350pub fn deserialize_empty_string_as_none<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
362where
363 D: Deserializer<'de>,
364{
365 let opt = Option::<String>::deserialize(deserializer)?;
366 Ok(opt.filter(|s| !s.is_empty()))
367}
368
369pub fn deserialize_empty_ustr_as_none<'de, D>(deserializer: D) -> Result<Option<Ustr>, D::Error>
375where
376 D: Deserializer<'de>,
377{
378 let opt = Option::<Ustr>::deserialize(deserializer)?;
379 Ok(opt.filter(|s| !s.is_empty()))
380}
381
382pub fn deserialize_string_to_u8<'de, D>(deserializer: D) -> Result<u8, D::Error>
390where
391 D: Deserializer<'de>,
392{
393 let s: String = Deserialize::deserialize(deserializer)?;
394 if s.is_empty() {
395 return Ok(0);
396 }
397 s.parse::<u8>().map_err(D::Error::custom)
398}
399
400pub fn deserialize_string_to_u64<'de, D>(deserializer: D) -> Result<u64, D::Error>
408where
409 D: Deserializer<'de>,
410{
411 let s = String::deserialize(deserializer)?;
412 if s.is_empty() {
413 Ok(0)
414 } else {
415 s.parse::<u64>().map_err(D::Error::custom)
416 }
417}
418
419pub fn deserialize_optional_string_to_u64<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
427where
428 D: Deserializer<'de>,
429{
430 let s: Option<String> = Option::deserialize(deserializer)?;
431 match s {
432 Some(s) if s.is_empty() => Ok(None),
433 Some(s) => s.parse().map(Some).map_err(D::Error::custom),
434 None => Ok(None),
435 }
436}
437
#[cfg(test)]
mod tests {
    //! Unit tests for the serde helpers and the JSON/MessagePack traits defined in
    //! the parent module, covering both accepted inputs and rejected inputs.

    use rstest::*;
    use rust_decimal::Decimal;
    use rust_decimal_macros::dec;
    use serde::{Deserialize, Serialize};
    use ustr::Ustr;

    use super::{
        Serializable, deserialize_decimal_from_str, deserialize_decimal_or_zero,
        deserialize_empty_string_as_none, deserialize_empty_ustr_as_none,
        deserialize_optional_decimal, deserialize_optional_decimal_from_str,
        deserialize_optional_decimal_or_zero, deserialize_optional_string_to_u64,
        deserialize_string_to_u8, deserialize_string_to_u64, deserialize_vec_decimal_from_str,
        from_bool_as_u8,
        msgpack::{FromMsgPack, ToMsgPack},
        parse_decimal, parse_optional_decimal, serialize_decimal_as_str,
        serialize_optional_decimal_as_str, serialize_vec_decimal_as_str,
    };

    #[derive(Deserialize)]
    pub struct TestStruct {
        #[serde(deserialize_with = "from_bool_as_u8")]
        pub value: u8,
    }

    #[rstest]
    #[case(r#"{"value": true}"#, 1)]
    #[case(r#"{"value": false}"#, 0)]
    fn test_deserialize_bool_as_u8_with_boolean(#[case] json_str: &str, #[case] expected: u8) {
        let test_struct: TestStruct = serde_json::from_str(json_str).unwrap();
        assert_eq!(test_struct.value, expected);
    }

    #[rstest]
    #[case(r#"{"value": 1}"#, 1)]
    #[case(r#"{"value": 0}"#, 0)]
    fn test_deserialize_bool_as_u8_with_u64(#[case] json_str: &str, #[case] expected: u8) {
        let test_struct: TestStruct = serde_json::from_str(json_str).unwrap();
        assert_eq!(test_struct.value, expected);
    }

    // Integers outside 0..=1, negative integers and strings must all be rejected.
    #[rstest]
    #[case(r#"{"value": 2}"#)]
    #[case(r#"{"value": -1}"#)]
    #[case(r#"{"value": "1"}"#)]
    fn test_deserialize_bool_as_u8_with_invalid_input(#[case] json: &str) {
        let result: Result<TestStruct, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct SerializableTestStruct {
        id: u32,
        name: String,
        value: f64,
    }

    impl Serializable for SerializableTestStruct {}

    #[rstest]
    fn test_serializable_json_roundtrip() {
        let original = SerializableTestStruct {
            id: 42,
            name: "test".to_string(),
            value: std::f64::consts::PI,
        };

        let json_bytes = original.to_json_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_json_bytes(&json_bytes).unwrap();

        assert_eq!(original, deserialized);
    }

    #[rstest]
    fn test_serializable_msgpack_roundtrip() {
        let original = SerializableTestStruct {
            id: 123,
            name: "msgpack_test".to_string(),
            value: std::f64::consts::E,
        };

        let msgpack_bytes = original.to_msgpack_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_msgpack_bytes(&msgpack_bytes).unwrap();

        assert_eq!(original, deserialized);
    }

    #[rstest]
    fn test_serializable_json_invalid_data() {
        let invalid_json = b"invalid json data";
        let result = SerializableTestStruct::from_json_bytes(invalid_json);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_serializable_msgpack_invalid_data() {
        let invalid_msgpack = b"invalid msgpack data";
        let result = SerializableTestStruct::from_msgpack_bytes(invalid_msgpack);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_serializable_json_empty_values() {
        let test_struct = SerializableTestStruct {
            id: 0,
            name: String::new(),
            value: 0.0,
        };

        let json_bytes = test_struct.to_json_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_json_bytes(&json_bytes).unwrap();

        assert_eq!(test_struct, deserialized);
    }

    #[rstest]
    fn test_serializable_msgpack_empty_values() {
        let test_struct = SerializableTestStruct {
            id: 0,
            name: String::new(),
            value: 0.0,
        };

        let msgpack_bytes = test_struct.to_msgpack_bytes().unwrap();
        let deserialized = SerializableTestStruct::from_msgpack_bytes(&msgpack_bytes).unwrap();

        assert_eq!(test_struct, deserialized);
    }

    #[derive(Deserialize)]
    struct TestDecimalFromStr {
        #[serde(deserialize_with = "deserialize_decimal_from_str")]
        value: Decimal,
    }

    #[derive(Deserialize)]
    struct TestOptionalDecimal {
        #[serde(deserialize_with = "deserialize_optional_decimal")]
        value: Option<Decimal>,
    }

    #[derive(Deserialize)]
    struct TestOptionalDecimalFromStr {
        #[serde(deserialize_with = "deserialize_optional_decimal_from_str")]
        value: Option<Decimal>,
    }

    #[derive(Deserialize)]
    struct TestDecimalOrZero {
        #[serde(deserialize_with = "deserialize_decimal_or_zero")]
        value: Decimal,
    }

    #[derive(Deserialize)]
    struct TestOptionalDecimalOrZero {
        #[serde(deserialize_with = "deserialize_optional_decimal_or_zero")]
        value: Decimal,
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct TestDecimalRoundtrip {
        #[serde(
            serialize_with = "serialize_decimal_as_str",
            deserialize_with = "deserialize_decimal_from_str"
        )]
        value: Decimal,
        #[serde(
            serialize_with = "serialize_optional_decimal_as_str",
            deserialize_with = "super::deserialize_optional_decimal_from_str"
        )]
        optional_value: Option<Decimal>,
    }

    // An empty string is not a valid decimal for the strict deserializer.
    #[rstest]
    #[case(r#"{"value":"abc"}"#)]
    #[case(r#"{"value":""}"#)]
    fn test_deserialize_decimal_from_str_invalid(#[case] json: &str) {
        let result: Result<TestDecimalFromStr, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[rstest]
    #[case(r#"{"value":"123.45"}"#, Some(dec!(123.45)))]
    #[case(r#"{"value":"0"}"#, None)]
    #[case(r#"{"value":""}"#, None)]
    fn test_deserialize_optional_decimal(#[case] json: &str, #[case] expected: Option<Decimal>) {
        let result: TestOptionalDecimal = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    // Unlike `deserialize_optional_decimal`, only the empty string (or null) maps to None.
    #[rstest]
    #[case(r#"{"value":"1.5"}"#, Some(dec!(1.5)))]
    #[case(r#"{"value":"0"}"#, Some(Decimal::ZERO))]
    #[case(r#"{"value":""}"#, None)]
    #[case(r#"{"value":null}"#, None)]
    fn test_deserialize_optional_decimal_from_str(
        #[case] json: &str,
        #[case] expected: Option<Decimal>,
    ) {
        let result: TestOptionalDecimalFromStr = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[rstest]
    #[case(r#"{"value":"123.45"}"#, dec!(123.45))]
    #[case(r#"{"value":"0"}"#, Decimal::ZERO)]
    #[case(r#"{"value":""}"#, Decimal::ZERO)]
    fn test_deserialize_decimal_or_zero(#[case] json: &str, #[case] expected: Decimal) {
        let result: TestDecimalOrZero = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[rstest]
    fn test_deserialize_decimal_or_zero_invalid() {
        let result: Result<TestDecimalOrZero, _> = serde_json::from_str(r#"{"value":"abc"}"#);
        assert!(result.is_err());
    }

    #[rstest]
    #[case(r#"{"value":"123.45"}"#, dec!(123.45))]
    #[case(r#"{"value":"0"}"#, Decimal::ZERO)]
    #[case(r#"{"value":null}"#, Decimal::ZERO)]
    fn test_deserialize_optional_decimal_or_zero(#[case] json: &str, #[case] expected: Decimal) {
        let result: TestOptionalDecimalOrZero = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[rstest]
    fn test_decimal_serialization_roundtrip() {
        let original = TestDecimalRoundtrip {
            value: dec!(123.456789012345678),
            optional_value: Some(dec!(0.000000001)),
        };

        let json = serde_json::to_string(&original).unwrap();

        assert!(json.contains("\"123.456789012345678\""));
        assert!(json.contains("\"0.000000001\""));

        let deserialized: TestDecimalRoundtrip = serde_json::from_str(&json).unwrap();
        assert_eq!(original.value, deserialized.value);
        assert_eq!(original.optional_value, deserialized.optional_value);
    }

    // Serialization normalizes away trailing fractional zeros in both fields.
    #[rstest]
    fn test_decimal_serialization_normalizes() {
        let original = TestDecimalRoundtrip {
            value: dec!(1.500),
            optional_value: Some(dec!(2.50)),
        };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"value":"1.5","optional_value":"2.5"}"#);
    }

    #[rstest]
    fn test_decimal_optional_none_handling() {
        let test_struct = TestDecimalRoundtrip {
            value: dec!(42.0),
            optional_value: None,
        };

        let json = serde_json::to_string(&test_struct).unwrap();
        assert!(json.contains("null"));

        let parsed: TestDecimalRoundtrip = serde_json::from_str(&json).unwrap();
        assert_eq!(test_struct.value, parsed.value);
        assert_eq!(None, parsed.optional_value);
    }

    #[derive(Deserialize)]
    struct TestEmptyStringAsNone {
        #[serde(deserialize_with = "deserialize_empty_string_as_none")]
        value: Option<String>,
    }

    #[rstest]
    #[case(r#"{"value":"hello"}"#, Some("hello".to_string()))]
    #[case(r#"{"value":""}"#, None)]
    #[case(r#"{"value":null}"#, None)]
    fn test_deserialize_empty_string_as_none(#[case] json: &str, #[case] expected: Option<String>) {
        let result: TestEmptyStringAsNone = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[derive(Deserialize)]
    struct TestEmptyUstrAsNone {
        #[serde(deserialize_with = "deserialize_empty_ustr_as_none")]
        value: Option<Ustr>,
    }

    #[rstest]
    #[case(r#"{"value":"hello"}"#, Some(Ustr::from("hello")))]
    #[case(r#"{"value":""}"#, None)]
    #[case(r#"{"value":null}"#, None)]
    fn test_deserialize_empty_ustr_as_none(#[case] json: &str, #[case] expected: Option<Ustr>) {
        let result: TestEmptyUstrAsNone = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct TestVecDecimal {
        #[serde(
            serialize_with = "serialize_vec_decimal_as_str",
            deserialize_with = "deserialize_vec_decimal_from_str"
        )]
        values: Vec<Decimal>,
    }

    #[rstest]
    fn test_vec_decimal_roundtrip() {
        let original = TestVecDecimal {
            values: vec![dec!(1.5), dec!(2.25), dec!(100.001)],
        };

        let json = serde_json::to_string(&original).unwrap();
        assert!(json.contains("[\"1.5\",\"2.25\",\"100.001\"]"));

        let parsed: TestVecDecimal = serde_json::from_str(&json).unwrap();
        assert_eq!(original.values, parsed.values);
    }

    #[rstest]
    fn test_vec_decimal_empty() {
        let original = TestVecDecimal { values: vec![] };

        let json = serde_json::to_string(&original).unwrap();
        let parsed: TestVecDecimal = serde_json::from_str(&json).unwrap();
        assert_eq!(original.values, parsed.values);
    }

    // A single malformed element fails the whole vector.
    #[rstest]
    fn test_vec_decimal_invalid_element() {
        let result: Result<TestVecDecimal, _> =
            serde_json::from_str(r#"{"values":["1.5","oops"]}"#);
        assert!(result.is_err());
    }

    #[derive(Deserialize)]
    struct TestStringToU8 {
        #[serde(deserialize_with = "deserialize_string_to_u8")]
        value: u8,
    }

    #[rstest]
    #[case(r#"{"value":"42"}"#, 42)]
    #[case(r#"{"value":"0"}"#, 0)]
    #[case(r#"{"value":""}"#, 0)]
    fn test_deserialize_string_to_u8(#[case] json: &str, #[case] expected: u8) {
        let result: TestStringToU8 = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    // Out-of-range, negative and non-numeric strings are rejected.
    #[rstest]
    #[case(r#"{"value":"256"}"#)]
    #[case(r#"{"value":"-1"}"#)]
    #[case(r#"{"value":"abc"}"#)]
    fn test_deserialize_string_to_u8_invalid(#[case] json: &str) {
        let result: Result<TestStringToU8, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[derive(Deserialize)]
    struct TestStringToU64 {
        #[serde(deserialize_with = "deserialize_string_to_u64")]
        value: u64,
    }

    #[rstest]
    #[case(r#"{"value":"12345678901234"}"#, 12345678901234)]
    #[case(r#"{"value":"0"}"#, 0)]
    #[case(r#"{"value":""}"#, 0)]
    fn test_deserialize_string_to_u64(#[case] json: &str, #[case] expected: u64) {
        let result: TestStringToU64 = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[derive(Deserialize)]
    struct TestOptionalStringToU64 {
        #[serde(deserialize_with = "deserialize_optional_string_to_u64")]
        value: Option<u64>,
    }

    #[rstest]
    #[case(r#"{"value":"12345678901234"}"#, Some(12345678901234))]
    #[case(r#"{"value":"0"}"#, Some(0))]
    #[case(r#"{"value":""}"#, None)]
    #[case(r#"{"value":null}"#, None)]
    fn test_deserialize_optional_string_to_u64(#[case] json: &str, #[case] expected: Option<u64>) {
        let result: TestOptionalStringToU64 = serde_json::from_str(json).unwrap();
        assert_eq!(result.value, expected);
    }

    #[rstest]
    #[case("123.45", dec!(123.45))]
    #[case("0", Decimal::ZERO)]
    #[case("0.0", Decimal::ZERO)]
    fn test_parse_decimal(#[case] input: &str, #[case] expected: Decimal) {
        let result = parse_decimal(input).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_parse_decimal_invalid() {
        assert!(parse_decimal("invalid").is_err());
        assert!(parse_decimal("").is_err());
    }

    #[rstest]
    #[case(&Some("123.45".to_string()), Some(dec!(123.45)))]
    #[case(&Some("0".to_string()), Some(Decimal::ZERO))]
    #[case(&Some(String::new()), None)]
    #[case(&None, None)]
    fn test_parse_optional_decimal(
        #[case] input: &Option<String>,
        #[case] expected: Option<Decimal>,
    ) {
        let result = parse_optional_decimal(input).unwrap();
        assert_eq!(result, expected);
    }

    #[rstest]
    fn test_parse_optional_decimal_invalid() {
        assert!(parse_optional_decimal(&Some("abc".to_string())).is_err());
    }
}