1use std::{
17 collections::HashMap,
18 fmt::{Display, Formatter},
19};
20
21use nautilus_core::{UUID4, UnixNanos};
22use serde::{Deserialize, Serialize};
23
24use crate::{
25 enums::AccountType,
26 identifiers::{AccountId, InstrumentId},
27 types::{AccountBalance, Currency, MarginBalance},
28};
29
/// An event carrying the state of an account: its balances, margins and related metadata.
// `#[repr(C)]` fixes the in-memory layout to the declaration order below, and the serde derives
// also follow declaration order, so do not reorder fields casually.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct AccountState {
    /// The account ID associated with this state.
    pub account_id: AccountId,
    /// The type of the account (e.g. cash or margin).
    pub account_type: AccountType,
    /// The account's base currency, if any (rendered as `None` by `Display` when absent).
    pub base_currency: Option<Currency>,
    /// The account balances; `has_same_balances_and_margins` matches these by currency.
    pub balances: Vec<AccountBalance>,
    /// The margin balances; `has_same_balances_and_margins` matches these by instrument ID.
    pub margins: Vec<MarginBalance>,
    /// Whether this state was reported.
    /// NOTE(review): presumably reported by the venue rather than calculated locally — confirm.
    pub is_reported: bool,
    /// The UUID4 identifier of this event; part of `PartialEq` together with
    /// `account_id` and `account_type`.
    pub event_id: UUID4,
    /// Event timestamp as `UnixNanos` (presumably when the state change occurred — confirm).
    pub ts_event: UnixNanos,
    /// Initialization timestamp as `UnixNanos` (presumably when this object was created — confirm).
    pub ts_init: UnixNanos,
}
58
59impl AccountState {
60 #[allow(clippy::too_many_arguments)]
62 pub fn new(
63 account_id: AccountId,
64 account_type: AccountType,
65 balances: Vec<AccountBalance>,
66 margins: Vec<MarginBalance>,
67 is_reported: bool,
68 event_id: UUID4,
69 ts_event: UnixNanos,
70 ts_init: UnixNanos,
71 base_currency: Option<Currency>,
72 ) -> Self {
73 Self {
74 account_id,
75 account_type,
76 base_currency,
77 balances,
78 margins,
79 is_reported,
80 event_id,
81 ts_event,
82 ts_init,
83 }
84 }
85
86 pub fn has_same_balances_and_margins(&self, other: &Self) -> bool {
97 if self.balances.len() != other.balances.len() || self.margins.len() != other.margins.len()
99 {
100 return false;
101 }
102
103 let self_balances: HashMap<Currency, &AccountBalance> = self
105 .balances
106 .iter()
107 .map(|balance| (balance.currency, balance))
108 .collect();
109
110 let other_balances: HashMap<Currency, &AccountBalance> = other
111 .balances
112 .iter()
113 .map(|balance| (balance.currency, balance))
114 .collect();
115
116 for (currency, self_balance) in &self_balances {
118 match other_balances.get(currency) {
119 Some(other_balance) => {
120 if self_balance != other_balance {
121 return false;
122 }
123 }
124 None => return false, }
126 }
127
128 let self_margins: HashMap<InstrumentId, &MarginBalance> = self
130 .margins
131 .iter()
132 .map(|margin| (margin.instrument_id, margin))
133 .collect();
134
135 let other_margins: HashMap<InstrumentId, &MarginBalance> = other
136 .margins
137 .iter()
138 .map(|margin| (margin.instrument_id, margin))
139 .collect();
140
141 for (instrument_id, self_margin) in &self_margins {
143 match other_margins.get(instrument_id) {
144 Some(other_margin) => {
145 if self_margin != other_margin {
146 return false;
147 }
148 }
149 None => return false, }
151 }
152
153 true
154 }
155}
156
157impl Display for AccountState {
158 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159 write!(
160 f,
161 "{}(account_id={}, account_type={}, base_currency={}, is_reported={}, balances=[{}], margins=[{}], event_id={})",
162 stringify!(AccountState),
163 self.account_id,
164 self.account_type,
165 self.base_currency.map_or_else(
166 || "None".to_string(),
167 |base_currency| format!("{}", base_currency.code)
168 ),
169 self.is_reported,
170 self.balances
171 .iter()
172 .map(|b| format!("{b}"))
173 .collect::<Vec<String>>()
174 .join(","),
175 self.margins
176 .iter()
177 .map(|m| format!("{m}"))
178 .collect::<Vec<String>>()
179 .join(","),
180 self.event_id
181 )
182 }
183}
184
185impl PartialEq for AccountState {
186 fn eq(&self, other: &Self) -> bool {
187 self.account_id == other.account_id
188 && self.account_type == other.account_type
189 && self.event_id == other.event_id
190 }
191}
192
#[cfg(test)]
mod tests {
    use nautilus_core::{UUID4, UnixNanos};
    use rstest::rstest;

    use crate::{
        enums::AccountType,
        events::{
            AccountState,
            account::stubs::{cash_account_state, margin_account_state},
        },
        identifiers::{AccountId, InstrumentId},
        types::{AccountBalance, Currency, MarginBalance, Money},
    };

