nautilus_model/events/account/state.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, fmt::Display};
17
18use nautilus_core::{UUID4, UnixNanos};
19use serde::{Deserialize, Serialize};
20
21use crate::{
22    enums::AccountType,
23    identifiers::{AccountId, InstrumentId},
24    types::{AccountBalance, Currency, MarginBalance},
25};
26
/// Represents an event which includes information on the state of the account.
///
/// Equality (`==`) compares only `account_id`, `account_type` and `event_id`; to compare the
/// balance and margin contents of two states use
/// [`AccountState::has_same_balances_and_margins`].
// NOTE(review): `repr(C)` fixes the in-memory field order, so reordering fields changes the
// layout — presumably relied on across an FFI boundary; confirm before changing.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct AccountState {
    /// The account ID associated with the event.
    pub account_id: AccountId,
    /// The type of the account (e.g. cash or margin).
    pub account_type: AccountType,
    /// The base currency for the account, or `None` if not applicable.
    pub base_currency: Option<Currency>,
    /// The balances in the account (matched by currency when comparing states).
    pub balances: Vec<AccountBalance>,
    /// The margin balances in the account (matched by instrument ID when comparing states).
    pub margins: Vec<MarginBalance>,
    /// Indicates if the account state is reported by the exchange
    /// (as opposed to system-calculated).
    pub is_reported: bool,
    /// The unique identifier for the event.
    pub event_id: UUID4,
    /// UNIX timestamp (nanoseconds) when the event occurred.
    pub ts_event: UnixNanos,
    /// UNIX timestamp (nanoseconds) when the event was initialized.
    pub ts_init: UnixNanos,
}
55
56impl AccountState {
57    /// Creates a new [`AccountState`] instance.
58    #[allow(clippy::too_many_arguments)]
59    pub fn new(
60        account_id: AccountId,
61        account_type: AccountType,
62        balances: Vec<AccountBalance>,
63        margins: Vec<MarginBalance>,
64        is_reported: bool,
65        event_id: UUID4,
66        ts_event: UnixNanos,
67        ts_init: UnixNanos,
68        base_currency: Option<Currency>,
69    ) -> Self {
70        Self {
71            account_id,
72            account_type,
73            base_currency,
74            balances,
75            margins,
76            is_reported,
77            event_id,
78            ts_event,
79            ts_init,
80        }
81    }
82
83    /// Returns `true` if this account state has the same balances and margins as another.
84    ///
85    /// This compares all balances and margins for equality, returning `true` only if
86    /// all balances and margins are equal. If any balance or margin is different or
87    /// missing, returns `false`.
88    ///
89    /// # Note
90    ///
91    /// This method does not compare event IDs, timestamps, or other metadata - only
92    /// the actual balance and margin values.
93    pub fn has_same_balances_and_margins(&self, other: &Self) -> bool {
94        // Quick check - if lengths differ, they can't be equal
95        if self.balances.len() != other.balances.len() || self.margins.len() != other.margins.len()
96        {
97            return false;
98        }
99
100        // Compare balances by currency
101        let self_balances: HashMap<Currency, &AccountBalance> = self
102            .balances
103            .iter()
104            .map(|balance| (balance.currency, balance))
105            .collect();
106
107        let other_balances: HashMap<Currency, &AccountBalance> = other
108            .balances
109            .iter()
110            .map(|balance| (balance.currency, balance))
111            .collect();
112
113        // Check if all balances are equal
114        for (currency, self_balance) in &self_balances {
115            match other_balances.get(currency) {
116                Some(other_balance) => {
117                    if self_balance != other_balance {
118                        return false;
119                    }
120                }
121                None => return false, // Currency missing in other
122            }
123        }
124
125        // Compare margins by instrument_id
126        let self_margins: HashMap<InstrumentId, &MarginBalance> = self
127            .margins
128            .iter()
129            .map(|margin| (margin.instrument_id, margin))
130            .collect();
131
132        let other_margins: HashMap<InstrumentId, &MarginBalance> = other
133            .margins
134            .iter()
135            .map(|margin| (margin.instrument_id, margin))
136            .collect();
137
138        // Check if all margins are equal
139        for (instrument_id, self_margin) in &self_margins {
140            match other_margins.get(instrument_id) {
141                Some(other_margin) => {
142                    if self_margin != other_margin {
143                        return false;
144                    }
145                }
146                None => return false, // Instrument missing in other
147            }
148        }
149
150        true
151    }
152}
153
154impl Display for AccountState {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(
157            f,
158            "{}(account_id={}, account_type={}, base_currency={}, is_reported={}, balances=[{}], margins=[{}], event_id={})",
159            stringify!(AccountState),
160            self.account_id,
161            self.account_type,
162            self.base_currency.map_or_else(
163                || "None".to_string(),
164                |base_currency| format!("{}", base_currency.code)
165            ),
166            self.is_reported,
167            self.balances
168                .iter()
169                .map(|b| format!("{b}"))
170                .collect::<Vec<String>>()
171                .join(", "),
172            self.margins
173                .iter()
174                .map(|m| format!("{m}"))
175                .collect::<Vec<String>>()
176                .join(", "),
177            self.event_id
178        )
179    }
180}
181
182impl PartialEq for AccountState {
183    fn eq(&self, other: &Self) -> bool {
184        self.account_id == other.account_id
185            && self.account_type == other.account_type
186            && self.event_id == other.event_id
187    }
188}
189
#[cfg(test)]
mod tests {
    use nautilus_core::{UUID4, UnixNanos};
    use rstest::rstest;

