nautilus_network/
tls.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Module for wrapping raw socket streams with TLS encryption.
17
18use std::{fs::File, io::BufReader, path::Path};
19
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_tungstenite::{
23    MaybeTlsStream,
24    tungstenite::{Error, handshake::client::Request, stream::Mode},
25};
26
27/// A connector that can be used when establishing connections, allowing to control whether
28/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the
29/// `Plain` variant.
30#[non_exhaustive]
31#[derive(Clone)]
32#[allow(dead_code)]
33pub enum Connector {
34    /// No TLS connection.
35    Plain,
36    /// TLS connection using `rustls`.
37    Rustls(std::sync::Arc<rustls::ClientConfig>),
38}
39
40mod encryption {
41
42    pub mod rustls {
43        use std::{convert::TryFrom, sync::Arc};
44
45        use nautilus_cryptography::tls::create_tls_config;
46        pub use rustls::ClientConfig;
47        use rustls::pki_types::ServerName;
48        use tokio::io::{AsyncRead, AsyncWrite};
49        use tokio_rustls::TlsConnector as TokioTlsConnector;
50        use tokio_tungstenite::{
51            MaybeTlsStream,
52            tungstenite::{Error, error::TlsError, stream::Mode},
53        };
54
55        pub async fn wrap_stream<S>(
56            socket: S,
57            domain: String,
58            mode: Mode,
59            tls_connector: Option<Arc<ClientConfig>>,
60        ) -> Result<MaybeTlsStream<S>, Error>
61        where
62            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
63        {
64            match mode {
65                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
66                Mode::Tls => {
67                    let config = match tls_connector {
68                        Some(config) => config,
69                        None => create_tls_config(),
70                    };
71                    let domain = ServerName::try_from(domain.as_str())
72                        .map_err(|_| TlsError::InvalidDnsName)?
73                        .to_owned();
74                    let stream = TokioTlsConnector::from(config);
75                    let connected = stream.connect(domain, socket).await;
76
77                    match connected {
78                        Err(e) => Err(Error::Io(e)),
79                        Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
80                    }
81                }
82            }
83        }
84    }
85
86    pub mod plain {
87        use tokio::io::{AsyncRead, AsyncWrite};
88        use tokio_tungstenite::{
89            MaybeTlsStream,
90            tungstenite::{
91                error::{Error, UrlError},
92                stream::Mode,
93            },
94        };
95
96        pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
97        where
98            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
99        {
100            match mode {
101                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
102                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
103            }
104        }
105    }
106}
107
108pub async fn tcp_tls<S>(
109    request: &Request,
110    mode: Mode,
111    stream: S,
112    connector: Option<Connector>,
113) -> Result<MaybeTlsStream<S>, Error>
114where
115    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
116    MaybeTlsStream<S>: Unpin,
117{
118    let domain = domain(request)?;
119
120    match connector {
121        Some(conn) => match conn {
122            Connector::Rustls(conn) => {
123                self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
124            }
125            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
126        },
127        None => self::encryption::rustls::wrap_stream(stream, domain, mode, None).await,
128    }
129}
130
131/// Extracts the host name from the request URI.
132///
133/// # Panics
134///
135/// Panics if the request URI has no host component.
136#[allow(clippy::result_large_err)]
137fn domain(request: &Request) -> Result<String, Error> {
138    match request.uri().host() {
139        // rustls expects IPv6 addresses without the surrounding [] brackets
140        Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
141        Some(d) => Ok(d.to_string()),
142        None => panic!("No host name"),
143    }
144}
145
146pub fn create_tls_config_from_certs_dir(certs_dir: &Path) -> anyhow::Result<rustls::ClientConfig> {
147    if !certs_dir.is_dir() {
148        anyhow::bail!("Certificate path is not a directory: {certs_dir:?}");
149    }
150
151    let mut client_cert = None;
152    let mut client_key = None;
153    let mut root_store = rustls::RootCertStore::empty();
154    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
155
156    for entry in std::fs::read_dir(certs_dir)? {
157        let entry = entry?;
158        let path = entry.path();
159
160        if client_key.is_none()
161            && let Ok(key) = load_private_key(&path)
162        {
163            client_key = Some(key);
164            continue;
165        }
166
167        if let Ok(certs) = load_certs(&path)
168            && !certs.is_empty()
169        {
170            if client_cert.is_none() {
171                client_cert = Some(certs);
172            } else {
173                for cert in certs {
174                    if let Err(e) = root_store.add(cert) {
175                        eprintln!("Warning: Invalid certificate in {path:?}: {e}");
176                    }
177                }
178            }
179        }
180    }
181
182    let (cert, key) = client_cert
183        .zip(client_key)
184        .ok_or_else(|| anyhow::anyhow!("Could not find both client certificate and private key"))?;
185
186    Ok(rustls::ClientConfig::builder()
187        .with_root_certificates(root_store)
188        .with_client_auth_cert(cert, key)?)
189}
190
191fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
192    let file = File::open(path)?;
193    let mut reader = BufReader::new(file);
194
195    let pkcs8_keys: Vec<_> = rustls_pemfile::pkcs8_private_keys(&mut reader)
196        .filter_map(std::result::Result::ok)
197        .collect();
198
199    if let Some(key) = pkcs8_keys.into_iter().next() {
200        return Ok(key.into());
201    }
202
203    let file = File::open(path)?;
204    let mut reader = BufReader::new(file);
205    let rsa_keys: Vec<_> = rustls_pemfile::rsa_private_keys(&mut reader)
206        .filter_map(std::result::Result::ok)
207        .collect();
208
209    if let Some(key) = rsa_keys.into_iter().next() {
210        return Ok(key.into());
211    }
212
213    anyhow::bail!("No valid private key found in {path:?}");
214}
215
216fn load_certs(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
217    let file = File::open(path)?;
218    let mut reader = BufReader::new(file);
219    let certs = rustls_pemfile::certs(&mut reader)
220        .filter_map(std::result::Result::ok)
221        .collect();
222    Ok(certs)
223}