1use ahash::AHashMap;
22use nautilus_core::{UnixNanos, datetime::secs_to_nanos_unchecked};
23use rust_decimal::{Decimal, prelude::ToPrimitive};
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 enums::{AccountType, LiquiditySide, OrderSide},
28 events::{AccountState, OrderFilled},
29 identifiers::AccountId,
30 instruments::{Instrument, InstrumentAny},
31 position::Position,
32 types::{AccountBalance, Currency, Money, Price, Quantity},
33};
34
/// Account state shared by account implementations: identity, balances, commissions and
/// the history of applied [`AccountState`] events.
///
/// NOTE(review): presumably embedded by concrete account types (e.g. cash/margin) — confirm
/// against the callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// The account identifier, taken from the initial event's `account_id`.
    pub id: AccountId,
    /// The account type, taken from the initial event.
    pub account_type: AccountType,
    /// The optional account base currency; used as the default currency for balance
    /// queries when no explicit currency is given.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction.
    /// NOTE(review): not read in this section — confirm its meaning with the callers.
    pub calculate_account_state: bool,
    /// The applied account state events, oldest first (the initial event is first).
    pub events: Vec<AccountState>,
    /// The running commission total per currency (stored as `f64`).
    pub commissions: AHashMap<Currency, f64>,
    /// The current balance per currency.
    pub balances: AHashMap<Currency, AccountBalance>,
    /// The total balance per currency as of the initial event.
    pub balances_starting: AHashMap<Currency, Money>,
}
50
51impl BaseAccount {
52 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
54 let mut balances_starting: AHashMap<Currency, Money> = AHashMap::new();
55 let mut balances: AHashMap<Currency, AccountBalance> = AHashMap::new();
56 event.balances.iter().for_each(|balance| {
57 balances_starting.insert(balance.currency, balance.total);
58 balances.insert(balance.currency, *balance);
59 });
60 Self {
61 id: event.account_id,
62 account_type: event.account_type,
63 base_currency: event.base_currency,
64 calculate_account_state,
65 events: vec![event],
66 commissions: AHashMap::new(),
67 balances,
68 balances_starting,
69 }
70 }
71
72 #[must_use]
78 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
79 let currency = currency
80 .or(self.base_currency)
81 .expect("Currency must be specified");
82 self.balances.get(¤cy)
83 }
84
85 #[must_use]
91 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
92 let currency = currency
93 .or(self.base_currency)
94 .expect("Currency must be specified");
95 let account_balance = self.balances.get(¤cy);
96 account_balance.map(|balance| balance.total)
97 }
98
99 #[must_use]
100 pub fn base_balances_total(&self) -> AHashMap<Currency, Money> {
101 self.balances
102 .iter()
103 .map(|(currency, balance)| (*currency, balance.total))
104 .collect()
105 }
106
107 #[must_use]
113 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
114 let currency = currency
115 .or(self.base_currency)
116 .expect("Currency must be specified");
117 let account_balance = self.balances.get(¤cy);
118 account_balance.map(|balance| balance.free)
119 }
120
121 #[must_use]
122 pub fn base_balances_free(&self) -> AHashMap<Currency, Money> {
123 self.balances
124 .iter()
125 .map(|(currency, balance)| (*currency, balance.free))
126 .collect()
127 }
128
129 #[must_use]
135 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
136 let currency = currency
137 .or(self.base_currency)
138 .expect("Currency must be specified");
139 let account_balance = self.balances.get(¤cy);
140 account_balance.map(|balance| balance.locked)
141 }
142
143 #[must_use]
144 pub fn base_balances_locked(&self) -> AHashMap<Currency, Money> {
145 self.balances
146 .iter()
147 .map(|(currency, balance)| (*currency, balance.locked))
148 .collect()
149 }
150
151 #[must_use]
152 pub fn base_last_event(&self) -> Option<AccountState> {
153 self.events.last().cloned()
154 }
155
156 pub fn update_balances(&mut self, balances: Vec<AccountBalance>) {
162 for balance in balances {
163 if balance.total.raw < 0 {
165 panic!("Cannot update balances with total less than 0.0")
167 } else {
168 self.balances.insert(balance.currency, balance);
170 }
171 }
172 }
173
174 pub fn update_commissions(&mut self, commission: Money) {
175 if commission.as_decimal() == Decimal::ZERO {
176 return;
177 }
178
179 let currency = commission.currency;
180 let total_commissions = self.commissions.get(¤cy).unwrap_or(&0.0);
181
182 self.commissions
183 .insert(currency, total_commissions + commission.as_f64());
184 }
185
186 pub fn base_apply(&mut self, event: AccountState) {
187 self.update_balances(event.balances.clone());
188 self.events.push(event);
189 }
190
191 pub fn base_purge_account_events(&mut self, ts_now: UnixNanos, lookback_secs: u64) {
199 let lookback_ns = UnixNanos::from(secs_to_nanos_unchecked(lookback_secs as f64));
200
201 let mut retained_events = Vec::new();
202
203 for event in &self.events {
204 if event.ts_event + lookback_ns > ts_now {
205 retained_events.push(event.clone());
206 }
207 }
208
209 if retained_events.is_empty() && !self.events.is_empty() {
211 retained_events.push(self.events.last().unwrap().clone());
213 }
214
215 self.events = retained_events;
216 }
217
218 pub fn base_calculate_balance_locked(
228 &mut self,
229 instrument: InstrumentAny,
230 side: OrderSide,
231 quantity: Quantity,
232 price: Price,
233 use_quote_for_inverse: Option<bool>,
234 ) -> anyhow::Result<Money> {
235 let base_currency = instrument
236 .base_currency()
237 .unwrap_or(instrument.quote_currency());
238 let quote_currency = instrument.quote_currency();
239 let notional: f64 = match side {
240 OrderSide::Buy => instrument
241 .calculate_notional_value(quantity, price, use_quote_for_inverse)
242 .as_f64(),
243 OrderSide::Sell => quantity.as_f64(),
244 _ => anyhow::bail!("Invalid `OrderSide` in `base_calculate_balance_locked`: {side}"),
245 };
246
247 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
249 Ok(Money::new(notional, base_currency))
250 } else if side == OrderSide::Buy {
251 Ok(Money::new(notional, quote_currency))
252 } else if side == OrderSide::Sell {
253 Ok(Money::new(notional, base_currency))
254 } else {
255 anyhow::bail!("Invalid `OrderSide` in `base_calculate_balance_locked`: {side}")
256 }
257 }
258
259 pub fn base_calculate_pnls(
269 &self,
270 instrument: InstrumentAny,
271 fill: OrderFilled,
272 position: Option<Position>,
273 ) -> anyhow::Result<Vec<Money>> {
274 let mut pnls: AHashMap<Currency, Money> = AHashMap::new();
275 let base_currency = instrument.base_currency();
276
277 let fill_qty_value = position.map_or(fill.last_qty.as_f64(), |pos| {
278 pos.quantity.as_f64().min(fill.last_qty.as_f64())
279 });
280 let fill_qty = Quantity::new(fill_qty_value, fill.last_qty.precision);
281
282 let notional = instrument.calculate_notional_value(fill_qty, fill.last_px, None);
283
284 if fill.order_side == OrderSide::Buy {
285 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
286 pnls.insert(
287 base_currency_value,
288 Money::new(fill_qty_value, base_currency_value),
289 );
290 }
291 pnls.insert(
292 notional.currency,
293 Money::new(-notional.as_f64(), notional.currency),
294 );
295 } else if fill.order_side == OrderSide::Sell {
296 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
297 pnls.insert(
298 base_currency_value,
299 Money::new(-fill_qty_value, base_currency_value),
300 );
301 }
302 pnls.insert(
303 notional.currency,
304 Money::new(notional.as_f64(), notional.currency),
305 );
306 } else {
307 anyhow::bail!(
308 "Invalid `OrderSide` in base_calculate_pnls: {}",
309 fill.order_side
310 );
311 }
312 Ok(pnls.into_values().collect())
313 }
314
315 #[allow(
321 clippy::missing_errors_doc,
322 reason = "Error conditions documented inline"
323 )]
324 pub fn base_calculate_commission(
325 &self,
326 instrument: InstrumentAny,
327 last_qty: Quantity,
328 last_px: Price,
329 liquidity_side: LiquiditySide,
330 use_quote_for_inverse: Option<bool>,
331 ) -> anyhow::Result<Money> {
332 anyhow::ensure!(
333 liquidity_side != LiquiditySide::NoLiquiditySide,
334 "Invalid `LiquiditySide`: {liquidity_side}"
335 );
336 let notional = instrument
337 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
338 .as_f64();
339 let commission = if liquidity_side == LiquiditySide::Maker {
340 notional * instrument.maker_fee().to_f64().unwrap()
341 } else if liquidity_side == LiquiditySide::Taker {
342 notional * instrument.taker_fee().to_f64().unwrap()
343 } else {
344 anyhow::bail!("Invalid `LiquiditySide`: {liquidity_side}");
345 };
346 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
347 Ok(Money::new(commission, instrument.base_currency().unwrap()))
348 } else {
349 Ok(Money::new(commission, instrument.quote_currency()))
350 }
351 }
352}
353
#[cfg(all(test, feature = "stubs"))]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_base_purge_account_events_retains_latest_when_all_purged() {
        use crate::{
            enums::AccountType,
            events::account::stubs::cash_account_state,
            identifiers::stubs::{account_id, uuid4},
            types::{Currency, stubs::stub_account_balance},
        };

        // Builds a USD cash-account state event with both timestamps set to `ts`.
        let make_event = |ts: u64| {
            AccountState::new(
                account_id(),
                AccountType::Cash,
                vec![stub_account_balance()],
                vec![],
                true,
                uuid4(),
                UnixNanos::from(ts),
                UnixNanos::from(ts),
                Some(Currency::USD()),
            )
        };

        let mut account = BaseAccount::new(cash_account_state(), true);

        let applied = [
            make_event(100_000_000),
            make_event(200_000_000),
            make_event(300_000_000),
        ];
        let latest_ts = applied[2].ts_event;
        for event in applied {
            account.base_apply(event);
        }

        // Initial event plus the three applied ones.
        assert_eq!(account.events.len(), 4);

        // A zero lookback far in the future purges everything but the latest event.
        account.base_purge_account_events(UnixNanos::from(1_000_000_000), 0);

        assert_eq!(account.events.len(), 1);
        assert_eq!(account.events[0].ts_event, latest_ts);
        assert_eq!(account.base_last_event().unwrap().ts_event, latest_ts);
    }
}