1use std::{
17 collections::HashMap,
18 fmt::{Display, Formatter},
19};
20
21use nautilus_core::{UUID4, UnixNanos};
22use serde::{Deserialize, Serialize};
23
24use crate::{
25 enums::AccountType,
26 identifiers::{AccountId, InstrumentId},
27 types::{AccountBalance, Currency, MarginBalance},
28};
29
/// A snapshot of an account's balances and margins, carried as an event
/// (it has an `event_id` plus `ts_event` / `ts_init` timestamps).
///
/// Note: `PartialEq` for this type compares only `account_id`, `account_type`
/// and `event_id`; use `has_same_balances_and_margins` to compare contents.
// `#[repr(C)]` fixes the memory layout to declaration order, and serde also
// serializes fields in declaration order — do not reorder fields.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct AccountState {
    /// The account ID this state belongs to.
    pub account_id: AccountId,
    /// The type of the account.
    pub account_type: AccountType,
    /// The account base currency, if any (rendered as `None` by `Display` when absent).
    pub base_currency: Option<Currency>,
    /// The account balances; comparisons key these by currency, so at most one
    /// entry per currency is expected.
    pub balances: Vec<AccountBalance>,
    /// The margin balances; comparisons key these by instrument ID.
    pub margins: Vec<MarginBalance>,
    /// Whether this state was reported (presumably by the venue rather than
    /// calculated locally — TODO confirm).
    pub is_reported: bool,
    /// The unique identifier of this event.
    pub event_id: UUID4,
    /// UNIX timestamp (nanoseconds) of the event — presumably when the state
    /// change occurred; confirm against producers.
    pub ts_event: UnixNanos,
    /// UNIX timestamp (nanoseconds) — presumably when this instance was
    /// initialized; confirm against producers.
    pub ts_init: UnixNanos,
}
58
59impl AccountState {
60 #[allow(clippy::too_many_arguments)]
62 pub fn new(
63 account_id: AccountId,
64 account_type: AccountType,
65 balances: Vec<AccountBalance>,
66 margins: Vec<MarginBalance>,
67 is_reported: bool,
68 event_id: UUID4,
69 ts_event: UnixNanos,
70 ts_init: UnixNanos,
71 base_currency: Option<Currency>,
72 ) -> Self {
73 Self {
74 account_id,
75 account_type,
76 base_currency,
77 balances,
78 margins,
79 is_reported,
80 event_id,
81 ts_event,
82 ts_init,
83 }
84 }
85
86 pub fn has_same_balances_and_margins(&self, other: &Self) -> bool {
97 if self.balances.len() != other.balances.len() || self.margins.len() != other.margins.len()
99 {
100 return false;
101 }
102
103 let self_balances: HashMap<Currency, &AccountBalance> = self
105 .balances
106 .iter()
107 .map(|balance| (balance.currency, balance))
108 .collect();
109
110 let other_balances: HashMap<Currency, &AccountBalance> = other
111 .balances
112 .iter()
113 .map(|balance| (balance.currency, balance))
114 .collect();
115
116 for (currency, self_balance) in &self_balances {
118 match other_balances.get(currency) {
119 Some(other_balance) => {
120 if self_balance != other_balance {
121 return false;
122 }
123 }
124 None => return false, }
126 }
127
128 let self_margins: HashMap<InstrumentId, &MarginBalance> = self
130 .margins
131 .iter()
132 .map(|margin| (margin.instrument_id, margin))
133 .collect();
134
135 let other_margins: HashMap<InstrumentId, &MarginBalance> = other
136 .margins
137 .iter()
138 .map(|margin| (margin.instrument_id, margin))
139 .collect();
140
141 for (instrument_id, self_margin) in &self_margins {
143 match other_margins.get(instrument_id) {
144 Some(other_margin) => {
145 if self_margin != other_margin {
146 return false;
147 }
148 }
149 None => return false, }
151 }
152
153 true
154 }
155}
156
157impl Display for AccountState {
158 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159 write!(
160 f,
161 "{}(account_id={}, account_type={}, base_currency={}, is_reported={}, balances=[{}], margins=[{}], event_id={})",
162 stringify!(AccountState),
163 self.account_id,
164 self.account_type,
165 self.base_currency.map_or_else(
166 || "None".to_string(),
167 |base_currency| format!("{}", base_currency.code)
168 ),
169 self.is_reported,
170 self.balances
171 .iter()
172 .map(|b| format!("{b}"))
173 .collect::<Vec<String>>()
174 .join(","),
175 self.margins
176 .iter()
177 .map(|m| format!("{m}"))
178 .collect::<Vec<String>>()
179 .join(","),
180 self.event_id
181 )
182 }
183}
184
185impl PartialEq for AccountState {
186 fn eq(&self, other: &Self) -> bool {
187 self.account_id == other.account_id
188 && self.account_type == other.account_type
189 && self.event_id == other.event_id
190 }
191}
192
#[cfg(test)]
mod tests {
    use nautilus_core::{UUID4, UnixNanos};
    use rstest::rstest;

    use crate::{
        enums::AccountType,
        events::{
            AccountState,
            account::stubs::{cash_account_state, margin_account_state},
        },
        identifiers::{AccountId, InstrumentId},
        types::{AccountBalance, Currency, MarginBalance, Money},
    };

    /// Builds an `AccountBalance` whose total/locked/free amounts all use `currency`.
    fn balance_in(currency: Currency, total: f64, locked: f64, free: f64) -> AccountBalance {
        AccountBalance::new(
            Money::new(total, currency),
            Money::new(locked, currency),
            Money::new(free, currency),
        )
    }

