1use std::collections::HashMap;
22
23use nautilus_core::{UnixNanos, datetime::secs_to_nanos};
24use rust_decimal::{Decimal, prelude::ToPrimitive};
25use serde::{Deserialize, Serialize};
26
27use crate::{
28 enums::{AccountType, LiquiditySide, OrderSide},
29 events::{AccountState, OrderFilled},
30 identifiers::AccountId,
31 instruments::{Instrument, InstrumentAny},
32 position::Position,
33 types::{AccountBalance, Currency, Money, Price, Quantity},
34};
35
/// Shared state common to account implementations.
///
/// Holds the account identity and type, current and starting balances per
/// currency, accumulated commissions per currency, and the history of
/// [`AccountState`] events applied to the account.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// The account identifier (taken from the initial account state event).
    pub id: AccountId,
    /// The account type (taken from the initial account state event).
    pub account_type: AccountType,
    /// The account base currency, if any; used as the fallback currency for the
    /// single-currency balance queries.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction. NOTE(review): not read by the methods shown
    /// here; presumably controls whether account state is calculated locally — confirm.
    pub calculate_account_state: bool,
    /// Account state events in the order they were applied (initial event first).
    pub events: Vec<AccountState>,
    /// Running total of non-zero commissions per currency, stored as `f64`.
    pub commissions: HashMap<Currency, f64>,
    /// Current balance per currency.
    pub balances: HashMap<Currency, AccountBalance>,
    /// Total balance per currency as captured from the initial account state event.
    pub balances_starting: HashMap<Currency, Money>,
}
51
52impl BaseAccount {
53 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
55 let mut balances_starting: HashMap<Currency, Money> = HashMap::new();
56 let mut balances: HashMap<Currency, AccountBalance> = HashMap::new();
57 event.balances.iter().for_each(|balance| {
58 balances_starting.insert(balance.currency, balance.total);
59 balances.insert(balance.currency, *balance);
60 });
61 Self {
62 id: event.account_id,
63 account_type: event.account_type,
64 base_currency: event.base_currency,
65 calculate_account_state,
66 events: vec![event],
67 commissions: HashMap::new(),
68 balances,
69 balances_starting,
70 }
71 }
72
73 #[must_use]
79 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
80 let currency = currency
81 .or(self.base_currency)
82 .expect("Currency must be specified");
83 self.balances.get(¤cy)
84 }
85
86 #[must_use]
92 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
93 let currency = currency
94 .or(self.base_currency)
95 .expect("Currency must be specified");
96 let account_balance = self.balances.get(¤cy);
97 account_balance.map(|balance| balance.total)
98 }
99
100 #[must_use]
101 pub fn base_balances_total(&self) -> HashMap<Currency, Money> {
102 self.balances
103 .iter()
104 .map(|(currency, balance)| (*currency, balance.total))
105 .collect()
106 }
107
108 #[must_use]
114 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
115 let currency = currency
116 .or(self.base_currency)
117 .expect("Currency must be specified");
118 let account_balance = self.balances.get(¤cy);
119 account_balance.map(|balance| balance.free)
120 }
121
122 #[must_use]
123 pub fn base_balances_free(&self) -> HashMap<Currency, Money> {
124 self.balances
125 .iter()
126 .map(|(currency, balance)| (*currency, balance.free))
127 .collect()
128 }
129
130 #[must_use]
136 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
137 let currency = currency
138 .or(self.base_currency)
139 .expect("Currency must be specified");
140 let account_balance = self.balances.get(¤cy);
141 account_balance.map(|balance| balance.locked)
142 }
143
144 #[must_use]
145 pub fn base_balances_locked(&self) -> HashMap<Currency, Money> {
146 self.balances
147 .iter()
148 .map(|(currency, balance)| (*currency, balance.locked))
149 .collect()
150 }
151
152 #[must_use]
153 pub fn base_last_event(&self) -> Option<AccountState> {
154 self.events.last().cloned()
155 }
156
157 pub fn update_balances(&mut self, balances: Vec<AccountBalance>) {
163 for balance in balances {
164 if balance.total.raw < 0 {
166 panic!("Cannot update balances with total less than 0.0")
168 } else {
169 self.balances.insert(balance.currency, balance);
171 }
172 }
173 }
174
175 pub fn update_commissions(&mut self, commission: Money) {
176 if commission.as_decimal() == Decimal::ZERO {
177 return;
178 }
179
180 let currency = commission.currency;
181 let total_commissions = self.commissions.get(¤cy).unwrap_or(&0.0);
182
183 self.commissions
184 .insert(currency, total_commissions + commission.as_f64());
185 }
186
187 pub fn base_apply(&mut self, event: AccountState) {
188 self.update_balances(event.balances.clone());
189 self.events.push(event);
190 }
191
192 pub fn base_purge_account_events(&mut self, ts_now: UnixNanos, lookback_secs: u64) {
200 let lookback_ns = UnixNanos::from(secs_to_nanos(lookback_secs as f64));
201
202 let mut retained_events = Vec::new();
203
204 for event in &self.events {
205 if event.ts_event + lookback_ns > ts_now {
206 retained_events.push(event.clone());
207 }
208 }
209
210 if retained_events.is_empty() && !self.events.is_empty() {
212 retained_events.push(self.events.last().unwrap().clone());
214 }
215
216 self.events = retained_events;
217 }
218
219 pub fn base_calculate_balance_locked(
229 &mut self,
230 instrument: InstrumentAny,
231 side: OrderSide,
232 quantity: Quantity,
233 price: Price,
234 use_quote_for_inverse: Option<bool>,
235 ) -> anyhow::Result<Money> {
236 let base_currency = instrument
237 .base_currency()
238 .unwrap_or(instrument.quote_currency());
239 let quote_currency = instrument.quote_currency();
240 let notional: f64 = match side {
241 OrderSide::Buy => instrument
242 .calculate_notional_value(quantity, price, use_quote_for_inverse)
243 .as_f64(),
244 OrderSide::Sell => quantity.as_f64(),
245 _ => panic!("Invalid `OrderSide` in `base_calculate_balance_locked`"),
246 };
247
248 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
250 Ok(Money::new(notional, base_currency))
251 } else if side == OrderSide::Buy {
252 Ok(Money::new(notional, quote_currency))
253 } else if side == OrderSide::Sell {
254 Ok(Money::new(notional, base_currency))
255 } else {
256 panic!("Invalid `OrderSide` in `base_calculate_balance_locked`")
257 }
258 }
259
260 pub fn base_calculate_pnls(
270 &self,
271 instrument: InstrumentAny,
272 fill: OrderFilled,
273 position: Option<Position>,
274 ) -> anyhow::Result<Vec<Money>> {
275 let mut pnls: HashMap<Currency, Money> = HashMap::new();
276 let base_currency = instrument.base_currency();
277
278 let fill_qty_value = position.map_or(fill.last_qty.as_f64(), |pos| {
279 pos.quantity.as_f64().min(fill.last_qty.as_f64())
280 });
281 let fill_qty = Quantity::new(fill_qty_value, fill.last_qty.precision);
282
283 let notional = instrument.calculate_notional_value(fill_qty, fill.last_px, None);
284
285 if fill.order_side == OrderSide::Buy {
286 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
287 pnls.insert(
288 base_currency_value,
289 Money::new(fill_qty_value, base_currency_value),
290 );
291 }
292 pnls.insert(
293 notional.currency,
294 Money::new(-notional.as_f64(), notional.currency),
295 );
296 } else if fill.order_side == OrderSide::Sell {
297 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
298 pnls.insert(
299 base_currency_value,
300 Money::new(-fill_qty_value, base_currency_value),
301 );
302 }
303 pnls.insert(
304 notional.currency,
305 Money::new(notional.as_f64(), notional.currency),
306 );
307 } else {
308 panic!("Invalid `OrderSide` in base_calculate_pnls")
309 }
310 Ok(pnls.into_values().collect())
311 }
312
313 pub fn base_calculate_commission(
323 &self,
324 instrument: InstrumentAny,
325 last_qty: Quantity,
326 last_px: Price,
327 liquidity_side: LiquiditySide,
328 use_quote_for_inverse: Option<bool>,
329 ) -> anyhow::Result<Money> {
330 assert!(
331 liquidity_side != LiquiditySide::NoLiquiditySide,
332 "Invalid `LiquiditySide`"
333 );
334 let notional = instrument
335 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
336 .as_f64();
337 let commission = if liquidity_side == LiquiditySide::Maker {
338 notional * instrument.maker_fee().to_f64().unwrap()
339 } else if liquidity_side == LiquiditySide::Taker {
340 notional * instrument.taker_fee().to_f64().unwrap()
341 } else {
342 panic!("Invalid `LiquiditySide` {liquidity_side}")
343 };
344 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
345 Ok(Money::new(commission, instrument.base_currency().unwrap()))
346 } else {
347 Ok(Money::new(commission, instrument.quote_currency()))
348 }
349 }
350}
351
#[cfg(all(test, feature = "stubs"))]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_base_purge_account_events_retains_latest_when_all_purged() {
        use crate::{
            enums::AccountType,
            events::account::stubs::cash_account_state,
            identifiers::stubs::{account_id, uuid4},
            types::{Currency, stubs::stub_account_balance},
        };

        // Builds a cash account state stamped with `ts` for both event and init time.
        let make_event = |ts: u64| {
            AccountState::new(
                account_id(),
                AccountType::Cash,
                vec![stub_account_balance()],
                vec![],
                true,
                uuid4(),
                UnixNanos::from(ts),
                UnixNanos::from(ts),
                Some(Currency::USD()),
            )
        };

        let mut account = BaseAccount::new(cash_account_state(), true);
        let latest = make_event(300_000_000);

        for event in [
            make_event(100_000_000),
            make_event(200_000_000),
            latest.clone(),
        ] {
            account.base_apply(event);
        }

        // The construction event plus the three applied events.
        assert_eq!(account.events.len(), 4);

        // A zero-second lookback at t = 1s leaves nothing in the window,
        // so only the most recently applied event should survive.
        account.base_purge_account_events(UnixNanos::from(1_000_000_000), 0);

        assert_eq!(account.events.len(), 1);
        assert_eq!(account.events[0].ts_event, latest.ts_event);
        assert_eq!(account.base_last_event().unwrap().ts_event, latest.ts_event);
    }
}