nautilus_infrastructure/sql/
pg.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
20#[derive(Debug, Clone, Builder)]
21#[builder(default)]
22#[cfg_attr(
23    feature = "python",
24    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
25)]
26pub struct PostgresConnectOptions {
27    pub host: String,
28    pub port: u16,
29    pub username: String,
30    pub password: String,
31    pub database: String,
32}
33
34impl PostgresConnectOptions {
35    /// Creates a new [`PostgresConnectOptions`] instance.
36    #[must_use]
37    pub const fn new(
38        host: String,
39        port: u16,
40        username: String,
41        password: String,
42        database: String,
43    ) -> Self {
44        Self {
45            host,
46            port,
47            username,
48            password,
49            database,
50        }
51    }
52
53    #[must_use]
54    pub fn connection_string(&self) -> String {
55        format!(
56            "postgres://{username}:{password}@{host}:{port}/{database}",
57            username = self.username,
58            password = self.password,
59            host = self.host,
60            port = self.port,
61            database = self.database
62        )
63    }
64
65    #[must_use]
66    pub fn default_administrator() -> Self {
67        Self::new(
68            String::from("localhost"),
69            5432,
70            String::from("nautilus"),
71            String::from("pass"),
72            String::from("nautilus"),
73        )
74    }
75}
76
77impl Default for PostgresConnectOptions {
78    fn default() -> Self {
79        Self::new(
80            String::from("localhost"),
81            5432,
82            String::from("nautilus"),
83            String::from("pass"),
84            String::from("nautilus"),
85        )
86    }
87}
88
89impl From<PostgresConnectOptions> for PgConnectOptions {
90    fn from(opt: PostgresConnectOptions) -> Self {
91        Self::new()
92            .host(opt.host.as_str())
93            .port(opt.port)
94            .username(opt.username.as_str())
95            .password(opt.password.as_str())
96            .database(opt.database.as_str())
97            .disable_statement_logging()
98    }
99}
100
101/// Constructs `PostgresConnectOptions` by merging provided arguments, environment variables, and defaults.
102///
103/// # Panics
104///
105/// Panics if an environment variable for port cannot be parsed into a `u16`.
106#[must_use]
107pub fn get_postgres_connect_options(
108    host: Option<String>,
109    port: Option<u16>,
110    username: Option<String>,
111    password: Option<String>,
112    database: Option<String>,
113) -> PostgresConnectOptions {
114    let defaults = PostgresConnectOptions::default_administrator();
115    let host = host
116        .or_else(|| std::env::var("POSTGRES_HOST").ok())
117        .unwrap_or(defaults.host);
118    let port = port
119        .or_else(|| {
120            std::env::var("POSTGRES_PORT")
121                .map(|port| port.parse::<u16>().unwrap())
122                .ok()
123        })
124        .unwrap_or(defaults.port);
125    let username = username
126        .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
127        .unwrap_or(defaults.username);
128    let database = database
129        .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
130        .unwrap_or(defaults.database);
131    let password = password
132        .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
133        .unwrap_or(defaults.password);
134    PostgresConnectOptions::new(host, port, username, password, database)
135}
136
137/// Connects to a Postgres database with the provided connection `options` returning a connection pool.
138///
139/// # Errors
140///
141/// Returns an error if establishing the database connection fails.
142pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
143    Ok(PgPool::connect_with(options).await?)
144}
145
146/// Scans the current working directory for the `nautilus_trader` repository
147/// and constructs the path to the SQL schema directory.
148///
149/// # Errors
150///
151/// Returns an error if the `SCHEMA_DIR` environment variable is not set and the repository
152/// cannot be located in the current directory path.
153///
154/// # Panics
155///
156/// Panics if the current working directory cannot be determined or contains invalid UTF-8.
157fn get_schema_dir() -> anyhow::Result<String> {
158    std::env::var("SCHEMA_DIR").or_else(|_| {
159        let nautilus_git_repo_name = "nautilus_trader";
160        let binding = std::env::current_dir().unwrap();
161        let current_dir = binding.to_str().unwrap();
162        match current_dir.find(nautilus_git_repo_name){
163            Some(index) => {
164                let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
165                Ok(schema_path)
166            }
167            None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
168        }
169    })
170}
171
172/// Initializes the Postgres database by creating schema, roles, and executing SQL files from `schema_dir`.
173///
174/// # Errors
175///
176/// Returns an error if any SQL execution or file system operation fails.
177///
178/// # Panics
179///
180/// Panics if `schema_dir` is missing and cannot be determined or if other unwraps fail.
181pub async fn init_postgres(
182    pg: &PgPool,
183    database: String,
184    password: String,
185    schema_dir: Option<String>,
186) -> anyhow::Result<()> {
187    log::info!("Initializing Postgres database with target permissions and schema");
188
189    // Create public schema
190    match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
191        .execute(pg)
192        .await
193    {
194        Ok(_) => log::info!("Schema public created successfully"),
195        Err(e) => log::error!("Error creating schema public: {e:?}"),
196    }
197
198    // Create role if not exists
199    match sqlx::query(format!("CREATE ROLE {database} PASSWORD '{password}' LOGIN;").as_str())
200        .execute(pg)
201        .await
202    {
203        Ok(_) => log::info!("Role {database} created successfully"),
204        Err(e) => {
205            if e.to_string().contains("already exists") {
206                log::info!("Role {database} already exists");
207            } else {
208                log::error!("Error creating role {database}: {e:?}");
209            }
210        }
211    }
212
213    // Execute all the sql files in schema dir
214    let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
215    let mut sql_files =
216        std::fs::read_dir(schema_dir)?.collect::<Result<Vec<_>, std::io::Error>>()?;
217    let plpgsql_regex = Regex::new(r"\$\$ LANGUAGE plpgsql(?:\s+SECURITY\s+DEFINER)?;")?;
218    for file in &mut sql_files {
219        let file_name = file.file_name();
220        log::info!("Executing schema file: {file_name:?}");
221        let file_path = file.path();
222        let sql_content = std::fs::read_to_string(file_path.clone())?;
223        let sql_statements: Vec<String> = match file_name.to_str() {
224            Some("functions.sql" | "partitions.sql") => {
225                let mut statements = Vec::new();
226                let mut last_end = 0;
227
228                for mat in plpgsql_regex.find_iter(&sql_content) {
229                    let statement = sql_content[last_end..mat.end()].to_string();
230                    if !statement.trim().is_empty() {
231                        statements.push(statement);
232                    }
233                    last_end = mat.end();
234                }
235                statements
236            }
237            _ => sql_content
238                .split(';')
239                .filter(|s| !s.trim().is_empty())
240                .map(|s| format!("{s};"))
241                .collect(),
242        };
243
244        for sql_statement in sql_statements {
245            sqlx::query(&sql_statement)
246                .execute(pg)
247                .await
248                .map_err(|e| {
249                    if e.to_string().contains("already exists") {
250                        log::info!("Already exists error on statement, skipping");
251                    } else {
252                        panic!("Error executing statement {sql_statement} with error: {e:?}")
253                    }
254                })
255                .unwrap();
256        }
257    }
258
259    // Grant connect
260    match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
261        .execute(pg)
262        .await
263    {
264        Ok(_) => log::info!("Connect privileges granted to role {database}"),
265        Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
266    }
267
268    // Grant all schema privileges to the role
269    match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
270        .execute(pg)
271        .await
272    {
273        Ok(_) => log::info!("All schema privileges granted to role {database}"),
274        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
275    }
276
277    // Grant all table privileges to the role
278    match sqlx::query(
279        format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
280    )
281    .execute(pg)
282    .await
283    {
284        Ok(_) => log::info!("All tables privileges granted to role {database}"),
285        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
286    }
287
288    // Grant all sequence privileges to the role
289    match sqlx::query(
290        format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
291    )
292    .execute(pg)
293    .await
294    {
295        Ok(_) => log::info!("All sequences privileges granted to role {database}"),
296        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
297    }
298
299    // Grant all function privileges to the role
300    match sqlx::query(
301        format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
302    )
303    .execute(pg)
304    .await
305    {
306        Ok(_) => log::info!("All functions privileges granted to role {database}"),
307        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
308    }
309
310    Ok(())
311}
312
313/// Drops the Postgres database with the given name using the provided connection pool.
314///
315/// # Errors
316///
317/// Returns an error if the DROP DATABASE command fails.
318pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
319    // Execute drop owned
320    match sqlx::query(format!("DROP OWNED BY {database}").as_str())
321        .execute(pg)
322        .await
323    {
324        Ok(_) => log::info!("Dropped owned objects by role {database}"),
325        Err(e) => log::error!("Error dropping owned by role {database}: {e:?}"),
326    }
327
328    // Revoke connect
329    match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
330        .execute(pg)
331        .await
332    {
333        Ok(_) => log::info!("Revoked connect privileges from role {database}"),
334        Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
335    }
336
337    // Revoke privileges
338    match sqlx::query(
339        format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
340    )
341    .execute(pg)
342    .await
343    {
344        Ok(_) => log::info!("Revoked all privileges from role {database}"),
345        Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
346    }
347
348    // Execute drop schema
349    match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
350        .execute(pg)
351        .await
352    {
353        Ok(_) => log::info!("Dropped schema public"),
354        Err(e) => log::error!("Error dropping schema public: {e:?}"),
355    }
356
357    // Drop role
358    match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
359        .execute(pg)
360        .await
361    {
362        Ok(_) => log::info!("Dropped role {database}"),
363        Err(e) => log::error!("Error dropping role {database}: {e:?}"),
364    }
365    Ok(())
366}