    use crate::{
        enums::AccountType,
        events::{
            AccountState,
            account::stubs::{cash_account_state, margin_account_state},
        },
        identifiers::{AccountId, InstrumentId},
        types::{AccountBalance, Currency, MarginBalance, Money},
    };

    /// Builds a reported `TEST-001` account state with a USD base currency.
    fn make_state(
        account_type: AccountType,
        balances: Vec<AccountBalance>,
        margins: Vec<MarginBalance>,
        event_id: UUID4,
        ts_event: u64,
        ts_init: u64,
    ) -> AccountState {
        AccountState::new(
            AccountId::new("TEST-001"),
            account_type,
            balances,
            margins,
            true,
            event_id,
            UnixNanos::from(ts_event),
            UnixNanos::from(ts_init),
            Some(Currency::USD()),
        )
    }

    fn usd_balance() -> AccountBalance {
        let usd = Currency::USD();
        AccountBalance::new(
            Money::new(1000000.0, usd),
            Money::new(0.0, usd),
            Money::new(1000000.0, usd),
        )
    }

    fn eur_balance() -> AccountBalance {
        let eur = Currency::EUR();
        AccountBalance::new(
            Money::new(500000.0, eur),
            Money::new(10000.0, eur),
            Money::new(490000.0, eur),
        )
    }

    fn btc_margin() -> MarginBalance {
        let usd = Currency::USD();
        MarginBalance::new(
            Money::new(5000.0, usd),
            Money::new(20000.0, usd),
            InstrumentId::from("BTCUSDT.COINBASE"),
        )
    }

    fn eth_margin() -> MarginBalance {
        let usd = Currency::USD();
        MarginBalance::new(
            Money::new(3000.0, usd),
            Money::new(15000.0, usd),
            InstrumentId::from("ETHUSDT.BINANCE"),
        )
    }

    /// Asserts `has_same_balances_and_margins` yields `expected` in both directions, since the
    /// comparison is meant to be symmetric.
    fn assert_same_balances_and_margins(lhs: &AccountState, rhs: &AccountState, expected: bool) {
        assert_eq!(lhs.has_same_balances_and_margins(rhs), expected);
        assert_eq!(rhs.has_same_balances_and_margins(lhs), expected);
    }

    #[rstest]
    fn test_equality() {
        let cash_account_state_1 = cash_account_state();
        let cash_account_state_2 = cash_account_state();
        assert_eq!(cash_account_state_1, cash_account_state_2);
    }

    #[rstest]
    fn test_equality_requires_same_event_id() {
        let state = cash_account_state();
        let mut other = cash_account_state();
        other.event_id = UUID4::new();
        assert_ne!(state, other);
    }

    #[rstest]
    fn test_equality_ignores_balances() {
        let state = cash_account_state();
        let mut other = cash_account_state();
        other.balances.clear();
        assert_eq!(state, other);
    }

