nautilus_model/instruments/
synthetic.rs1use std::{
17 collections::HashMap,
18 hash::{Hash, Hasher},
19};
20
21use derive_builder::Builder;
22use evalexpr::{ContextWithMutableVariables, DefaultNumericTypes, HashMapContext, Node, Value};
23use nautilus_core::{UnixNanos, correctness::FAILED};
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 identifiers::{InstrumentId, Symbol, Venue},
28 types::Price,
29};
/// A synthetic instrument whose price is derived from the prices of its component
/// instruments by evaluating an arithmetic `formula` (parsed with `evalexpr`).
///
/// NOTE(review): the derived `Builder` can set `formula`/`components` independently of
/// the derived `operator_tree`/`variables` state; confirm builder users keep them
/// consistent (the `new`/`new_checked` constructors derive them together).
#[derive(Clone, Debug, Builder)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.model")
)]
pub struct SyntheticInstrument {
    /// The instrument ID (`new`/`new_checked` build it with `Venue::synthetic()`).
    pub id: InstrumentId,
    /// The decimal precision used for calculated prices and for the price increment.
    pub price_precision: u8,
    /// The minimum price increment (`new_checked` sets it to `10^-price_precision`).
    pub price_increment: Price,
    /// The component instrument IDs; their string forms are the formula's variable names.
    pub components: Vec<InstrumentId>,
    /// The source text of the derivation formula (an `evalexpr` expression).
    pub formula: String,
    /// Event timestamp (UNIX nanoseconds).
    pub ts_event: UnixNanos,
    /// Initialization timestamp (UNIX nanoseconds).
    pub ts_init: UnixNanos,
    // The fields below are derived state: they are not serialized, and on
    // deserialization `variables`/`operator_tree` are rebuilt while `context` starts empty.
    /// Evaluation context holding the most recent component prices, keyed by variable name.
    context: HashMapContext,
    /// The string forms of `components`, in the same order; inputs are matched to these positionally.
    variables: Vec<String>,
    /// The parsed form of `formula`, evaluated against `context`.
    operator_tree: Node,
}
58
59impl Serialize for SyntheticInstrument {
60 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
61 where
62 S: serde::Serializer,
63 {
64 use serde::ser::SerializeStruct;
65 let mut state = serializer.serialize_struct("SyntheticInstrument", 7)?;
66 state.serialize_field("id", &self.id)?;
67 state.serialize_field("price_precision", &self.price_precision)?;
68 state.serialize_field("price_increment", &self.price_increment)?;
69 state.serialize_field("components", &self.components)?;
70 state.serialize_field("formula", &self.formula)?;
71 state.serialize_field("ts_event", &self.ts_event)?;
72 state.serialize_field("ts_init", &self.ts_init)?;
73 state.end()
74 }
75}
76
77impl<'de> Deserialize<'de> for SyntheticInstrument {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: serde::Deserializer<'de>,
81 {
82 #[derive(Deserialize)]
83 struct Fields {
84 id: InstrumentId,
85 price_precision: u8,
86 price_increment: Price,
87 components: Vec<InstrumentId>,
88 formula: String,
89 ts_event: UnixNanos,
90 ts_init: UnixNanos,
91 }
92
93 let fields = Fields::deserialize(deserializer)?;
94
95 let variables = fields
96 .components
97 .iter()
98 .map(std::string::ToString::to_string)
99 .collect();
100
101 let operator_tree =
102 evalexpr::build_operator_tree(&fields.formula).map_err(serde::de::Error::custom)?;
103
104 Ok(SyntheticInstrument {
105 id: fields.id,
106 price_precision: fields.price_precision,
107 price_increment: fields.price_increment,
108 components: fields.components,
109 formula: fields.formula,
110 ts_event: fields.ts_event,
111 ts_init: fields.ts_init,
112 context: HashMapContext::new(),
113 variables,
114 operator_tree,
115 })
116 }
117}
118
119impl SyntheticInstrument {
120 pub fn new_checked(
126 symbol: Symbol,
127 price_precision: u8,
128 components: Vec<InstrumentId>,
129 formula: String,
130 ts_event: UnixNanos,
131 ts_init: UnixNanos,
132 ) -> anyhow::Result<Self> {
133 let price_increment = Price::new(10f64.powi(-i32::from(price_precision)), price_precision);
134
135 let variables: Vec<String> = components
137 .iter()
138 .map(std::string::ToString::to_string)
139 .collect();
140
141 let operator_tree = evalexpr::build_operator_tree(&formula)?;
142
143 Ok(Self {
144 id: InstrumentId::new(symbol, Venue::synthetic()),
145 price_precision,
146 price_increment,
147 components,
148 formula,
149 context: HashMapContext::new(),
150 variables,
151 operator_tree,
152 ts_event,
153 ts_init,
154 })
155 }
156
157 pub fn new(
159 symbol: Symbol,
160 price_precision: u8,
161 components: Vec<InstrumentId>,
162 formula: String,
163 ts_event: UnixNanos,
164 ts_init: UnixNanos,
165 ) -> Self {
166 Self::new_checked(
167 symbol,
168 price_precision,
169 components,
170 formula,
171 ts_event,
172 ts_init,
173 )
174 .expect(FAILED)
175 }
176
177 #[must_use]
178 pub fn is_valid_formula(&self, formula: &str) -> bool {
179 evalexpr::build_operator_tree::<DefaultNumericTypes>(formula).is_ok()
180 }
181
182 pub fn change_formula(&mut self, formula: String) -> anyhow::Result<()> {
183 let operator_tree = evalexpr::build_operator_tree::<DefaultNumericTypes>(&formula)?;
184 self.formula = formula;
185 self.operator_tree = operator_tree;
186 Ok(())
187 }
188
189 #[allow(dead_code)]
192 pub fn calculate_from_map(&mut self, inputs: &HashMap<String, f64>) -> anyhow::Result<Price> {
193 let mut input_values = Vec::new();
194
195 for variable in &self.variables {
196 if let Some(&value) = inputs.get(variable) {
197 input_values.push(value);
198 self.context
199 .set_value(variable.clone(), Value::Float(value))
200 .expect("TODO: Unable to set value");
201 } else {
202 panic!("Missing price for component: {variable}");
203 }
204 }
205
206 self.calculate(&input_values)
207 }
208
209 pub fn calculate(&mut self, inputs: &[f64]) -> anyhow::Result<Price> {
212 if inputs.len() != self.variables.len() {
213 return Err(anyhow::anyhow!("Invalid number of input values"));
214 }
215
216 for (variable, input) in self.variables.iter().zip(inputs) {
217 self.context
218 .set_value(variable.clone(), Value::Float(*input))?;
219 }
220
221 let result: Value = self.operator_tree.eval_with_context(&self.context)?;
222
223 match result {
224 Value::Float(price) => Ok(Price::new(price, self.price_precision)),
225 _ => Err(anyhow::anyhow!(
226 "Failed to evaluate formula to a floating point number"
227 )),
228 }
229 }
230}
231
232impl PartialEq<Self> for SyntheticInstrument {
233 fn eq(&self, other: &Self) -> bool {
234 self.id == other.id
235 }
236}
237
238impl Eq for SyntheticInstrument {}
239
240impl Hash for SyntheticInstrument {
241 fn hash<H: Hasher>(&self, state: &mut H) {
242 self.id.hash(state);
243 }
244}
245
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_calculate_from_map() {
        let mut synth = SyntheticInstrument::default();
        let mut inputs = HashMap::new();
        inputs.insert("BTC.BINANCE".to_string(), 100.0);
        inputs.insert("LTC.BINANCE".to_string(), 200.0);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price.as_f64(), 150.0);
        assert_eq!(
            synth.formula,
            "(BTC.BINANCE + LTC.BINANCE) / 2.0".to_string()
        );
    }

    #[rstest]
    fn test_calculate() {
        let mut synth = SyntheticInstrument::default();
        let inputs = vec![100.0, 200.0];
        let price = synth.calculate(&inputs).unwrap();
        assert_eq!(price.as_f64(), 150.0);
    }

    #[rstest]
    fn test_calculate_with_wrong_number_of_inputs_returns_error() {
        let mut synth = SyntheticInstrument::default();

        // The default instrument has two components, so one or three inputs are invalid.
        assert!(synth.calculate(&[100.0]).is_err());
        assert!(synth.calculate(&[100.0, 200.0, 300.0]).is_err());
    }

    #[rstest]
    fn test_change_formula() {
        let mut synth = SyntheticInstrument::default();
        let new_formula = "(BTC.BINANCE + LTC.BINANCE) / 4".to_string();
        synth.change_formula(new_formula.clone()).unwrap();

        let mut inputs = HashMap::new();
        inputs.insert("BTC.BINANCE".to_string(), 100.0);
        inputs.insert("LTC.BINANCE".to_string(), 200.0);
        let price = synth.calculate_from_map(&inputs).unwrap();

        assert_eq!(price.as_f64(), 75.0);
        assert_eq!(synth.formula, new_formula);
    }

    #[rstest]
    fn test_change_formula_with_invalid_formula_keeps_previous() {
        let mut synth = SyntheticInstrument::default();
        let original_formula = synth.formula.clone();

        // Unbalanced parenthesis: rejected when building the operator tree.
        assert!(
            synth
                .change_formula("(BTC.BINANCE + LTC.BINANCE".to_string())
                .is_err()
        );
        assert_eq!(synth.formula, original_formula);

        // The previously parsed formula is still the one being evaluated.
        let price = synth.calculate(&[100.0, 200.0]).unwrap();
        assert_eq!(price.as_f64(), 150.0);
    }

    #[rstest]
    fn test_is_valid_formula() {
        let synth = SyntheticInstrument::default();

        assert!(synth.is_valid_formula("(BTC.BINANCE + LTC.BINANCE) / 2.0"));
        assert!(!synth.is_valid_formula("(BTC.BINANCE + LTC.BINANCE"));
    }

    #[rstest]
    fn test_equality_is_by_id_only() {
        let synth1 = SyntheticInstrument::default();
        let mut synth2 = SyntheticInstrument::default();
        synth2
            .change_formula("(BTC.BINANCE + LTC.BINANCE) / 4".to_string())
            .unwrap();

        assert_eq!(synth1, synth2);
    }
}