nautilus_model/defi/pool_identifier.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    fmt::{Debug, Display, Formatter},
18    hash::{Hash, Hasher},
19    str::FromStr,
20};
21
22use alloy_primitives::Address;
23use nautilus_core::correctness::FAILED;
24use serde::{Deserialize, Deserializer, Serialize, Serializer};
25use ustr::Ustr;
26
27/// Protocol-aware pool identifier for DeFi liquidity pools.
28///
29/// This enum distinguishes between two types of pool identifiers:
30/// - **Address**: Used by V2/V3 protocols where pool identifier equals pool contract address (42 chars: "0x" + 40 hex)
31/// - **PoolId**: Used by V4 protocols where pool identifier is a bytes32 hash (66 chars: "0x" + 64 hex)
32///
33/// The type implements case-insensitive equality and hashing for address comparison,
34/// while preserving the original case for display purposes.
35#[derive(Clone, Copy, PartialOrd, Ord)]
36pub enum PoolIdentifier {
37    /// V2/V3 pool identifier (checksummed Ethereum address)
38    Address(Ustr),
39    /// V4 pool identifier (32-byte pool ID as hex string)
40    PoolId(Ustr),
41}
42
43impl PoolIdentifier {
44    /// Creates a new [`PoolIdentifier`] instance with correctness checking.
45    ///
46    /// Automatically detects variant based on string length:
47    /// - 42 characters (0x + 40 hex): Address variant
48    /// - 66 characters (0x + 64 hex): PoolId variant
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if:
53    /// - String doesn't start with "0x"
54    /// - Length is neither 42 nor 66 characters
55    /// - Contains invalid hex characters
56    /// - Address checksum validation fails (for Address variant)
57    pub fn new_checked<T: AsRef<str>>(value: T) -> anyhow::Result<Self> {
58        let value = value.as_ref();
59
60        if !value.starts_with("0x") {
61            anyhow::bail!("Pool identifier must start with '0x', got: {value}");
62        }
63
64        match value.len() {
65            42 => {
66                validate_hex_string(value)?;
67
68                // Parse without strict checksum validation, then normalize to checksummed format
69                let addr = value
70                    .parse::<Address>()
71                    .map_err(|e| anyhow::anyhow!("Invalid address: {e}"))?;
72
73                // Store the checksummed version
74                Ok(Self::Address(Ustr::from(addr.to_checksum(None).as_str())))
75            }
76            66 => {
77                // PoolId variant (32 bytes)
78                validate_hex_string(value)?;
79
80                // Store lowercase version for consistency
81                Ok(Self::PoolId(Ustr::from(&value.to_lowercase())))
82            }
83            len => {
84                anyhow::bail!(
85                    "Pool identifier must be 42 chars (address) or 66 chars (pool ID), got {len} chars: {value}"
86                )
87            }
88        }
89    }
90
91    /// Creates a new [`PoolIdentifier`] instance.
92    ///
93    /// # Panics
94    ///
95    /// Panics if validation fails.
96    #[must_use]
97    pub fn new<T: AsRef<str>>(value: T) -> Self {
98        Self::new_checked(value).expect(FAILED)
99    }
100
101    /// Creates an Address variant from an alloy Address.
102    ///
103    /// Returns the checksummed representation.
104    #[must_use]
105    pub fn from_address(address: Address) -> Self {
106        Self::Address(Ustr::from(address.to_checksum(None).as_str()))
107    }
108
109    /// Creates a PoolId variant from raw bytes (32 bytes).
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if bytes length is not 32.
114    pub fn from_pool_id_bytes(bytes: &[u8]) -> anyhow::Result<Self> {
115        anyhow::ensure!(
116            bytes.len() == 32,
117            "Pool ID must be 32 bytes, got {}",
118            bytes.len()
119        );
120
121        let hex_string = format!("0x{}", hex::encode(bytes));
122        Ok(Self::PoolId(Ustr::from(&hex_string)))
123    }
124
125    /// Creates a PoolId variant from a hex string (with or without 0x prefix).
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if the string is not valid 64-character hex.
130    pub fn from_pool_id_hex<T: AsRef<str>>(hex: T) -> anyhow::Result<Self> {
131        let hex = hex.as_ref();
132        let hex_str = hex.strip_prefix("0x").unwrap_or(hex);
133
134        anyhow::ensure!(
135            hex_str.len() == 64,
136            "Pool ID hex must be 64 characters (32 bytes), got {}",
137            hex_str.len()
138        );
139
140        validate_hex_string(&format!("0x{hex_str}"))?;
141
142        Ok(Self::PoolId(Ustr::from(&format!(
143            "0x{}",
144            hex_str.to_lowercase()
145        ))))
146    }
147
148    /// Returns the inner identifier value as a Ustr.
149    #[must_use]
150    pub fn inner(&self) -> Ustr {
151        match self {
152            Self::Address(s) | Self::PoolId(s) => *s,
153        }
154    }
155
156    /// Returns the inner identifier value as a string slice.
157    #[must_use]
158    pub fn as_str(&self) -> &str {
159        match self {
160            Self::Address(s) | Self::PoolId(s) => s.as_str(),
161        }
162    }
163
164    /// Returns true if this is an Address variant (V2/V3 pools).
165    #[must_use]
166    pub fn is_address(&self) -> bool {
167        matches!(self, Self::Address(_))
168    }
169
170    /// Returns true if this is a PoolId variant (V4 pools).
171    #[must_use]
172    pub fn is_pool_id(&self) -> bool {
173        matches!(self, Self::PoolId(_))
174    }
175
176    /// Converts to native Address type (V2/V3 pools only).
177    ///
178    /// Returns the underlying Address for use with alloy/ethers operations.
179    ///
180    /// # Errors
181    ///
182    /// Returns error if this is a PoolId variant or if parsing fails.
183    pub fn to_address(&self) -> anyhow::Result<Address> {
184        match self {
185            Self::Address(s) => Address::parse_checksummed(s.as_str(), None)
186                .map_err(|e| anyhow::anyhow!("Failed to parse address: {e}")),
187            Self::PoolId(_) => anyhow::bail!("Cannot convert PoolId variant to Address"),
188        }
189    }
190
191    /// Converts to native bytes array (V4 pools only).
192    ///
193    /// Returns the 32-byte pool ID for use in V4-specific operations.
194    ///
195    /// # Errors
196    ///
197    /// Returns error if this is an Address variant or if hex decoding fails.
198    pub fn to_pool_id_bytes(&self) -> anyhow::Result<[u8; 32]> {
199        match self {
200            Self::PoolId(s) => {
201                let hex = s.as_str().strip_prefix("0x").unwrap_or(s.as_str());
202                let bytes = hex::decode(hex)
203                    .map_err(|e| anyhow::anyhow!("Failed to decode pool ID hex: {e}",))?;
204
205                bytes
206                    .try_into()
207                    .map_err(|_| anyhow::anyhow!("Pool ID must be exactly 32 bytes"))
208            }
209            Self::Address(_) => anyhow::bail!("Cannot convert Address variant to PoolId bytes"),
210        }
211    }
212}
213
214/// Validates that a string contains only valid hexadecimal characters after "0x" prefix.
215fn validate_hex_string(s: &str) -> anyhow::Result<()> {
216    let hex_part = &s[2..];
217    if !hex_part.chars().all(|c| c.is_ascii_hexdigit()) {
218        anyhow::bail!("Invalid hex characters in: {s}");
219    }
220    Ok(())
221}
222
223impl PartialEq for PoolIdentifier {
224    fn eq(&self, other: &Self) -> bool {
225        match (self, other) {
226            (Self::Address(a), Self::Address(b)) | (Self::PoolId(a), Self::PoolId(b)) => {
227                // Case-insensitive comparison
228                a.as_str().eq_ignore_ascii_case(b.as_str())
229            }
230            // Different variants are never equal
231            _ => false,
232        }
233    }
234}
235
236impl Eq for PoolIdentifier {}
237
238impl Hash for PoolIdentifier {
239    fn hash<H: Hasher>(&self, state: &mut H) {
240        // Hash the variant discriminant first
241        std::mem::discriminant(self).hash(state);
242
243        // Then hash the lowercase version of the string
244        match self {
245            Self::Address(s) | Self::PoolId(s) => {
246                for byte in s.as_str().bytes() {
247                    state.write_u8(byte.to_ascii_lowercase());
248                }
249            }
250        }
251    }
252}
253
254impl Display for PoolIdentifier {
255    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        match self {
257            Self::Address(s) | Self::PoolId(s) => write!(f, "{s}"),
258        }
259    }
260}
261
262impl Debug for PoolIdentifier {
263    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
264        match self {
265            Self::Address(s) => write!(f, "Address({s:?})"),
266            Self::PoolId(s) => write!(f, "PoolId({s:?})"),
267        }
268    }
269}
270
271impl Serialize for PoolIdentifier {
272    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
273    where
274        S: Serializer,
275    {
276        // Serialize as plain string (same as current String behavior)
277        match self {
278            Self::Address(s) | Self::PoolId(s) => s.serialize(serializer),
279        }
280    }
281}
282
283impl<'de> Deserialize<'de> for PoolIdentifier {
284    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
285    where
286        D: Deserializer<'de>,
287    {
288        let value_str: &str = Deserialize::deserialize(deserializer)?;
289        Self::new_checked(value_str).map_err(serde::de::Error::custom)
290    }
291}
292
293impl FromStr for PoolIdentifier {
294    type Err = anyhow::Error;
295
296    fn from_str(s: &str) -> Result<Self, Self::Err> {
297        Self::new_checked(s)
298    }
299}
300
301impl From<&str> for PoolIdentifier {
302    fn from(value: &str) -> Self {
303        Self::new(value)
304    }
305}
306
307impl From<String> for PoolIdentifier {
308    fn from(value: String) -> Self {
309        Self::new(value)
310    }
311}
312
313impl AsRef<str> for PoolIdentifier {
314    fn as_ref(&self) -> &str {
315        self.as_str()
316    }
317}
318
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")] // Valid checksummed address
    #[case("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2")] // Lowercase address
    #[case("0xC02AAA39B223FE8D0A0E5C4F27EAD9083C756CC2")] // Uppercase address
    #[case("0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461")] // V4 Pool ID
    fn test_valid_pool_identifiers(#[case] input: &str) {
        let result = PoolIdentifier::new_checked(input);
        assert!(result.is_ok(), "Input should be valid: {input}");
    }

    #[rstest]
    #[case("")] // Empty
    #[case("C02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")] // Missing 0x
    #[case("0XC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")] // Uppercase 0X prefix
    #[case("0xC02aaA39")] // Too short
    #[case("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2EXTRA")] // Too long
    #[case("0xGGGGGGGGb223FE8D0A0e5C4F27eAD9083C756Cc2")] // Invalid hex (address length)
    #[case("0xg9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461")] // Invalid hex (pool ID length)
    fn test_invalid_pool_identifiers(#[case] input: &str) {
        let result = PoolIdentifier::new_checked(input);
        assert!(result.is_err(), "Input should fail: {input}");
    }

    #[rstest]
    #[case("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")]
    #[case("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2")]
    #[case("0xC02AAA39B223FE8D0A0E5C4F27EAD9083C756CC2")]
    fn test_address_normalized_to_checksum(#[case] input: &str) {
        let id = PoolIdentifier::new(input);

        assert!(id.is_address());
        assert_eq!(id.as_str(), "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
    }

    #[rstest]
    fn test_pool_id_normalized_to_lowercase() {
        let upper = PoolIdentifier::new(
            "0xC9BC8043294146424A4E4607D8AD837D6A659142822BBAAABC83BB57E7447461",
        );
        let lower = PoolIdentifier::new(
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        );

        assert_eq!(
            upper.as_str(),
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461"
        );
        assert_eq!(upper, lower);
    }

    #[rstest]
    fn test_case_insensitive_equality() {
        let addr1 = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let addr2 = PoolIdentifier::new("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2");
        let addr3 = PoolIdentifier::new("0xC02AAA39B223FE8D0A0E5C4F27EAD9083C756CC2");

        assert_eq!(addr1, addr2);
        assert_eq!(addr2, addr3);
        assert_eq!(addr1, addr3);
    }

    #[rstest]
    fn test_case_insensitive_hashing() {
        use std::collections::HashMap;

        let mut map = HashMap::new();
        let addr1 = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let addr2 = PoolIdentifier::new("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2");

        map.insert(addr1, "value1");

        // Should be able to retrieve using different case
        assert_eq!(map.get(&addr2), Some(&"value1"));
    }

    #[rstest]
    fn test_display_preserves_case() {
        let checksummed = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2";
        let addr = PoolIdentifier::new_checked(checksummed).unwrap();

        // Display should show checksummed version
        assert_eq!(addr.to_string(), checksummed);
    }

    #[rstest]
    fn test_variant_detection() {
        let address = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let pool_id = PoolIdentifier::new(
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        );

        assert!(address.is_address());
        assert!(!address.is_pool_id());

        assert!(pool_id.is_pool_id());
        assert!(!pool_id.is_address());
    }

    #[rstest]
    fn test_different_variants_not_equal() {
        let address = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let pool_id = PoolIdentifier::new(
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        );

        assert_ne!(address, pool_id);
    }

    #[rstest]
    fn test_from_str_matches_new() {
        let parsed: PoolIdentifier = "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2"
            .parse()
            .unwrap();

        assert_eq!(
            parsed,
            PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
        );
    }

    #[rstest]
    fn test_serialization_roundtrip() {
        let original = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: PoolIdentifier = serde_json::from_str(&json).unwrap();

        assert_eq!(original, deserialized);
    }

    #[rstest]
    fn test_serialization_roundtrip_pool_id() {
        let original = PoolIdentifier::new(
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        );

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: PoolIdentifier = serde_json::from_str(&json).unwrap();

        assert!(deserialized.is_pool_id());
        assert_eq!(original, deserialized);
    }

    #[rstest]
    fn test_from_address() {
        let addr = Address::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2").unwrap();
        let pool_id = PoolIdentifier::from_address(addr);

        assert!(pool_id.is_address());
        assert_eq!(
            pool_id.to_string(),
            "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"
        );
    }

    #[rstest]
    fn test_from_pool_id_bytes() {
        let bytes: [u8; 32] = [
            0xc9, 0xbc, 0x80, 0x43, 0x29, 0x41, 0x46, 0x42, 0x4a, 0x4e, 0x46, 0x07, 0xd8, 0xad,
            0x83, 0x7d, 0x6a, 0x65, 0x91, 0x42, 0x82, 0x2b, 0xba, 0xaa, 0xbc, 0x83, 0xbb, 0x57,
            0xe7, 0x44, 0x74, 0x61,
        ];

        let pool_id = PoolIdentifier::from_pool_id_bytes(&bytes).unwrap();

        assert!(pool_id.is_pool_id());
        assert_eq!(
            pool_id.to_string(),
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461"
        );
    }

    #[rstest]
    #[case(0)]
    #[case(31)]
    #[case(33)]
    fn test_from_pool_id_bytes_wrong_length(#[case] len: usize) {
        let bytes = vec![0u8; len];
        assert!(PoolIdentifier::from_pool_id_bytes(&bytes).is_err());
    }

    #[rstest]
    #[case("0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461")] // With prefix
    #[case("c9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461")] // Without prefix
    #[case("0xC9BC8043294146424A4E4607D8AD837D6A659142822BBAAABC83BB57E7447461")] // Uppercase
    fn test_from_pool_id_hex(#[case] input: &str) {
        let pool_id = PoolIdentifier::from_pool_id_hex(input).unwrap();

        assert!(pool_id.is_pool_id());
        assert_eq!(
            pool_id.as_str(),
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461"
        );
    }

    #[rstest]
    #[case("0xc9bc8043")] // Too short
    #[case("0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e744746100")] // Too long
    #[case("0xg9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461")] // Invalid hex
    fn test_from_pool_id_hex_invalid(#[case] input: &str) {
        assert!(
            PoolIdentifier::from_pool_id_hex(input).is_err(),
            "Input should fail: {input}"
        );
    }

    #[rstest]
    fn test_to_address() {
        let id = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let address = id.to_address().unwrap();

        assert_eq!(
            address.to_string(),
            "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"
        );
    }

    #[rstest]
    fn test_to_address_fails_for_pool_id() {
        let pool_id = PoolIdentifier::new(
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        );
        let result = pool_id.to_address();

        assert!(result.is_err());
    }

    #[rstest]
    fn test_to_pool_id_bytes() {
        let pool_id = PoolIdentifier::new(
            "0xc9bc8043294146424a4e4607d8ad837d6a659142822bbaaabc83bb57e7447461",
        );
        let bytes = pool_id.to_pool_id_bytes().unwrap();

        assert_eq!(bytes.len(), 32);
        assert_eq!(bytes[0], 0xc9);
        assert_eq!(bytes[31], 0x61);
    }

    #[rstest]
    fn test_to_pool_id_bytes_fails_for_address() {
        let address = PoolIdentifier::new("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let result = address.to_pool_id_bytes();

        assert!(result.is_err());
    }

    #[rstest]
    fn test_conversion_roundtrip_address() {
        let original_addr =
            Address::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2").unwrap();
        let pool_id = PoolIdentifier::from_address(original_addr);
        let converted_addr = pool_id.to_address().unwrap();

        assert_eq!(original_addr, converted_addr);
    }

    #[rstest]
    fn test_conversion_roundtrip_pool_id() {
        let original_bytes: [u8; 32] = [
            0xc9, 0xbc, 0x80, 0x43, 0x29, 0x41, 0x46, 0x42, 0x4a, 0x4e, 0x46, 0x07, 0xd8, 0xad,
            0x83, 0x7d, 0x6a, 0x65, 0x91, 0x42, 0x82, 0x2b, 0xba, 0xaa, 0xbc, 0x83, 0xbb, 0x57,
            0xe7, 0x44, 0x74, 0x61,
        ];

        let pool_id = PoolIdentifier::from_pool_id_bytes(&original_bytes).unwrap();
        let converted_bytes = pool_id.to_pool_id_bytes().unwrap();

        assert_eq!(original_bytes, converted_bytes);
    }
}