    #[rstest]
    fn test_display_cash_account_state(cash_account_state: AccountState) {
        let display = format!("{cash_account_state}");
        assert_eq!(
            display,
            "AccountState(account_id=SIM-001, account_type=CASH, base_currency=USD, is_reported=true, \
            balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
            margins=[], event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_display_margin_account_state(margin_account_state: AccountState) {
        let display = format!("{margin_account_state}");
        assert_eq!(
            display,
            "AccountState(account_id=SIM-001, account_type=MARGIN, base_currency=USD, is_reported=true, \
            balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
            margins=[MarginBalance(initial=5000.00 USD, maintenance=20000.00 USD, instrument_id=BTCUSDT.COINBASE)], \
            event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_identical() {
        let state1 = cash_account_state();
        let state2 = cash_account_state();
        assert_same_balances_and_margins(&state1, &state2, true);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_amounts() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        // Same currency as the fixture, different amounts
        let usd = Currency::USD();
        state2.balances = vec![AccountBalance::new(
            Money::new(2000000.0, usd),
            Money::new(50000.0, usd),
            Money::new(1950000.0, usd),
        )];
        assert_same_balances_and_margins(&state1, &state2, false);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_currencies() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        // Same amounts as the fixture, different currency
        let eur = Currency::EUR();
        state2.balances = vec![AccountBalance::new(
            Money::new(1525000.0, eur),
            Money::new(25000.0, eur),
            Money::new(1500000.0, eur),
        )];
        assert_same_balances_and_margins(&state1, &state2, false);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_balance() {
        let state1 = cash_account_state();
        let mut state2 = cash_account_state();
        let eur = Currency::EUR();
        state2.balances.push(AccountBalance::new(
            Money::new(1000000.0, eur),
            Money::new(0.0, eur),
            Money::new(1000000.0, eur),
        ));
        assert_same_balances_and_margins(&state1, &state2, false);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_amounts() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        // Same instrument as the fixture, different amounts
        let usd = Currency::USD();
        state2.margins = vec![MarginBalance::new(
            Money::new(10000.0, usd),
            Money::new(40000.0, usd),
            InstrumentId::from("BTCUSDT.COINBASE"),
        )];
        assert_same_balances_and_margins(&state1, &state2, false);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_instruments() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        // Same amounts as the fixture, different instrument
        let usd = Currency::USD();
        state2.margins = vec![MarginBalance::new(
            Money::new(5000.0, usd),
            Money::new(20000.0, usd),
            InstrumentId::from("ETHUSDT.BINANCE"),
        )];
        assert_same_balances_and_margins(&state1, &state2, false);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_margin() {
        let state1 = margin_account_state();
        let mut state2 = margin_account_state();
        state2.margins.push(eth_margin());
        assert_same_balances_and_margins(&state1, &state2, false);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_empty_collections() {
        let state1 = make_state(AccountType::Cash, vec![], vec![], UUID4::new(), 1, 2);
        // Different event_id and timestamps must not affect the comparison
        let state2 = make_state(AccountType::Cash, vec![], vec![], UUID4::new(), 3, 4);
        assert_same_balances_and_margins(&state1, &state2, true);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_multiple_balances_and_margins() {
        let balances = vec![usd_balance(), eur_balance()];
        let margins = vec![btc_margin(), eth_margin()];

        let state1 = make_state(
            AccountType::Margin,
            balances.clone(),
            margins.clone(),
            UUID4::new(),
            1,
            2,
        );
        // Different event_id and timestamps must not affect the comparison
        let state2 = make_state(AccountType::Margin, balances, margins, UUID4::new(), 3, 4);

        assert_same_balances_and_margins(&state1, &state2, true);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_ignores_entry_order() {
        let state1 = make_state(
            AccountType::Margin,
            vec![usd_balance(), eur_balance()],
            vec![btc_margin(), eth_margin()],
            UUID4::new(),
            1,
            2,
        );
        let state2 = make_state(
            AccountType::Margin,
            vec![eur_balance(), usd_balance()],
            vec![eth_margin(), btc_margin()],
            UUID4::new(),
            1,
            2,
        );

        assert_same_balances_and_margins(&state1, &state2, true);
    }
}