1use ahash::AHashMap;
22use nautilus_core::{UnixNanos, datetime::secs_to_nanos_unchecked};
23use rust_decimal::{Decimal, prelude::ToPrimitive};
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 enums::{AccountType, LiquiditySide, OrderSide},
28 events::{AccountState, OrderFilled},
29 identifiers::AccountId,
30 instruments::{Instrument, InstrumentAny},
31 position::Position,
32 types::{AccountBalance, Currency, Money, Price, Quantity},
33};
34
/// Core account state: identity, current and starting balances, accumulated
/// commissions, and the history of [`AccountState`] events applied to the account.
///
/// NOTE(review): the `base_*` method prefix suggests this is embedded by concrete
/// account types (cash/margin) — confirm against the wrapping types.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// The account identifier, taken from the initial account state event.
    pub id: AccountId,
    /// The account type, taken from the initial account state event.
    pub account_type: AccountType,
    /// Optional base currency; used as the fallback currency by the single-balance
    /// getters when no explicit currency is passed.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction. Not read by the methods of this type;
    /// presumably tells the owner whether to calculate account state locally
    /// rather than rely on reported state — confirm with callers.
    pub calculate_account_state: bool,
    /// Applied account state events, oldest first; the initial event is always the
    /// first entry after construction.
    pub events: Vec<AccountState>,
    /// Accumulated commission totals per currency, summed as `f64`.
    pub commissions: AHashMap<Currency, f64>,
    /// Current balance per currency; each update replaces the whole entry.
    pub balances: AHashMap<Currency, AccountBalance>,
    /// Total balance per currency as recorded from the initial account state event.
    pub balances_starting: AHashMap<Currency, Money>,
}
50
51impl BaseAccount {
52 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
54 let mut balances_starting: AHashMap<Currency, Money> = AHashMap::new();
55 let mut balances: AHashMap<Currency, AccountBalance> = AHashMap::new();
56 event.balances.iter().for_each(|balance| {
57 balances_starting.insert(balance.currency, balance.total);
58 balances.insert(balance.currency, *balance);
59 });
60 Self {
61 id: event.account_id,
62 account_type: event.account_type,
63 base_currency: event.base_currency,
64 calculate_account_state,
65 events: vec![event],
66 commissions: AHashMap::new(),
67 balances,
68 balances_starting,
69 }
70 }
71
72 #[must_use]
78 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
79 let currency = currency
80 .or(self.base_currency)
81 .expect("Currency must be specified");
82 self.balances.get(¤cy)
83 }
84
85 #[must_use]
91 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
92 let currency = currency
93 .or(self.base_currency)
94 .expect("Currency must be specified");
95 let account_balance = self.balances.get(¤cy);
96 account_balance.map(|balance| balance.total)
97 }
98
99 #[must_use]
100 pub fn base_balances_total(&self) -> AHashMap<Currency, Money> {
101 self.balances
102 .iter()
103 .map(|(currency, balance)| (*currency, balance.total))
104 .collect()
105 }
106
107 #[must_use]
113 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
114 let currency = currency
115 .or(self.base_currency)
116 .expect("Currency must be specified");
117 let account_balance = self.balances.get(¤cy);
118 account_balance.map(|balance| balance.free)
119 }
120
121 #[must_use]
122 pub fn base_balances_free(&self) -> AHashMap<Currency, Money> {
123 self.balances
124 .iter()
125 .map(|(currency, balance)| (*currency, balance.free))
126 .collect()
127 }
128
129 #[must_use]
135 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
136 let currency = currency
137 .or(self.base_currency)
138 .expect("Currency must be specified");
139 let account_balance = self.balances.get(¤cy);
140 account_balance.map(|balance| balance.locked)
141 }
142
143 #[must_use]
144 pub fn base_balances_locked(&self) -> AHashMap<Currency, Money> {
145 self.balances
146 .iter()
147 .map(|(currency, balance)| (*currency, balance.locked))
148 .collect()
149 }
150
151 #[must_use]
152 pub fn base_last_event(&self) -> Option<AccountState> {
153 self.events.last().cloned()
154 }
155
156 pub fn update_balances(&mut self, balances: Vec<AccountBalance>) {
162 for balance in balances {
163 if balance.total.raw < 0 {
165 panic!("Cannot update balances with total less than 0.0")
167 } else {
168 self.balances.insert(balance.currency, balance);
170 }
171 }
172 }
173
174 pub fn update_commissions(&mut self, commission: Money) {
175 if commission.as_decimal() == Decimal::ZERO {
176 return;
177 }
178
179 let currency = commission.currency;
180 let total_commissions = self.commissions.get(¤cy).unwrap_or(&0.0);
181
182 self.commissions
183 .insert(currency, total_commissions + commission.as_f64());
184 }
185
186 #[must_use]
188 pub fn commission(&self, currency: &Currency) -> Option<Money> {
189 self.commissions
190 .get(currency)
191 .map(|&amount| Money::new(amount, *currency))
192 }
193
194 #[must_use]
196 pub fn commissions(&self) -> AHashMap<Currency, Money> {
197 self.commissions
198 .iter()
199 .map(|(currency, &amount)| (*currency, Money::new(amount, *currency)))
200 .collect()
201 }
202
203 pub fn base_apply(&mut self, event: AccountState) {
204 self.update_balances(event.balances.clone());
205 self.events.push(event);
206 }
207
208 pub fn base_purge_account_events(&mut self, ts_now: UnixNanos, lookback_secs: u64) {
216 let lookback_ns = UnixNanos::from(secs_to_nanos_unchecked(lookback_secs as f64));
217
218 let mut retained_events = Vec::new();
219
220 for event in &self.events {
221 if event.ts_event + lookback_ns > ts_now {
222 retained_events.push(event.clone());
223 }
224 }
225
226 if retained_events.is_empty() && !self.events.is_empty() {
228 retained_events.push(self.events.last().unwrap().clone());
230 }
231
232 self.events = retained_events;
233 }
234
235 pub fn base_calculate_balance_locked(
245 &mut self,
246 instrument: InstrumentAny,
247 side: OrderSide,
248 quantity: Quantity,
249 price: Price,
250 use_quote_for_inverse: Option<bool>,
251 ) -> anyhow::Result<Money> {
252 let base_currency = instrument
253 .base_currency()
254 .unwrap_or(instrument.quote_currency());
255 let quote_currency = instrument.quote_currency();
256 let notional: f64 = match side {
257 OrderSide::Buy => instrument
258 .calculate_notional_value(quantity, price, use_quote_for_inverse)
259 .as_f64(),
260 OrderSide::Sell => quantity.as_f64(),
261 _ => anyhow::bail!("Invalid `OrderSide` in `base_calculate_balance_locked`: {side}"),
262 };
263
264 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
266 Ok(Money::new(notional, base_currency))
267 } else if side == OrderSide::Buy {
268 Ok(Money::new(notional, quote_currency))
269 } else if side == OrderSide::Sell {
270 Ok(Money::new(notional, base_currency))
271 } else {
272 anyhow::bail!("Invalid `OrderSide` in `base_calculate_balance_locked`: {side}")
273 }
274 }
275
276 pub fn base_calculate_pnls(
286 &self,
287 instrument: InstrumentAny,
288 fill: OrderFilled,
289 position: Option<Position>,
290 ) -> anyhow::Result<Vec<Money>> {
291 let mut pnls: AHashMap<Currency, Money> = AHashMap::new();
292 let base_currency = instrument.base_currency();
293
294 let fill_qty_value = position.map_or(fill.last_qty.as_f64(), |pos| {
295 pos.quantity.as_f64().min(fill.last_qty.as_f64())
296 });
297 let fill_qty = Quantity::new(fill_qty_value, fill.last_qty.precision);
298
299 let notional = instrument.calculate_notional_value(fill_qty, fill.last_px, None);
300
301 if fill.order_side == OrderSide::Buy {
302 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
303 pnls.insert(
304 base_currency_value,
305 Money::new(fill_qty_value, base_currency_value),
306 );
307 }
308 pnls.insert(
309 notional.currency,
310 Money::new(-notional.as_f64(), notional.currency),
311 );
312 } else if fill.order_side == OrderSide::Sell {
313 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
314 pnls.insert(
315 base_currency_value,
316 Money::new(-fill_qty_value, base_currency_value),
317 );
318 }
319 pnls.insert(
320 notional.currency,
321 Money::new(notional.as_f64(), notional.currency),
322 );
323 } else {
324 anyhow::bail!(
325 "Invalid `OrderSide` in base_calculate_pnls: {}",
326 fill.order_side
327 );
328 }
329 Ok(pnls.into_values().collect())
330 }
331
332 #[allow(
338 clippy::missing_errors_doc,
339 reason = "Error conditions documented inline"
340 )]
341 pub fn base_calculate_commission(
342 &self,
343 instrument: InstrumentAny,
344 last_qty: Quantity,
345 last_px: Price,
346 liquidity_side: LiquiditySide,
347 use_quote_for_inverse: Option<bool>,
348 ) -> anyhow::Result<Money> {
349 anyhow::ensure!(
350 liquidity_side != LiquiditySide::NoLiquiditySide,
351 "Invalid `LiquiditySide`: {liquidity_side}"
352 );
353 let notional = instrument
354 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
355 .as_f64();
356 let commission = if liquidity_side == LiquiditySide::Maker {
357 notional * instrument.maker_fee().to_f64().unwrap()
358 } else if liquidity_side == LiquiditySide::Taker {
359 notional * instrument.taker_fee().to_f64().unwrap()
360 } else {
361 anyhow::bail!("Invalid `LiquiditySide`: {liquidity_side}");
362 };
363 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
364 Ok(Money::new(commission, instrument.base_currency().unwrap()))
365 } else {
366 Ok(Money::new(commission, instrument.quote_currency()))
367 }
368 }
369}
370
#[cfg(all(test, feature = "stubs"))]
mod tests {
    use rstest::rstest;

