1use std::{
26 fmt::Debug,
27 ops::{Deref, DerefMut},
28};
29
30use ahash::AHashSet;
31use nautilus_common::actor::{DataActor, DataActorCore};
32use nautilus_model::{
33 data::QuoteTick,
34 enums::OrderSide,
35 identifiers::{InstrumentId, StrategyId},
36 instruments::Instrument,
37 orders::Order,
38 types::{Price, Quantity},
39};
40use rust_decimal::Decimal;
41
42use crate::strategy::{Strategy, StrategyConfig, StrategyCore};
43
44pub struct GridMarketMaker {
51 core: StrategyCore,
52 instrument_id: InstrumentId,
53 trade_size: Quantity,
54 num_levels: usize,
55 grid_interval: f64,
56 skew_factor: f64,
57 max_position: Quantity,
58 requote_threshold: f64,
59 price_precision: u8,
60 last_quoted_mid: Option<Price>,
61}
62
63impl GridMarketMaker {
64 #[must_use]
66 pub fn new(
67 instrument_id: InstrumentId,
68 trade_size: Quantity,
69 num_levels: usize,
70 grid_interval: f64,
71 skew_factor: f64,
72 max_position: Quantity,
73 requote_threshold: f64,
74 ) -> Self {
75 let config = StrategyConfig {
76 strategy_id: Some(StrategyId::from("GRID_MM-001")),
77 order_id_tag: Some("001".to_string()),
78 ..Default::default()
79 };
80 Self {
81 core: StrategyCore::new(config),
82 instrument_id,
83 trade_size,
84 num_levels,
85 grid_interval,
86 skew_factor,
87 max_position,
88 requote_threshold,
89 price_precision: 0,
90 last_quoted_mid: None,
91 }
92 }
93
94 fn should_requote(&self, mid: Price) -> bool {
95 match self.last_quoted_mid {
96 Some(last_mid) => (mid.as_f64() - last_mid.as_f64()).abs() >= self.requote_threshold,
97 None => true,
98 }
99 }
100
101 fn grid_orders(
108 &self,
109 mid: Price,
110 net_position: f64,
111 worst_long: Decimal,
112 worst_short: Decimal,
113 ) -> Vec<(OrderSide, Price)> {
114 let precision = self.price_precision;
115 let skew = Price::new(self.skew_factor * net_position, precision);
116 let trade_size = self.trade_size.as_decimal();
117 let max_pos = self.max_position.as_decimal();
118 let mut projected_long = worst_long;
119 let mut projected_short = worst_short;
120 let mut orders = Vec::new();
121
122 for level in 1..=self.num_levels {
123 let offset = Price::new(level as f64 * self.grid_interval, precision);
124
125 if projected_long + trade_size <= max_pos {
126 orders.push((OrderSide::Buy, mid - offset - skew));
127 projected_long += trade_size;
128 }
129
130 if projected_short - trade_size >= -max_pos {
131 orders.push((OrderSide::Sell, mid + offset - skew));
132 projected_short -= trade_size;
133 }
134 }
135
136 orders
137 }
138}
139
140impl Deref for GridMarketMaker {
141 type Target = DataActorCore;
142 fn deref(&self) -> &Self::Target {
143 &self.core
144 }
145}
146
147impl DerefMut for GridMarketMaker {
148 fn deref_mut(&mut self) -> &mut Self::Target {
149 &mut self.core
150 }
151}
152
153impl Debug for GridMarketMaker {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 f.debug_struct(stringify!(GridMarketMaker))
156 .field("instrument_id", &self.instrument_id)
157 .field("trade_size", &self.trade_size)
158 .field("num_levels", &self.num_levels)
159 .field("grid_interval", &self.grid_interval)
160 .field("skew_factor", &self.skew_factor)
161 .field("max_position", &self.max_position)
162 .field("requote_threshold", &self.requote_threshold)
163 .finish()
164 }
165}
166
167impl DataActor for GridMarketMaker {
168 fn on_start(&mut self) -> anyhow::Result<()> {
169 let price_precision = {
170 let cache = self.cache();
171 cache
172 .instrument(&self.instrument_id)
173 .expect("Instrument should be in cache")
174 .price_precision()
175 };
176 self.price_precision = price_precision;
177
178 self.subscribe_quotes(self.instrument_id, None, None);
179 Ok(())
180 }
181
182 fn on_stop(&mut self) -> anyhow::Result<()> {
183 self.cancel_all_orders(self.instrument_id, None, None)?;
184 self.close_all_positions(self.instrument_id, None, None, None, None, None, None)?;
185 self.unsubscribe_quotes(self.instrument_id, None, None);
186 Ok(())
187 }
188
189 fn on_quote(&mut self, quote: &QuoteTick) -> anyhow::Result<()> {
190 let mid_f64 = (quote.bid_price.as_f64() + quote.ask_price.as_f64()) / 2.0;
192 let mid = Price::new(mid_f64, self.price_precision);
193
194 if !self.should_requote(mid) {
195 return Ok(());
196 }
197
198 self.cancel_all_orders(self.instrument_id, None, None)?;
199
200 let (net_position, worst_long, worst_short) = {
203 let strategy_id = StrategyId::from(self.actor_id.inner().as_str());
204 let instrument_id = Some(&self.instrument_id);
205 let strategy = Some(&strategy_id);
206 let cache = self.cache();
207
208 let mut position_qty = 0.0_f64;
209 let mut position_dec = Decimal::ZERO;
210 for p in cache.positions_open(None, instrument_id, strategy, None, None) {
211 position_qty += p.signed_qty;
212 position_dec += p.quantity.as_decimal()
213 * if p.signed_qty < 0.0 {
214 Decimal::NEGATIVE_ONE
215 } else {
216 Decimal::ONE
217 };
218 }
219
220 let mut pending_buy_dec = Decimal::ZERO;
221 let mut pending_sell_dec = Decimal::ZERO;
222 let mut seen = AHashSet::new();
223
224 for order in cache
226 .orders_open(None, instrument_id, strategy, None, None)
227 .iter()
228 .chain(
229 cache
230 .orders_inflight(None, instrument_id, strategy, None, None)
231 .iter(),
232 )
233 {
234 if !seen.insert(order.client_order_id()) {
235 continue;
236 }
237 let qty = order.leaves_qty().as_decimal();
238 match order.order_side() {
239 OrderSide::Buy => pending_buy_dec += qty,
240 _ => pending_sell_dec += qty,
241 }
242 }
243
244 (
245 position_qty,
246 position_dec + pending_buy_dec,
247 position_dec - pending_sell_dec,
248 )
249 };
250
251 let grid = self.grid_orders(mid, net_position, worst_long, worst_short);
252
253 if grid.is_empty() {
256 return Ok(());
257 }
258
259 let instrument_id = self.instrument_id;
260 let trade_size = self.trade_size;
261
262 for (side, price) in grid {
263 let order = self.core.order_factory().limit(
264 instrument_id,
265 side,
266 trade_size,
267 price,
268 None,
269 None,
270 Some(true), None,
272 None,
273 None,
274 None,
275 None,
276 None,
277 None,
278 None,
279 None,
280 );
281 self.submit_order(order, None, None)?;
282 }
283
284 self.last_quoted_mid = Some(mid);
285 Ok(())
286 }
287}
288
289impl Strategy for GridMarketMaker {
290 fn core(&self) -> &StrategyCore {
291 &self.core
292 }
293
294 fn core_mut(&mut self) -> &mut StrategyCore {
295 &mut self.core
296 }
297}
298
#[cfg(test)]
mod tests {
    use nautilus_model::{
        enums::OrderSide,
        identifiers::InstrumentId,
        types::{Price, Quantity},
    };
    use rstest::rstest;
    use rust_decimal::Decimal;
    use rust_decimal_macros::dec;

