1use std::collections::{HashMap, HashSet};
25
26use itertools::Itertools;
27use nautilus_core::correctness::{check_equal_usize, check_map_not_empty, FAILED};
28use nautilus_model::{enums::PriceType, identifiers::Symbol, types::Currency};
29use rust_decimal::Decimal;
30use ustr::Ustr;
31
32#[must_use]
36pub fn get_exchange_rate(
37 from_currency: Currency,
38 to_currency: Currency,
39 price_type: PriceType,
40 quotes_bid: HashMap<Symbol, Decimal>,
41 quotes_ask: HashMap<Symbol, Decimal>,
42) -> Decimal {
43 check_map_not_empty("es_bid, stringify!(quotes_bid)).expect(FAILED);
44 check_map_not_empty("es_ask, stringify!(quotes_ask)).expect(FAILED);
45 check_equal_usize(
46 quotes_bid.len(),
47 quotes_ask.len(),
48 "quotes_bid.len()",
49 "quotes_ask.len()",
50 )
51 .expect(FAILED);
52
53 if from_currency == to_currency {
54 return Decimal::ONE;
55 }
56
57 let calculation_quotes = match price_type {
58 PriceType::Bid => quotes_bid,
59 PriceType::Ask => quotes_ask,
60 PriceType::Mid => quotes_bid
61 .iter()
62 .map(|(k, v)| {
63 let ask = quotes_ask.get(k).unwrap_or(v);
64 (*k, (v + ask) / Decimal::TWO)
65 })
66 .collect(),
67 _ => {
68 panic!("Cannot calculate exchange rate for PriceType: {price_type:?}");
69 }
70 };
71
72 let mut codes = HashSet::new();
73 let mut exchange_rates: HashMap<Ustr, HashMap<Ustr, Decimal>> = HashMap::new();
74
75 for (symbol, quote) in &calculation_quotes {
77 let pieces: Vec<&str> = symbol.as_str().split('/').collect();
79 let code_lhs = Ustr::from(pieces[0]);
80 let code_rhs = Ustr::from(pieces[1]);
81
82 codes.insert(code_lhs);
83 codes.insert(code_rhs);
84
85 exchange_rates.entry(code_lhs).or_default();
87 exchange_rates.entry(code_rhs).or_default();
88
89 if let Some(rates_lhs) = exchange_rates.get_mut(&code_lhs) {
91 rates_lhs.insert(code_lhs, Decimal::ONE);
92 rates_lhs.insert(code_rhs, *quote);
93 }
94 if let Some(rates_rhs) = exchange_rates.get_mut(&code_rhs) {
95 rates_rhs.insert(code_rhs, Decimal::ONE);
96 }
97 }
98
99 let code_perms: Vec<(Ustr, Ustr)> = codes
101 .iter()
102 .cartesian_product(codes.iter())
103 .filter(|(a, b)| a != b)
104 .map(|(a, b)| (*a, *b))
105 .collect();
106
107 for (perm0, perm1) in &code_perms {
109 let rate_0_to_1 = exchange_rates
111 .get(perm0)
112 .and_then(|rates| rates.get(perm1))
113 .copied();
114
115 if let Some(rate) = rate_0_to_1 {
116 if let Some(xrate_perm1) = exchange_rates.get_mut(perm1) {
117 if !xrate_perm1.contains_key(perm0) {
118 xrate_perm1.insert(*perm0, Decimal::ONE / rate);
119 }
120 }
121 }
122
123 let rate_1_to_0 = exchange_rates
125 .get(perm1)
126 .and_then(|rates| rates.get(perm0))
127 .copied();
128
129 if let Some(rate) = rate_1_to_0 {
130 if let Some(xrate_perm0) = exchange_rates.get_mut(perm0) {
131 if !xrate_perm0.contains_key(perm1) {
132 xrate_perm0.insert(*perm1, Decimal::ONE / rate);
133 }
134 }
135 }
136 }
137
138 if let Some(quotes) = exchange_rates.get(&from_currency.code) {
140 if let Some(&rate) = quotes.get(&to_currency.code) {
141 return rate;
142 }
143 }
144
145 for (perm0, perm1) in &code_perms {
147 if exchange_rates
149 .get(perm1)
150 .is_some_and(|rates| rates.contains_key(perm0))
151 {
152 continue;
153 }
154
155 for code in &codes {
157 let rates_through_common = {
159 let rates_perm0 = exchange_rates.get(perm0);
160 let rates_perm1 = exchange_rates.get(perm1);
161
162 match (rates_perm0, rates_perm1) {
163 (Some(rates0), Some(rates1)) => {
164 if let (Some(&rate1), Some(&rate2)) = (rates0.get(code), rates1.get(code)) {
165 Some((rate1, rate2))
166 } else {
167 None
168 }
169 }
170 _ => None,
171 }
172 };
173
174 let rates_from_code = if rates_through_common.is_none() {
176 if let Some(rates_code) = exchange_rates.get(code) {
177 if let (Some(&rate1), Some(&rate2)) =
178 (rates_code.get(perm0), rates_code.get(perm1))
179 {
180 Some((rate1, rate2))
181 } else {
182 None
183 }
184 } else {
185 None
186 }
187 } else {
188 None
189 };
190
191 if let Some((common_rate1, common_rate2)) = rates_through_common.or(rates_from_code) {
193 if let Some(rates_perm1) = exchange_rates.get_mut(perm1) {
195 rates_perm1.insert(*perm0, common_rate2 / common_rate1);
196 }
197
198 if let Some(rates_perm0) = exchange_rates.get_mut(perm0) {
200 if !rates_perm0.contains_key(perm1) {
201 rates_perm0.insert(*perm1, common_rate1 / common_rate2);
202 }
203 }
204 }
205 }
206 }
207
208 let xrate = exchange_rates
209 .get(&from_currency.code)
210 .and_then(|quotes| quotes.get(&to_currency.code))
211 .copied()
212 .unwrap_or(Decimal::ZERO);
213
214 xrate
215}
216
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rust_decimal::prelude::FromPrimitive;
    use rust_decimal_macros::dec;

    use super::*;

    /// Shorthand for building a `Currency` from its code in tests.
    fn ccy(code: &str) -> Currency {
        Currency::from_str(code).unwrap()
    }

    /// Builds matching bid/ask quote tables for four USD pairs.
    fn setup_test_quotes() -> (HashMap<Symbol, Decimal>, HashMap<Symbol, Decimal>) {
        let mut quotes_bid = HashMap::new();
        let mut quotes_ask = HashMap::new();

        quotes_bid.insert(Symbol::from_str_unchecked("EUR/USD"), dec!(1.1000));
        quotes_ask.insert(Symbol::from_str_unchecked("EUR/USD"), dec!(1.1002));

        quotes_bid.insert(Symbol::from_str_unchecked("GBP/USD"), dec!(1.3000));
        quotes_ask.insert(Symbol::from_str_unchecked("GBP/USD"), dec!(1.3002));

        quotes_bid.insert(Symbol::from_str_unchecked("USD/JPY"), dec!(110.00));
        quotes_ask.insert(Symbol::from_str_unchecked("USD/JPY"), dec!(110.02));

        quotes_bid.insert(Symbol::from_str_unchecked("AUD/USD"), dec!(0.7500));
        quotes_ask.insert(Symbol::from_str_unchecked("AUD/USD"), dec!(0.7502));

        (quotes_bid, quotes_ask)
    }

    #[test]
    fn test_same_currency() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        let rate = get_exchange_rate(
            ccy("USD"),
            ccy("USD"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );

        assert_eq!(rate, Decimal::ONE);
    }

    #[test]
    fn test_direct_pair() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        let rate_bid = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Bid,
            quotes_bid.clone(),
            quotes_ask.clone(),
        );
        assert_eq!(rate_bid, dec!(1.1000));

        let rate_ask = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Ask,
            quotes_bid.clone(),
            quotes_ask.clone(),
        );
        assert_eq!(rate_ask, dec!(1.1002));

        // Mid is the average of bid and ask.
        let rate_mid = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );
        assert_eq!(rate_mid, dec!(1.1001));
    }

    #[test]
    fn test_inverse_pair() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        let rate = get_exchange_rate(
            ccy("USD"),
            ccy("EUR"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );

        // USD/EUR is the reciprocal of the EUR/USD mid.
        let expected = Decimal::ONE / dec!(1.1001);
        assert!((rate - expected).abs() < dec!(0.0001));
    }

    #[test]
    fn test_cross_pair_through_usd() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        let rate = get_exchange_rate(
            ccy("EUR"),
            ccy("JPY"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );

        // EUR/JPY = EUR/USD * USD/JPY.
        let expected = dec!(1.1001) * dec!(110.01);
        assert!((rate - expected).abs() < dec!(0.01));
    }

    #[test]
    fn test_multiple_path_cross_pair() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        let rate = get_exchange_rate(
            ccy("GBP"),
            ccy("AUD"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );

        // GBP/AUD = GBP/USD / AUD/USD.
        let expected = dec!(1.3001) / dec!(0.7501);
        assert!((rate - expected).abs() < dec!(0.01));
    }

    #[test]
    fn test_missing_pairs() {
        let mut quotes_bid = HashMap::new();
        let mut quotes_ask = HashMap::new();

        quotes_bid.insert(Symbol::from_str_unchecked("EUR/USD"), dec!(1.1000));
        quotes_ask.insert(Symbol::from_str_unchecked("EUR/USD"), dec!(1.1002));

        let rate = get_exchange_rate(
            ccy("EUR"),
            ccy("JPY"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );

        // No quote involves JPY, so no rate can be derived.
        assert_eq!(rate, Decimal::ZERO);
    }

    #[test]
    #[should_panic]
    fn test_empty_quotes() {
        let quotes_bid = HashMap::new();
        let quotes_ask = HashMap::new();

        // The call itself must panic. No assertion follows it, so the only possible
        // source of the panic is `get_exchange_rate`.
        let _ = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );
    }

    #[test]
    #[should_panic]
    fn test_unequal_quotes_length() {
        let mut quotes_bid = HashMap::new();
        let mut quotes_ask = HashMap::new();

        quotes_bid.insert(Symbol::from_str_unchecked("EUR/USD"), dec!(1.1000));
        quotes_bid.insert(Symbol::from_str_unchecked("GBP/USD"), dec!(1.3000));
        quotes_ask.insert(Symbol::from_str_unchecked("EUR/USD"), dec!(1.1002));

        // The call itself must panic; see `test_empty_quotes`.
        let _ = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );
    }

    #[test]
    #[should_panic]
    fn test_invalid_price_type() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        // The call itself must panic. A trailing `assert_eq!(.., Decimal::ZERO)` would
        // hide a regression here: if the call returned the EUR/USD rate instead of
        // panicking, the failing assertion would satisfy `should_panic`.
        let _ = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Last,
            quotes_bid,
            quotes_ask,
        );
    }

    #[test]
    fn test_extensive_cross_pairs() {
        let mut quotes_bid = HashMap::new();
        let mut quotes_ask = HashMap::new();

        // (symbol, (bid, ask))
        let pairs = [
            ("EUR/USD", (1.1000, 1.1002)),
            ("GBP/USD", (1.3000, 1.3002)),
            ("USD/JPY", (110.00, 110.02)),
            ("EUR/GBP", (0.8461, 0.8463)),
            ("AUD/USD", (0.7500, 0.7502)),
            ("NZD/USD", (0.7000, 0.7002)),
            ("USD/CAD", (1.2500, 1.2502)),
        ];

        for (pair, (bid, ask)) in pairs {
            quotes_bid.insert(
                Symbol::from_str_unchecked(pair),
                Decimal::from_f64(bid).unwrap(),
            );
            quotes_ask.insert(
                Symbol::from_str_unchecked(pair),
                Decimal::from_f64(ask).unwrap(),
            );
        }

        // (from, to, expected mid rate)
        let test_pairs = [
            ("EUR", "JPY", 121.022), // EUR/USD * USD/JPY
            ("GBP", "JPY", 143.024), // GBP/USD * USD/JPY
            ("AUD", "JPY", 82.51),   // AUD/USD * USD/JPY
            ("EUR", "CAD", 1.375),   // EUR/USD * USD/CAD
            ("NZD", "CAD", 0.875),   // NZD/USD * USD/CAD
            ("AUD", "NZD", 1.071),   // AUD/USD / NZD/USD
        ];

        for (from, to, expected) in test_pairs {
            let rate = get_exchange_rate(
                ccy(from),
                ccy(to),
                PriceType::Mid,
                quotes_bid.clone(),
                quotes_ask.clone(),
            );

            let expected_dec = Decimal::from_f64(expected).unwrap();
            assert!(
                (rate - expected_dec).abs() < dec!(0.01),
                "Failed for pair {from}/{to}: got {rate}, expected {expected_dec}"
            );
        }
    }

    #[test]
    fn test_rate_consistency() {
        let (quotes_bid, quotes_ask) = setup_test_quotes();

        let rate_eur_usd = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Mid,
            quotes_bid.clone(),
            quotes_ask.clone(),
        );

        let rate_usd_eur = get_exchange_rate(
            ccy("USD"),
            ccy("EUR"),
            PriceType::Mid,
            quotes_bid,
            quotes_ask,
        );

        // A rate multiplied by its inverse must be (approximately) one.
        assert!((rate_eur_usd * rate_usd_eur - Decimal::ONE).abs() < dec!(0.0001));
    }
}