1use ahash::AHashMap;
22use nautilus_core::{
23 UnixNanos,
24 correctness::{FAILED, check_equal},
25 datetime::secs_to_nanos_unchecked,
26};
27use rust_decimal::prelude::ToPrimitive;
28use serde::{Deserialize, Serialize};
29
30use crate::{
31 enums::{AccountType, LiquiditySide, OrderSide},
32 events::{AccountState, OrderFilled},
33 identifiers::AccountId,
34 instruments::{Instrument, InstrumentAny},
35 position::Position,
36 types::{AccountBalance, Currency, Money, Price, Quantity},
37};
38
/// Account state built from a stream of [`AccountState`] events.
///
/// Tracks the current balance per currency, the starting balance per currency,
/// accumulated commissions per currency, and the history of applied events.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// Account identifier; every applied event must carry the same `account_id`.
    pub id: AccountId,
    /// Account type, taken from the initial event.
    pub account_type: AccountType,
    /// Optional base currency. It is the fallback currency for balance lookups.
    /// When it is `None`, PnL calculations also report a base-currency leg.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction.
    // NOTE(review): not read anywhere in this impl; presumably consumed by concrete account types — confirm.
    pub calculate_account_state: bool,
    /// History of applied account state events, oldest first (the initial event is first).
    pub events: Vec<AccountState>,
    /// Running commission totals per currency.
    pub commissions: AHashMap<Currency, Money>,
    /// Latest balance per currency.
    pub balances: AHashMap<Currency, AccountBalance>,
    /// Balance totals per currency as reported by the initial event.
    pub balances_starting: AHashMap<Currency, Money>,
}
54
55impl BaseAccount {
56 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
58 let mut balances_starting: AHashMap<Currency, Money> = AHashMap::new();
59 let mut balances: AHashMap<Currency, AccountBalance> = AHashMap::new();
60 event.balances.iter().for_each(|balance| {
61 balances_starting.insert(balance.currency, balance.total);
62 balances.insert(balance.currency, *balance);
63 });
64 Self {
65 id: event.account_id,
66 account_type: event.account_type,
67 base_currency: event.base_currency,
68 calculate_account_state,
69 events: vec![event],
70 commissions: AHashMap::new(),
71 balances,
72 balances_starting,
73 }
74 }
75
76 #[must_use]
82 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
83 let currency = currency
84 .or(self.base_currency)
85 .expect("Currency must be specified");
86 self.balances.get(¤cy)
87 }
88
89 #[must_use]
95 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
96 let currency = currency
97 .or(self.base_currency)
98 .expect("Currency must be specified");
99 let account_balance = self.balances.get(¤cy);
100 account_balance.map(|balance| balance.total)
101 }
102
103 #[must_use]
104 pub fn base_balances_total(&self) -> AHashMap<Currency, Money> {
105 self.balances
106 .iter()
107 .map(|(currency, balance)| (*currency, balance.total))
108 .collect()
109 }
110
111 #[must_use]
117 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
118 let currency = currency
119 .or(self.base_currency)
120 .expect("Currency must be specified");
121 let account_balance = self.balances.get(¤cy);
122 account_balance.map(|balance| balance.free)
123 }
124
125 #[must_use]
126 pub fn base_balances_free(&self) -> AHashMap<Currency, Money> {
127 self.balances
128 .iter()
129 .map(|(currency, balance)| (*currency, balance.free))
130 .collect()
131 }
132
133 #[must_use]
139 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
140 let currency = currency
141 .or(self.base_currency)
142 .expect("Currency must be specified");
143 let account_balance = self.balances.get(¤cy);
144 account_balance.map(|balance| balance.locked)
145 }
146
147 #[must_use]
148 pub fn base_balances_locked(&self) -> AHashMap<Currency, Money> {
149 self.balances
150 .iter()
151 .map(|(currency, balance)| (*currency, balance.locked))
152 .collect()
153 }
154
155 #[must_use]
156 pub fn base_last_event(&self) -> Option<AccountState> {
157 self.events.last().cloned()
158 }
159
160 pub fn update_balances(&mut self, balances: &[AccountBalance]) {
167 for balance in balances {
168 self.balances.insert(balance.currency, *balance);
169 }
170 }
171
172 pub fn update_commissions(&mut self, commission: Money) {
173 let commission = commission.normalized();
175 if commission.is_zero() {
176 return;
177 }
178 let currency = commission.currency;
179 self.commissions
180 .entry(currency)
181 .and_modify(|total| *total = *total + commission)
182 .or_insert(commission);
183 }
184
185 #[must_use]
187 pub fn commission(&self, currency: &Currency) -> Option<Money> {
188 self.commissions.get(currency).copied()
189 }
190
191 #[must_use]
193 pub fn commissions(&self) -> AHashMap<Currency, Money> {
194 self.commissions.clone()
195 }
196
197 pub fn base_apply(&mut self, event: AccountState) {
203 check_equal(&event.account_id, &self.id, "event.account_id", "self.id").expect(FAILED);
204 self.update_balances(&event.balances);
205 self.events.push(event);
206 }
207
208 pub fn base_purge_account_events(&mut self, ts_now: UnixNanos, lookback_secs: u64) {
216 let lookback_ns = UnixNanos::from(secs_to_nanos_unchecked(lookback_secs as f64));
217
218 let mut retained_events = Vec::new();
219
220 for event in &self.events {
221 if event.ts_event + lookback_ns > ts_now {
222 retained_events.push(event.clone());
223 }
224 }
225
226 if retained_events.is_empty() && !self.events.is_empty() {
228 retained_events.push(self.events.last().unwrap().clone());
230 }
231
232 self.events = retained_events;
233 }
234
235 pub fn base_calculate_balance_locked(
245 &mut self,
246 instrument: InstrumentAny,
247 side: OrderSide,
248 quantity: Quantity,
249 price: Price,
250 use_quote_for_inverse: Option<bool>,
251 ) -> anyhow::Result<Money> {
252 let base_currency = instrument
253 .base_currency()
254 .unwrap_or(instrument.quote_currency());
255 let quote_currency = instrument.quote_currency();
256 let notional: f64 = match side {
257 OrderSide::Buy => instrument
258 .calculate_notional_value(quantity, price, use_quote_for_inverse)
259 .as_f64(),
260 OrderSide::Sell => quantity.as_f64(),
261 _ => anyhow::bail!("Invalid `OrderSide` in `base_calculate_balance_locked`: {side}"),
262 };
263
264 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
266 Ok(Money::new(notional, base_currency))
267 } else if side == OrderSide::Buy {
268 Ok(Money::new(notional, quote_currency))
269 } else if side == OrderSide::Sell {
270 Ok(Money::new(notional, base_currency))
271 } else {
272 anyhow::bail!("Invalid `OrderSide` in `base_calculate_balance_locked`: {side}")
273 }
274 }
275
276 pub fn base_calculate_pnls(
293 &self,
294 instrument: InstrumentAny,
295 fill: OrderFilled,
296 _position: Option<Position>,
297 ) -> anyhow::Result<Vec<Money>> {
298 let mut pnls: AHashMap<Currency, Money> = AHashMap::new();
299 let base_currency = instrument.base_currency();
300
301 let fill_qty = fill.last_qty;
303 let fill_qty_value = fill_qty.as_f64();
304
305 let notional = instrument.calculate_notional_value(fill_qty, fill.last_px, None);
306
307 if fill.order_side == OrderSide::Buy {
308 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
309 pnls.insert(
310 base_currency_value,
311 Money::new(fill_qty_value, base_currency_value),
312 );
313 }
314 pnls.insert(
315 notional.currency,
316 Money::new(-notional.as_f64(), notional.currency),
317 );
318 } else if fill.order_side == OrderSide::Sell {
319 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
320 pnls.insert(
321 base_currency_value,
322 Money::new(-fill_qty_value, base_currency_value),
323 );
324 }
325 pnls.insert(
326 notional.currency,
327 Money::new(notional.as_f64(), notional.currency),
328 );
329 } else {
330 anyhow::bail!(
331 "Invalid `OrderSide` in base_calculate_pnls: {}",
332 fill.order_side
333 );
334 }
335 Ok(pnls.into_values().collect())
336 }
337
338 #[allow(
344 clippy::missing_errors_doc,
345 reason = "Error conditions documented inline"
346 )]
347 pub fn base_calculate_commission(
348 &self,
349 instrument: InstrumentAny,
350 last_qty: Quantity,
351 last_px: Price,
352 liquidity_side: LiquiditySide,
353 use_quote_for_inverse: Option<bool>,
354 ) -> anyhow::Result<Money> {
355 anyhow::ensure!(
356 liquidity_side != LiquiditySide::NoLiquiditySide,
357 "Invalid `LiquiditySide`: {liquidity_side}"
358 );
359 let notional = instrument
360 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
361 .as_f64();
362 let commission = if liquidity_side == LiquiditySide::Maker {
363 notional * instrument.maker_fee().to_f64().unwrap()
364 } else if liquidity_side == LiquiditySide::Taker {
365 notional * instrument.taker_fee().to_f64().unwrap()
366 } else {
367 anyhow::bail!("Invalid `LiquiditySide`: {liquidity_side}");
368 };
369 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
370 Ok(Money::new(commission, instrument.base_currency().unwrap()))
371 } else {
372 Ok(Money::new(commission, instrument.quote_currency()))
373 }
374 }
375}
376
#[cfg(all(test, feature = "stubs"))]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::{
        events::account::stubs::cash_account_state,
        identifiers::stubs::{account_id, uuid4},
        types::stubs::stub_account_balance,
    };

    /// Builds a USD cash-account state for the stub account, stamped with `ts` nanoseconds
    /// for both `ts_event` and `ts_init`.
    fn cash_state_at(ts: u64) -> AccountState {
        let stamp = UnixNanos::from(ts);
        AccountState::new(
            account_id(),
            AccountType::Cash,
            vec![stub_account_balance()],
            vec![],
            true,
            uuid4(),
            stamp,
            stamp,
            Some(Currency::USD()),
        )
    }

    #[rstest]
    fn test_base_purge_account_events_retains_latest_when_all_purged() {
        let mut account = BaseAccount::new(cash_account_state(), true);
        let newest_ts = UnixNanos::from(300_000_000);

        for ts in [100_000_000_u64, 200_000_000, 300_000_000] {
            account.base_apply(cash_state_at(ts));
        }
        // Initial stub state plus the three applied states.
        assert_eq!(account.events.len(), 4);

        // Zero lookback at a time after every event: only the latest may survive.
        account.base_purge_account_events(UnixNanos::from(1_000_000_000), 0);

        assert_eq!(account.events.len(), 1);
        assert_eq!(account.events[0].ts_event, newest_ts);
        assert_eq!(account.base_last_event().unwrap().ts_event, newest_ts);
    }

    #[rstest]
    fn test_update_commissions_sub_canonical_raw_skipped() {
        let usd = Currency::USD();
        let mut account = BaseAccount::new(cash_account_state(), true);

        // A single raw unit is expected to normalize to zero and therefore be ignored.
        account.update_commissions(Money::from_raw(1, usd));

        assert!(account.commission(&usd).is_none());
    }
}