1use std::collections::HashMap;
22
23use nautilus_core::{UnixNanos, datetime::secs_to_nanos};
24use rust_decimal::{Decimal, prelude::ToPrimitive};
25use serde::{Deserialize, Serialize};
26
27use crate::{
28 enums::{AccountType, LiquiditySide, OrderSide},
29 events::{AccountState, OrderFilled},
30 identifiers::AccountId,
31 instruments::{Instrument, InstrumentAny},
32 position::Position,
33 types::{AccountBalance, Currency, Money, Price, Quantity},
34};
35
/// Core account state: identity, current and starting balances, accumulated
/// commissions, and the history of [`AccountState`] events applied to the account.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// Account identifier, taken from the initial account state event.
    pub id: AccountId,
    /// Account type, taken from the initial account state event.
    pub account_type: AccountType,
    /// Base currency, if any; used as the fallback when a balance lookup gives no currency.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction and stored as-is.
    pub calculate_account_state: bool,
    /// Applied account state events in application order (the initial event is first).
    pub events: Vec<AccountState>,
    /// Running commission total per currency, accumulated as `f64`.
    pub commissions: HashMap<Currency, f64>,
    /// Latest balance per currency.
    pub balances: HashMap<Currency, AccountBalance>,
    /// Total balance per currency as reported by the initial account state event.
    pub balances_starting: HashMap<Currency, Money>,
}
51
52impl BaseAccount {
53 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
55 let mut balances_starting: HashMap<Currency, Money> = HashMap::new();
56 let mut balances: HashMap<Currency, AccountBalance> = HashMap::new();
57 event.balances.iter().for_each(|balance| {
58 balances_starting.insert(balance.currency, balance.total);
59 balances.insert(balance.currency, *balance);
60 });
61 Self {
62 id: event.account_id,
63 account_type: event.account_type,
64 base_currency: event.base_currency,
65 calculate_account_state,
66 events: vec![event],
67 commissions: HashMap::new(),
68 balances,
69 balances_starting,
70 }
71 }
72
73 #[must_use]
79 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
80 let currency = currency
81 .or(self.base_currency)
82 .expect("Currency must be specified");
83 self.balances.get(¤cy)
84 }
85
86 #[must_use]
92 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
93 let currency = currency
94 .or(self.base_currency)
95 .expect("Currency must be specified");
96 let account_balance = self.balances.get(¤cy);
97 account_balance.map(|balance| balance.total)
98 }
99
100 #[must_use]
101 pub fn base_balances_total(&self) -> HashMap<Currency, Money> {
102 self.balances
103 .iter()
104 .map(|(currency, balance)| (*currency, balance.total))
105 .collect()
106 }
107
108 #[must_use]
114 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
115 let currency = currency
116 .or(self.base_currency)
117 .expect("Currency must be specified");
118 let account_balance = self.balances.get(¤cy);
119 account_balance.map(|balance| balance.free)
120 }
121
122 #[must_use]
123 pub fn base_balances_free(&self) -> HashMap<Currency, Money> {
124 self.balances
125 .iter()
126 .map(|(currency, balance)| (*currency, balance.free))
127 .collect()
128 }
129
130 #[must_use]
136 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
137 let currency = currency
138 .or(self.base_currency)
139 .expect("Currency must be specified");
140 let account_balance = self.balances.get(¤cy);
141 account_balance.map(|balance| balance.locked)
142 }
143
144 #[must_use]
145 pub fn base_balances_locked(&self) -> HashMap<Currency, Money> {
146 self.balances
147 .iter()
148 .map(|(currency, balance)| (*currency, balance.locked))
149 .collect()
150 }
151
152 #[must_use]
153 pub fn base_last_event(&self) -> Option<AccountState> {
154 self.events.last().cloned()
155 }
156
157 pub fn update_balances(&mut self, balances: Vec<AccountBalance>) {
163 for balance in balances {
164 if balance.total.raw < 0 {
166 panic!("Cannot update balances with total less than 0.0")
168 } else {
169 self.balances.insert(balance.currency, balance);
171 }
172 }
173 }
174
175 pub fn update_commissions(&mut self, commission: Money) {
176 if commission.as_decimal() == Decimal::ZERO {
177 return;
178 }
179
180 let currency = commission.currency;
181 let total_commissions = self.commissions.get(¤cy).unwrap_or(&0.0);
182
183 self.commissions
184 .insert(currency, total_commissions + commission.as_f64());
185 }
186
187 pub fn base_apply(&mut self, event: AccountState) {
188 self.update_balances(event.balances.clone());
189 self.events.push(event);
190 }
191
192 pub fn base_purge_account_events(&mut self, ts_now: UnixNanos, lookback_secs: u64) {
200 let lookback_ns = UnixNanos::from(secs_to_nanos(lookback_secs as f64));
201
202 let mut retained_events = Vec::new();
203
204 for event in &self.events {
205 if event.ts_event + lookback_ns > ts_now {
206 retained_events.push(event.clone());
207 }
208 }
209
210 if retained_events.is_empty() && !self.events.is_empty() {
212 retained_events.push(self.events.last().unwrap().clone());
214 }
215
216 self.events = retained_events;
217 }
218
219 pub fn base_calculate_balance_locked(
229 &mut self,
230 instrument: InstrumentAny,
231 side: OrderSide,
232 quantity: Quantity,
233 price: Price,
234 use_quote_for_inverse: Option<bool>,
235 ) -> anyhow::Result<Money> {
236 let base_currency = instrument
237 .base_currency()
238 .unwrap_or(instrument.quote_currency());
239 let quote_currency = instrument.quote_currency();
240 let notional: f64 = match side {
241 OrderSide::Buy => instrument
242 .calculate_notional_value(quantity, price, use_quote_for_inverse)
243 .as_f64(),
244 OrderSide::Sell => quantity.as_f64(),
245 _ => panic!("Invalid `OrderSide` in `base_calculate_balance_locked`"),
246 };
247
248 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
250 Ok(Money::new(notional, base_currency))
251 } else if side == OrderSide::Buy {
252 Ok(Money::new(notional, quote_currency))
253 } else if side == OrderSide::Sell {
254 Ok(Money::new(notional, base_currency))
255 } else {
256 panic!("Invalid `OrderSide` in `base_calculate_balance_locked`")
257 }
258 }
259
260 pub fn base_calculate_pnls(
270 &self,
271 instrument: InstrumentAny,
272 fill: OrderFilled,
273 position: Option<Position>,
274 ) -> anyhow::Result<Vec<Money>> {
275 let mut pnls: HashMap<Currency, Money> = HashMap::new();
276 let quote_currency = instrument.quote_currency();
277 let base_currency = instrument.base_currency();
278
279 let fill_px = fill.last_px.as_f64();
280 let fill_qty = position.map_or(fill.last_qty.as_f64(), |pos| {
281 pos.quantity.as_f64().min(fill.last_qty.as_f64())
282 });
283 if fill.order_side == OrderSide::Buy {
284 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
285 pnls.insert(
286 base_currency_value,
287 Money::new(fill_qty, base_currency_value),
288 );
289 }
290 pnls.insert(
291 quote_currency,
292 Money::new(-(fill_qty * fill_px), quote_currency),
293 );
294 } else if fill.order_side == OrderSide::Sell {
295 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
296 pnls.insert(
297 base_currency_value,
298 Money::new(-fill_qty, base_currency_value),
299 );
300 }
301 pnls.insert(
302 quote_currency,
303 Money::new(fill_qty * fill_px, quote_currency),
304 );
305 } else {
306 panic!("Invalid `OrderSide` in base_calculate_pnls")
307 }
308 Ok(pnls.into_values().collect())
309 }
310
311 pub fn base_calculate_commission(
321 &self,
322 instrument: InstrumentAny,
323 last_qty: Quantity,
324 last_px: Price,
325 liquidity_side: LiquiditySide,
326 use_quote_for_inverse: Option<bool>,
327 ) -> anyhow::Result<Money> {
328 assert!(
329 liquidity_side != LiquiditySide::NoLiquiditySide,
330 "Invalid `LiquiditySide`"
331 );
332 let notional = instrument
333 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
334 .as_f64();
335 let commission = if liquidity_side == LiquiditySide::Maker {
336 notional * instrument.maker_fee().to_f64().unwrap()
337 } else if liquidity_side == LiquiditySide::Taker {
338 notional * instrument.taker_fee().to_f64().unwrap()
339 } else {
340 panic!("Invalid `LiquiditySide` {liquidity_side}")
341 };
342 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
343 Ok(Money::new(commission, instrument.base_currency().unwrap()))
344 } else {
345 Ok(Money::new(commission, instrument.quote_currency()))
346 }
347 }
348}
349
#[cfg(all(test, feature = "stubs"))]
mod tests {
    use super::*;

