Skip to main content

nautilus_infrastructure/sql/
pg.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
20fn validate_sql_identifier(value: &str, label: &str) -> anyhow::Result<()> {
21    if value.is_empty() {
22        anyhow::bail!("{label} must not be empty");
23    }
24    if !value.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
25        anyhow::bail!(
26            "{label} contains invalid characters (only alphanumeric and underscore allowed): {value}"
27        );
28    }
29    Ok(())
30}
31
/// Escapes `value` for embedding inside a single-quoted SQL string literal.
///
/// Every single quote (`'`) is doubled (`''`); all other characters are copied unchanged.
fn escape_sql_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            // A doubled quote is how SQL spells a literal quote inside '...'.
            escaped.push('\'');
        }
        escaped.push(ch);
    }
    escaped
}
35
36#[derive(Debug, Clone, Builder)]
37#[builder(default)]
38#[cfg_attr(
39    feature = "python",
40    pyo3::pyclass(
41        module = "nautilus_trader.core.nautilus_pyo3.infrastructure",
42        from_py_object
43    )
44)]
45#[cfg_attr(
46    feature = "python",
47    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.infrastructure")
48)]
49pub struct PostgresConnectOptions {
50    pub host: String,
51    pub port: u16,
52    pub username: String,
53    pub password: String,
54    pub database: String,
55}
56
57impl PostgresConnectOptions {
58    /// Creates a new [`PostgresConnectOptions`] instance.
59    #[must_use]
60    pub const fn new(
61        host: String,
62        port: u16,
63        username: String,
64        password: String,
65        database: String,
66    ) -> Self {
67        Self {
68            host,
69            port,
70            username,
71            password,
72            database,
73        }
74    }
75
76    #[must_use]
77    pub fn connection_string(&self) -> String {
78        format!(
79            "postgres://{username}:{password}@{host}:{port}/{database}",
80            username = self.username,
81            password = self.password,
82            host = self.host,
83            port = self.port,
84            database = self.database
85        )
86    }
87
88    /// Returns the connection string with the password masked for safe logging.
89    #[must_use]
90    pub fn connection_string_masked(&self) -> String {
91        format!(
92            "postgres://{username}:***@{host}:{port}/{database}",
93            username = self.username,
94            host = self.host,
95            port = self.port,
96            database = self.database
97        )
98    }
99
100    #[must_use]
101    pub fn default_administrator() -> Self {
102        Self::new(
103            String::from("localhost"),
104            5432,
105            String::from("nautilus"),
106            String::from("pass"),
107            String::from("nautilus"),
108        )
109    }
110}
111
112impl Default for PostgresConnectOptions {
113    fn default() -> Self {
114        Self::new(
115            String::from("localhost"),
116            5432,
117            String::from("nautilus"),
118            String::from("pass"),
119            String::from("nautilus"),
120        )
121    }
122}
123
124impl From<PostgresConnectOptions> for PgConnectOptions {
125    fn from(opt: PostgresConnectOptions) -> Self {
126        Self::new()
127            .host(opt.host.as_str())
128            .port(opt.port)
129            .username(opt.username.as_str())
130            .password(opt.password.as_str())
131            .database(opt.database.as_str())
132            .disable_statement_logging()
133    }
134}
135
136/// Constructs `PostgresConnectOptions` by merging provided arguments, environment variables, and defaults.
137///
138/// # Panics
139///
140/// Panics if an environment variable for port cannot be parsed into a `u16`.
141#[must_use]
142pub fn get_postgres_connect_options(
143    host: Option<String>,
144    port: Option<u16>,
145    username: Option<String>,
146    password: Option<String>,
147    database: Option<String>,
148) -> PostgresConnectOptions {
149    let defaults = PostgresConnectOptions::default_administrator();
150    let host = host
151        .or_else(|| std::env::var("POSTGRES_HOST").ok())
152        .unwrap_or(defaults.host);
153    let port = port
154        .or_else(|| {
155            std::env::var("POSTGRES_PORT")
156                .map(|port| port.parse::<u16>().unwrap())
157                .ok()
158        })
159        .unwrap_or(defaults.port);
160    let username = username
161        .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
162        .unwrap_or(defaults.username);
163    let database = database
164        .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
165        .unwrap_or(defaults.database);
166    let password = password
167        .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
168        .unwrap_or(defaults.password);
169    PostgresConnectOptions::new(host, port, username, password, database)
170}
171
172/// Connects to a Postgres database with the provided connection `options` returning a connection pool.
173///
174/// # Errors
175///
176/// Returns an error if establishing the database connection fails.
177pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
178    Ok(PgPool::connect_with(options).await?)
179}
180
181/// Scans the current working directory for the `nautilus_trader` repository
182/// and constructs the path to the SQL schema directory.
183///
184/// # Errors
185///
186/// Returns an error if the `SCHEMA_DIR` environment variable is not set and the repository
187/// cannot be located in the current directory path.
188///
189/// # Panics
190///
191/// Panics if the current working directory cannot be determined or contains invalid UTF-8.
192fn get_schema_dir() -> anyhow::Result<String> {
193    std::env::var("SCHEMA_DIR").or_else(|_| {
194        let nautilus_git_repo_name = "nautilus_trader";
195        let binding = std::env::current_dir().unwrap();
196        let current_dir = binding.to_str().unwrap();
197        match current_dir.find(nautilus_git_repo_name){
198            Some(index) => {
199                let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
200                Ok(schema_path)
201            }
202            None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
203        }
204    })
205}
206
207/// Initializes the Postgres database by creating schema, roles, and executing SQL files from `schema_dir`.
208///
209/// # Errors
210///
211/// Returns an error if any SQL execution or file system operation fails.
212///
213/// # Panics
214///
215/// Panics if `schema_dir` is missing and cannot be determined or if other unwraps fail.
216pub async fn init_postgres(
217    pg: &PgPool,
218    database: String,
219    password: String,
220    schema_dir: Option<String>,
221) -> anyhow::Result<()> {
222    log::info!("Initializing Postgres database with target permissions and schema");
223
224    validate_sql_identifier(&database, "database")?;
225
226    // Create public schema
227    match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
228        .execute(pg)
229        .await
230    {
231        Ok(_) => log::info!("Schema public created successfully"),
232        Err(e) => log::error!("Error creating schema public: {e:?}"),
233    }
234
235    // Create role if not exists
236    let escaped_password = escape_sql_string(&password);
237    match sqlx::query(
238        format!("CREATE ROLE {database} PASSWORD '{escaped_password}' LOGIN;").as_str(),
239    )
240    .execute(pg)
241    .await
242    {
243        Ok(_) => log::info!("Role {database} created successfully"),
244        Err(e) => {
245            if e.to_string().contains("already exists") {
246                log::info!("Role {database} already exists");
247            } else {
248                log::error!("Error creating role {database}: {e:?}");
249            }
250        }
251    }
252
253    // Execute all the sql files in schema dir
254    let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
255    let sql_files = vec!["types.sql", "functions.sql", "partitions.sql", "tables.sql"];
256    let plpgsql_regex =
257        Regex::new(r"\$\$ LANGUAGE plpgsql(?:[ \t\r\n]+SECURITY[ \t\r\n]+DEFINER)?;")?;
258    for file_name in &sql_files {
259        log::info!("Executing schema file: {file_name:?}");
260        let file_path = format!("{schema_dir}/{file_name}");
261        let sql_content = std::fs::read_to_string(&file_path)?;
262        let sql_statements: Vec<String> = match *file_name {
263            "functions.sql" | "partitions.sql" => {
264                let mut statements = Vec::new();
265                let mut last_end = 0;
266
267                for mat in plpgsql_regex.find_iter(&sql_content) {
268                    let statement = sql_content[last_end..mat.end()].to_string();
269                    if !statement.trim().is_empty() {
270                        statements.push(statement);
271                    }
272                    last_end = mat.end();
273                }
274                statements
275            }
276            _ => sql_content
277                .split(';')
278                .filter(|s| !s.trim().is_empty())
279                .map(|s| format!("{s};"))
280                .collect(),
281        };
282
283        for sql_statement in sql_statements {
284            sqlx::query(&sql_statement)
285                .execute(pg)
286                .await
287                .map_err(|e| {
288                    if e.to_string().contains("already exists") {
289                        log::info!("Already exists error on statement, skipping");
290                    } else {
291                        panic!("Error executing statement {sql_statement} with error: {e:?}")
292                    }
293                })
294                .unwrap();
295        }
296    }
297
298    // Grant connect
299    match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
300        .execute(pg)
301        .await
302    {
303        Ok(_) => log::info!("Connect privileges granted to role {database}"),
304        Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
305    }
306
307    // Grant all schema privileges to the role
308    match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
309        .execute(pg)
310        .await
311    {
312        Ok(_) => log::info!("All schema privileges granted to role {database}"),
313        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
314    }
315
316    // Grant all table privileges to the role
317    match sqlx::query(
318        format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
319    )
320    .execute(pg)
321    .await
322    {
323        Ok(_) => log::info!("All tables privileges granted to role {database}"),
324        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
325    }
326
327    // Grant all sequence privileges to the role
328    match sqlx::query(
329        format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
330    )
331    .execute(pg)
332    .await
333    {
334        Ok(_) => log::info!("All sequences privileges granted to role {database}"),
335        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
336    }
337
338    // Grant all function privileges to the role
339    match sqlx::query(
340        format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
341    )
342    .execute(pg)
343    .await
344    {
345        Ok(_) => log::info!("All functions privileges granted to role {database}"),
346        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
347    }
348
349    Ok(())
350}
351
352/// Drops the Postgres database with the given name using the provided connection pool.
353///
354/// # Errors
355///
356/// Returns an error if the DROP DATABASE command fails.
357pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
358    validate_sql_identifier(&database, "database")?;
359
360    // Execute drop owned
361    match sqlx::query(format!("DROP OWNED BY {database}").as_str())
362        .execute(pg)
363        .await
364    {
365        Ok(_) => log::info!("Dropped owned objects by role {database}"),
366        Err(e) => {
367            let err_msg = e.to_string();
368            if err_msg.contains("2BP01") || err_msg.contains("required by the database system") {
369                log::warn!("Skipping system-required objects for role {database}");
370            } else {
371                log::error!("Error dropping owned by role {database}: {e:?}");
372            }
373        }
374    }
375
376    // Revoke connect
377    match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
378        .execute(pg)
379        .await
380    {
381        Ok(_) => log::info!("Revoked connect privileges from role {database}"),
382        Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
383    }
384
385    // Revoke privileges
386    match sqlx::query(
387        format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
388    )
389    .execute(pg)
390    .await
391    {
392        Ok(_) => log::info!("Revoked all privileges from role {database}"),
393        Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
394    }
395
396    // Execute drop schema
397    match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
398        .execute(pg)
399        .await
400    {
401        Ok(_) => log::info!("Dropped schema public"),
402        Err(e) => log::error!("Error dropping schema public: {e:?}"),
403    }
404
405    // Drop role
406    match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
407        .execute(pg)
408        .await
409    {
410        Ok(_) => log::info!("Dropped role {database}"),
411        Err(e) => {
412            let err_msg = e.to_string();
413            if err_msg.contains("55006") || err_msg.contains("current user cannot be dropped") {
414                log::warn!("Cannot drop currently connected role {database}");
415            } else {
416                log::error!("Error dropping role {database}: {e:?}");
417            }
418        }
419    }
420    Ok(())
421}