1use std::{
19 collections::HashMap,
20 sync::{
21 Arc,
22 atomic::{AtomicU64, Ordering},
23 },
24 time::Duration,
25};
26
27use futures_util::future::BoxFuture;
28use tokio::{
29 sync::{Mutex, OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
30 time,
31};
32use tracing::{error, info, warn};
33
34use crate::{
35 common::consts::INFLIGHT_MAX,
36 http::{
37 error::{Error, Result},
38 models::{HyperliquidFills, HyperliquidL2Book, HyperliquidOrderStatus},
39 },
40 websocket::messages::{
41 ActionRequest, CancelByCloidRequest, CancelRequest, HyperliquidWsRequest, ModifyRequest,
42 OrderRequest, OrderTypeRequest, PostRequest, PostResponse, TimeInForceRequest, TpSlRequest,
43 },
44};
45
46#[derive(Debug)]
52struct Waiter {
53 tx: oneshot::Sender<PostResponse>,
54 _permit: OwnedSemaphorePermit,
56}
57
58#[derive(Debug)]
59pub struct PostRouter {
60 inner: Mutex<HashMap<u64, Waiter>>,
61 inflight: Arc<Semaphore>, }
63
64impl Default for PostRouter {
65 fn default() -> Self {
66 Self {
67 inner: Mutex::new(HashMap::new()),
68 inflight: Arc::new(Semaphore::new(INFLIGHT_MAX)),
69 }
70 }
71}
72
73impl PostRouter {
74 pub fn new() -> Arc<Self> {
75 Arc::new(Self::default())
76 }
77
78 pub async fn register(&self, id: u64) -> Result<oneshot::Receiver<PostResponse>> {
80 let permit = self
82 .inflight
83 .clone()
84 .acquire_owned()
85 .await
86 .map_err(|_| Error::transport("post router semaphore closed"))?;
87
88 let (tx, rx) = oneshot::channel::<PostResponse>();
89 let mut map = self.inner.lock().await;
90 if map.contains_key(&id) {
91 return Err(Error::transport(format!("post id {id} already registered")));
92 }
93 map.insert(
94 id,
95 Waiter {
96 tx,
97 _permit: permit,
98 },
99 );
100 Ok(rx)
101 }
102
103 pub async fn complete(&self, resp: PostResponse) {
105 let id = resp.id;
106 let waiter = {
107 let mut map = self.inner.lock().await;
108 map.remove(&id)
109 };
110 if let Some(waiter) = waiter {
111 if waiter.tx.send(resp).is_err() {
112 warn!(id, "post waiter dropped before delivery");
113 }
114 } else {
116 warn!(id, "post response with unknown id (late/duplicate?)");
117 }
118 }
119
120 pub async fn cancel(&self, id: u64) {
122 let _ = {
123 let mut map = self.inner.lock().await;
124 map.remove(&id)
125 };
126 }
128
129 pub async fn await_with_timeout(
131 &self,
132 id: u64,
133 rx: oneshot::Receiver<PostResponse>,
134 timeout: Duration,
135 ) -> Result<PostResponse> {
136 match time::timeout(timeout, rx).await {
137 Ok(Ok(resp)) => Ok(resp),
138 Ok(Err(_closed)) => {
139 self.cancel(id).await;
140 Err(Error::transport("post response channel closed"))
141 }
142 Err(_elapsed) => {
143 self.cancel(id).await;
144 Err(Error::Timeout)
145 }
146 }
147 }
148}
149
/// Sequential generator of post request ids, usable through a shared reference.
#[derive(Debug)]
pub struct PostIds(AtomicU64);

impl PostIds {
    /// Creates a generator whose first returned id is `start`.
    pub fn new(start: u64) -> Self {
        PostIds(AtomicU64::from(start))
    }

    /// Returns the current id and advances the counter by one.
    ///
    /// `Relaxed` ordering is enough: `fetch_add` is atomic, so concurrent callers still
    /// receive distinct values. The result wraps around on overflow.
    pub fn next(&self) -> u64 {
        let PostIds(counter) = self;
        counter.fetch_add(1, Ordering::Relaxed)
    }
}
165
/// Scheduling lane a post request is queued on.
///
/// `lane_for_action` selects `Alo` only for non-empty order actions whose orders are all
/// ALO limit orders; every other request uses `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostLane {
    /// Lane for all-ALO order batches.
    Alo,
    /// Lane for everything else.
    Normal,
}
175
176#[derive(Debug)]
177pub struct ScheduledPost {
178 pub id: u64,
179 pub request: PostRequest,
180 pub lane: PostLane,
181}
182
183#[derive(Debug)]
184pub struct PostBatcher {
185 tx_alo: mpsc::Sender<ScheduledPost>,
186 tx_normal: mpsc::Sender<ScheduledPost>,
187}
188
189impl PostBatcher {
190 pub fn new<F>(send_fn: F) -> Self
192 where
193 F: Send + 'static + Clone + FnMut(HyperliquidWsRequest) -> BoxFuture<'static, Result<()>>,
194 {
195 let (tx_alo, rx_alo) = mpsc::channel::<ScheduledPost>(1024);
196 let (tx_normal, rx_normal) = mpsc::channel::<ScheduledPost>(4096);
197
198 tokio::spawn(Self::run_lane(
200 "ALO",
201 rx_alo,
202 Duration::from_millis(100),
203 send_fn.clone(),
204 ));
205
206 tokio::spawn(Self::run_lane(
208 "NORMAL",
209 rx_normal,
210 Duration::from_millis(50),
211 send_fn,
212 ));
213
214 Self { tx_alo, tx_normal }
215 }
216
217 async fn run_lane<F>(
218 lane_name: &'static str,
219 mut rx: mpsc::Receiver<ScheduledPost>,
220 tick: Duration,
221 mut send_fn: F,
222 ) where
223 F: Send + 'static + FnMut(HyperliquidWsRequest) -> BoxFuture<'static, Result<()>>,
224 {
225 let mut pend: Vec<ScheduledPost> = Vec::with_capacity(128);
226 let mut interval = time::interval(tick);
227 interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
228
229 loop {
230 tokio::select! {
231 maybe_item = rx.recv() => {
232 match maybe_item {
233 Some(item) => pend.push(item),
234 None => break, }
236 }
237 _ = interval.tick() => {
238 if pend.is_empty() { continue; }
239 let to_send = std::mem::take(&mut pend);
240 for item in to_send {
241 let req = HyperliquidWsRequest::Post { id: item.id, request: item.request.clone() };
242 if let Err(e) = send_fn(req).await {
243 error!(lane=%lane_name, id=%item.id, "failed to send post: {e}");
244 }
245 }
246 }
247 }
248 }
249 info!(lane=%lane_name, "post lane terminated");
250 }
251
252 pub async fn enqueue(&self, item: ScheduledPost) -> Result<()> {
253 match item.lane {
254 PostLane::Alo => self
255 .tx_alo
256 .send(item)
257 .await
258 .map_err(|_| Error::transport("ALO lane closed")),
259 PostLane::Normal => self
260 .tx_normal
261 .send(item)
262 .await
263 .map_err(|_| Error::transport("NORMAL lane closed")),
264 }
265 }
266}
267
268pub fn lane_for_action(action: &ActionRequest) -> PostLane {
270 match action {
271 ActionRequest::Order { orders, .. } => {
272 if orders.is_empty() {
273 return PostLane::Normal;
274 }
275 let all_alo = orders.iter().all(|o| {
276 matches!(
277 o.t,
278 OrderTypeRequest::Limit {
279 tif: TimeInForceRequest::Alo
280 }
281 )
282 });
283 if all_alo {
284 PostLane::Alo
285 } else {
286 PostLane::Normal
287 }
288 }
289 _ => PostLane::Normal,
290 }
291}
292
/// Grouping mode attached to an order action; `OrderBuilder::build` sends it as the
/// `grouping` string.
#[derive(Debug, Clone, Copy)]
pub enum Grouping {
    /// Wire value `"na"`.
    Na,
    /// Wire value `"normalTpsl"`.
    NormalTpsl,
    /// Wire value `"positionTpsl"`.
    PositionTpsl,
}

