use std::collections::{HashMap, HashSet};
use itertools::Itertools;
use nautilus_core::correctness::{check_equal_usize, check_map_not_empty, FAILED};
use nautilus_model::{enums::PriceType, identifiers::Symbol, types::Currency};
use rust_decimal::Decimal;
use ustr::Ustr;
/// Calculates the exchange rate for converting `from_currency` into `to_currency`
/// from the supplied bid and ask quotes.
///
/// Quote symbols are split on `/` and the first two pieces are taken as the left-hand
/// and right-hand currency codes of the pair (e.g. `EUR/USD`). The rate is resolved in
/// this order:
///
/// 1. `Decimal::ONE` when both currencies are equal.
/// 2. The directly quoted rate, or the inverse of a directly quoted rate.
/// 3. A cross rate derived through a currency code shared by both legs.
///
/// `price_type` selects which quotes are used: `Bid`, `Ask`, or `Mid` (the average of
/// bid and ask per symbol; a symbol missing from `quotes_ask` falls back to its bid).
///
/// Returns `Decimal::ZERO` when no rate can be derived from the quotes.
///
/// # Panics
///
/// - If the `quotes_bid` / `quotes_ask` precondition checks fail (empty map, or maps of
///   different lengths).
/// - If `price_type` is anything other than `Bid`, `Ask` or `Mid`.
/// - If a quote symbol does not contain a `/` separator.
/// - If an inverse or cross rate has to be derived from a zero quote (`Decimal`
///   division by zero panics).
#[must_use]
pub fn get_exchange_rate(
    from_currency: Currency,
    to_currency: Currency,
    price_type: PriceType,
    quotes_bid: HashMap<Symbol, Decimal>,
    quotes_ask: HashMap<Symbol, Decimal>,
) -> Decimal {
    check_map_not_empty(&quotes_bid, stringify!(quotes_bid)).expect(FAILED);
    check_map_not_empty(&quotes_ask, stringify!(quotes_ask)).expect(FAILED);
    check_equal_usize(
        quotes_bid.len(),
        quotes_ask.len(),
        "quotes_bid.len()",
        "quotes_ask.len()",
    )
    .expect(FAILED);

    if from_currency == to_currency {
        return Decimal::ONE;
    }

    let calculation_quotes = match price_type {
        PriceType::Bid => quotes_bid,
        PriceType::Ask => quotes_ask,
        PriceType::Mid => quotes_bid
            .iter()
            .map(|(k, v)| {
                let ask = quotes_ask.get(k).unwrap_or(v);
                (*k, (v + ask) / Decimal::TWO)
            })
            .collect(),
        _ => {
            panic!("Cannot calculate exchange rate for PriceType: {price_type:?}");
        }
    };

    // Seed the rate table with every directly quoted pair plus identity rates.
    let mut code_set = HashSet::new();
    let mut exchange_rates: HashMap<Ustr, HashMap<Ustr, Decimal>> = HashMap::new();
    for (symbol, quote) in &calculation_quotes {
        let mut pieces = symbol.as_str().split('/');
        let (Some(lhs), Some(rhs)) = (pieces.next(), pieces.next()) else {
            panic!(
                "Invalid currency pair symbol '{}': expected the form 'BASE/QUOTE'",
                symbol.as_str()
            );
        };
        let code_lhs = Ustr::from(lhs);
        let code_rhs = Ustr::from(rhs);
        code_set.insert(code_lhs);
        code_set.insert(code_rhs);

        let rates_lhs = exchange_rates.entry(code_lhs).or_default();
        rates_lhs.insert(code_lhs, Decimal::ONE);
        rates_lhs.insert(code_rhs, *quote);
        exchange_rates
            .entry(code_rhs)
            .or_default()
            .insert(code_rhs, Decimal::ONE);
    }

    // Sort the codes so that, when several conversion paths exist, the chosen cross
    // rate does not depend on the (randomized) hash iteration order.
    let mut codes: Vec<Ustr> = code_set.into_iter().collect();
    codes.sort_unstable_by(|a, b| a.as_str().cmp(b.as_str()));

    // Every ordered pair of distinct currency codes.
    let code_perms: Vec<(Ustr, Ustr)> = codes
        .iter()
        .copied()
        .cartesian_product(codes.iter().copied())
        .filter(|(a, b)| a != b)
        .collect();

    // Fill in missing inverse rates. Both (a, b) and (b, a) are visited, so handling a
    // single direction per visit covers every pair.
    for (base, quote) in &code_perms {
        let Some(rate) = lookup_rate(&exchange_rates, base, quote) else {
            continue;
        };
        if let Some(inverse_rates) = exchange_rates.get_mut(quote) {
            inverse_rates
                .entry(*base)
                .or_insert_with(|| Decimal::ONE / rate);
        }
    }

    // A direct or inverse quote is preferred over any derived cross rate.
    if let Some(rate) = lookup_rate(&exchange_rates, &from_currency.code, &to_currency.code) {
        return rate;
    }

    // Derive the remaining cross rates through a common currency. From here on the
    // table is symmetric (a -> b is present exactly when b -> a is), so looking up each
    // leg's rate *to* the common code is sufficient.
    for (perm0, perm1) in &code_perms {
        if lookup_rate(&exchange_rates, perm1, perm0).is_some() {
            continue;
        }
        for code in &codes {
            let leg0 = lookup_rate(&exchange_rates, perm0, code);
            let leg1 = lookup_rate(&exchange_rates, perm1, code);
            let (Some(rate0), Some(rate1)) = (leg0, leg1) else {
                continue;
            };
            // perm1 -> perm0 == (perm1 -> code) / (perm0 -> code)
            if let Some(rates_perm1) = exchange_rates.get_mut(perm1) {
                rates_perm1.insert(*perm0, rate1 / rate0);
            }
            if let Some(rates_perm0) = exchange_rates.get_mut(perm0) {
                rates_perm0
                    .entry(*perm1)
                    .or_insert_with(|| rate0 / rate1);
            }
        }
    }

    lookup_rate(&exchange_rates, &from_currency.code, &to_currency.code)
        .unwrap_or(Decimal::ZERO)
}

