1use std::collections::{HashMap, HashSet};
21
22use nautilus_model::enums::PriceType;
23use ustr::Ustr;
24
25pub fn get_exchange_rate(
40 from_currency: Ustr,
41 to_currency: Ustr,
42 price_type: PriceType,
43 quotes_bid: HashMap<String, f64>,
44 quotes_ask: HashMap<String, f64>,
45) -> anyhow::Result<Option<f64>> {
46 if from_currency == to_currency {
47 return Ok(Some(1.0));
50 }
51
52 if quotes_bid.is_empty() || quotes_ask.is_empty() {
53 anyhow::bail!("Quote maps must not be empty");
54 }
55 if quotes_bid.len() != quotes_ask.len() {
56 anyhow::bail!("Quote maps must have equal lengths");
57 }
58
59 let effective_quotes: HashMap<String, f64> = match price_type {
61 PriceType::Bid => quotes_bid,
62 PriceType::Ask => quotes_ask,
63 PriceType::Mid => {
64 let mut mid_quotes = HashMap::new();
65 for (pair, bid) in "es_bid {
66 let ask = quotes_ask
67 .get(pair)
68 .ok_or_else(|| anyhow::anyhow!("Missing ask quote for pair {pair}"))?;
69 mid_quotes.insert(pair.clone(), (bid + ask) / 2.0);
70 }
71 mid_quotes
72 }
73 _ => anyhow::bail!("Invalid `price_type`, was '{price_type}'"),
74 };
75
76 let mut graph: HashMap<Ustr, Vec<(Ustr, f64)>> = HashMap::new();
78 for (pair, rate) in effective_quotes {
79 let parts: Vec<&str> = pair.split('/').collect();
80 if parts.len() != 2 {
81 log::warn!("Skipping invalid pair string: {pair}");
82 continue;
83 }
84 let base = Ustr::from(parts[0]);
85 let quote = Ustr::from(parts[1]);
86
87 graph.entry(base).or_default().push((quote, rate));
88 graph.entry(quote).or_default().push((base, 1.0 / rate));
89 }
90
91 let mut stack: Vec<(Ustr, f64)> = vec![(from_currency, 1.0)];
93 let mut visited: HashSet<Ustr> = HashSet::new();
94 visited.insert(from_currency);
95
96 while let Some((current, current_rate)) = stack.pop() {
97 if current == to_currency {
98 return Ok(Some(current_rate));
99 }
100 if let Some(neighbors) = graph.get(¤t) {
101 for (neighbor, rate) in neighbors {
102 if visited.insert(*neighbor) {
103 stack.push((*neighbor, current_rate * rate));
104 }
105 }
106 }
107 }
108
109 Ok(None)
111}
112
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rstest::rstest;
    use ustr::Ustr;

    use super::*;

    /// Absolute tolerance used when comparing a computed rate with its expected value.
    const TOLERANCE: f64 = 0.0001;

    /// Bid and ask quote maps, in the order `get_exchange_rate` takes them.
    type QuoteMaps = (HashMap<String, f64>, HashMap<String, f64>);

    /// Builds bid/ask maps from `(pair, bid, ask)` triples.
    fn build_quotes(entries: &[(&str, f64, f64)]) -> QuoteMaps {
        let bids = entries
            .iter()
            .map(|&(pair, bid, _)| (pair.to_string(), bid))
            .collect();
        let asks = entries
            .iter()
            .map(|&(pair, _, ask)| (pair.to_string(), ask))
            .collect();
        (bids, asks)
    }

    /// A small FX universe in which every pair involves USD.
    fn setup_test_quotes() -> QuoteMaps {
        build_quotes(&[
            ("EUR/USD", 1.1000, 1.1002),
            ("GBP/USD", 1.3000, 1.3002),
            ("USD/JPY", 110.00, 110.02),
            ("AUD/USD", 0.7500, 0.7502),
        ])
    }

    /// Calls `get_exchange_rate` with currency codes given as plain string slices.
    fn convert(
        from: &str,
        to: &str,
        price_type: PriceType,
        (quotes_bid, quotes_ask): QuoteMaps,
    ) -> anyhow::Result<Option<f64>> {
        get_exchange_rate(
            Ustr::from(from),
            Ustr::from(to),
            price_type,
            quotes_bid,
            quotes_ask,
        )
    }

    #[rstest]
    fn test_invalid_pair_string() {
        // The slash-less key must be ignored while the well-formed one is still used.
        let quotes = build_quotes(&[("EURUSD", 1.1000, 1.1002), ("EUR/USD", 1.1000, 1.1002)]);

        let rate = convert("EUR", "USD", PriceType::Mid, quotes).unwrap();

        let expected = f64::midpoint(1.1000, 1.1002);
        assert!((rate.unwrap() - expected).abs() < TOLERANCE);
    }

    #[rstest]
    fn test_same_currency() {
        let rate = convert("USD", "USD", PriceType::Mid, setup_test_quotes()).unwrap();
        assert_eq!(rate, Some(1.0));
    }

    #[rstest]
    #[case(PriceType::Bid, 1.1000)]
    #[case(PriceType::Ask, 1.1002)]
    #[case(PriceType::Mid, f64::midpoint(1.1000, 1.1002))]
    fn test_direct_pair(#[case] price_type: PriceType, #[case] expected: f64) {
        let rate = convert("EUR", "USD", price_type, setup_test_quotes())
            .unwrap()
            .unwrap_or_else(|| panic!("Expected a conversion rate for {price_type}"));

        assert!((rate - expected).abs() < TOLERANCE);
    }

    #[rstest]
    fn test_inverse_pair() {
        let quotes = setup_test_quotes();

        let forward = convert("EUR", "USD", PriceType::Mid, quotes.clone()).unwrap();
        let backward = convert("USD", "EUR", PriceType::Mid, quotes).unwrap();

        let (Some(eur_usd), Some(usd_eur)) = (forward, backward) else {
            panic!("Expected valid conversion rates for inverse conversion");
        };
        // Converting there and back again must multiply to (approximately) one.
        assert!(eur_usd.mul_add(usd_eur, -1.0).abs() < TOLERANCE);
    }

    #[rstest]
    fn test_cross_pair_through_usd() {
        let rate = convert("EUR", "JPY", PriceType::Mid, setup_test_quotes())
            .unwrap()
            .expect("Expected conversion rate through USD but got None");

        let expected = f64::midpoint(1.1000, 1.1002) * f64::midpoint(110.00, 110.02);
        assert!((rate - expected).abs() < 0.1);
    }

    #[rstest]
    fn test_no_conversion_path() {
        let quotes = build_quotes(&[("EUR/USD", 1.1000, 1.1002)]);

        let rate = convert("EUR", "JPY", PriceType::Mid, quotes).unwrap();

        assert_eq!(rate, None);
    }

    #[rstest]
    fn test_empty_quotes() {
        let result = convert(
            "EUR",
            "USD",
            PriceType::Mid,
            (HashMap::new(), HashMap::new()),
        );
        assert!(result.is_err());
    }

    #[rstest]
    fn test_unequal_quotes_length() {
        let quotes_bid = HashMap::from([
            ("EUR/USD".to_string(), 1.1000),
            ("GBP/USD".to_string(), 1.3000),
        ]);
        let quotes_ask = HashMap::from([("EUR/USD".to_string(), 1.1002)]);

        let result = convert("EUR", "USD", PriceType::Mid, (quotes_bid, quotes_ask));

        assert!(result.is_err());
    }

    #[rstest]
    fn test_invalid_price_type() {
        let result = convert("EUR", "USD", PriceType::Last, setup_test_quotes());
        assert!(result.is_err());
    }

    #[rstest]
    fn test_cycle_handling() {
        // EUR/USD and USD/EUR form a two-node cycle; the search must still terminate.
        let quotes = build_quotes(&[("EUR/USD", 1.1, 1.1002), ("USD/EUR", 0.909, 0.9091)]);

        let rate = convert("EUR", "USD", PriceType::Mid, quotes).unwrap();

        let expected = f64::midpoint(1.1, 1.1002);
        assert!((rate.unwrap() - expected).abs() < TOLERANCE);
    }

    #[rstest]
    fn test_multiple_paths() {
        let quotes = build_quotes(&[
            ("EUR/USD", 1.1000, 1.1002),
            ("EUR/GBP", 0.8461, 0.8463),
            ("GBP/USD", 1.3000, 1.3002),
        ]);

        let rate = convert("EUR", "USD", PriceType::Mid, quotes).unwrap();

        let direct = f64::midpoint(1.1000, 1.1002);
        let indirect = f64::midpoint(0.8461, 0.8463) * f64::midpoint(1.3000, 1.3002);
        // Both routes agree, so the result is valid whichever path is taken.
        assert!((direct - indirect).abs() < TOLERANCE);
        assert!((rate.unwrap() - direct).abs() < TOLERANCE);
    }
}