impl Grouping {
    /// Returns the wire-format string for this grouping mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Grouping::Na => "na",
            Grouping::NormalTpsl => "normalTpsl",
            Grouping::PositionTpsl => "positionTpsl",
        }
    }
}
312
313#[derive(Debug, Clone)]
315pub struct LimitOrderParams {
316 pub asset: u32,
317 pub is_buy: bool,
318 pub px: String,
319 pub sz: String,
320 pub reduce_only: bool,
321 pub tif: TimeInForceRequest,
322 pub cloid: Option<String>,
323}
324
325#[derive(Debug, Clone)]
327pub struct TriggerOrderParams {
328 pub asset: u32,
329 pub is_buy: bool,
330 pub px: String,
331 pub sz: String,
332 pub reduce_only: bool,
333 pub is_market: bool,
334 pub trigger_px: String,
335 pub tpsl: TpSlRequest,
336 pub cloid: Option<String>,
337}
338
339#[derive(Debug, Default)]
341pub struct OrderBuilder {
342 orders: Vec<OrderRequest>,
343 grouping: Grouping,
344}
345
346impl Default for Grouping {
347 fn default() -> Self {
348 Self::Na
349 }
350}
351
352impl OrderBuilder {
353 pub fn new() -> Self {
354 Self::default()
355 }
356 pub fn grouping(mut self, g: Grouping) -> Self {
357 self.grouping = g;
358 self
359 }
360
361 #[allow(clippy::too_many_arguments)]
363 pub fn push_limit(
364 self,
365 asset: u32,
366 is_buy: bool,
367 px: impl ToString,
368 sz: impl ToString,
369 reduce_only: bool,
370 tif: TimeInForceRequest,
371 cloid: Option<String>,
372 ) -> Self {
373 let params = LimitOrderParams {
374 asset,
375 is_buy,
376 px: px.to_string(),
377 sz: sz.to_string(),
378 reduce_only,
379 tif,
380 cloid,
381 };
382 self.push_limit_order(params)
383 }
384
385 pub fn push_limit_order(mut self, params: LimitOrderParams) -> Self {
387 self.orders.push(OrderRequest {
388 a: params.asset,
389 b: params.is_buy,
390 p: params.px,
391 s: params.sz,
392 r: params.reduce_only,
393 t: OrderTypeRequest::Limit { tif: params.tif },
394 c: params.cloid,
395 });
396 self
397 }
398
399 #[allow(clippy::too_many_arguments)]
401 pub fn push_trigger(
402 self,
403 asset: u32,
404 is_buy: bool,
405 px: impl ToString,
406 sz: impl ToString,
407 reduce_only: bool,
408 is_market: bool,
409 trigger_px: impl ToString,
410 tpsl: TpSlRequest,
411 cloid: Option<String>,
412 ) -> Self {
413 let params = TriggerOrderParams {
414 asset,
415 is_buy,
416 px: px.to_string(),
417 sz: sz.to_string(),
418 reduce_only,
419 is_market,
420 trigger_px: trigger_px.to_string(),
421 tpsl,
422 cloid,
423 };
424 self.push_trigger_order(params)
425 }
426
427 pub fn push_trigger_order(mut self, params: TriggerOrderParams) -> Self {
429 self.orders.push(OrderRequest {
430 a: params.asset,
431 b: params.is_buy,
432 p: params.px,
433 s: params.sz,
434 r: params.reduce_only,
435 t: OrderTypeRequest::Trigger {
436 is_market: params.is_market,
437 trigger_px: params.trigger_px,
438 tpsl: params.tpsl,
439 },
440 c: params.cloid,
441 });
442 self
443 }
444 pub fn build(self) -> ActionRequest {
445 ActionRequest::Order {
446 orders: self.orders,
447 grouping: self.grouping.as_str().to_string(),
448 }
449 }
450}
451
452pub fn cancel_many(cancels: Vec<(u32, u64)>) -> ActionRequest {
453 ActionRequest::Cancel {
454 cancels: cancels
455 .into_iter()
456 .map(|(a, o)| CancelRequest { a, o })
457 .collect(),
458 }
459}
460pub fn cancel_by_cloid(asset: u32, cloid: impl Into<String>) -> ActionRequest {
461 ActionRequest::CancelByCloid {
462 cancels: vec![CancelByCloidRequest {
463 asset,
464 cloid: cloid.into(),
465 }],
466 }
467}
468pub fn modify(oid: u64, new_order: OrderRequest) -> ActionRequest {
469 ActionRequest::Modify {
470 modifies: vec![ModifyRequest {
471 oid,
472 order: new_order,
473 }],
474 }
475}
476
477pub fn info_l2_book(coin: &str) -> PostRequest {
479 PostRequest::Info {
480 payload: serde_json::json!({"type":"l2Book","coin":coin}),
481 }
482}
483pub fn info_all_mids() -> PostRequest {
484 PostRequest::Info {
485 payload: serde_json::json!({"type":"allMids"}),
486 }
487}
488pub fn info_order_status(user: &str, oid: u64) -> PostRequest {
489 PostRequest::Info {
490 payload: serde_json::json!({"type":"orderStatus","user":user,"oid":oid}),
491 }
492}
493pub fn info_open_orders(user: &str, frontend: Option<bool>) -> PostRequest {
494 let mut body = serde_json::json!({"type":"openOrders","user":user});
495 if let Some(fe) = frontend {
496 body["frontend"] = serde_json::json!(fe);
497 }
498 PostRequest::Info { payload: body }
499}
500pub fn info_user_fills(user: &str, aggregate_by_time: Option<bool>) -> PostRequest {
501 let mut body = serde_json::json!({"type":"userFills","user":user});
502 if let Some(agg) = aggregate_by_time {
503 body["aggregateByTime"] = serde_json::json!(agg);
504 }
505 PostRequest::Info { payload: body }
506}
507pub fn info_user_rate_limit(user: &str) -> PostRequest {
508 PostRequest::Info {
509 payload: serde_json::json!({"type":"userRateLimit","user":user}),
510 }
511}
512pub fn info_candle(coin: &str, interval: &str) -> PostRequest {
513 PostRequest::Info {
514 payload: serde_json::json!({"type":"candle","coin":coin,"interval":interval}),
515 }
516}
517
518pub fn parse_l2_book(payload: &serde_json::Value) -> Result<HyperliquidL2Book> {
523 serde_json::from_value(payload.clone()).map_err(Error::Serde)
524}
525pub fn parse_user_fills(payload: &serde_json::Value) -> Result<HyperliquidFills> {
526 serde_json::from_value(payload.clone()).map_err(Error::Serde)
527}
528pub fn parse_order_status(payload: &serde_json::Value) -> Result<HyperliquidOrderStatus> {
529 serde_json::from_value(payload.clone()).map_err(Error::Serde)
530}
531
532#[derive(Debug)]
535pub enum ActionOutcome<'a> {
536 Resting {
537 oid: u64,
538 },
539 Filled {
540 total_sz: &'a str,
541 avg_px: &'a str,
542 oid: Option<u64>,
543 },
544 Error {
545 msg: &'a str,
546 },
547 Unknown(&'a serde_json::Value),
548}
549pub fn classify_action_payload(payload: &serde_json::Value) -> ActionOutcome<'_> {
550 if let Some(oid) = payload.get("oid").and_then(|v| v.as_u64()) {
551 if let (Some(total_sz), Some(avg_px)) = (
552 payload.get("totalSz").and_then(|v| v.as_str()),
553 payload.get("avgPx").and_then(|v| v.as_str()),
554 ) {
555 return ActionOutcome::Filled {
556 total_sz,
557 avg_px,
558 oid: Some(oid),
559 };
560 }
561 return ActionOutcome::Resting { oid };
562 }
563 if let (Some(total_sz), Some(avg_px)) = (
564 payload.get("totalSz").and_then(|v| v.as_str()),
565 payload.get("avgPx").and_then(|v| v.as_str()),
566 ) {
567 return ActionOutcome::Filled {
568 total_sz,
569 avg_px,
570 oid: None,
571 };
572 }
573 if let Some(msg) = payload
574 .get("error")
575 .and_then(|v| v.as_str())
576 .or_else(|| payload.get("message").and_then(|v| v.as_str()))
577 {
578 return ActionOutcome::Error { msg };
579 }
580 ActionOutcome::Unknown(payload)
581}
582
583#[derive(Clone, Debug)]
588pub struct WsSender {
589 inner: Arc<tokio::sync::Mutex<mpsc::Sender<HyperliquidWsRequest>>>,
590}
591
592impl WsSender {
593 pub fn new(tx: mpsc::Sender<HyperliquidWsRequest>) -> Self {
594 Self {
595 inner: Arc::new(tokio::sync::Mutex::new(tx)),
596 }
597 }
598
599 pub async fn send(&self, req: HyperliquidWsRequest) -> Result<()> {
600 let sender = self.inner.lock().await;
601 sender
602 .send(req)
603 .await
604 .map_err(|_| Error::transport("WebSocket sender closed"))
605 }
606}
607
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use tokio::{
        sync::oneshot,
        time::{Duration, sleep, timeout},
    };

    use super::*;
    use crate::{
        common::consts::INFLIGHT_MAX,
        websocket::messages::{
            ActionRequest, HyperliquidWsRequest, OrderRequest, OrderTypeRequest, TimeInForceRequest,
        },
    };

    /// Minimal ALO limit order for `asset` (lane-classifier fixture).
    fn mk_limit_alo(asset: u32) -> OrderRequest {
        OrderRequest {
            a: asset,
            b: true,
            p: "1".to_string(),
            s: "1".to_string(),
            r: false,
            t: OrderTypeRequest::Limit {
                tif: TimeInForceRequest::Alo,
            },
            c: None,
        }
    }

    /// Minimal GTC limit order for `asset` (lane-classifier fixture).
    fn mk_limit_gtc(asset: u32) -> OrderRequest {
        OrderRequest {
            a: asset,
            b: true,
            p: "1".to_string(),
            s: "1".to_string(),
            r: false,
            t: OrderTypeRequest::Limit {
                tif: TimeInForceRequest::Gtc,
            },
            c: None,
        }
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn register_duplicate_id_errors() {
        let router = PostRouter::new();
        let _rx = router.register(42).await.expect("first register OK");

        let err = router.register(42).await.expect_err("duplicate must error");
        let msg = err.to_string().to_lowercase();
        assert!(
            msg.contains("already") || msg.contains("duplicate"),
            "unexpected error: {msg}"
        );
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn timeout_cancels_and_allows_reregister() {
        let router = PostRouter::new();
        let id = 7;

        let rx = router.register(id).await.unwrap();
        let err = router
            .await_with_timeout(id, rx, Duration::from_millis(25))
            .await
            .expect_err("should timeout");
        assert!(
            err.to_string().to_lowercase().contains("timeout")
                || err.to_string().to_lowercase().contains("closed"),
            "unexpected error kind: {err}"
        );

        // The timed-out waiter was unregistered, so the same id can be registered again.
        let _rx2 = router
            .register(id)
            .await
            .expect("id should be reusable after timeout cancel");
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn inflight_cap_blocks_then_unblocks() {
        let router = PostRouter::new();

        // Occupy every in-flight slot.
        let mut rxs = Vec::with_capacity(INFLIGHT_MAX);
        for i in 0..INFLIGHT_MAX {
            rxs.push(router.register(i as u64).await.unwrap());
        }

        let router2 = Arc::clone(&router);
        let (entered_tx, entered_rx) = oneshot::channel::<()>();
        let (done_tx, mut done_rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            let _ = entered_tx.send(());
            let _rx = router2.register(9_999_999).await.unwrap();
            let _ = done_tx.send(());
        });

        entered_rx.await.unwrap();

        // Poll through `&mut` so the receiver is still usable for the second half.
        assert!(
            timeout(Duration::from_millis(50), &mut done_rx)
                .await
                .is_err(),
            "should still be blocked while at cap"
        );

        // Freeing one slot must let the blocked registration complete.
        router.cancel(0).await;

        timeout(Duration::from_secs(1), done_rx)
            .await
            .expect("register should unblock once a slot is freed")
            .expect("registering task should signal completion");
    }

    #[rstest(
        orders, expected,
        case::all_alo(vec![mk_limit_alo(0), mk_limit_alo(1)], PostLane::Alo),
        case::mixed_alo_gtc(vec![mk_limit_alo(0), mk_limit_gtc(1)], PostLane::Normal),
        case::all_gtc(vec![mk_limit_gtc(0), mk_limit_gtc(1)], PostLane::Normal),
        case::empty(vec![], PostLane::Normal),
    )]
    fn lane_classifier_cases(orders: Vec<OrderRequest>, expected: PostLane) {
        let action = ActionRequest::Order {
            orders,
            grouping: "na".to_string(),
        };
        assert_eq!(lane_for_action(&action), expected);
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn batcher_sends_on_tick() {
        let sent: Arc<tokio::sync::Mutex<Vec<u64>>> = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let sent_closure = sent.clone();

        // Record the id of every post the batcher forwards.
        let send_fn = move |req: HyperliquidWsRequest| -> BoxFuture<'static, Result<()>> {
            let sent_inner = sent_closure.clone();
            Box::pin(async move {
                if let HyperliquidWsRequest::Post { id, .. } = req {
                    sent_inner.lock().await.push(id);
                }
                Ok(())
            })
        };

        let batcher = PostBatcher::new(send_fn);

        for id in 1..=5u64 {
            batcher
                .enqueue(ScheduledPost {
                    id,
                    request: PostRequest::Info {
                        payload: serde_json::json!({"type":"allMids"}),
                    },
                    lane: PostLane::Normal,
                })
                .await
                .unwrap();
        }

        // NORMAL lane ticks every 50 ms.
        sleep(Duration::from_millis(80)).await;

        let got = sent.lock().await.clone();
        assert_eq!(got.len(), 5, "expected 5 sends on first tick");
        assert_eq!(got, vec![1, 2, 3, 4, 5]);
    }
}