1use std::{
17 collections::HashMap,
18 hash::{Hash, Hasher},
19};
20
21use derive_builder::Builder;
22use evalexpr::{ContextWithMutableVariables, HashMapContext, Node, Value};
23use nautilus_core::{UnixNanos, correctness::FAILED};
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 identifiers::{InstrumentId, Symbol, Venue},
28 types::Price,
29};
30
31fn make_safe_formula_with_variables_and_mapping(
36 formula: &str,
37 components: &[InstrumentId],
38) -> (String, Vec<String>, HashMap<String, String>) {
39 let mut safe_formula = formula.to_string();
40 let mut variables = Vec::with_capacity(components.len());
41 let mut safe_to_original = HashMap::new();
42
43 for component in components {
44 let original = component.to_string();
45 let safe = original.replace('-', "_");
46 safe_to_original.insert(safe.clone(), original.clone());
47 if original != safe {
48 safe_formula = safe_formula.replace(&original, &safe);
50 }
51
52 variables.push(safe);
53 }
54
55 (safe_formula, variables, safe_to_original)
56}
57
/// A synthetic instrument whose price is derived by evaluating a formula over the
/// prices of its component instruments.
///
/// The public fields describe the instrument; the private fields hold the parsed
/// formula and the evaluation state that is reused between calculations.
#[derive(Clone, Debug, Builder)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct SyntheticInstrument {
    /// The instrument ID; constructors in this file build it from a symbol and the
    /// synthetic venue. Equality and hashing of the instrument use only this field.
    pub id: InstrumentId,
    /// The precision passed to `Price::new` for every calculated price.
    pub price_precision: u8,
    /// The price increment; `new_checked` derives it as `10^-price_precision`.
    pub price_increment: Price,
    /// The component instrument IDs. Their order defines the positional order of the
    /// inputs accepted by `calculate`.
    pub components: Vec<InstrumentId>,
    /// The derivation formula, stored in its sanitized form (hyphens inside component
    /// identifiers replaced by underscores).
    pub formula: String,
    /// Timestamp supplied by the caller as `ts_event`.
    /// NOTE(review): presumably UNIX nanoseconds of the defining event — confirm.
    pub ts_event: UnixNanos,
    /// Timestamp supplied by the caller as `ts_init`.
    /// NOTE(review): presumably UNIX nanoseconds of object initialization — confirm.
    pub ts_init: UnixNanos,
    // Expression context that holds the current value of each variable; it is
    // overwritten on every calculation.
    context: HashMapContext,
    // Sanitized variable names, one per component and in component order.
    variables: Vec<String>,
    // Maps each sanitized variable name back to the original component ID string.
    safe_to_original: HashMap<String, String>,
    // The parsed form of `formula` that is evaluated against `context`.
    operator_tree: Node,
}
91
92impl Serialize for SyntheticInstrument {
93 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
94 where
95 S: serde::Serializer,
96 {
97 use serde::ser::SerializeStruct;
98 let mut state = serializer.serialize_struct("SyntheticInstrument", 7)?;
99 state.serialize_field("id", &self.id)?;
100 state.serialize_field("price_precision", &self.price_precision)?;
101 state.serialize_field("price_increment", &self.price_increment)?;
102 state.serialize_field("components", &self.components)?;
103 state.serialize_field("formula", &self.formula)?;
104 state.serialize_field("ts_event", &self.ts_event)?;
105 state.serialize_field("ts_init", &self.ts_init)?;
106 state.end()
107 }
108}
109
110impl<'de> Deserialize<'de> for SyntheticInstrument {
111 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112 where
113 D: serde::Deserializer<'de>,
114 {
115 #[derive(Deserialize)]
116 struct Fields {
117 id: InstrumentId,
118 price_precision: u8,
119 price_increment: Price,
120 components: Vec<InstrumentId>,
121 formula: String,
122 ts_event: UnixNanos,
123 ts_init: UnixNanos,
124 }
125
126 let fields = Fields::deserialize(deserializer)?;
127
128 let (safe_formula, variables, safe_to_original) =
129 make_safe_formula_with_variables_and_mapping(&fields.formula, &fields.components);
130
131 let operator_tree =
132 evalexpr::build_operator_tree(&safe_formula).map_err(serde::de::Error::custom)?;
133
134 Ok(Self {
135 id: fields.id,
136 price_precision: fields.price_precision,
137 price_increment: fields.price_increment,
138 components: fields.components,
139 formula: safe_formula,
140 ts_event: fields.ts_event,
141 ts_init: fields.ts_init,
142 context: HashMapContext::new(),
143 variables,
144 safe_to_original,
145 operator_tree,
146 })
147 }
148}
149
150impl SyntheticInstrument {
151 pub fn new_checked(
160 symbol: Symbol,
161 price_precision: u8,
162 components: Vec<InstrumentId>,
163 formula: String,
164 ts_event: UnixNanos,
165 ts_init: UnixNanos,
166 ) -> anyhow::Result<Self> {
167 let price_increment = Price::new(10f64.powi(-i32::from(price_precision)), price_precision);
168
169 let (safe_formula, variables, safe_to_original) =
171 make_safe_formula_with_variables_and_mapping(&formula, &components);
172 let operator_tree = evalexpr::build_operator_tree(&safe_formula)?;
173
174 Ok(Self {
175 id: InstrumentId::new(symbol, Venue::synthetic()),
176 price_precision,
177 price_increment,
178 components,
179 formula: safe_formula,
180 context: HashMapContext::new(),
181 variables,
182 safe_to_original,
183 operator_tree,
184 ts_event,
185 ts_init,
186 })
187 }
188
189 pub fn is_valid_formula_for_components(formula: &str, components: &[InstrumentId]) -> bool {
190 let (safe_formula, _, _) =
191 make_safe_formula_with_variables_and_mapping(formula, components);
192 evalexpr::build_operator_tree(&safe_formula).is_ok()
193 }
194
195 pub fn new(
201 symbol: Symbol,
202 price_precision: u8,
203 components: Vec<InstrumentId>,
204 formula: String,
205 ts_event: UnixNanos,
206 ts_init: UnixNanos,
207 ) -> Self {
208 Self::new_checked(
209 symbol,
210 price_precision,
211 components,
212 formula,
213 ts_event,
214 ts_init,
215 )
216 .expect(FAILED)
217 }
218
219 #[must_use]
220 pub fn is_valid_formula(&self, formula: &str) -> bool {
221 Self::is_valid_formula_for_components(formula, &self.components)
222 }
223
224 pub fn change_formula(&mut self, formula: String) -> anyhow::Result<()> {
228 let (safe_formula, _, _) =
229 make_safe_formula_with_variables_and_mapping(&formula, &self.components);
230 let operator_tree = evalexpr::build_operator_tree(&safe_formula)?;
231 self.formula = safe_formula;
232 self.operator_tree = operator_tree;
233 Ok(())
234 }
235
236 pub fn calculate_from_map(&mut self, inputs: &HashMap<String, f64>) -> anyhow::Result<Price> {
243 let mut input_values = Vec::new();
244
245 for variable in &self.variables {
246 let original = self
247 .safe_to_original
248 .get(variable)
249 .ok_or_else(|| anyhow::anyhow!("Variable not found in mapping: {variable}"))?;
250
251 let value = inputs
252 .get(original)
253 .copied()
254 .ok_or_else(|| anyhow::anyhow!("Missing price for component: {original}"))?;
255
256 input_values.push(value);
257
258 self.context
259 .set_value(variable.clone(), Value::Float(value))
260 .map_err(|e| anyhow::anyhow!("Failed to set value for variable {variable}: {e}"))?;
261 }
262
263 self.calculate(&input_values)
264 }
265
266 pub fn calculate(&mut self, inputs: &[f64]) -> anyhow::Result<Price> {
272 if inputs.len() != self.variables.len() {
273 anyhow::bail!("Invalid number of input values");
274 }
275
276 for (variable, input) in self.variables.iter().zip(inputs) {
277 self.context
278 .set_value(variable.clone(), Value::Float(*input))?;
279 }
280
281 let result: Value = self.operator_tree.eval_with_context(&self.context)?;
282
283 match result {
284 Value::Float(price) => Ok(Price::new(price, self.price_precision)),
285 _ => anyhow::bail!("Failed to evaluate formula to a floating point number"),
286 }
287 }
288}
289
290impl PartialEq<Self> for SyntheticInstrument {
291 fn eq(&self, other: &Self) -> bool {
292 self.id == other.id
293 }
294}
295
// Marker impl: equality is decided solely by `id` in the `PartialEq` impl.
impl Eq for SyntheticInstrument {}
297
298impl Hash for SyntheticInstrument {
299 fn hash<H: Hasher>(&self, state: &mut H) {
300 self.id.hash(state);
301 }
302}
303
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    /// Parses each string into an [`InstrumentId`], panicking on malformed input.
    fn instrument_ids(raw: &[&str]) -> Vec<InstrumentId> {
        raw.iter()
            .map(|text| InstrumentId::from_str(text).unwrap())
            .collect()
    }

    /// Pairs each component's original string form with the price at the same position.
    fn price_map(components: &[InstrumentId], prices: &[f64]) -> HashMap<String, f64> {
        components
            .iter()
            .map(ToString::to_string)
            .zip(prices.iter().copied())
            .collect()
    }

    /// Builds a two-decimal synthetic that averages exactly two components.
    fn pair_average(symbol: &str, components: &[InstrumentId]) -> SyntheticInstrument {
        let formula = format!("({} + {}) / 2.0", components[0], components[1]);
        SyntheticInstrument::new(
            Symbol::from(symbol),
            2,
            components.to_vec(),
            formula,
            0.into(),
            0.into(),
        )
    }

    #[rstest]
    fn test_calculate_from_map() {
        let mut synth = SyntheticInstrument::default();
        let inputs = HashMap::from([
            ("BTC.BINANCE".to_string(), 100.0),
            ("LTC.BINANCE".to_string(), 200.0),
        ]);

        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("150.0"));
        assert_eq!(
            synth.formula,
            "(BTC.BINANCE + LTC.BINANCE) / 2.0".to_string()
        );
    }

    #[rstest]
    fn test_calculate() {
        let mut synth = SyntheticInstrument::default();

        let price = synth.calculate(&[100.0, 200.0]).unwrap();

        assert_eq!(price, Price::from("150.0"));
    }

    #[rstest]
    fn test_change_formula() {
        let mut synth = SyntheticInstrument::default();
        let new_formula = "(BTC.BINANCE + LTC.BINANCE) / 4".to_string();
        synth.change_formula(new_formula.clone()).unwrap();

        let inputs = HashMap::from([
            ("BTC.BINANCE".to_string(), 100.0),
            ("LTC.BINANCE".to_string(), 200.0),
        ]);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("75.0"));
        assert_eq!(synth.formula, new_formula);
    }

    #[rstest]
    fn test_hyphenated_instrument_ids_are_sanitized_and_backward_compatible_calculate() {
        let components = instrument_ids(&[
            "ETHUSDC-PERP.BINANCE_FUTURES",
            "ETH_USDC-PERP.HYPERLIQUID",
        ]);
        let mut synth = pair_average("ETH-USDC", &components);

        // The map is keyed by the original (hyphenated) identifiers.
        let inputs = price_map(&components, &[100.0, 200.0]);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("150.0"));
    }

    #[rstest]
    fn test_hyphenated_instrument_ids_are_sanitized_calculate() {
        let components = instrument_ids(&["ETH-USDT-SWAP.OKX", "ETH-USDC-PERP.HYPERLIQUID"]);
        let mut synth = pair_average("ETH-USD", &components);

        let price = synth.calculate(&[100.0, 200.0]).unwrap();

        assert_eq!(price, Price::from("150.0"));
    }

    #[rstest]
    fn test_hyphenated_instrument_ids_are_sanitized_calculate_from_map() {
        let components = instrument_ids(&["ETH-USDT-SWAP.OKX", "ETH-USDC-PERP.HYPERLIQUID"]);
        let mut synth = pair_average("ETH-USD", &components);

        // The stored formula must reference only the sanitized identifiers.
        for component in &components {
            let original = component.to_string();
            let safe = original.replace('-', "_");

            assert!(
                !synth.formula.contains(&original),
                "internal formula should not contain hyphenated identifier {original}"
            );
            assert!(
                synth.formula.contains(&safe),
                "internal formula should contain safe identifier {safe}"
            );
        }

        let inputs = price_map(&components, &[100.0, 200.0]);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price, Price::from("150.0"));
    }
}
447}