1use std::collections::HashMap;
17
18use rust_decimal::{prelude::ToPrimitive, Decimal};
19use serde::{Deserialize, Serialize};
20
21use crate::{
22 enums::{AccountType, LiquiditySide, OrderSide},
23 events::{AccountState, OrderFilled},
24 identifiers::AccountId,
25 instruments::InstrumentAny,
26 position::Position,
27 types::{AccountBalance, Currency, Money, Price, Quantity},
28};
29
/// Shared state and bookkeeping for an account, built from a stream of [`AccountState`] events.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// Account identifier, taken from the initial event's `account_id`.
    pub id: AccountId,
    /// Account type, taken from the initial event.
    pub account_type: AccountType,
    /// Account base currency, if any. Used as the fallback when balance getters are called
    /// without an explicit currency. When `None`, base-currency PnL is also booked in
    /// `base_calculate_pnls`.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction. It is not read anywhere in this impl; presumably it
    /// controls whether account state is calculated locally (TODO confirm in implementors).
    pub calculate_account_state: bool,
    /// History of applied account state events, oldest (initial) first.
    pub events: Vec<AccountState>,
    /// Running commission totals per currency, accumulated as `f64`.
    pub commissions: HashMap<Currency, f64>,
    /// Current balance per currency; entries are replaced on each balance update.
    pub balances: HashMap<Currency, AccountBalance>,
    /// Total balance per currency as of the initial event.
    pub balances_starting: HashMap<Currency, Money>,
}
45
46impl BaseAccount {
47 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
49 let mut balances_starting: HashMap<Currency, Money> = HashMap::new();
50 let mut balances: HashMap<Currency, AccountBalance> = HashMap::new();
51 event.balances.iter().for_each(|balance| {
52 balances_starting.insert(balance.currency, balance.total);
53 balances.insert(balance.currency, *balance);
54 });
55 Self {
56 id: event.account_id,
57 account_type: event.account_type,
58 base_currency: event.base_currency,
59 calculate_account_state,
60 events: vec![event],
61 commissions: HashMap::new(),
62 balances,
63 balances_starting,
64 }
65 }
66
67 #[must_use]
68 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
69 let currency = currency
70 .or(self.base_currency)
71 .expect("Currency must be specified");
72 self.balances.get(¤cy)
73 }
74
75 #[must_use]
76 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
77 let currency = currency
78 .or(self.base_currency)
79 .expect("Currency must be specified");
80 let account_balance = self.balances.get(¤cy);
81 account_balance.map(|balance| balance.total)
82 }
83
84 #[must_use]
85 pub fn base_balances_total(&self) -> HashMap<Currency, Money> {
86 self.balances
87 .iter()
88 .map(|(currency, balance)| (*currency, balance.total))
89 .collect()
90 }
91
92 #[must_use]
93 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
94 let currency = currency
95 .or(self.base_currency)
96 .expect("Currency must be specified");
97 let account_balance = self.balances.get(¤cy);
98 account_balance.map(|balance| balance.free)
99 }
100
101 #[must_use]
102 pub fn base_balances_free(&self) -> HashMap<Currency, Money> {
103 self.balances
104 .iter()
105 .map(|(currency, balance)| (*currency, balance.free))
106 .collect()
107 }
108
109 #[must_use]
110 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
111 let currency = currency
112 .or(self.base_currency)
113 .expect("Currency must be specified");
114 let account_balance = self.balances.get(¤cy);
115 account_balance.map(|balance| balance.locked)
116 }
117
118 #[must_use]
119 pub fn base_balances_locked(&self) -> HashMap<Currency, Money> {
120 self.balances
121 .iter()
122 .map(|(currency, balance)| (*currency, balance.locked))
123 .collect()
124 }
125
126 #[must_use]
127 pub fn base_last_event(&self) -> Option<AccountState> {
128 self.events.last().cloned()
129 }
130
131 pub fn update_balances(&mut self, balances: Vec<AccountBalance>) {
132 for balance in balances {
133 if balance.total.raw < 0 {
135 panic!("Cannot update balances with total less than 0.0")
137 } else {
138 self.balances.insert(balance.currency, balance);
140 }
141 }
142 }
143
144 pub fn update_commissions(&mut self, commission: Money) {
145 if commission.as_decimal() == Decimal::ZERO {
146 return;
147 }
148
149 let currency = commission.currency;
150 let total_commissions = self.commissions.get(¤cy).unwrap_or(&0.0);
151
152 self.commissions
153 .insert(currency, total_commissions + commission.as_f64());
154 }
155
156 pub fn base_apply(&mut self, event: AccountState) {
157 self.update_balances(event.balances.clone());
158 self.events.push(event);
159 }
160
161 pub fn base_calculate_balance_locked(
162 &mut self,
163 instrument: InstrumentAny,
164 side: OrderSide,
165 quantity: Quantity,
166 price: Price,
167 use_quote_for_inverse: Option<bool>,
168 ) -> anyhow::Result<Money> {
169 let base_currency = instrument
170 .base_currency()
171 .unwrap_or(instrument.quote_currency());
172 let quote_currency = instrument.quote_currency();
173 let notional: f64 = match side {
174 OrderSide::Buy => instrument
175 .calculate_notional_value(quantity, price, use_quote_for_inverse)
176 .as_f64(),
177 OrderSide::Sell => quantity.as_f64(),
178 _ => panic!("Invalid `OrderSide` in `base_calculate_balance_locked`"),
179 };
180 let taker_fee = instrument.taker_fee().to_f64().unwrap();
182 let locked: f64 = (notional * taker_fee).mul_add(2.0, notional);
183
184 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
186 Ok(Money::new(locked, base_currency))
187 } else if side == OrderSide::Buy {
188 Ok(Money::new(locked, quote_currency))
189 } else if side == OrderSide::Sell {
190 Ok(Money::new(locked, base_currency))
191 } else {
192 panic!("Invalid `OrderSide` in `base_calculate_balance_locked`")
193 }
194 }
195
196 pub fn base_calculate_pnls(
197 &self,
198 instrument: InstrumentAny,
199 fill: OrderFilled,
200 position: Option<Position>,
201 ) -> anyhow::Result<Vec<Money>> {
202 let mut pnls: HashMap<Currency, Money> = HashMap::new();
203 let quote_currency = instrument.quote_currency();
204 let base_currency = instrument.base_currency();
205
206 let fill_px = fill.last_px.as_f64();
207 let fill_qty = position.map_or(fill.last_qty.as_f64(), |pos| {
208 pos.quantity.as_f64().min(fill.last_qty.as_f64())
209 });
210 if fill.order_side == OrderSide::Buy {
211 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
212 pnls.insert(
213 base_currency_value,
214 Money::new(fill_qty, base_currency_value),
215 );
216 }
217 pnls.insert(
218 quote_currency,
219 Money::new(-(fill_qty * fill_px), quote_currency),
220 );
221 } else if fill.order_side == OrderSide::Sell {
222 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
223 pnls.insert(
224 base_currency_value,
225 Money::new(-fill_qty, base_currency_value),
226 );
227 }
228 pnls.insert(
229 quote_currency,
230 Money::new(fill_qty * fill_px, quote_currency),
231 );
232 } else {
233 panic!("Invalid `OrderSide` in base_calculate_pnls")
234 }
235 Ok(pnls.into_values().collect())
236 }
237
238 pub fn base_calculate_commission(
239 &self,
240 instrument: InstrumentAny,
241 last_qty: Quantity,
242 last_px: Price,
243 liquidity_side: LiquiditySide,
244 use_quote_for_inverse: Option<bool>,
245 ) -> anyhow::Result<Money> {
246 assert!(
247 liquidity_side != LiquiditySide::NoLiquiditySide,
248 "Invalid `LiquiditySide`"
249 );
250 let notional = instrument
251 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
252 .as_f64();
253 let commission = if liquidity_side == LiquiditySide::Maker {
254 notional * instrument.maker_fee().to_f64().unwrap()
255 } else if liquidity_side == LiquiditySide::Taker {
256 notional * instrument.taker_fee().to_f64().unwrap()
257 } else {
258 panic!("Invalid `LiquiditySide` {liquidity_side}")
259 };
260 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
261 Ok(Money::new(commission, instrument.base_currency().unwrap()))
262 } else {
263 Ok(Money::new(commission, instrument.quote_currency()))
264 }
265 }
266}
267
/// Interface for account types.
///
/// Implementors must be `'static + Send`. Many methods parallel the `base_*` helpers on
/// `BaseAccount`; presumably implementors delegate to them (NOTE(review): confirm per type).
pub trait Account: 'static + Send {
    /// Returns the account identifier.
    fn id(&self) -> AccountId;
    /// Returns the account type.
    fn account_type(&self) -> AccountType;
    /// Returns the account base currency, if any.
    fn base_currency(&self) -> Option<Currency>;
    /// Whether this is a cash account (presumably derived from `account_type` — confirm).
    fn is_cash_account(&self) -> bool;
    /// Whether this is a margin account (presumably derived from `account_type` — confirm).
    fn is_margin_account(&self) -> bool;
    /// Presumably exposes the `calculate_account_state` flag given at construction (TODO confirm).
    fn calculated_account_state(&self) -> bool;
    /// Total balance for `currency`; `None` presumably falls back to the base currency,
    /// as `BaseAccount::base_balance_total` does.
    fn balance_total(&self, currency: Option<Currency>) -> Option<Money>;
    /// Total balance for every currency held.
    fn balances_total(&self) -> HashMap<Currency, Money>;
    /// Free balance for `currency`; `None` presumably falls back to the base currency.
    fn balance_free(&self, currency: Option<Currency>) -> Option<Money>;
    /// Free balance for every currency held.
    fn balances_free(&self) -> HashMap<Currency, Money>;
    /// Locked balance for `currency`; `None` presumably falls back to the base currency.
    fn balance_locked(&self, currency: Option<Currency>) -> Option<Money>;
    /// Locked balance for every currency held.
    fn balances_locked(&self) -> HashMap<Currency, Money>;
    /// Full balance record for `currency`; `None` presumably falls back to the base currency.
    fn balance(&self, currency: Option<Currency>) -> Option<&AccountBalance>;
    /// Most recently applied account state event, if any.
    fn last_event(&self) -> Option<AccountState>;
    /// All applied account state events (returned as an owned `Vec`).
    fn events(&self) -> Vec<AccountState>;
    /// Number of applied account state events.
    fn event_count(&self) -> usize;
    /// Currencies the account holds balances in.
    fn currencies(&self) -> Vec<Currency>;
    /// Starting balance per currency.
    fn starting_balances(&self) -> HashMap<Currency, Money>;
    /// Current balance per currency (returned as an owned map).
    fn balances(&self) -> HashMap<Currency, AccountBalance>;
    /// Applies an account state event to the account.
    fn apply(&mut self, event: AccountState);
    /// Calculates the balance to lock for an order.
    ///
    /// Takes `&mut self`, although `BaseAccount::base_calculate_balance_locked` does not
    /// mutate state; implementors may.
    fn calculate_balance_locked(
        &mut self,
        instrument: InstrumentAny,
        side: OrderSide,
        quantity: Quantity,
        price: Price,
        use_quote_for_inverse: Option<bool>,
    ) -> anyhow::Result<Money>;
    /// Calculates the PnLs resulting from `fill`, optionally relative to an existing `position`.
    fn calculate_pnls(
        &self,
        instrument: InstrumentAny,
        fill: OrderFilled,
        position: Option<Position>,
    ) -> anyhow::Result<Vec<Money>>;
    /// Calculates the commission for a fill with the given quantity, price and liquidity side.
    fn calculate_commission(
        &self,
        instrument: InstrumentAny,
        last_qty: Quantity,
        last_px: Price,
        liquidity_side: LiquiditySide,
        use_quote_for_inverse: Option<bool>,
    ) -> anyhow::Result<Money>;
}