1use std::{collections::HashMap, fmt::Display};
17
18use nautilus_core::{UUID4, UnixNanos};
19use serde::{Deserialize, Serialize};
20
21use crate::{
22 enums::AccountType,
23 identifiers::{AccountId, InstrumentId},
24 types::{AccountBalance, Currency, MarginBalance},
25};
26
/// An account state event: a snapshot of an account's balances and margins.
///
/// Serializable via serde; when the `python` feature is enabled it is also exposed
/// as a pyo3 class in the `nautilus_trader.core.nautilus_pyo3.model` module.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct AccountState {
    /// The account ID associated with the event.
    pub account_id: AccountId,
    /// The type of the account (e.g. cash or margin).
    pub account_type: AccountType,
    /// The account's base currency, if any.
    pub base_currency: Option<Currency>,
    /// The account balances; `has_same_balances_and_margins` matches them by currency.
    pub balances: Vec<AccountBalance>,
    /// The margin balances; `has_same_balances_and_margins` matches them by instrument ID.
    pub margins: Vec<MarginBalance>,
    /// Whether the state was reported.
    /// NOTE(review): presumably reported by the venue rather than calculated locally — confirm.
    pub is_reported: bool,
    /// The identifier of this event; part of the identity used by `PartialEq`.
    pub event_id: UUID4,
    /// Timestamp of the event, in UNIX nanoseconds.
    pub ts_event: UnixNanos,
    /// Timestamp at which this object was initialized (presumably — confirm), in UNIX nanoseconds.
    pub ts_init: UnixNanos,
}
55
56impl AccountState {
57 #[allow(clippy::too_many_arguments)]
59 pub fn new(
60 account_id: AccountId,
61 account_type: AccountType,
62 balances: Vec<AccountBalance>,
63 margins: Vec<MarginBalance>,
64 is_reported: bool,
65 event_id: UUID4,
66 ts_event: UnixNanos,
67 ts_init: UnixNanos,
68 base_currency: Option<Currency>,
69 ) -> Self {
70 Self {
71 account_id,
72 account_type,
73 base_currency,
74 balances,
75 margins,
76 is_reported,
77 event_id,
78 ts_event,
79 ts_init,
80 }
81 }
82
83 pub fn has_same_balances_and_margins(&self, other: &Self) -> bool {
94 if self.balances.len() != other.balances.len() || self.margins.len() != other.margins.len()
96 {
97 return false;
98 }
99
100 let self_balances: HashMap<Currency, &AccountBalance> = self
102 .balances
103 .iter()
104 .map(|balance| (balance.currency, balance))
105 .collect();
106
107 let other_balances: HashMap<Currency, &AccountBalance> = other
108 .balances
109 .iter()
110 .map(|balance| (balance.currency, balance))
111 .collect();
112
113 for (currency, self_balance) in &self_balances {
115 match other_balances.get(currency) {
116 Some(other_balance) => {
117 if self_balance != other_balance {
118 return false;
119 }
120 }
121 None => return false, }
123 }
124
125 let self_margins: HashMap<InstrumentId, &MarginBalance> = self
127 .margins
128 .iter()
129 .map(|margin| (margin.instrument_id, margin))
130 .collect();
131
132 let other_margins: HashMap<InstrumentId, &MarginBalance> = other
133 .margins
134 .iter()
135 .map(|margin| (margin.instrument_id, margin))
136 .collect();
137
138 for (instrument_id, self_margin) in &self_margins {
140 match other_margins.get(instrument_id) {
141 Some(other_margin) => {
142 if self_margin != other_margin {
143 return false;
144 }
145 }
146 None => return false, }
148 }
149
150 true
151 }
152}
153
154impl Display for AccountState {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 write!(
157 f,
158 "{}(account_id={}, account_type={}, base_currency={}, is_reported={}, balances=[{}], margins=[{}], event_id={})",
159 stringify!(AccountState),
160 self.account_id,
161 self.account_type,
162 self.base_currency.map_or_else(
163 || "None".to_string(),
164 |base_currency| format!("{}", base_currency.code)
165 ),
166 self.is_reported,
167 self.balances
168 .iter()
169 .map(|b| format!("{b}"))
170 .collect::<Vec<String>>()
171 .join(", "),
172 self.margins
173 .iter()
174 .map(|m| format!("{m}"))
175 .collect::<Vec<String>>()
176 .join(", "),
177 self.event_id
178 )
179 }
180}
181
182impl PartialEq for AccountState {
183 fn eq(&self, other: &Self) -> bool {
184 self.account_id == other.account_id
185 && self.account_type == other.account_type
186 && self.event_id == other.event_id
187 }
188}
189
#[cfg(test)]
mod tests {
    use nautilus_core::{UUID4, UnixNanos};
    use rstest::rstest;

    use crate::{
        enums::AccountType,
        events::{
            AccountState,
            account::stubs::{cash_account_state, margin_account_state},
        },
        identifiers::{AccountId, InstrumentId},
        types::{AccountBalance, Currency, MarginBalance, Money},
    };