    use super::GridMarketMaker;

    /// Price precision used for every strategy and price built in these tests.
    const PRECISION: u8 = 2;

    /// Builds an ETHUSDT-PERP strategy with a 0.100 trade size and preset precision.
    fn create_strategy(
        num_levels: usize,
        grid_interval: f64,
        skew_factor: f64,
        max_position: Quantity,
        requote_threshold: f64,
    ) -> GridMarketMaker {
        let mut strategy = GridMarketMaker::new(
            InstrumentId::from("ETHUSDT-PERP.BINANCE"),
            Quantity::from("0.100"),
            num_levels,
            grid_interval,
            skew_factor,
            max_position,
            requote_threshold,
        );
        // Normally filled in by `on_start` from the cached instrument.
        strategy.price_precision = PRECISION;
        strategy
    }

    /// Three levels, 1.0 interval, no skew, 10.0 max position, 0.50 requote threshold.
    fn default_strategy() -> GridMarketMaker {
        create_strategy(3, 1.0, 0.0, Quantity::from("10.0"), 0.50)
    }

    /// Parses `value` into a `Price` at the test precision.
    fn mid(value: &str) -> Price {
        Price::new(value.parse::<f64>().unwrap(), PRECISION)
    }

    /// Number of orders in `orders` on the given `side`.
    fn count_side(orders: &[(OrderSide, Price)], side: OrderSide) -> usize {
        orders.iter().filter(|(s, _)| *s == side).count()
    }

