1use std::{fs::File, io::BufReader, path::Path};
19
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_tungstenite::{
23 tungstenite::{handshake::client::Request, stream::Mode, Error},
24 MaybeTlsStream,
25};
26
27#[non_exhaustive]
31#[derive(Clone)]
32#[allow(dead_code)]
33pub enum Connector {
34 Plain,
36 Rustls(std::sync::Arc<rustls::ClientConfig>),
38}
39
40mod encryption {
41
42 pub mod rustls {
43 use std::{convert::TryFrom, sync::Arc};
44
45 use nautilus_cryptography::tls::create_tls_config;
46 use rustls::pki_types::ServerName;
47 pub use rustls::ClientConfig;
48 use tokio::io::{AsyncRead, AsyncWrite};
49 use tokio_rustls::TlsConnector as TokioTlsConnector;
50 use tokio_tungstenite::{
51 tungstenite::{error::TlsError, stream::Mode, Error},
52 MaybeTlsStream,
53 };
54
55 pub async fn wrap_stream<S>(
56 socket: S,
57 domain: String,
58 mode: Mode,
59 tls_connector: Option<Arc<ClientConfig>>,
60 ) -> Result<MaybeTlsStream<S>, Error>
61 where
62 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
63 {
64 match mode {
65 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
66 Mode::Tls => {
67 let config = match tls_connector {
68 Some(config) => config,
69 None => create_tls_config(),
70 };
71 let domain = ServerName::try_from(domain.as_str())
72 .map_err(|_| TlsError::InvalidDnsName)?
73 .to_owned();
74 let stream = TokioTlsConnector::from(config);
75 let connected = stream.connect(domain, socket).await;
76
77 match connected {
78 Err(e) => Err(Error::Io(e)),
79 Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
80 }
81 }
82 }
83 }
84 }
85
86 pub mod plain {
87 use tokio::io::{AsyncRead, AsyncWrite};
88 use tokio_tungstenite::{
89 tungstenite::{
90 error::{Error, UrlError},
91 stream::Mode,
92 },
93 MaybeTlsStream,
94 };
95
96 pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
97 where
98 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
99 {
100 match mode {
101 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
102 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
103 }
104 }
105 }
106}
107
108pub async fn tcp_tls<S>(
109 request: &Request,
110 mode: Mode,
111 stream: S,
112 connector: Option<Connector>,
113) -> Result<MaybeTlsStream<S>, Error>
114where
115 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
116 MaybeTlsStream<S>: Unpin,
117{
118 let domain = domain(request)?;
119
120 match connector {
121 Some(conn) => match conn {
122 Connector::Rustls(conn) => {
123 self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
124 }
125 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
126 },
127 None => self::encryption::rustls::wrap_stream(stream, domain, mode, None).await,
128 }
129}
130
131fn domain(request: &Request) -> Result<String, Error> {
132 match request.uri().host() {
133 Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
135 Some(d) => Ok(d.to_string()),
136 None => panic!("No host name"),
137 }
138}
139
140pub fn create_tls_config_from_certs_dir(certs_dir: &Path) -> anyhow::Result<rustls::ClientConfig> {
141 if !certs_dir.is_dir() {
142 return Err(anyhow::anyhow!(
143 "Certificate path is not a directory: {}",
144 certs_dir.display()
145 ));
146 }
147
148 let mut client_cert = None;
149 let mut client_key = None;
150 let mut root_store = rustls::RootCertStore::empty();
151 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
152
153 for entry in std::fs::read_dir(certs_dir)? {
154 let entry = entry?;
155 let path = entry.path();
156
157 if client_key.is_none() {
158 if let Ok(key) = load_private_key(&path) {
159 client_key = Some(key);
160 continue;
161 }
162 }
163
164 if let Ok(certs) = load_certs(&path) {
165 if !certs.is_empty() {
166 if client_cert.is_none() {
167 client_cert = Some(certs);
168 } else {
169 for cert in certs {
170 if let Err(e) = root_store.add(cert) {
171 eprintln!("Warning: Invalid certificate in {}: {e}", path.display());
172 }
173 }
174 }
175 }
176 }
177 }
178
179 let (cert, key) = client_cert
180 .zip(client_key)
181 .ok_or_else(|| anyhow::anyhow!("Could not find both client certificate and private key"))?;
182
183 Ok(rustls::ClientConfig::builder()
184 .with_root_certificates(root_store)
185 .with_client_auth_cert(cert, key)?)
186}
187
188fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
189 let file = File::open(path)?;
190 let mut reader = BufReader::new(file);
191
192 let pkcs8_keys: Vec<_> = rustls_pemfile::pkcs8_private_keys(&mut reader)
193 .filter_map(|result| result.ok())
194 .collect();
195
196 if let Some(key) = pkcs8_keys.into_iter().next() {
197 return Ok(key.into());
198 }
199
200 let file = File::open(path)?;
201 let mut reader = BufReader::new(file);
202 let rsa_keys: Vec<_> = rustls_pemfile::rsa_private_keys(&mut reader)
203 .filter_map(|result| result.ok())
204 .collect();
205
206 if let Some(key) = rsa_keys.into_iter().next() {
207 return Ok(key.into());
208 }
209
210 Err(anyhow::anyhow!(
211 "No valid private key found in {}",
212 path.display()
213 ))
214}
215
216fn load_certs(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
217 let file = File::open(path)?;
218 let mut reader = BufReader::new(file);
219 let certs = rustls_pemfile::certs(&mut reader)
220 .filter_map(|result| result.ok())
221 .collect();
222 Ok(certs)
223}