    use super::*;

    /// A zero-second lookback must still leave the latest event in place, so
    /// `base_last_event` keeps returning the most recent state.
    #[rstest]
    fn test_base_purge_account_events_retains_latest_when_all_purged() {
        use crate::{
            enums::AccountType,
            events::account::stubs::cash_account_state,
            identifiers::stubs::{account_id, uuid4},
            types::{Currency, stubs::stub_account_balance},
        };

        // Builds a cash account state stamped (event and init) at `ts` nanoseconds.
        let make_event = |ts: u64| {
            AccountState::new(
                account_id(),
                AccountType::Cash,
                vec![stub_account_balance()],
                vec![],
                true,
                uuid4(),
                UnixNanos::from(ts),
                UnixNanos::from(ts),
                Some(Currency::USD()),
            )
        };

        let mut account = BaseAccount::new(cash_account_state(), true);

        let latest = make_event(300_000_000);
        for event in [
            make_event(100_000_000),
            make_event(200_000_000),
            latest.clone(),
        ] {
            account.base_apply(event);
        }

        // Initial state from `new` plus the three applied events.
        assert_eq!(account.events.len(), 4);

        // `ts_now` lies beyond every event and the window is zero seconds wide.
        account.base_purge_account_events(UnixNanos::from(1_000_000_000), 0);

        assert_eq!(account.events.len(), 1);
        assert_eq!(account.events[0].ts_event, latest.ts_event);
        assert_eq!(account.base_last_event().unwrap().ts_event, latest.ts_event);
    }
}