    /// Builds an [`AccountBalance`] whose three amounts share `currency`.
    fn balance(total: f64, locked: f64, free: f64, currency: Currency) -> AccountBalance {
        AccountBalance::new(
            Money::new(total, currency),
            Money::new(locked, currency),
            Money::new(free, currency),
        )
    }

    /// Builds a USD-denominated [`MarginBalance`] for `instrument_id`.
    fn usd_margin(initial: f64, maintenance: f64, instrument_id: InstrumentId) -> MarginBalance {
        let usd = Currency::USD();
        MarginBalance::new(
            Money::new(initial, usd),
            Money::new(maintenance, usd),
            instrument_id,
        )
    }

    /// Builds a reported `TEST-001` account state with a USD base currency.
    fn test_state(
        account_type: AccountType,
        balances: Vec<AccountBalance>,
        margins: Vec<MarginBalance>,
        event_id: UUID4,
        ts_event: u64,
        ts_init: u64,
    ) -> AccountState {
        AccountState::new(
            AccountId::new("TEST-001"),
            account_type,
            balances,
            margins,
            true,
            event_id,
            UnixNanos::from(ts_event),
            UnixNanos::from(ts_init),
            Some(Currency::USD()),
        )
    }

    #[rstest]
    fn test_equality() {
        assert_eq!(cash_account_state(), cash_account_state());
    }

    #[rstest]
    fn test_display_cash_account_state(cash_account_state: AccountState) {
        let expected = "AccountState(account_id=SIM-001, account_type=CASH, base_currency=USD, is_reported=true, \
            balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
            margins=[], event_id=16578139-a945-4b65-b46c-bc131a15d8e7)";
        assert_eq!(cash_account_state.to_string(), expected);
    }

    #[rstest]
    fn test_display_margin_account_state(margin_account_state: AccountState) {
        let expected = "AccountState(account_id=SIM-001, account_type=MARGIN, base_currency=USD, is_reported=true, \
            balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
            margins=[MarginBalance(initial=5000.00 USD, maintenance=20000.00 USD, instrument_id=BTCUSDT.COINBASE)], \
            event_id=16578139-a945-4b65-b46c-bc131a15d8e7)";
        assert_eq!(margin_account_state.to_string(), expected);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_identical() {
        let (lhs, rhs) = (cash_account_state(), cash_account_state());
        assert!(lhs.has_same_balances_and_margins(&rhs));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_amounts() {
        let baseline = cash_account_state();
        let mut changed = cash_account_state();
        changed.balances = vec![balance(
            2_000_000.0,
            50_000.0,
            1_950_000.0,
            Currency::USD(),
        )];
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_currencies() {
        let baseline = cash_account_state();
        let mut changed = cash_account_state();
        changed.balances = vec![balance(
            1_525_000.0,
            25_000.0,
            1_500_000.0,
            Currency::EUR(),
        )];
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_balance() {
        let baseline = cash_account_state();
        let mut changed = cash_account_state();
        changed
            .balances
            .push(balance(1_000_000.0, 0.0, 1_000_000.0, Currency::EUR()));
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_amounts() {
        let baseline = margin_account_state();
        let mut changed = margin_account_state();
        changed.margins = vec![usd_margin(
            10_000.0,
            40_000.0,
            InstrumentId::from("BTCUSDT.COINBASE"),
        )];
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_instruments() {
        let baseline = margin_account_state();
        let mut changed = margin_account_state();
        changed.margins = vec![usd_margin(
            5_000.0,
            20_000.0,
            InstrumentId::from("ETHUSDT.BINANCE"),
        )];
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_margin() {
        let baseline = margin_account_state();
        let mut changed = margin_account_state();
        changed.margins.push(usd_margin(
            3_000.0,
            15_000.0,
            InstrumentId::from("ETHUSDT.BINANCE"),
        ));
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_empty_collections() {
        let first = test_state(AccountType::Cash, vec![], vec![], UUID4::new(), 1, 2);
        let second = test_state(AccountType::Cash, vec![], vec![], UUID4::new(), 3, 4);
        assert!(first.has_same_balances_and_margins(&second));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_multiple_balances_and_margins() {
        let (usd, eur) = (Currency::USD(), Currency::EUR());

        let balances = vec![
            balance(1_000_000.0, 0.0, 1_000_000.0, usd),
            balance(500_000.0, 10_000.0, 490_000.0, eur),
        ];
        let margins = vec![
            usd_margin(5_000.0, 20_000.0, InstrumentId::from("BTCUSDT.COINBASE")),
            usd_margin(3_000.0, 15_000.0, InstrumentId::from("ETHUSDT.BINANCE")),
        ];

        let first = test_state(
            AccountType::Margin,
            balances.clone(),
            margins.clone(),
            UUID4::new(),
            1,
            2,
        );
        let second = test_state(AccountType::Margin, balances, margins, UUID4::new(), 3, 4);

        assert!(first.has_same_balances_and_margins(&second));
    }
}