    /// Builds an [`AccountBalance`] whose three amounts share `currency`.
    fn balance(total: f64, locked: f64, free: f64, currency: Currency) -> AccountBalance {
        AccountBalance::new(
            Money::new(total, currency),
            Money::new(locked, currency),
            Money::new(free, currency),
        )
    }

    /// Builds a [`MarginBalance`] for `instrument_id` whose amounts share `currency`.
    fn margin(
        initial: f64,
        maintenance: f64,
        currency: Currency,
        instrument_id: InstrumentId,
    ) -> MarginBalance {
        MarginBalance::new(
            Money::new(initial, currency),
            Money::new(maintenance, currency),
            instrument_id,
        )
    }

    /// Builds a reported `TEST-001` state with a USD base currency.
    fn build_state(
        account_type: AccountType,
        balances: Vec<AccountBalance>,
        margins: Vec<MarginBalance>,
        event_id: UUID4,
        ts_event: UnixNanos,
        ts_init: UnixNanos,
    ) -> AccountState {
        AccountState::new(
            AccountId::new("TEST-001"),
            account_type,
            balances,
            margins,
            true,
            event_id,
            ts_event,
            ts_init,
            Some(Currency::USD()),
        )
    }

    #[rstest]
    fn test_equality() {
        // The stub uses a fixed event ID, so two independently built states compare equal.
        assert_eq!(cash_account_state(), cash_account_state());
    }

    #[rstest]
    fn test_display_cash_account_state(cash_account_state: AccountState) {
        assert_eq!(
            cash_account_state.to_string(),
            "AccountState(account_id=SIM-001, account_type=CASH, base_currency=USD, is_reported=true, \
             balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
             margins=[], event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_display_margin_account_state(margin_account_state: AccountState) {
        assert_eq!(
            margin_account_state.to_string(),
            "AccountState(account_id=SIM-001, account_type=MARGIN, base_currency=USD, is_reported=true, \
             balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], \
             margins=[MarginBalance(initial=5000.00 USD, maintenance=20000.00 USD, instrument_id=BTCUSDT.COINBASE)], \
             event_id=16578139-a945-4b65-b46c-bc131a15d8e7)"
        );
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_identical() {
        assert!(cash_account_state().has_same_balances_and_margins(&cash_account_state()));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_amounts() {
        let baseline = cash_account_state();
        let changed = AccountState {
            balances: vec![balance(2000000.0, 50000.0, 1950000.0, Currency::USD())],
            ..cash_account_state()
        };
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_currencies() {
        let baseline = cash_account_state();
        let changed = AccountState {
            balances: vec![balance(1525000.0, 25000.0, 1500000.0, Currency::EUR())],
            ..cash_account_state()
        };
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_balance() {
        let baseline = cash_account_state();
        let mut extended = cash_account_state();
        extended
            .balances
            .push(balance(1000000.0, 0.0, 1000000.0, Currency::EUR()));
        assert!(!baseline.has_same_balances_and_margins(&extended));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_amounts() {
        let baseline = margin_account_state();
        let changed = AccountState {
            margins: vec![margin(
                10000.0,
                40000.0,
                Currency::USD(),
                InstrumentId::from("BTCUSDT.COINBASE"),
            )],
            ..margin_account_state()
        };
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_instruments() {
        let baseline = margin_account_state();
        let changed = AccountState {
            margins: vec![margin(
                5000.0,
                20000.0,
                Currency::USD(),
                InstrumentId::from("ETHUSDT.BINANCE"),
            )],
            ..margin_account_state()
        };
        assert!(!baseline.has_same_balances_and_margins(&changed));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_margin() {
        let baseline = margin_account_state();
        let mut extended = margin_account_state();
        extended.margins.push(margin(
            3000.0,
            15000.0,
            Currency::USD(),
            InstrumentId::from("ETHUSDT.BINANCE"),
        ));
        assert!(!baseline.has_same_balances_and_margins(&extended));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_empty_collections() {
        let first = build_state(
            AccountType::Cash,
            vec![],
            vec![],
            UUID4::new(),
            UnixNanos::from(1),
            UnixNanos::from(2),
        );
        // Different event ID and timestamps must not affect the comparison.
        let second = build_state(
            AccountType::Cash,
            vec![],
            vec![],
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
        );

        assert!(first.has_same_balances_and_margins(&second));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_multiple_balances_and_margins() {
        let usd = Currency::USD();
        let eur = Currency::EUR();

        let balances = vec![
            balance(1000000.0, 0.0, 1000000.0, usd),
            balance(500000.0, 10000.0, 490000.0, eur),
        ];
        let margins = vec![
            margin(5000.0, 20000.0, usd, InstrumentId::from("BTCUSDT.COINBASE")),
            margin(3000.0, 15000.0, usd, InstrumentId::from("ETHUSDT.BINANCE")),
        ];

        let first = build_state(
            AccountType::Margin,
            balances.clone(),
            margins.clone(),
            UUID4::new(),
            UnixNanos::from(1),
            UnixNanos::from(2),
        );
        let second = build_state(
            AccountType::Margin,
            balances,
            margins,
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
        );

        assert!(first.has_same_balances_and_margins(&second));
    }
}