1use std::collections::{HashMap, HashSet};
21
22use nautilus_model::enums::PriceType;
23use ustr::Ustr;
24
25pub fn get_exchange_rate(
40 from_currency: Ustr,
41 to_currency: Ustr,
42 price_type: PriceType,
43 quotes_bid: HashMap<String, f64>,
44 quotes_ask: HashMap<String, f64>,
45) -> anyhow::Result<Option<f64>> {
46 if from_currency == to_currency {
47 return Ok(Some(1.0));
50 }
51
52 if quotes_bid.is_empty() || quotes_ask.is_empty() {
53 return Err(anyhow::anyhow!("Quote maps must not be empty"));
54 }
55 if quotes_bid.len() != quotes_ask.len() {
56 return Err(anyhow::anyhow!("Quote maps must have equal lengths"));
57 }
58
59 let effective_quotes: HashMap<String, f64> = match price_type {
61 PriceType::Bid => quotes_bid,
62 PriceType::Ask => quotes_ask,
63 PriceType::Mid => {
64 let mut mid_quotes = HashMap::new();
65 for (pair, bid) in "es_bid {
66 let ask = quotes_ask
67 .get(pair)
68 .ok_or_else(|| anyhow::anyhow!("Missing ask quote for pair {pair}"))?;
69 mid_quotes.insert(pair.clone(), (bid + ask) / 2.0);
70 }
71 mid_quotes
72 }
73 _ => anyhow::bail!("Invalid `price_type`, was '{price_type}'"),
74 };
75
76 let mut graph: HashMap<Ustr, Vec<(Ustr, f64)>> = HashMap::new();
78 for (pair, rate) in effective_quotes {
79 let parts: Vec<&str> = pair.split('/').collect();
80 if parts.len() != 2 {
81 log::warn!("Skipping invalid pair string: {pair}");
82 continue;
83 }
84 let base = Ustr::from(parts[0]);
85 let quote = Ustr::from(parts[1]);
86
87 graph.entry(base).or_default().push((quote, rate));
88 graph.entry(quote).or_default().push((base, 1.0 / rate));
89 }
90
91 let mut stack: Vec<(Ustr, f64)> = vec![(from_currency, 1.0)];
93 let mut visited: HashSet<Ustr> = HashSet::new();
94 visited.insert(from_currency);
95
96 while let Some((current, current_rate)) = stack.pop() {
97 if current == to_currency {
98 return Ok(Some(current_rate));
99 }
100 if let Some(neighbors) = graph.get(¤t) {
101 for (neighbor, rate) in neighbors {
102 if visited.insert(*neighbor) {
103 stack.push((*neighbor, current_rate * rate));
104 }
105 }
106 }
107 }
108
109 Ok(None)
111}
112
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rstest::rstest;
    use ustr::Ustr;

    use super::*;

    /// Bid and ask quote maps, keyed by `"BASE/QUOTE"` pair strings.
    type QuoteMaps = (HashMap<String, f64>, HashMap<String, f64>);

    /// Splits `(pair, bid, ask)` triples into separate bid and ask maps.
    fn quote_maps(entries: &[(&str, f64, f64)]) -> QuoteMaps {
        let bids = entries
            .iter()
            .map(|&(pair, bid, _)| (pair.to_string(), bid))
            .collect();
        let asks = entries
            .iter()
            .map(|&(pair, _, ask)| (pair.to_string(), ask))
            .collect();
        (bids, asks)
    }

    /// Shared fixture: EUR/USD, GBP/USD, USD/JPY and AUD/USD quotes.
    fn setup_test_quotes() -> QuoteMaps {
        quote_maps(&[
            ("EUR/USD", 1.1000, 1.1002),
            ("GBP/USD", 1.3000, 1.3002),
            ("USD/JPY", 110.00, 110.02),
            ("AUD/USD", 0.7500, 0.7502),
        ])
    }

    /// Calls [`get_exchange_rate`] with currency codes given as string slices.
    fn convert(
        from: &str,
        to: &str,
        price_type: PriceType,
        (quotes_bid, quotes_ask): QuoteMaps,
    ) -> anyhow::Result<Option<f64>> {
        get_exchange_rate(
            Ustr::from(from),
            Ustr::from(to),
            price_type,
            quotes_bid,
            quotes_ask,
        )
    }

    /// Arithmetic mid of a bid/ask pair.
    fn mid(bid: f64, ask: f64) -> f64 {
        (bid + ask) / 2.0
    }

    #[rstest]
    fn test_invalid_pair_string() {
        let quotes = quote_maps(&[("EURUSD", 1.1000, 1.1002), ("EUR/USD", 1.1000, 1.1002)]);

        let rate = convert("EUR", "USD", PriceType::Mid, quotes)
            .unwrap()
            .unwrap();

        assert!((rate - mid(1.1000, 1.1002)).abs() < 0.0001);
    }

    #[rstest]
    fn test_same_currency() {
        let rate = convert("USD", "USD", PriceType::Mid, setup_test_quotes()).unwrap();
        assert_eq!(rate, Some(1.0));
    }

    #[rstest]
    #[case(PriceType::Bid, 1.1000)]
    #[case(PriceType::Ask, 1.1002)]
    #[case(PriceType::Mid, (1.1000 + 1.1002) / 2.0)]
    fn test_direct_pair(#[case] price_type: PriceType, #[case] expected: f64) {
        let rate = convert("EUR", "USD", price_type, setup_test_quotes())
            .unwrap()
            .unwrap_or_else(|| panic!("Expected a conversion rate for {price_type}"));
        assert!((rate - expected).abs() < 0.0001);
    }

    #[rstest]
    fn test_inverse_pair() {
        let quotes = setup_test_quotes();

        let forward = convert("EUR", "USD", PriceType::Mid, quotes.clone()).unwrap();
        let backward = convert("USD", "EUR", PriceType::Mid, quotes).unwrap();

        let (Some(eur_usd), Some(usd_eur)) = (forward, backward) else {
            panic!("Expected valid conversion rates for inverse conversion");
        };
        assert!(eur_usd.mul_add(usd_eur, -1.0).abs() < 0.0001);
    }

    #[rstest]
    fn test_cross_pair_through_usd() {
        let rate = convert("EUR", "JPY", PriceType::Mid, setup_test_quotes()).unwrap();
        let expected = mid(1.1000, 1.1002) * mid(110.00, 110.02);

        match rate {
            Some(val) => assert!((val - expected).abs() < 0.1),
            None => panic!("Expected conversion rate through USD but got None"),
        }
    }

    #[rstest]
    fn test_no_conversion_path() {
        let quotes = quote_maps(&[("EUR/USD", 1.1000, 1.1002)]);
        let rate = convert("EUR", "JPY", PriceType::Mid, quotes).unwrap();
        assert_eq!(rate, None);
    }

    #[rstest]
    fn test_empty_quotes() {
        let result = convert(
            "EUR",
            "USD",
            PriceType::Mid,
            (HashMap::new(), HashMap::new()),
        );
        assert!(result.is_err());
    }

    #[rstest]
    fn test_unequal_quotes_length() {
        let quotes_bid = HashMap::from([
            ("EUR/USD".to_string(), 1.1000),
            ("GBP/USD".to_string(), 1.3000),
        ]);
        let quotes_ask = HashMap::from([("EUR/USD".to_string(), 1.1002)]);

        let result = convert("EUR", "USD", PriceType::Mid, (quotes_bid, quotes_ask));
        assert!(result.is_err());
    }

    #[rstest]
    fn test_invalid_price_type() {
        let result = convert("EUR", "USD", PriceType::Last, setup_test_quotes());
        assert!(result.is_err());
    }

    #[rstest]
    fn test_cycle_handling() {
        let quotes = quote_maps(&[("EUR/USD", 1.1, 1.1002), ("USD/EUR", 0.909, 0.9091)]);

        let rate = convert("EUR", "USD", PriceType::Mid, quotes)
            .unwrap()
            .unwrap();

        assert!((rate - mid(1.1, 1.1002)).abs() < 0.0001);
    }

    #[rstest]
    fn test_multiple_paths() {
        let quotes = quote_maps(&[
            ("EUR/USD", 1.1000, 1.1002),
            ("EUR/GBP", 0.8461, 0.8463),
            ("GBP/USD", 1.3000, 1.3002),
        ]);

        let rate = convert("EUR", "USD", PriceType::Mid, quotes).unwrap();

        let direct = mid(1.1000, 1.1002);
        let indirect = mid(0.8461, 0.8463) * mid(1.3000, 1.3002);
        assert!((direct - indirect).abs() < 0.0001_f64);
        assert!((rate.unwrap() - direct).abs() < 0.0001_f64);
    }
}