nautilus_network/
tls.rs
1use std::{fs::File, io::BufReader, path::Path};
19
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_tungstenite::{
23 MaybeTlsStream,
24 tungstenite::{Error, handshake::client::Request, stream::Mode},
25};
26
27#[non_exhaustive]
31#[derive(Clone)]
32#[allow(dead_code)]
33pub enum Connector {
34 Plain,
36 Rustls(std::sync::Arc<rustls::ClientConfig>),
38}
39
40mod encryption {
41
42 pub mod rustls {
43 use std::{convert::TryFrom, sync::Arc};
44
45 use nautilus_cryptography::tls::create_tls_config;
46 pub use rustls::ClientConfig;
47 use rustls::pki_types::ServerName;
48 use tokio::io::{AsyncRead, AsyncWrite};
49 use tokio_rustls::TlsConnector as TokioTlsConnector;
50 use tokio_tungstenite::{
51 MaybeTlsStream,
52 tungstenite::{Error, error::TlsError, stream::Mode},
53 };
54
55 pub async fn wrap_stream<S>(
56 socket: S,
57 domain: String,
58 mode: Mode,
59 tls_connector: Option<Arc<ClientConfig>>,
60 ) -> Result<MaybeTlsStream<S>, Error>
61 where
62 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
63 {
64 match mode {
65 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
66 Mode::Tls => {
67 let config = match tls_connector {
68 Some(config) => config,
69 None => create_tls_config(),
70 };
71 let domain = ServerName::try_from(domain.as_str())
72 .map_err(|_| TlsError::InvalidDnsName)?
73 .to_owned();
74 let stream = TokioTlsConnector::from(config);
75 let connected = stream.connect(domain, socket).await;
76
77 match connected {
78 Err(e) => Err(Error::Io(e)),
79 Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
80 }
81 }
82 }
83 }
84 }
85
86 pub mod plain {
87 use tokio::io::{AsyncRead, AsyncWrite};
88 use tokio_tungstenite::{
89 MaybeTlsStream,
90 tungstenite::{
91 error::{Error, UrlError},
92 stream::Mode,
93 },
94 };
95
96 pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
97 where
98 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
99 {
100 match mode {
101 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
102 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
103 }
104 }
105 }
106}
107
108pub async fn tcp_tls<S>(
109 request: &Request,
110 mode: Mode,
111 stream: S,
112 connector: Option<Connector>,
113) -> Result<MaybeTlsStream<S>, Error>
114where
115 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
116 MaybeTlsStream<S>: Unpin,
117{
118 let domain = domain(request)?;
119
120 match connector {
121 Some(conn) => match conn {
122 Connector::Rustls(conn) => {
123 self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
124 }
125 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
126 },
127 None => self::encryption::rustls::wrap_stream(stream, domain, mode, None).await,
128 }
129}
130
131fn domain(request: &Request) -> Result<String, Error> {
132 match request.uri().host() {
133 Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
135 Some(d) => Ok(d.to_string()),
136 None => panic!("No host name"),
137 }
138}
139
140pub fn create_tls_config_from_certs_dir(certs_dir: &Path) -> anyhow::Result<rustls::ClientConfig> {
141 if !certs_dir.is_dir() {
142 anyhow::bail!("Certificate path is not a directory: {certs_dir:?}");
143 }
144
145 let mut client_cert = None;
146 let mut client_key = None;
147 let mut root_store = rustls::RootCertStore::empty();
148 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
149
150 for entry in std::fs::read_dir(certs_dir)? {
151 let entry = entry?;
152 let path = entry.path();
153
154 if client_key.is_none() {
155 if let Ok(key) = load_private_key(&path) {
156 client_key = Some(key);
157 continue;
158 }
159 }
160
161 if let Ok(certs) = load_certs(&path) {
162 if !certs.is_empty() {
163 if client_cert.is_none() {
164 client_cert = Some(certs);
165 } else {
166 for cert in certs {
167 if let Err(e) = root_store.add(cert) {
168 eprintln!("Warning: Invalid certificate in {path:?}: {e}");
169 }
170 }
171 }
172 }
173 }
174 }
175
176 let (cert, key) = client_cert
177 .zip(client_key)
178 .ok_or_else(|| anyhow::anyhow!("Could not find both client certificate and private key"))?;
179
180 Ok(rustls::ClientConfig::builder()
181 .with_root_certificates(root_store)
182 .with_client_auth_cert(cert, key)?)
183}
184
185fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
186 let file = File::open(path)?;
187 let mut reader = BufReader::new(file);
188
189 let pkcs8_keys: Vec<_> = rustls_pemfile::pkcs8_private_keys(&mut reader)
190 .filter_map(std::result::Result::ok)
191 .collect();
192
193 if let Some(key) = pkcs8_keys.into_iter().next() {
194 return Ok(key.into());
195 }
196
197 let file = File::open(path)?;
198 let mut reader = BufReader::new(file);
199 let rsa_keys: Vec<_> = rustls_pemfile::rsa_private_keys(&mut reader)
200 .filter_map(std::result::Result::ok)
201 .collect();
202
203 if let Some(key) = rsa_keys.into_iter().next() {
204 return Ok(key.into());
205 }
206
207 anyhow::bail!("No valid private key found in {path:?}");
208}
209
210fn load_certs(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
211 let file = File::open(path)?;
212 let mut reader = BufReader::new(file);
213 let certs = rustls_pemfile::certs(&mut reader)
214 .filter_map(std::result::Result::ok)
215 .collect();
216 Ok(certs)
217}