nautilus_network/
tls.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Module for wrapping raw socket streams with TLS encryption.
17
18use std::{fs::File, io::BufReader, path::Path};
19
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_tungstenite::{
23    MaybeTlsStream,
24    tungstenite::{Error, handshake::client::Request, stream::Mode},
25};
26
27/// A connector that can be used when establishing connections, allowing to control whether
28/// `native-tls` or `rustls` is used to create a TLS connection. Or TLS can be disabled with the
29/// `Plain` variant.
30#[non_exhaustive]
31#[derive(Clone)]
32#[allow(dead_code)]
33pub enum Connector {
34    /// No TLS connection.
35    Plain,
36    /// TLS connection using `rustls`.
37    Rustls(std::sync::Arc<rustls::ClientConfig>),
38}
39
40mod encryption {
41
42    pub mod rustls {
43        use std::{convert::TryFrom, sync::Arc};
44
45        use nautilus_cryptography::tls::create_tls_config;
46        pub use rustls::ClientConfig;
47        use rustls::pki_types::ServerName;
48        use tokio::io::{AsyncRead, AsyncWrite};
49        use tokio_rustls::TlsConnector as TokioTlsConnector;
50        use tokio_tungstenite::{
51            MaybeTlsStream,
52            tungstenite::{Error, error::TlsError, stream::Mode},
53        };
54
55        pub async fn wrap_stream<S>(
56            socket: S,
57            domain: String,
58            mode: Mode,
59            tls_connector: Option<Arc<ClientConfig>>,
60        ) -> Result<MaybeTlsStream<S>, Error>
61        where
62            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
63        {
64            match mode {
65                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
66                Mode::Tls => {
67                    let config = match tls_connector {
68                        Some(config) => config,
69                        None => create_tls_config(),
70                    };
71                    let domain = ServerName::try_from(domain.as_str())
72                        .map_err(|_| TlsError::InvalidDnsName)?
73                        .to_owned();
74                    let stream = TokioTlsConnector::from(config);
75                    let connected = stream.connect(domain, socket).await;
76
77                    match connected {
78                        Err(e) => Err(Error::Io(e)),
79                        Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
80                    }
81                }
82            }
83        }
84    }
85
86    pub mod plain {
87        use tokio::io::{AsyncRead, AsyncWrite};
88        use tokio_tungstenite::{
89            MaybeTlsStream,
90            tungstenite::{
91                error::{Error, UrlError},
92                stream::Mode,
93            },
94        };
95
96        pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
97        where
98            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
99        {
100            match mode {
101                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
102                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
103            }
104        }
105    }
106}
107
108pub async fn tcp_tls<S>(
109    request: &Request,
110    mode: Mode,
111    stream: S,
112    connector: Option<Connector>,
113) -> Result<MaybeTlsStream<S>, Error>
114where
115    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
116    MaybeTlsStream<S>: Unpin,
117{
118    let domain = domain(request)?;
119
120    match connector {
121        Some(conn) => match conn {
122            Connector::Rustls(conn) => {
123                self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
124            }
125            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
126        },
127        None => self::encryption::rustls::wrap_stream(stream, domain, mode, None).await,
128    }
129}
130
131fn domain(request: &Request) -> Result<String, Error> {
132    match request.uri().host() {
133        // rustls expects IPv6 addresses without the surrounding [] brackets
134        Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
135        Some(d) => Ok(d.to_string()),
136        None => panic!("No host name"),
137    }
138}
139
140pub fn create_tls_config_from_certs_dir(certs_dir: &Path) -> anyhow::Result<rustls::ClientConfig> {
141    if !certs_dir.is_dir() {
142        anyhow::bail!("Certificate path is not a directory: {certs_dir:?}");
143    }
144
145    let mut client_cert = None;
146    let mut client_key = None;
147    let mut root_store = rustls::RootCertStore::empty();
148    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
149
150    for entry in std::fs::read_dir(certs_dir)? {
151        let entry = entry?;
152        let path = entry.path();
153
154        if client_key.is_none() {
155            if let Ok(key) = load_private_key(&path) {
156                client_key = Some(key);
157                continue;
158            }
159        }
160
161        if let Ok(certs) = load_certs(&path) {
162            if !certs.is_empty() {
163                if client_cert.is_none() {
164                    client_cert = Some(certs);
165                } else {
166                    for cert in certs {
167                        if let Err(e) = root_store.add(cert) {
168                            eprintln!("Warning: Invalid certificate in {path:?}: {e}");
169                        }
170                    }
171                }
172            }
173        }
174    }
175
176    let (cert, key) = client_cert
177        .zip(client_key)
178        .ok_or_else(|| anyhow::anyhow!("Could not find both client certificate and private key"))?;
179
180    Ok(rustls::ClientConfig::builder()
181        .with_root_certificates(root_store)
182        .with_client_auth_cert(cert, key)?)
183}
184
185fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
186    let file = File::open(path)?;
187    let mut reader = BufReader::new(file);
188
189    let pkcs8_keys: Vec<_> = rustls_pemfile::pkcs8_private_keys(&mut reader)
190        .filter_map(std::result::Result::ok)
191        .collect();
192
193    if let Some(key) = pkcs8_keys.into_iter().next() {
194        return Ok(key.into());
195    }
196
197    let file = File::open(path)?;
198    let mut reader = BufReader::new(file);
199    let rsa_keys: Vec<_> = rustls_pemfile::rsa_private_keys(&mut reader)
200        .filter_map(std::result::Result::ok)
201        .collect();
202
203    if let Some(key) = rsa_keys.into_iter().next() {
204        return Ok(key.into());
205    }
206
207    anyhow::bail!("No valid private key found in {path:?}");
208}
209
210fn load_certs(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
211    let file = File::open(path)?;
212    let mut reader = BufReader::new(file);
213    let certs = rustls_pemfile::certs(&mut reader)
214        .filter_map(std::result::Result::ok)
215        .collect();
216    Ok(certs)
217}