    /// Builds a USD-denominated `MarginBalance` for `instrument_id`.
    fn usd_margin(initial: f64, maintenance: f64, instrument_id: InstrumentId) -> MarginBalance {
        let usd = Currency::USD();
        MarginBalance::new(
            Money::new(initial, usd),
            Money::new(maintenance, usd),
            instrument_id,
        )
    }

    /// Builds a reported `TEST-001` account state with a USD base currency.
    fn test_state(
        account_type: AccountType,
        balances: Vec<AccountBalance>,
        margins: Vec<MarginBalance>,
        event_id: UUID4,
        ts_event: UnixNanos,
        ts_init: UnixNanos,
    ) -> AccountState {
        AccountState::new(
            AccountId::new("TEST-001"),
            account_type,
            balances,
            margins,
            true,
            event_id,
            ts_event,
            ts_init,
            Some(Currency::USD()),
        )
    }

    #[rstest]
    fn test_equality() {
        let (first, second) = (cash_account_state(), cash_account_state());
        assert_eq!(first, second);
    }

    #[rstest]
    fn test_display_cash_account_state(cash_account_state: AccountState) {
        let expected = concat!(
            "AccountState(account_id=SIM-001, account_type=CASH, base_currency=USD, is_reported=true, ",
            "balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], ",
            "margins=[], event_id=16578139-a945-4b65-b46c-bc131a15d8e7)",
        );
        assert_eq!(cash_account_state.to_string(), expected);
    }

    #[rstest]
    fn test_display_margin_account_state(margin_account_state: AccountState) {
        let expected = concat!(
            "AccountState(account_id=SIM-001, account_type=MARGIN, base_currency=USD, is_reported=true, ",
            "balances=[AccountBalance(total=1525000.00 USD, locked=25000.00 USD, free=1500000.00 USD)], ",
            "margins=[MarginBalance(initial=5000.00 USD, maintenance=20000.00 USD, instrument_id=BTCUSDT.COINBASE)], ",
            "event_id=16578139-a945-4b65-b46c-bc131a15d8e7)",
        );
        assert_eq!(margin_account_state.to_string(), expected);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_identical() {
        let left = cash_account_state();
        let right = cash_account_state();
        assert!(left.has_same_balances_and_margins(&right));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_amounts() {
        let left = cash_account_state();
        let mut right = cash_account_state();
        right.balances = vec![balance_in(
            Currency::USD(),
            2_000_000.0,
            50_000.0,
            1_950_000.0,
        )];
        let same = left.has_same_balances_and_margins(&right);
        assert!(!same);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_balance_currencies() {
        let left = cash_account_state();
        let mut right = cash_account_state();
        right.balances = vec![balance_in(
            Currency::EUR(),
            1_525_000.0,
            25_000.0,
            1_500_000.0,
        )];
        let same = left.has_same_balances_and_margins(&right);
        assert!(!same);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_balance() {
        let left = cash_account_state();
        let mut right = cash_account_state();
        right
            .balances
            .push(balance_in(Currency::EUR(), 1_000_000.0, 0.0, 1_000_000.0));
        let same = left.has_same_balances_and_margins(&right);
        assert!(!same);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_amounts() {
        let left = margin_account_state();
        let mut right = margin_account_state();
        right.margins = vec![usd_margin(
            10_000.0,
            40_000.0,
            InstrumentId::from("BTCUSDT.COINBASE"),
        )];
        let same = left.has_same_balances_and_margins(&right);
        assert!(!same);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_different_margin_instruments() {
        let left = margin_account_state();
        let mut right = margin_account_state();
        right.margins = vec![usd_margin(
            5_000.0,
            20_000.0,
            InstrumentId::from("ETHUSDT.BINANCE"),
        )];
        let same = left.has_same_balances_and_margins(&right);
        assert!(!same);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_when_missing_margin() {
        let left = margin_account_state();
        let mut right = margin_account_state();
        right.margins.push(usd_margin(
            3_000.0,
            15_000.0,
            InstrumentId::from("ETHUSDT.BINANCE"),
        ));
        let same = left.has_same_balances_and_margins(&right);
        assert!(!same);
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_empty_collections() {
        let left = test_state(
            AccountType::Cash,
            Vec::new(),
            Vec::new(),
            UUID4::new(),
            UnixNanos::from(1),
            UnixNanos::from(2),
        );
        let right = test_state(
            AccountType::Cash,
            Vec::new(),
            Vec::new(),
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
        );

        assert!(left.has_same_balances_and_margins(&right));
    }

    #[rstest]
    fn test_has_same_balances_and_margins_with_multiple_balances_and_margins() {
        let usd = Currency::USD();
        let eur = Currency::EUR();

        let balances = vec![
            balance_in(usd, 1_000_000.0, 0.0, 1_000_000.0),
            balance_in(eur, 500_000.0, 10_000.0, 490_000.0),
        ];
        let margins = vec![
            usd_margin(5_000.0, 20_000.0, InstrumentId::from("BTCUSDT.COINBASE")),
            usd_margin(3_000.0, 15_000.0, InstrumentId::from("ETHUSDT.BINANCE")),
        ];

        let left = test_state(
            AccountType::Margin,
            balances.clone(),
            margins.clone(),
            UUID4::new(),
            UnixNanos::from(1),
            UnixNanos::from(2),
        );
        let right = test_state(
            AccountType::Margin,
            balances,
            margins,
            UUID4::new(),
            UnixNanos::from(3),
            UnixNanos::from(4),
        );

        assert!(left.has_same_balances_and_margins(&right));
    }
}