    #[rstest]
    #[case::no_previous_quote(None, "1000.00", true)]
    #[case::within_threshold(Some("1000.00"), "1000.30", false)]
    #[case::at_threshold(Some("1000.00"), "1000.50", true)]
    #[case::beyond_threshold_negative(Some("1000.00"), "999.40", true)]
    fn test_should_requote(
        #[case] last_quoted: Option<&str>,
        #[case] current: &str,
        #[case] expected: bool,
    ) {
        let mut strategy = default_strategy();
        strategy.last_quoted_mid = last_quoted.map(mid);
        assert_eq!(strategy.should_requote(mid(current)), expected);
    }

    #[rstest]
    fn test_grid_orders_flat_position_symmetric() {
        let orders = default_strategy().grid_orders(mid("1000.00"), 0.0, dec!(0), dec!(0));
        assert_eq!(orders.len(), 6);

        let prices_on = |side: OrderSide| -> Vec<Price> {
            orders
                .iter()
                .filter(|(s, _)| *s == side)
                .map(|(_, p)| *p)
                .collect()
        };

        assert_eq!(
            prices_on(OrderSide::Buy),
            vec![mid("999.00"), mid("998.00"), mid("997.00")]
        );
        assert_eq!(
            prices_on(OrderSide::Sell),
            vec![mid("1001.00"), mid("1002.00"), mid("1003.00")]
        );
    }

    #[rstest]
    fn test_grid_orders_skew_shifts_prices() {
        // Net long 2.0 with skew factor 1.0 shifts both sides down by 2.00.
        let strategy = create_strategy(1, 5.0, 1.0, Quantity::from("10.0"), 0.50);
        let orders = strategy.grid_orders(mid("1000.00"), 2.0, dec!(2), dec!(2));

        assert_eq!(
            orders,
            vec![
                (OrderSide::Buy, mid("993.00")),
                (OrderSide::Sell, mid("1003.00")),
            ]
        );
    }

    #[rstest]
    #[case::limits_buy_levels(9.9, dec!(9.9), 1, 3)]
    #[case::limits_sell_levels(-9.9, dec!(-9.9), 3, 1)]
    #[case::blocks_all_buys(10.0, dec!(10), 0, 3)]
    fn test_grid_orders_max_position(
        #[case] net_position: f64,
        #[case] exposure: Decimal,
        #[case] expected_buys: usize,
        #[case] expected_sells: usize,
    ) {
        let orders =
            default_strategy().grid_orders(mid("1000.00"), net_position, exposure, exposure);

        assert_eq!(count_side(&orders, OrderSide::Buy), expected_buys);
        assert_eq!(count_side(&orders, OrderSide::Sell), expected_sells);
    }

    #[rstest]
    #[case::projected_exposure_across_levels(Quantity::from("0.150"), 1)]
    #[case::empty_when_fully_constrained(Quantity::from("0.050"), 0)]
    fn test_grid_orders_small_max_position(
        #[case] max_position: Quantity,
        #[case] per_side: usize,
    ) {
        let strategy = create_strategy(3, 1.0, 0.0, max_position, 0.50);
        let orders = strategy.grid_orders(mid("1000.00"), 0.0, dec!(0), dec!(0));

        assert_eq!(orders.len(), 2 * per_side);
        assert_eq!(count_side(&orders, OrderSide::Buy), per_side);
        assert_eq!(count_side(&orders, OrderSide::Sell), per_side);
    }
}