1use std::{
17 collections::HashMap,
18 fmt::Display,
19 ops::{Deref, DerefMut},
20};
21
22use rust_decimal::{Decimal, prelude::ToPrimitive};
23use serde::{Deserialize, Serialize};
24
25use crate::{
26 accounts::{Account, base::BaseAccount},
27 enums::{AccountType, LiquiditySide, OrderSide},
28 events::{AccountState, OrderFilled},
29 identifiers::AccountId,
30 instruments::InstrumentAny,
31 position::Position,
32 types::{AccountBalance, Currency, Money, Price, Quantity},
33};
34
/// A cash account that wraps a [`BaseAccount`] and exposes its state and behaviour
/// through [`Deref`]/[`DerefMut`] and the [`Account`] trait.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct CashAccount {
    /// Shared account state (id, account type, base currency, balances, events)
    /// that this type delegates to.
    pub base: BaseAccount,
}
43
44impl CashAccount {
45 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
47 Self {
48 base: BaseAccount::new(event, calculate_account_state),
49 }
50 }
51
52 #[must_use]
53 pub fn is_cash_account(&self) -> bool {
54 self.account_type == AccountType::Cash
55 }
56 #[must_use]
57 pub fn is_margin_account(&self) -> bool {
58 self.account_type == AccountType::Margin
59 }
60
61 #[must_use]
62 pub const fn is_unleveraged(&self) -> bool {
63 false
64 }
65
66 pub fn recalculate_balance(&mut self, currency: Currency) {
67 let current_balance = match self.balances.get(¤cy) {
68 Some(balance) => *balance,
69 None => {
70 return;
71 }
72 };
73
74 let total_locked = self
75 .balances
76 .values()
77 .filter(|balance| balance.currency == currency)
78 .fold(Decimal::ZERO, |acc, balance| {
79 acc + balance.locked.as_decimal()
80 });
81
82 let new_balance = AccountBalance::new(
83 current_balance.total,
84 Money::new(total_locked.to_f64().unwrap(), currency),
85 Money::new(
86 (current_balance.total.as_decimal() - total_locked)
87 .to_f64()
88 .unwrap(),
89 currency,
90 ),
91 );
92
93 self.balances.insert(currency, new_balance);
94 }
95}
96
97impl Account for CashAccount {
98 fn id(&self) -> AccountId {
99 self.id
100 }
101
102 fn account_type(&self) -> AccountType {
103 self.account_type
104 }
105
106 fn base_currency(&self) -> Option<Currency> {
107 self.base_currency
108 }
109
110 fn is_cash_account(&self) -> bool {
111 self.account_type == AccountType::Cash
112 }
113
114 fn is_margin_account(&self) -> bool {
115 self.account_type == AccountType::Margin
116 }
117
118 fn calculated_account_state(&self) -> bool {
119 false }
121
122 fn balance_total(&self, currency: Option<Currency>) -> Option<Money> {
123 self.base_balance_total(currency)
124 }
125
126 fn balances_total(&self) -> HashMap<Currency, Money> {
127 self.base_balances_total()
128 }
129
130 fn balance_free(&self, currency: Option<Currency>) -> Option<Money> {
131 self.base_balance_free(currency)
132 }
133
134 fn balances_free(&self) -> HashMap<Currency, Money> {
135 self.base_balances_free()
136 }
137
138 fn balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
139 self.base_balance_locked(currency)
140 }
141
142 fn balances_locked(&self) -> HashMap<Currency, Money> {
143 self.base_balances_locked()
144 }
145
146 fn balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
147 self.base_balance(currency)
148 }
149
150 fn last_event(&self) -> Option<AccountState> {
151 self.base_last_event()
152 }
153
154 fn events(&self) -> Vec<AccountState> {
155 self.events.clone()
156 }
157
158 fn event_count(&self) -> usize {
159 self.events.len()
160 }
161
162 fn currencies(&self) -> Vec<Currency> {
163 self.balances.keys().copied().collect()
164 }
165
166 fn starting_balances(&self) -> HashMap<Currency, Money> {
167 self.balances_starting.clone()
168 }
169
170 fn balances(&self) -> HashMap<Currency, AccountBalance> {
171 self.balances.clone()
172 }
173
174 fn apply(&mut self, event: AccountState) {
175 self.base_apply(event);
176 }
177
178 fn calculate_balance_locked(
179 &mut self,
180 instrument: InstrumentAny,
181 side: OrderSide,
182 quantity: Quantity,
183 price: Price,
184 use_quote_for_inverse: Option<bool>,
185 ) -> anyhow::Result<Money> {
186 self.base_calculate_balance_locked(instrument, side, quantity, price, use_quote_for_inverse)
187 }
188
189 fn calculate_pnls(
190 &self,
191 instrument: InstrumentAny, fill: OrderFilled, position: Option<Position>,
194 ) -> anyhow::Result<Vec<Money>> {
195 self.base_calculate_pnls(instrument, fill, position)
196 }
197
198 fn calculate_commission(
199 &self,
200 instrument: InstrumentAny,
201 last_qty: Quantity,
202 last_px: Price,
203 liquidity_side: LiquiditySide,
204 use_quote_for_inverse: Option<bool>,
205 ) -> anyhow::Result<Money> {
206 self.base_calculate_commission(
207 instrument,
208 last_qty,
209 last_px,
210 liquidity_side,
211 use_quote_for_inverse,
212 )
213 }
214}
215
216impl Deref for CashAccount {
217 type Target = BaseAccount;
218
219 fn deref(&self) -> &Self::Target {
220 &self.base
221 }
222}
223
224impl DerefMut for CashAccount {
225 fn deref_mut(&mut self) -> &mut Self::Target {
226 &mut self.base
227 }
228}
229
230impl PartialEq for CashAccount {
231 fn eq(&self, other: &Self) -> bool {
232 self.id == other.id
233 }
234}
235
236impl Eq for CashAccount {}
237
238impl Display for CashAccount {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 write!(
241 f,
242 "CashAccount(id={}, type={}, base={})",
243 self.id,
244 self.account_type,
245 self.base_currency.map_or_else(
246 || "None".to_string(),
247 |base_currency| format!("{}", base_currency.code)
248 ),
249 )
250 }
251}
252
#[cfg(test)]
mod tests {
    // Unit tests for `CashAccount`. Inputs come from `rstest` fixtures glob-imported
    // from the `accounts::stubs`, `events::account::stubs` and `instruments::stubs`
    // modules; their exact contents are defined outside this file.
    use std::collections::{HashMap, HashSet};