/// Returns the stored rate for converting `from` into `to`, if one is present.
fn lookup_rate(
    rates: &HashMap<Ustr, HashMap<Ustr, Decimal>>,
    from: &Ustr,
    to: &Ustr,
) -> Option<Decimal> {
    rates.get(from).and_then(|quotes| quotes.get(to)).copied()
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rust_decimal::prelude::FromPrimitive;
    use rust_decimal_macros::dec;

    use super::*;

    /// Bid and ask quote maps, in that order.
    type QuotePair = (HashMap<Symbol, Decimal>, HashMap<Symbol, Decimal>);

    /// Parses a currency code, panicking if it cannot be parsed.
    fn ccy(code: &str) -> Currency {
        Currency::from_str(code).unwrap()
    }

    /// Builds bid/ask maps from `(pair, bid, ask)` rows.
    fn build_quotes(rows: &[(&str, Decimal, Decimal)]) -> QuotePair {
        let mut quotes_bid = HashMap::new();
        let mut quotes_ask = HashMap::new();
        for &(pair, bid, ask) in rows {
            quotes_bid.insert(Symbol::from_str_unchecked(pair), bid);
            quotes_ask.insert(Symbol::from_str_unchecked(pair), ask);
        }
        (quotes_bid, quotes_ask)
    }

    /// Quotes for four pairs that all include USD, shared by most tests.
    fn setup_test_quotes() -> QuotePair {
        build_quotes(&[
            ("EUR/USD", dec!(1.1000), dec!(1.1002)),
            ("GBP/USD", dec!(1.3000), dec!(1.3002)),
            ("USD/JPY", dec!(110.00), dec!(110.02)),
            ("AUD/USD", dec!(0.7500), dec!(0.7502)),
        ])
    }

    #[test]
    fn test_same_currency() {
        let (bids, asks) = setup_test_quotes();
        let rate = get_exchange_rate(ccy("USD"), ccy("USD"), PriceType::Mid, bids, asks);
        assert_eq!(rate, Decimal::ONE);
    }

    #[test]
    fn test_direct_pair() {
        let (bids, asks) = setup_test_quotes();
        let cases = [
            (PriceType::Bid, dec!(1.1000)),
            (PriceType::Ask, dec!(1.1002)),
            (PriceType::Mid, dec!(1.1001)),
        ];
        for (price_type, expected) in cases {
            let rate = get_exchange_rate(
                ccy("EUR"),
                ccy("USD"),
                price_type,
                bids.clone(),
                asks.clone(),
            );
            assert_eq!(rate, expected);
        }
    }

    #[test]
    fn test_inverse_pair() {
        let (bids, asks) = setup_test_quotes();
        let rate = get_exchange_rate(ccy("USD"), ccy("EUR"), PriceType::Mid, bids, asks);
        let expected = Decimal::ONE / dec!(1.1001);
        assert!((rate - expected).abs() < dec!(0.0001));
    }

    #[test]
    fn test_cross_pair_through_usd() {
        let (bids, asks) = setup_test_quotes();
        let rate = get_exchange_rate(ccy("EUR"), ccy("JPY"), PriceType::Mid, bids, asks);
        let expected = dec!(1.1001) * dec!(110.01);
        assert!((rate - expected).abs() < dec!(0.01));
    }

    #[test]
    fn test_multiple_path_cross_pair() {
        let (bids, asks) = setup_test_quotes();
        let rate = get_exchange_rate(ccy("GBP"), ccy("AUD"), PriceType::Mid, bids, asks);
        let expected = dec!(1.3001) / dec!(0.7501);
        assert!((rate - expected).abs() < dec!(0.01));
    }

    #[test]
    fn test_missing_pairs() {
        // JPY appears in no quoted pair, so no rate can be derived.
        let (bids, asks) = build_quotes(&[("EUR/USD", dec!(1.1000), dec!(1.1002))]);
        let rate = get_exchange_rate(ccy("EUR"), ccy("JPY"), PriceType::Mid, bids, asks);
        assert_eq!(rate, Decimal::ZERO);
    }

    #[test]
    #[should_panic]
    fn test_empty_quotes() {
        let out_xrate = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Mid,
            HashMap::new(),
            HashMap::new(),
        );
        assert_eq!(out_xrate, Decimal::ZERO);
    }

    #[test]
    #[should_panic]
    fn test_unequal_quotes_length() {
        let (mut bids, asks) = build_quotes(&[("EUR/USD", dec!(1.1000), dec!(1.1002))]);
        // One extra bid makes the two maps differ in length.
        bids.insert(Symbol::from_str_unchecked("GBP/USD"), dec!(1.3000));
        let out_xrate = get_exchange_rate(ccy("EUR"), ccy("USD"), PriceType::Mid, bids, asks);
        assert_eq!(out_xrate, Decimal::ZERO);
    }

    #[test]
    #[should_panic]
    fn test_invalid_price_type() {
        let (bids, asks) = setup_test_quotes();
        let out_xrate = get_exchange_rate(ccy("EUR"), ccy("USD"), PriceType::Last, bids, asks);
        assert_eq!(out_xrate, Decimal::ZERO);
    }

    #[test]
    fn test_extensive_cross_pairs() {
        let raw_quotes = [
            ("EUR/USD", 1.1000, 1.1002),
            ("GBP/USD", 1.3000, 1.3002),
            ("USD/JPY", 110.00, 110.02),
            ("EUR/GBP", 0.8461, 0.8463),
            ("AUD/USD", 0.7500, 0.7502),
            ("NZD/USD", 0.7000, 0.7002),
            ("USD/CAD", 1.2500, 1.2502),
        ];
        let mut bids = HashMap::new();
        let mut asks = HashMap::new();
        for (pair, bid, ask) in raw_quotes {
            let symbol = Symbol::from_str_unchecked(pair);
            bids.insert(symbol, Decimal::from_f64(bid).unwrap());
            asks.insert(symbol, Decimal::from_f64(ask).unwrap());
        }

        let expectations = [
            ("EUR", "JPY", 121.022),
            ("GBP", "JPY", 143.024),
            ("AUD", "JPY", 82.51),
            ("EUR", "CAD", 1.375),
            ("NZD", "CAD", 0.875),
            ("AUD", "NZD", 1.071),
        ];
        for (from, to, expected) in expectations {
            let rate = get_exchange_rate(
                ccy(from),
                ccy(to),
                PriceType::Mid,
                bids.clone(),
                asks.clone(),
            );
            let expected_dec = Decimal::from_f64(expected).unwrap();
            assert!(
                (rate - expected_dec).abs() < dec!(0.01),
                "Failed for pair {from}/{to}: got {rate}, expected {expected_dec}"
            );
        }
    }

    #[test]
    fn test_rate_consistency() {
        let (bids, asks) = setup_test_quotes();
        let rate_eur_usd = get_exchange_rate(
            ccy("EUR"),
            ccy("USD"),
            PriceType::Mid,
            bids.clone(),
            asks.clone(),
        );
        let rate_usd_eur = get_exchange_rate(ccy("USD"), ccy("EUR"), PriceType::Mid, bids, asks);
        // A rate multiplied by its inverse should be (approximately) one.
        assert!((rate_eur_usd * rate_usd_eur - Decimal::ONE).abs() < dec!(0.0001));
    }
}