1use std::{
19 collections::HashMap,
20 fmt::Display,
21 ops::{Deref, DerefMut},
22};
23
24use rust_decimal::{Decimal, prelude::ToPrimitive};
25use serde::{Deserialize, Serialize};
26
27use crate::{
28 accounts::{Account, base::BaseAccount},
29 enums::{AccountType, LiquiditySide, OrderSide},
30 events::{AccountState, OrderFilled},
31 identifiers::AccountId,
32 instruments::InstrumentAny,
33 position::Position,
34 types::{AccountBalance, Currency, Money, Price, Quantity},
35};
36
/// A cash (unleveraged-style) trading account.
///
/// Wraps a [`BaseAccount`], which holds the shared account state (id, type, balances,
/// events), and adds a borrowing policy that is enforced when new account state events
/// are applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct CashAccount {
    /// Shared account state; also reachable directly through `Deref`/`DerefMut`.
    pub base: BaseAccount,
    /// When `false`, applying an account state that contains a balance with a negative
    /// total panics instead of being recorded.
    pub allow_borrowing: bool,
}
46
47impl CashAccount {
48 pub fn new(event: AccountState, calculate_account_state: bool, allow_borrowing: bool) -> Self {
50 Self {
51 base: BaseAccount::new(event, calculate_account_state),
52 allow_borrowing,
53 }
54 }
55
56 #[must_use]
57 pub fn is_cash_account(&self) -> bool {
58 self.account_type == AccountType::Cash
59 }
60 #[must_use]
61 pub fn is_margin_account(&self) -> bool {
62 self.account_type == AccountType::Margin
63 }
64
65 #[must_use]
66 pub const fn is_unleveraged(&self) -> bool {
67 false
68 }
69
70 pub fn recalculate_balance(&mut self, currency: Currency) {
76 let current_balance = match self.balances.get(¤cy) {
77 Some(balance) => *balance,
78 None => {
79 return;
80 }
81 };
82
83 let total_locked = self
84 .balances
85 .values()
86 .filter(|balance| balance.currency == currency)
87 .fold(Decimal::ZERO, |acc, balance| {
88 acc + balance.locked.as_decimal()
89 });
90
91 let new_balance = AccountBalance::new(
92 current_balance.total,
93 Money::new(total_locked.to_f64().unwrap(), currency),
94 Money::new(
95 (current_balance.total.as_decimal() - total_locked)
96 .to_f64()
97 .unwrap(),
98 currency,
99 ),
100 );
101
102 self.balances.insert(currency, new_balance);
103 }
104}
105
106impl Account for CashAccount {
107 fn id(&self) -> AccountId {
108 self.id
109 }
110
111 fn account_type(&self) -> AccountType {
112 self.account_type
113 }
114
115 fn base_currency(&self) -> Option<Currency> {
116 self.base_currency
117 }
118
119 fn is_cash_account(&self) -> bool {
120 self.account_type == AccountType::Cash
121 }
122
123 fn is_margin_account(&self) -> bool {
124 self.account_type == AccountType::Margin
125 }
126
127 fn calculated_account_state(&self) -> bool {
128 false }
130
131 fn balance_total(&self, currency: Option<Currency>) -> Option<Money> {
132 self.base_balance_total(currency)
133 }
134
135 fn balances_total(&self) -> HashMap<Currency, Money> {
136 self.base_balances_total()
137 }
138
139 fn balance_free(&self, currency: Option<Currency>) -> Option<Money> {
140 self.base_balance_free(currency)
141 }
142
143 fn balances_free(&self) -> HashMap<Currency, Money> {
144 self.base_balances_free()
145 }
146
147 fn balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
148 self.base_balance_locked(currency)
149 }
150
151 fn balances_locked(&self) -> HashMap<Currency, Money> {
152 self.base_balances_locked()
153 }
154
155 fn balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
156 self.base_balance(currency)
157 }
158
159 fn last_event(&self) -> Option<AccountState> {
160 self.base_last_event()
161 }
162
163 fn events(&self) -> Vec<AccountState> {
164 self.events.clone()
165 }
166
167 fn event_count(&self) -> usize {
168 self.events.len()
169 }
170
171 fn currencies(&self) -> Vec<Currency> {
172 self.balances.keys().copied().collect()
173 }
174
175 fn starting_balances(&self) -> HashMap<Currency, Money> {
176 self.balances_starting.clone()
177 }
178
179 fn balances(&self) -> HashMap<Currency, AccountBalance> {
180 self.balances.clone()
181 }
182
183 fn apply(&mut self, event: AccountState) {
184 if !self.allow_borrowing {
186 for balance in &event.balances {
187 if balance.total.as_decimal() < rust_decimal::Decimal::ZERO {
188 panic!(
189 "Account balance negative: {} {}",
190 balance.total.as_decimal(),
191 balance.currency.code
192 );
193 }
194 }
195 }
196 self.base_apply(event);
197 }
198
199 fn purge_account_events(&mut self, ts_now: nautilus_core::UnixNanos, lookback_secs: u64) {
200 self.base.base_purge_account_events(ts_now, lookback_secs);
201 }
202
203 fn calculate_balance_locked(
204 &mut self,
205 instrument: InstrumentAny,
206 side: OrderSide,
207 quantity: Quantity,
208 price: Price,
209 use_quote_for_inverse: Option<bool>,
210 ) -> anyhow::Result<Money> {
211 self.base_calculate_balance_locked(instrument, side, quantity, price, use_quote_for_inverse)
212 }
213
214 fn calculate_pnls(
215 &self,
216 instrument: InstrumentAny, fill: OrderFilled, position: Option<Position>,
219 ) -> anyhow::Result<Vec<Money>> {
220 self.base_calculate_pnls(instrument, fill, position)
221 }
222
223 fn calculate_commission(
224 &self,
225 instrument: InstrumentAny,
226 last_qty: Quantity,
227 last_px: Price,
228 liquidity_side: LiquiditySide,
229 use_quote_for_inverse: Option<bool>,
230 ) -> anyhow::Result<Money> {
231 self.base_calculate_commission(
232 instrument,
233 last_qty,
234 last_px,
235 liquidity_side,
236 use_quote_for_inverse,
237 )
238 }
239}
240
241impl Deref for CashAccount {
242 type Target = BaseAccount;
243
244 fn deref(&self) -> &Self::Target {
245 &self.base
246 }
247}
248
249impl DerefMut for CashAccount {
250 fn deref_mut(&mut self) -> &mut Self::Target {
251 &mut self.base
252 }
253}
254
255impl PartialEq for CashAccount {
256 fn eq(&self, other: &Self) -> bool {
257 self.id == other.id
258 }
259}
260
261impl Eq for CashAccount {}
262
263impl Display for CashAccount {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 write!(
266 f,
267 "CashAccount(id={}, type={}, base={})",
268 self.id,
269 self.account_type,
270 self.base_currency.map_or_else(
271 || "None".to_string(),
272 |base_currency| format!("{}", base_currency.code)
273 ),
274 )
275 }
276}
277
// Unit tests for `CashAccount`. Fixtures (`cash_account`, `audusd_sim`, ...) are injected
// by rstest by parameter name, so parameter names must match the fixture functions.
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use rstest::rstest;

    use crate::{
        accounts::{Account, CashAccount, stubs::*},
        enums::{AccountType, LiquiditySide, OrderSide, OrderType},
        events::{AccountState, account::stubs::*},
        identifiers::{AccountId, position_id::PositionId},
        instruments::{CryptoPerpetual, CurrencyPair, Equity, Instrument, InstrumentAny, stubs::*},
        orders::{builder::OrderTestBuilder, stubs::TestOrderEventStubs},
        position::Position,
        types::{Currency, Money, Price, Quantity},
    };

    // `Display` output format, including the base currency code.
    #[rstest]
    fn test_display(cash_account: CashAccount) {
        assert_eq!(
            format!("{cash_account}"),
            "CashAccount(id=SIM-001, type=CASH, base=USD)"
        );
    }

    // Single-currency (USD) account: identity, event history and total/free/locked balances
    // reflect the initial fixture state.
    #[rstest]
    fn test_instantiate_single_asset_cash_account(
        cash_account: CashAccount,
        cash_account_state: AccountState,
    ) {
        assert_eq!(cash_account.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account.account_type, AccountType::Cash);
        assert_eq!(cash_account.base_currency, Some(Currency::from("USD")));
        assert_eq!(cash_account.last_event(), Some(cash_account_state.clone()));
        assert_eq!(cash_account.events(), vec![cash_account_state]);
        assert_eq!(cash_account.event_count(), 1);
        assert_eq!(
            cash_account.balance_total(None),
            Some(Money::from("1525000 USD"))
        );
        assert_eq!(
            cash_account.balance_free(None),
            Some(Money::from("1500000 USD"))
        );
        assert_eq!(
            cash_account.balance_locked(None),
            Some(Money::from("25000 USD"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("USD"), Money::from("1525000 USD"));
        assert_eq!(cash_account.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("USD"), Money::from("1500000 USD"));
        assert_eq!(cash_account.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("USD"), Money::from("25000 USD"));
        assert_eq!(cash_account.balances_locked(), balances_locked_expected);
    }

    // Multi-currency (BTC + ETH) account with no base currency: per-currency balances are
    // queried with an explicit `Some(currency)`.
    #[rstest]
    fn test_instantiate_multi_asset_cash_account(
        cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
    ) {
        assert_eq!(cash_account_multi.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account_multi.account_type, AccountType::Cash);
        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi.clone())
        );
        assert_eq!(cash_account_state_multi.base_currency, None);
        assert_eq!(cash_account_multi.events(), vec![cash_account_state_multi]);
        assert_eq!(cash_account_multi.event_count(), 1);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("BTC"), Money::from("10 BTC"));
        balances_total_expected.insert(Currency::from("ETH"), Money::from("20 ETH"));
        assert_eq!(cash_account_multi.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("BTC"), Money::from("10 BTC"));
        balances_free_expected.insert(Currency::from("ETH"), Money::from("20 ETH"));
        assert_eq!(cash_account_multi.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("BTC"), Money::from("0 BTC"));
        balances_locked_expected.insert(Currency::from("ETH"), Money::from("0 ETH"));
        assert_eq!(
            cash_account_multi.balances_locked(),
            balances_locked_expected
        );
    }

    // Applying a second state event appends it to the history and replaces the BTC
    // balance, while the untouched ETH balance is unchanged.
    #[rstest]
    fn test_apply_given_new_state_event_updates_correctly(
        mut cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
        cash_account_state_multi_changed_btc: AccountState,
    ) {
        cash_account_multi.apply(cash_account_state_multi_changed_btc.clone());
        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi_changed_btc.clone())
        );
        assert_eq!(
            cash_account_multi.events,
            vec![
                cash_account_state_multi,
                cash_account_state_multi_changed_btc
            ]
        );
        assert_eq!(cash_account_multi.event_count(), 2);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("9 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("8.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
    }

    // Buying 1,000,000 AUD/USD at 0.8 locks the quote notional: 1,000,000 * 0.8 = 800,000 USD.
    #[rstest]
    fn test_calculate_balance_locked_buy(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Buy,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("800000 USD"));
    }

    // Selling 1,000,000 AUD/USD locks the quantity itself in the base currency (AUD).
    #[rstest]
    fn test_calculate_balance_locked_sell(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Sell,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("1000000 AUD"));
    }

    // Selling an equity (no base currency) expects the raw quantity (100) denominated in USD.
    // NOTE(review): confirm this fallback (quantity, not notional) is the intended behavior.
    #[rstest]
    fn test_calculate_balance_locked_sell_no_base_currency(
        mut cash_account_million_usd: CashAccount,
        equity_aapl: Equity,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                equity_aapl.into_any(),
                OrderSide::Sell,
                Quantity::from("100"),
                Price::from("1500.0"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("100 USD"));
    }

    // Opening buy of 1,000,000 AUD/USD at 0.8: the expected result contains only the USD leg,
    // -800,000 USD.
    #[rstest]
    fn test_calculate_pnls_for_single_currency_cash_account(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let audusd_sim = InstrumentAny::CurrencyPair(audusd_sim);
        let order = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(audusd_sim.id())
            .side(OrderSide::Buy)
            .quantity(Quantity::from("1000000"))
            .build();
        let fill = TestOrderEventStubs::filled(
            &order,
            &audusd_sim,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("0.8")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&audusd_sim, fill.clone().into());
        let pnls = cash_account_million_usd
            .calculate_pnls(audusd_sim, fill.into(), Some(position))
            .unwrap();
        assert_eq!(pnls, vec![Money::from("-800000 USD")]);
    }

    // Multi-currency account trading BTC/USDT at 45,500: selling 0.5 BTC yields
    // +22,750 USDT / -0.5 BTC, and buying 0.5 BTC yields the mirror image. Results are
    // compared as sets because the order of the legs is not asserted.
    #[rstest]
    fn test_calculate_pnls_for_multi_currency_cash_account_btcusdt(
        cash_account_multi: CashAccount,
        currency_pair_btcusdt: CurrencyPair,
    ) {
        let btcusdt = InstrumentAny::CurrencyPair(currency_pair_btcusdt);
        let order1 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Sell)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill1 = TestOrderEventStubs::filled(
            &order1,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&btcusdt, fill1.clone().into());
        let result1 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill1.into(),
                Some(position.clone()),
            )
            .unwrap();
        let order2 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Buy)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill2 = TestOrderEventStubs::filled(
            &order2,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let result2 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill2.into(),
                Some(position),
            )
            .unwrap();
        let result1_set: HashSet<Money> = result1.into_iter().collect();
        let result1_expected: HashSet<Money> =
            vec![Money::from("22750 USDT"), Money::from("-0.5 BTC")]
                .into_iter()
                .collect();
        let result2_set: HashSet<Money> = result2.into_iter().collect();
        let result2_expected: HashSet<Money> =
            vec![Money::from("-22750 USDT"), Money::from("0.5 BTC")]
                .into_iter()
                .collect();
        assert_eq!(result1_set, result1_expected);
        assert_eq!(result2_set, result2_expected);
    }

    // Inverse instrument (XBT/USD): the maker commission is reported in BTC by default, or
    // in USD when `use_quote_for_inverse` is `Some(true)`. The negative sign corresponds to
    // the expected maker values in the cases.
    #[rstest]
    #[case(false, Money::from("-0.00218331 BTC"))]
    #[case(true, Money::from("-25.0 USD"))]
    fn test_calculate_commission_for_inverse_maker_crypto(
        #[case] use_quote_for_inverse: bool,
        #[case] expected: Money,
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Maker,
                Some(use_quote_for_inverse),
            )
            .unwrap();
        assert_eq!(result, expected);
    }

    // FX taker commission on 1,500,000 AUD/USD at 0.8005, settled in USD.
    #[rstest]
    fn test_calculate_commission_for_taker_fx(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                audusd_sim.into_any(),
                Quantity::from("1500000"),
                Price::from("0.8005"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("24.02 USD"));
    }

    // Inverse crypto taker commission, with the default (non-quote) settlement in BTC.
    #[rstest]
    fn test_calculate_commission_crypto_taker(
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("0.00654993 BTC"));
    }

    // FX taker commission on USD/JPY, settled in the JPY quote currency.
    #[rstest]
    fn test_calculate_commission_fx_taker(cash_account_million_usd: CashAccount) {
        let instrument = usdjpy_idealpro();
        let result = cash_account_million_usd
            .calculate_commission(
                instrument.into_any(),
                Quantity::from("2200000"),
                Price::from("120.310"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("5294 JPY"));
    }
}