    use rstest::rstest;

    use crate::{
        accounts::{Account, CashAccount, stubs::*},
        enums::{AccountType, LiquiditySide, OrderSide, OrderType},
        events::{AccountState, account::stubs::*},
        identifiers::{AccountId, position_id::PositionId},
        instruments::{CryptoPerpetual, CurrencyPair, Equity, Instrument, InstrumentAny, stubs::*},
        orders::{builder::OrderTestBuilder, stubs::TestOrderEventStubs},
        position::Position,
        types::{Currency, Money, Price, Quantity},
    };

    // `Display` renders the account id, account type and base currency code.
    #[rstest]
    fn test_display(cash_account: CashAccount) {
        assert_eq!(
            format!("{cash_account}"),
            "CashAccount(id=SIM-001, type=CASH, base=USD)"
        );
    }

    // Single-currency (USD) account: the fixture state has 1_525_000 total,
    // split into 1_500_000 free and 25_000 locked.
    #[rstest]
    fn test_instantiate_single_asset_cash_account(
        cash_account: CashAccount,
        cash_account_state: AccountState,
    ) {
        assert_eq!(cash_account.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account.account_type, AccountType::Cash);
        assert_eq!(cash_account.base_currency, Some(Currency::from("USD")));
        assert_eq!(cash_account.last_event(), Some(cash_account_state.clone()));
        assert_eq!(cash_account.events(), vec![cash_account_state]);
        assert_eq!(cash_account.event_count(), 1);
        assert_eq!(
            cash_account.balance_total(None),
            Some(Money::from("1525000 USD"))
        );
        assert_eq!(
            cash_account.balance_free(None),
            Some(Money::from("1500000 USD"))
        );
        assert_eq!(
            cash_account.balance_locked(None),
            Some(Money::from("25000 USD"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("USD"), Money::from("1525000 USD"));
        assert_eq!(cash_account.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("USD"), Money::from("1500000 USD"));
        assert_eq!(cash_account.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("USD"), Money::from("25000 USD"));
        assert_eq!(cash_account.balances_locked(), balances_locked_expected);
    }

    // Multi-currency account without a base currency: BTC and ETH balances,
    // nothing locked in either currency.
    #[rstest]
    fn test_instantiate_multi_asset_cash_account(
        cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
    ) {
        assert_eq!(cash_account_multi.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account_multi.account_type, AccountType::Cash);
        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi.clone())
        );
        assert_eq!(cash_account_state_multi.base_currency, None);
        assert_eq!(cash_account_multi.events(), vec![cash_account_state_multi]);
        assert_eq!(cash_account_multi.event_count(), 1);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("BTC"), Money::from("10 BTC"));
        balances_total_expected.insert(Currency::from("ETH"), Money::from("20 ETH"));
        assert_eq!(cash_account_multi.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("BTC"), Money::from("10 BTC"));
        balances_free_expected.insert(Currency::from("ETH"), Money::from("20 ETH"));
        assert_eq!(cash_account_multi.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("BTC"), Money::from("0 BTC"));
        balances_locked_expected.insert(Currency::from("ETH"), Money::from("0 ETH"));
        assert_eq!(
            cash_account_multi.balances_locked(),
            balances_locked_expected
        );
    }

    // Applying a second state event appends it to the history and replaces the BTC
    // balance (9 total = 8.5 free + 0.5 locked); the ETH balance is unchanged.
    #[rstest]
    fn test_apply_given_new_state_event_updates_correctly(
        mut cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
        cash_account_state_multi_changed_btc: AccountState,
    ) {
        cash_account_multi.apply(cash_account_state_multi_changed_btc.clone());

        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi_changed_btc.clone())
        );
        assert_eq!(
            cash_account_multi.events,
            vec![
                cash_account_state_multi,
                cash_account_state_multi_changed_btc
            ]
        );
        assert_eq!(cash_account_multi.event_count(), 2);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("9 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("8.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
    }

    // Buying 1_000_000 AUD/USD at 0.8 locks quote currency (USD).
    // NOTE(review): 800_032 looks like the 800_000 notional plus a fee-related amount
    // derived from the `audusd_sim` stub; confirm against the fixture's fee rates.
    #[rstest]
    fn test_calculate_balance_locked_buy(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Buy,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("800032 USD"));
    }

    // Selling the same pair locks base currency (AUD) rather than the quote notional.
    // NOTE(review): the extra 40 AUD over the 1_000_000 quantity is presumably
    // fee-related; confirm against the fixture.
    #[rstest]
    fn test_calculate_balance_locked_sell(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Sell,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("1000040 AUD"));
    }

    // For an equity (no base currency) the sell lock equals the order quantity (100),
    // expressed in USD.
    #[rstest]
    fn test_calculate_balance_locked_sell_no_base_currency(
        mut cash_account_million_usd: CashAccount,
        equity_aapl: Equity,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                equity_aapl.into_any(),
                OrderSide::Sell,
                Quantity::from("100"),
                Price::from("1500.0"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("100 USD"));
    }

    // Buying 1_000_000 AUD/USD at 0.8 on a USD-base account yields a single PnL
    // entry: the -800_000 USD quote outflow.
    #[rstest]
    fn test_calculate_pnls_for_single_currency_cash_account(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let audusd_sim = InstrumentAny::CurrencyPair(audusd_sim);
        let order = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(audusd_sim.id())
            .side(OrderSide::Buy)
            .quantity(Quantity::from("1000000"))
            .build();
        let fill = TestOrderEventStubs::filled(
            &order,
            &audusd_sim,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("0.8")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&audusd_sim, fill.clone().into());
        let pnls = cash_account_million_usd
            .calculate_pnls(audusd_sim, fill.into(), Some(position))
            .unwrap();
        assert_eq!(pnls, vec![Money::from("-800000 USD")]);
    }

    // Sell then buy 0.5 BTC at 45_500 USDT on a multi-currency account. Each fill
    // yields one BTC and one USDT PnL with opposite signs. The results are compared
    // as sets so the order of the returned Vec is not asserted.
    #[rstest]
    fn test_calculate_pnls_for_multi_currency_cash_account_btcusdt(
        cash_account_multi: CashAccount,
        currency_pair_btcusdt: CurrencyPair,
    ) {
        let btcusdt = InstrumentAny::CurrencyPair(currency_pair_btcusdt);
        let order1 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Sell)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill1 = TestOrderEventStubs::filled(
            &order1,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&btcusdt, fill1.clone().into());
        let result1 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill1.into(),
                Some(position.clone()),
            )
            .unwrap();
        let order2 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Buy)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill2 = TestOrderEventStubs::filled(
            &order2,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let result2 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill2.into(),
                Some(position),
            )
            .unwrap();
        let result1_set: HashSet<Money> = result1.into_iter().collect();
        let result1_expected: HashSet<Money> =
            vec![Money::from("22750 USDT"), Money::from("-0.5 BTC")]
                .into_iter()
                .collect();
        let result2_set: HashSet<Money> = result2.into_iter().collect();
        let result2_expected: HashSet<Money> =
            vec![Money::from("-22750 USDT"), Money::from("0.5 BTC")]
                .into_iter()
                .collect();
        assert_eq!(result1_set, result1_expected);
        assert_eq!(result2_set, result2_expected);
    }

    // Inverse instrument, maker side: the commission is reported in BTC by default
    // and in USD when `use_quote_for_inverse` is set. The negative values presumably
    // reflect a maker rebate in the `xbtusd_bitmex` fixture; confirm there.
    #[rstest]
    #[case(false, Money::from("-0.00218331 BTC"))]
    #[case(true, Money::from("-25.0 USD"))]
    fn test_calculate_commission_for_inverse_maker_crypto(
        #[case] use_quote_for_inverse: bool,
        #[case] expected: Money,
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Maker,
                Some(use_quote_for_inverse),
            )
            .unwrap();
        assert_eq!(result, expected);
    }

    // FX taker fill: 1_500_000 × 0.8005 = 1_200_750 USD notional, with the expected
    // commission rounded to 24.02 USD.
    #[rstest]
    fn test_calculate_commission_for_taker_fx(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                audusd_sim.into_any(),
                Quantity::from("1500000"),
                Price::from("0.8005"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("24.02 USD"));
    }

    // Inverse instrument, taker side, default settlement: the commission is in BTC.
    #[rstest]
    fn test_calculate_commission_crypto_taker(
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("0.00654993 BTC"));
    }

    // USD/JPY taker fill: the commission is in the JPY quote currency.
    // `usdjpy_idealpro` is called directly rather than injected as a fixture.
    #[rstest]
    fn test_calculate_commission_fx_taker(cash_account_million_usd: CashAccount) {
        let instrument = usdjpy_idealpro();
        let result = cash_account_million_usd
            .calculate_commission(
                instrument.into_any(),
                Quantity::from("2200000"),
                Price::from("120.310"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("5294 JPY"));
    }
}