1#![allow(dead_code)]
17
18use std::{
19 collections::HashMap,
20 fmt::Display,
21 hash::{Hash, Hasher},
22 ops::{Deref, DerefMut},
23};
24
25use rust_decimal::prelude::ToPrimitive;
26use serde::{Deserialize, Serialize};
27
28use crate::{
29 accounts::base::{Account, BaseAccount},
30 enums::{AccountType, LiquiditySide, OrderSide},
31 events::{AccountState, OrderFilled},
32 identifiers::{AccountId, InstrumentId},
33 instruments::{Instrument, InstrumentAny},
34 position::Position,
35 types::{AccountBalance, Currency, MarginBalance, Money, Price, Quantity},
36};
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[cfg_attr(
40 feature = "python",
41 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
42)]
43pub struct MarginAccount {
44 pub base: BaseAccount,
45 pub leverages: HashMap<InstrumentId, f64>,
46 pub margins: HashMap<InstrumentId, MarginBalance>,
47 pub default_leverage: f64,
48}
49
50impl MarginAccount {
51 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
53 Self {
54 base: BaseAccount::new(event, calculate_account_state),
55 leverages: HashMap::new(),
56 margins: HashMap::new(),
57 default_leverage: 1.0,
58 }
59 }
60
61 pub fn set_default_leverage(&mut self, leverage: f64) {
62 self.default_leverage = leverage;
63 }
64
65 pub fn set_leverage(&mut self, instrument_id: InstrumentId, leverage: f64) {
66 self.leverages.insert(instrument_id, leverage);
67 }
68
69 #[must_use]
70 pub fn get_leverage(&self, instrument_id: &InstrumentId) -> f64 {
71 *self
72 .leverages
73 .get(instrument_id)
74 .unwrap_or(&self.default_leverage)
75 }
76
77 #[must_use]
78 pub fn is_unleveraged(&self, instrument_id: InstrumentId) -> bool {
79 self.get_leverage(&instrument_id) == 1.0
80 }
81
82 #[must_use]
83 pub fn is_cash_account(&self) -> bool {
84 self.account_type == AccountType::Cash
85 }
86 #[must_use]
87 pub fn is_margin_account(&self) -> bool {
88 self.account_type == AccountType::Margin
89 }
90
91 #[must_use]
92 pub fn initial_margins(&self) -> HashMap<InstrumentId, Money> {
93 let mut initial_margins: HashMap<InstrumentId, Money> = HashMap::new();
94 self.margins.values().for_each(|margin_balance| {
95 initial_margins.insert(margin_balance.instrument_id, margin_balance.initial);
96 });
97 initial_margins
98 }
99
100 #[must_use]
101 pub fn maintenance_margins(&self) -> HashMap<InstrumentId, Money> {
102 let mut maintenance_margins: HashMap<InstrumentId, Money> = HashMap::new();
103 self.margins.values().for_each(|margin_balance| {
104 maintenance_margins.insert(margin_balance.instrument_id, margin_balance.maintenance);
105 });
106 maintenance_margins
107 }
108
109 pub fn update_initial_margin(&mut self, instrument_id: InstrumentId, margin_init: Money) {
110 let margin_balance = self.margins.get(&instrument_id);
111 if margin_balance.is_none() {
112 self.margins.insert(
113 instrument_id,
114 MarginBalance::new(
115 margin_init,
116 Money::new(0.0, margin_init.currency),
117 instrument_id,
118 ),
119 );
120 } else {
121 let mut new_margin_balance = *margin_balance.unwrap();
123 new_margin_balance.initial = margin_init;
124 self.margins.insert(instrument_id, new_margin_balance);
125 }
126 self.recalculate_balance(margin_init.currency);
127 }
128
129 #[must_use]
130 pub fn initial_margin(&self, instrument_id: InstrumentId) -> Money {
131 let margin_balance = self.margins.get(&instrument_id);
132 assert!(
133 margin_balance.is_some(),
134 "Cannot get margin_init when no margin_balance"
135 );
136 margin_balance.unwrap().initial
137 }
138
139 pub fn update_maintenance_margin(
140 &mut self,
141 instrument_id: InstrumentId,
142 margin_maintenance: Money,
143 ) {
144 let margin_balance = self.margins.get(&instrument_id);
145 if margin_balance.is_none() {
146 self.margins.insert(
147 instrument_id,
148 MarginBalance::new(
149 Money::new(0.0, margin_maintenance.currency),
150 margin_maintenance,
151 instrument_id,
152 ),
153 );
154 } else {
155 let mut new_margin_balance = *margin_balance.unwrap();
157 new_margin_balance.maintenance = margin_maintenance;
158 self.margins.insert(instrument_id, new_margin_balance);
159 }
160 self.recalculate_balance(margin_maintenance.currency);
161 }
162
163 #[must_use]
164 pub fn maintenance_margin(&self, instrument_id: InstrumentId) -> Money {
165 let margin_balance = self.margins.get(&instrument_id);
166 assert!(
167 margin_balance.is_some(),
168 "Cannot get maintenance_margin when no margin_balance"
169 );
170 margin_balance.unwrap().maintenance
171 }
172
173 pub fn calculate_initial_margin<T: Instrument>(
174 &mut self,
175 instrument: T,
176 quantity: Quantity,
177 price: Price,
178 use_quote_for_inverse: Option<bool>,
179 ) -> Money {
180 let notional = instrument.calculate_notional_value(quantity, price, use_quote_for_inverse);
181 let leverage = self.get_leverage(&instrument.id());
182 if leverage == 0.0 {
183 self.leverages
184 .insert(instrument.id(), self.default_leverage);
185 }
186 let adjusted_notional = notional / leverage;
187 let initial_margin_f64 = instrument.margin_init().to_f64().unwrap();
188 let mut margin = adjusted_notional * initial_margin_f64;
189 margin += adjusted_notional * instrument.taker_fee().to_f64().unwrap() * 2.0;
191 let use_quote_for_inverse = use_quote_for_inverse.unwrap_or(false);
192 if instrument.is_inverse() && !use_quote_for_inverse {
193 Money::new(margin, instrument.base_currency().unwrap())
194 } else {
195 Money::new(margin, instrument.quote_currency())
196 }
197 }
198
199 pub fn calculate_maintenance_margin<T: Instrument>(
200 &mut self,
201 instrument: T,
202 quantity: Quantity,
203 price: Price,
204 use_quote_for_inverse: Option<bool>,
205 ) -> Money {
206 let notional = instrument.calculate_notional_value(quantity, price, use_quote_for_inverse);
207 let leverage = self.get_leverage(&instrument.id());
208 if leverage == 0.0 {
209 self.leverages
210 .insert(instrument.id(), self.default_leverage);
211 }
212 let adjusted_notional = notional / leverage;
213 let margin_maint_f64 = instrument.margin_maint().to_f64().unwrap();
214 let mut margin = adjusted_notional * margin_maint_f64;
215 margin += adjusted_notional * instrument.taker_fee().to_f64().unwrap();
217 let use_quote_for_inverse = use_quote_for_inverse.unwrap_or(false);
218 if instrument.is_inverse() && !use_quote_for_inverse {
219 Money::new(margin, instrument.base_currency().unwrap())
220 } else {
221 Money::new(margin, instrument.quote_currency())
222 }
223 }
224
225 pub fn recalculate_balance(&mut self, currency: Currency) {
226 let current_balance = match self.balances.get(¤cy) {
227 Some(balance) => balance,
228 None => panic!("Cannot recalculate balance when no starting balance"),
229 };
230
231 let mut total_margin = 0;
232 self.margins.values().for_each(|margin| {
234 if margin.currency == currency {
235 total_margin += margin.initial.raw;
236 total_margin += margin.maintenance.raw;
237 }
238 });
239 let total_free = current_balance.total.raw - total_margin;
240 assert!(
242 total_free >= 0,
243 "Cannot recalculate balance when total_free is less than 0.0"
244 );
245 let new_balance = AccountBalance::new(
246 current_balance.total,
247 Money::from_raw(total_margin, currency),
248 Money::from_raw(total_free, currency),
249 );
250 self.balances.insert(currency, new_balance);
251 }
252}
253
254impl Deref for MarginAccount {
255 type Target = BaseAccount;
256
257 fn deref(&self) -> &Self::Target {
258 &self.base
259 }
260}
261
262impl DerefMut for MarginAccount {
263 fn deref_mut(&mut self) -> &mut Self::Target {
264 &mut self.base
265 }
266}
267
268impl Account for MarginAccount {
269 fn id(&self) -> AccountId {
270 self.id
271 }
272
273 fn account_type(&self) -> AccountType {
274 self.account_type
275 }
276
277 fn base_currency(&self) -> Option<Currency> {
278 self.base_currency
279 }
280
281 fn is_cash_account(&self) -> bool {
282 self.account_type == AccountType::Cash
283 }
284
285 fn is_margin_account(&self) -> bool {
286 self.account_type == AccountType::Margin
287 }
288
289 fn calculated_account_state(&self) -> bool {
290 false }
292
293 fn balance_total(&self, currency: Option<Currency>) -> Option<Money> {
294 self.base_balance_total(currency)
295 }
296 fn balances_total(&self) -> HashMap<Currency, Money> {
297 self.base_balances_total()
298 }
299
300 fn balance_free(&self, currency: Option<Currency>) -> Option<Money> {
301 self.base_balance_free(currency)
302 }
303 fn balances_free(&self) -> HashMap<Currency, Money> {
304 self.base_balances_free()
305 }
306
307 fn balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
308 self.base_balance_locked(currency)
309 }
310 fn balances_locked(&self) -> HashMap<Currency, Money> {
311 self.base_balances_locked()
312 }
313
314 fn balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
315 self.base_balance(currency)
316 }
317
318 fn last_event(&self) -> Option<AccountState> {
319 self.base_last_event()
320 }
321 fn events(&self) -> Vec<AccountState> {
322 self.events.clone()
323 }
324 fn event_count(&self) -> usize {
325 self.events.len()
326 }
327 fn currencies(&self) -> Vec<Currency> {
328 self.balances.keys().copied().collect()
329 }
330 fn starting_balances(&self) -> HashMap<Currency, Money> {
331 self.balances_starting.clone()
332 }
333 fn balances(&self) -> HashMap<Currency, AccountBalance> {
334 self.balances.clone()
335 }
336 fn apply(&mut self, event: AccountState) {
337 self.base_apply(event);
338 }
339 fn calculate_balance_locked(
340 &mut self,
341 instrument: InstrumentAny,
342 side: OrderSide,
343 quantity: Quantity,
344 price: Price,
345 use_quote_for_inverse: Option<bool>,
346 ) -> anyhow::Result<Money> {
347 self.base_calculate_balance_locked(instrument, side, quantity, price, use_quote_for_inverse)
348 }
349 fn calculate_pnls(
350 &self,
351 instrument: InstrumentAny,
352 fill: OrderFilled,
353 position: Option<Position>,
354 ) -> anyhow::Result<Vec<Money>> {
355 self.base_calculate_pnls(instrument, fill, position)
356 }
357 fn calculate_commission(
358 &self,
359 instrument: InstrumentAny,
360 last_qty: Quantity,
361 last_px: Price,
362 liquidity_side: LiquiditySide,
363 use_quote_for_inverse: Option<bool>,
364 ) -> anyhow::Result<Money> {
365 self.base_calculate_commission(
366 instrument,
367 last_qty,
368 last_px,
369 liquidity_side,
370 use_quote_for_inverse,
371 )
372 }
373}
374
375impl PartialEq for MarginAccount {
376 fn eq(&self, other: &Self) -> bool {
377 self.id == other.id
378 }
379}
380
381impl Eq for MarginAccount {}
382
383impl Display for MarginAccount {
384 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385 write!(
386 f,
387 "MarginAccount(id={}, type={}, base={})",
388 self.id,
389 self.account_type,
390 self.base_currency.map_or_else(
391 || "None".to_string(),
392 |base_currency| format!("{}", base_currency.code)
393 ),
394 )
395 }
396}
397
398impl Hash for MarginAccount {
399 fn hash<H: Hasher>(&self, state: &mut H) {
400 self.id.hash(state);
401 }
402}
403
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rstest::rstest;

    use crate::{
        accounts::{stubs::*, Account, MarginAccount},
        events::{account::stubs::*, AccountState},
        identifiers::{stubs::*, InstrumentId},
        instruments::{stubs::*, CryptoPerpetual, CurrencyPair},
        types::{Currency, Money, Price, Quantity},
    };

    #[rstest]
    fn test_display(margin_account: MarginAccount) {
        assert_eq!(
            margin_account.to_string(),
            "MarginAccount(id=SIM-001, type=MARGIN, base=USD)"
        );
    }

    #[rstest]
    fn test_base_account_properties(
        margin_account: MarginAccount,
        margin_account_state: AccountState,
    ) {
        assert_eq!(margin_account.base_currency, Some(Currency::from("USD")));
        assert_eq!(
            margin_account.last_event(),
            Some(margin_account_state.clone())
        );
        assert_eq!(margin_account.events(), vec![margin_account_state]);
        assert_eq!(margin_account.event_count(), 1);
        assert_eq!(
            margin_account.balance_total(None),
            Some(Money::from("1525000 USD"))
        );
        assert_eq!(
            margin_account.balance_free(None),
            Some(Money::from("1500000 USD"))
        );
        assert_eq!(
            margin_account.balance_locked(None),
            Some(Money::from("25000 USD"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("USD"), Money::from("1525000 USD"));
        assert_eq!(margin_account.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("USD"), Money::from("1500000 USD"));
        assert_eq!(margin_account.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("USD"), Money::from("25000 USD"));
        assert_eq!(margin_account.balances_locked(), balances_locked_expected);
    }

    #[rstest]
    fn test_set_default_leverage(mut margin_account: MarginAccount) {
        assert_eq!(margin_account.default_leverage, 1.0);
        margin_account.set_default_leverage(10.0);
        assert_eq!(margin_account.default_leverage, 10.0);
    }

    #[rstest]
    fn test_get_leverage_default_leverage(
        margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        assert_eq!(margin_account.get_leverage(&instrument_id_aud_usd_sim), 1.0);
    }

    #[rstest]
    fn test_set_leverage(
        mut margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        assert_eq!(margin_account.leverages.len(), 0);
        margin_account.set_leverage(instrument_id_aud_usd_sim, 10.0);
        assert_eq!(margin_account.leverages.len(), 1);
        assert_eq!(
            margin_account.get_leverage(&instrument_id_aud_usd_sim),
            10.0
        );
    }

    #[rstest]
    fn test_is_unleveraged_with_leverage_returns_false(
        mut margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        margin_account.set_leverage(instrument_id_aud_usd_sim, 10.0);
        assert!(!margin_account.is_unleveraged(instrument_id_aud_usd_sim));
    }

    #[rstest]
    fn test_is_unleveraged_with_no_leverage_returns_true(
        mut margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        margin_account.set_leverage(instrument_id_aud_usd_sim, 1.0);
        assert!(margin_account.is_unleveraged(instrument_id_aud_usd_sim));
    }

    #[rstest]
    fn test_is_unleveraged_with_default_leverage_of_1_returns_true(
        margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        assert!(margin_account.is_unleveraged(instrument_id_aud_usd_sim));
    }

    #[rstest]
    fn test_update_margin_init(
        mut margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        assert_eq!(margin_account.margins.len(), 0);
        let margin = Money::from("10000 USD");
        margin_account.update_initial_margin(instrument_id_aud_usd_sim, margin);
        assert_eq!(
            margin_account.initial_margin(instrument_id_aud_usd_sim),
            margin
        );
        let margins: Vec<Money> = margin_account
            .margins
            .values()
            .map(|margin_balance| margin_balance.initial)
            .collect();
        assert_eq!(margins, vec![margin]);
    }

    #[rstest]
    fn test_update_margin_maintenance(
        mut margin_account: MarginAccount,
        instrument_id_aud_usd_sim: InstrumentId,
    ) {
        let margin = Money::from("10000 USD");
        margin_account.update_maintenance_margin(instrument_id_aud_usd_sim, margin);
        assert_eq!(
            margin_account.maintenance_margin(instrument_id_aud_usd_sim),
            margin
        );
        let margins: Vec<Money> = margin_account
            .margins
            .values()
            .map(|margin_balance| margin_balance.maintenance)
            .collect();
        assert_eq!(margins, vec![margin]);
    }

    #[rstest]
    fn test_calculate_margin_init_with_leverage(
        mut margin_account: MarginAccount,
        audusd_sim: CurrencyPair,
    ) {
        margin_account.set_leverage(audusd_sim.id, 50.0);
        let result = margin_account.calculate_initial_margin(
            audusd_sim,
            Quantity::from(100_000),
            Price::from("0.8000"),
            None,
        );
        assert_eq!(result, Money::from("48.06 USD"));
    }

    #[rstest]
    fn test_calculate_margin_init_with_default_leverage(
        mut margin_account: MarginAccount,
        audusd_sim: CurrencyPair,
    ) {
        margin_account.set_default_leverage(10.0);
        let result = margin_account.calculate_initial_margin(
            audusd_sim,
            Quantity::from(100_000),
            Price::from("0.8"),
            None,
        );
        assert_eq!(result, Money::from("240.32 USD"));
    }

    #[rstest]
    fn test_calculate_margin_init_with_no_leverage_for_inverse(
        mut margin_account: MarginAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        // use_quote_for_inverse = false: inverse margin is denominated in the base (BTC).
        let result_use_quote_inverse_false = margin_account.calculate_initial_margin(
            xbtusd_bitmex,
            Quantity::from(100_000),
            Price::from("11493.60"),
            Some(false),
        );
        assert_eq!(result_use_quote_inverse_false, Money::from("0.10005568 BTC"));
        // use_quote_for_inverse = true: margin is denominated in the quote (USD).
        let result_use_quote_inverse_true = margin_account.calculate_initial_margin(
            xbtusd_bitmex,
            Quantity::from(100_000),
            Price::from("11493.60"),
            Some(true),
        );
        assert_eq!(result_use_quote_inverse_true, Money::from("1150 USD"));
    }

    #[rstest]
    fn test_calculate_margin_maintenance_with_no_leverage(
        mut margin_account: MarginAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = margin_account.calculate_maintenance_margin(
            xbtusd_bitmex,
            Quantity::from(100_000),
            Price::from("11493.60"),
            None,
        );
        assert_eq!(result, Money::from("0.03697710 BTC"));
    }

    #[rstest]
    fn test_calculate_margin_maintenance_with_leverage_fx_instrument(
        mut margin_account: MarginAccount,
        audusd_sim: CurrencyPair,
    ) {
        margin_account.set_default_leverage(50.0);
        let result = margin_account.calculate_maintenance_margin(
            audusd_sim,
            Quantity::from(1_000_000),
            Price::from("1"),
            None,
        );
        assert_eq!(result, Money::from("600.40 USD"));
    }

    #[rstest]
    fn test_calculate_margin_maintenance_with_leverage_inverse_instrument(
        mut margin_account: MarginAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        margin_account.set_default_leverage(10.0);
        let result = margin_account.calculate_maintenance_margin(
            xbtusd_bitmex,
            Quantity::from(100_000),
            Price::from("100000.00"),
            None,
        );
        assert_eq!(result, Money::from("0.00042500 BTC"));
    }
}