nautilus_model/events/account/state.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::HashMap,
18    fmt::{Display, Formatter},
19};
20
21use nautilus_core::{UUID4, UnixNanos};
22use serde::{Deserialize, Serialize};
23
24use crate::{
25    enums::AccountType,
26    identifiers::{AccountId, InstrumentId},
27    types::{AccountBalance, Currency, MarginBalance},
28};
29
/// Represents an event which includes information on the state of the account.
///
/// Equality via `==` only considers `account_id`, `account_type` and `event_id`; the
/// balances, margins, base currency, `is_reported` flag and timestamps are ignored. Use
/// [`AccountState::has_same_balances_and_margins`] to compare the balance and margin contents
/// of two states.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct AccountState {
    /// The account ID associated with the event.
    pub account_id: AccountId,
    /// The type of the account (e.g., cash or margin).
    pub account_type: AccountType,
    /// The base currency for the account, if applicable (`None` otherwise).
    pub base_currency: Option<Currency>,
    /// The balances in the account.
    ///
    /// Presumably one entry per currency — balance comparisons key entries by currency.
    pub balances: Vec<AccountBalance>,
    /// The margin balances in the account.
    ///
    /// Presumably one entry per instrument — margin comparisons key entries by instrument ID.
    pub margins: Vec<MarginBalance>,
    /// Indicates if the account state is reported by the exchange
    /// (as opposed to system-calculated).
    pub is_reported: bool,
    /// The unique identifier for the event.
    pub event_id: UUID4,
    /// UNIX timestamp (nanoseconds) when the event occurred.
    pub ts_event: UnixNanos,
    /// UNIX timestamp (nanoseconds) when the event was initialized.
    pub ts_init: UnixNanos,
}
58
59impl AccountState {
60    /// Creates a new [`AccountState`] instance.
61    #[allow(clippy::too_many_arguments)]
62    pub fn new(
63        account_id: AccountId,
64        account_type: AccountType,
65        balances: Vec<AccountBalance>,
66        margins: Vec<MarginBalance>,
67        is_reported: bool,
68        event_id: UUID4,
69        ts_event: UnixNanos,
70        ts_init: UnixNanos,
71        base_currency: Option<Currency>,
72    ) -> Self {
73        Self {
74            account_id,
75            account_type,
76            base_currency,
77            balances,
78            margins,
79            is_reported,
80            event_id,
81            ts_event,
82            ts_init,
83        }
84    }
85
86    /// Returns `true` if this account state has the same balances and margins as another.
87    ///
88    /// This compares all balances and margins for equality, returning `true` only if
89    /// all balances and margins are equal. If any balance or margin is different or
90    /// missing, returns `false`.
91    ///
92    /// # Note
93    ///
94    /// This method does not compare event IDs, timestamps, or other metadata - only
95    /// the actual balance and margin values.
96    pub fn has_same_balances_and_margins(&self, other: &Self) -> bool {
97        // Quick check - if lengths differ, they can't be equal
98        if self.balances.len() != other.balances.len() || self.margins.len() != other.margins.len()
99        {
100            return false;
101        }
102
103        // Compare balances by currency
104        let self_balances: HashMap<Currency, &AccountBalance> = self
105            .balances
106            .iter()
107            .map(|balance| (balance.currency, balance))
108            .collect();
109
110        let other_balances: HashMap<Currency, &AccountBalance> = other
111            .balances
112            .iter()
113            .map(|balance| (balance.currency, balance))
114            .collect();
115
116        // Check if all balances are equal
117        for (currency, self_balance) in &self_balances {
118            match other_balances.get(currency) {
119                Some(other_balance) => {
120                    if self_balance != other_balance {
121                        return false;
122                    }
123                }
124                None => return false, // Currency missing in other
125            }
126        }
127
128        // Compare margins by instrument_id
129        let self_margins: HashMap<InstrumentId, &MarginBalance> = self
130            .margins
131            .iter()
132            .map(|margin| (margin.instrument_id, margin))
133            .collect();
134
135        let other_margins: HashMap<InstrumentId, &MarginBalance> = other
136            .margins
137            .iter()
138            .map(|margin| (margin.instrument_id, margin))
139            .collect();
140
141        // Check if all margins are equal
142        for (instrument_id, self_margin) in &self_margins {
143            match other_margins.get(instrument_id) {
144                Some(other_margin) => {
145                    if self_margin != other_margin {
146                        return false;
147                    }
148                }
149                None => return false, // Instrument missing in other
150            }
151        }
152
153        true
154    }
155}
156
157impl Display for AccountState {
158    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159        write!(
160            f,
161            "{}(account_id={}, account_type={}, base_currency={}, is_reported={}, balances=[{}], margins=[{}], event_id={})",
162            stringify!(AccountState),
163            self.account_id,
164            self.account_type,
165            self.base_currency.map_or_else(
166                || "None".to_string(),
167                |base_currency| format!("{}", base_currency.code)
168            ),
169            self.is_reported,
170            self.balances
171                .iter()
172                .map(|b| format!("{b}"))
173                .collect::<Vec<String>>()
174                .join(","),
175            self.margins
176                .iter()
177                .map(|m| format!("{m}"))
178                .collect::<Vec<String>>()
179                .join(","),
180            self.event_id
181        )
182    }
183}
184
185impl PartialEq for AccountState {
186    fn eq(&self, other: &Self) -> bool {
187        self.account_id == other.account_id
188            && self.account_type == other.account_type
189            && self.event_id == other.event_id
190    }
191}
192
#[cfg(test)]
mod tests {
    //! Tests for `AccountState` equality, display formatting, and balance/margin comparison.

