nautilus_network/
tls.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Module for wrapping raw socket streams with TLS encryption.
17
18use std::{fs::File, io::BufReader, path::Path};
19
20use rustls::pki_types::{CertificateDer, PrivateKeyDer};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio_tungstenite::{
23    tungstenite::{handshake::client::Request, stream::Mode, Error},
24    MaybeTlsStream,
25};
26
/// A connector that selects how a connection is secured while it is being established.
///
/// The only TLS backend offered by this enum is `rustls`; TLS can be disabled altogether with
/// the `Plain` variant.
#[non_exhaustive]
#[derive(Clone)]
// NOTE(review): presumably needed because not every variant is constructed inside this
// crate — confirm before removing.
#[allow(dead_code)]
pub enum Connector {
    /// No TLS connection.
    Plain,
    /// TLS connection using `rustls`, configured by the given shared client configuration.
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
39
40mod encryption {
41
42    pub mod rustls {
43        use std::{convert::TryFrom, sync::Arc};
44
45        use nautilus_cryptography::tls::create_tls_config;
46        use rustls::pki_types::ServerName;
47        pub use rustls::ClientConfig;
48        use tokio::io::{AsyncRead, AsyncWrite};
49        use tokio_rustls::TlsConnector as TokioTlsConnector;
50        use tokio_tungstenite::{
51            tungstenite::{error::TlsError, stream::Mode, Error},
52            MaybeTlsStream,
53        };
54
55        pub async fn wrap_stream<S>(
56            socket: S,
57            domain: String,
58            mode: Mode,
59            tls_connector: Option<Arc<ClientConfig>>,
60        ) -> Result<MaybeTlsStream<S>, Error>
61        where
62            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
63        {
64            match mode {
65                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
66                Mode::Tls => {
67                    let config = match tls_connector {
68                        Some(config) => config,
69                        None => create_tls_config(),
70                    };
71                    let domain = ServerName::try_from(domain.as_str())
72                        .map_err(|_| TlsError::InvalidDnsName)?
73                        .to_owned();
74                    let stream = TokioTlsConnector::from(config);
75                    let connected = stream.connect(domain, socket).await;
76
77                    match connected {
78                        Err(e) => Err(Error::Io(e)),
79                        Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
80                    }
81                }
82            }
83        }
84    }
85
86    pub mod plain {
87        use tokio::io::{AsyncRead, AsyncWrite};
88        use tokio_tungstenite::{
89            tungstenite::{
90                error::{Error, UrlError},
91                stream::Mode,
92            },
93            MaybeTlsStream,
94        };
95
96        pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
97        where
98            S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
99        {
100            match mode {
101                Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
102                Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
103            }
104        }
105    }
106}
107
108pub async fn tcp_tls<S>(
109    request: &Request,
110    mode: Mode,
111    stream: S,
112    connector: Option<Connector>,
113) -> Result<MaybeTlsStream<S>, Error>
114where
115    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
116    MaybeTlsStream<S>: Unpin,
117{
118    let domain = domain(request)?;
119
120    match connector {
121        Some(conn) => match conn {
122            Connector::Rustls(conn) => {
123                self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
124            }
125            Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
126        },
127        None => self::encryption::rustls::wrap_stream(stream, domain, mode, None).await,
128    }
129}
130
131fn domain(request: &Request) -> Result<String, Error> {
132    match request.uri().host() {
133        // rustls expects IPv6 addresses without the surrounding [] brackets
134        Some(d) if d.starts_with('[') && d.ends_with(']') => Ok(d[1..d.len() - 1].to_string()),
135        Some(d) => Ok(d.to_string()),
136        None => panic!("No host name"),
137    }
138}
139
140pub fn create_tls_config_from_certs_dir(certs_dir: &Path) -> anyhow::Result<rustls::ClientConfig> {
141    if !certs_dir.is_dir() {
142        return Err(anyhow::anyhow!(
143            "Certificate path is not a directory: {}",
144            certs_dir.display()
145        ));
146    }
147
148    let mut client_cert = None;
149    let mut client_key = None;
150    let mut root_store = rustls::RootCertStore::empty();
151    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
152
153    for entry in std::fs::read_dir(certs_dir)? {
154        let entry = entry?;
155        let path = entry.path();
156
157        if client_key.is_none() {
158            if let Ok(key) = load_private_key(&path) {
159                client_key = Some(key);
160                continue;
161            }
162        }
163
164        if let Ok(certs) = load_certs(&path) {
165            if !certs.is_empty() {
166                if client_cert.is_none() {
167                    client_cert = Some(certs);
168                } else {
169                    for cert in certs {
170                        if let Err(e) = root_store.add(cert) {
171                            eprintln!("Warning: Invalid certificate in {}: {e}", path.display());
172                        }
173                    }
174                }
175            }
176        }
177    }
178
179    let (cert, key) = client_cert
180        .zip(client_key)
181        .ok_or_else(|| anyhow::anyhow!("Could not find both client certificate and private key"))?;
182
183    Ok(rustls::ClientConfig::builder()
184        .with_root_certificates(root_store)
185        .with_client_auth_cert(cert, key)?)
186}
187
188fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
189    let file = File::open(path)?;
190    let mut reader = BufReader::new(file);
191
192    let pkcs8_keys: Vec<_> = rustls_pemfile::pkcs8_private_keys(&mut reader)
193        .filter_map(|result| result.ok())
194        .collect();
195
196    if let Some(key) = pkcs8_keys.into_iter().next() {
197        return Ok(key.into());
198    }
199
200    let file = File::open(path)?;
201    let mut reader = BufReader::new(file);
202    let rsa_keys: Vec<_> = rustls_pemfile::rsa_private_keys(&mut reader)
203        .filter_map(|result| result.ok())
204        .collect();
205
206    if let Some(key) = rsa_keys.into_iter().next() {
207        return Ok(key.into());
208    }
209
210    Err(anyhow::anyhow!(
211        "No valid private key found in {}",
212        path.display()
213    ))
214}
215
216fn load_certs(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
217    let file = File::open(path)?;
218    let mut reader = BufReader::new(file);
219    let certs = rustls_pemfile::certs(&mut reader)
220        .filter_map(|result| result.ok())
221        .collect();
222    Ok(certs)
223}