    #[test]
    fn test_base_purge_account_events_retains_latest_when_all_purged() {
        use crate::{
            enums::AccountType,
            events::account::stubs::cash_account_state,
            identifiers::stubs::{account_id, uuid4},
            types::{Currency, stubs::stub_account_balance},
        };

        // Cash account state whose event and init timestamps are both `ts`.
        let state_at = |ts: u64| {
            AccountState::new(
                account_id(),
                AccountType::Cash,
                vec![stub_account_balance()],
                vec![],
                true,
                uuid4(),
                UnixNanos::from(ts),
                UnixNanos::from(ts),
                Some(Currency::USD()),
            )
        };

        let mut account = BaseAccount::new(cash_account_state(), true);

        let states = [
            state_at(100_000_000),
            state_at(200_000_000),
            state_at(300_000_000),
        ];
        let newest_ts = states[2].ts_event;

        for state in states {
            account.base_apply(state);
        }
        // Initial event plus the three applied ones.
        assert_eq!(account.events.len(), 4);

        // Zero lookback with `ts_now` past every event: everything falls outside the window.
        account.base_purge_account_events(UnixNanos::from(1_000_000_000), 0);

        assert_eq!(account.events.len(), 1);
        assert_eq!(account.events[0].ts_event, newest_ts);
        assert_eq!(account.base_last_event().unwrap().ts_event, newest_ts);
    }
}