    use nautilus_core::{UUID4, UnixNanos};
    use rstest::rstest;

    use crate::{
        enums::AccountType,
        events::{
            AccountState,
            account::stubs::{cash_account_state, margin_account_state},
        },
        identifiers::{AccountId, InstrumentId},
        types::{AccountBalance, Currency, MarginBalance, Money},
    };

    #[rstest]
    fn test_equality() {
        let cash_account_state_1 = cash_account_state();
        let cash_account_state_2 = cash_account_state();
        assert_eq!(cash_account_state_1, cash_account_state_2);
    }

    #[rstest]
    fn test_display_cash_account_state(cash_account_state: AccountState) {
        let display = format!("{cash_account_state}");
        assert_eq!(
            display,
            "AccountState(account_id=SIM-001, account_type=CASH, base_currency=USD, is_reported=true, \
            balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
            margins=[], event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_display_margin_account_state(margin_account_state: AccountState) {
        let display = format!("{margin_account_state}");
        assert_eq!(
            display,
            "AccountState(account_id=SIM-001, account_type=MARGIN, base_currency=USD, is_reported=true, \
            balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
            margins=[MarginBalance(initial=5000.00 USD, maintenance=20000.00 USD, instrument_id=BTCUSDT.COINBASE)], \
            event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_display_without_base_currency() {
        // Covers the `None` branch of the base-currency rendering and empty lists.
        let event_id = UUID4::new();
        let state = AccountState::new(
            AccountId::new("TEST-001"),
            AccountType::Cash,
            vec![],
            vec![],
            false,
            event_id,
            UnixNanos::from(1),
            UnixNanos::from(2),
            None,
        );

        assert_eq!(
            format!("{state}"),
            format!(
                "AccountState(account_id=TEST-001, account_type=CASH, base_currency=None, \
                is_reported=false, balances=[], margins=[], event_id={event_id})"
            )
        );
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_identical() {
        let state1 = cash_account_state();
        let state2 = cash_account_state();
        assert!(state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_amounts() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        // Create a different balance with same currency
        let usd = Currency::USD();
        let different_balance = AccountBalance::new(
            Money::new(2000000.0, usd),
            Money::new(50000.0, usd),
            Money::new(1950000.0, usd),
        );
        state2.balances = vec![different_balance];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_currencies() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        // Create a balance with different currency
        let eur = Currency::EUR();
        let different_balance = AccountBalance::new(
            Money::new(1525000.0, eur),
            Money::new(25000.0, eur),
            Money::new(1500000.0, eur),
        );
        state2.balances = vec![different_balance];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_balance() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        // Add an additional balance to state2
        let eur = Currency::EUR();
        let additional_balance = AccountBalance::new(
            Money::new(1000000.0, eur),
            Money::new(0.0, eur),
            Money::new(1000000.0, eur),
        );
        state2.balances.push(additional_balance);
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_amounts() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        // Create a different margin with same instrument_id
        let usd = Currency::USD();
        let instrument_id = InstrumentId::from("BTCUSDT.COINBASE");
        let different_margin = MarginBalance::new(
            Money::new(10000.0, usd),
            Money::new(40000.0, usd),
            instrument_id,
        );
        state2.margins = vec![different_margin];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_instruments() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        // Create a margin with different instrument_id
        let usd = Currency::USD();
        let different_instrument_id = InstrumentId::from("ETHUSDT.BINANCE");
        let different_margin = MarginBalance::new(
            Money::new(5000.0, usd),
            Money::new(20000.0, usd),
            different_instrument_id,
        );
        state2.margins = vec![different_margin];
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_margin() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        // Add an additional margin to state2
        let usd = Currency::USD();
        let additional_instrument_id = InstrumentId::from("ETHUSDT.BINANCE");
        let additional_margin = MarginBalance::new(
            Money::new(3000.0, usd),
            Money::new(15000.0, usd),
            additional_instrument_id,
        );
        state2.margins.push(additional_margin);
        assert!(!state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_empty_collections() {
        let account_id = AccountId::new("TEST-001");
        let event_id = UUID4::new();
        let ts_event = UnixNanos::from(1);
        let ts_init = UnixNanos::from(2);

        let state1 = AccountState::new(
            account_id,
            AccountType::Cash,
            vec![], // Empty balances
            vec![], // Empty margins
            true,
            event_id,
            ts_event,
            ts_init,
            Some(Currency::USD()),
        );

        let state2 = AccountState::new(
            account_id,
            AccountType::Cash,
            vec![], // Empty balances
            vec![], // Empty margins
            true,
            UUID4::new(),       // Different event_id
            UnixNanos::from(3), // Different timestamps
            UnixNanos::from(4),
            Some(Currency::USD()),
        );

        assert!(state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_multiple_balances_and_margins() {
        let account_id = AccountId::new("TEST-001");
        let event_id = UUID4::new();
        let ts_event = UnixNanos::from(1);
        let ts_init = UnixNanos::from(2);

        let usd = Currency::USD();
        let eur = Currency::EUR();
        let btc_instrument = InstrumentId::from("BTCUSDT.COINBASE");
        let eth_instrument = InstrumentId::from("ETHUSDT.BINANCE");

        let balances = vec![
            AccountBalance::new(
                Money::new(1000000.0, usd),
                Money::new(0.0, usd),
                Money::new(1000000.0, usd),
            ),
            AccountBalance::new(
                Money::new(500000.0, eur),
                Money::new(10000.0, eur),
                Money::new(490000.0, eur),
            ),
        ];

        let margins = vec![
            MarginBalance::new(
                Money::new(5000.0, usd),
                Money::new(20000.0, usd),
                btc_instrument,
            ),
            MarginBalance::new(
                Money::new(3000.0, usd),
                Money::new(15000.0, usd),
                eth_instrument,
            ),
        ];

        let state1 = AccountState::new(
            account_id,
            AccountType::Margin,
            balances.clone(),
            margins.clone(),
            true,
            event_id,
            ts_event,
            ts_init,
            Some(usd),
        );

        let state2 = AccountState::new(
            account_id,
            AccountType::Margin,
            balances,
            margins,
            true,
            UUID4::new(),       // Different event_id
            UnixNanos::from(3), // Different timestamps
            UnixNanos::from(4),
            Some(usd),
        );

        assert!(state1.has_same_balances_and_margins(&state2));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_ignores_order_and_is_symmetric() {
        let usd = Currency::USD();
        let eur = Currency::EUR();

        let balances = vec![
            AccountBalance::new(
                Money::new(1000000.0, usd),
                Money::new(0.0, usd),
                Money::new(1000000.0, usd),
            ),
            AccountBalance::new(
                Money::new(500000.0, eur),
                Money::new(10000.0, eur),
                Money::new(490000.0, eur),
            ),
        ];
        let margins = vec![
            MarginBalance::new(
                Money::new(5000.0, usd),
                Money::new(20000.0, usd),
                InstrumentId::from("BTCUSDT.COINBASE"),
            ),
            MarginBalance::new(
                Money::new(3000.0, usd),
                Money::new(15000.0, usd),
                InstrumentId::from("ETHUSDT.BINANCE"),
            ),
        ];

        // Same entries, opposite order.
        let mut reversed_balances = balances.clone();
        reversed_balances.reverse();
        let mut reversed_margins = margins.clone();
        reversed_margins.reverse();

        let make_state = |balances: Vec<AccountBalance>, margins: Vec<MarginBalance>| {
            AccountState::new(
                AccountId::new("TEST-001"),
                AccountType::Margin,
                balances,
                margins,
                true,
                UUID4::new(),
                UnixNanos::from(1),
                UnixNanos::from(2),
                Some(usd),
            )
        };

        let state1 = make_state(balances, margins);
        let state2 = make_state(reversed_balances, reversed_margins);

        assert!(state1.has_same_balances_and_margins(&state2));
        assert!(state2.has_same_balances_and_margins(&state1));
    }
}