nautilus_infrastructure/sql/
pg.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
20#[derive(Debug, Clone, Builder)]
21#[builder(default)]
22#[cfg_attr(
23    feature = "python",
24    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
25)]
26#[cfg_attr(
27    feature = "python",
28    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.infrastructure")
29)]
30pub struct PostgresConnectOptions {
31    pub host: String,
32    pub port: u16,
33    pub username: String,
34    pub password: String,
35    pub database: String,
36}
37
38impl PostgresConnectOptions {
39    /// Creates a new [`PostgresConnectOptions`] instance.
40    #[must_use]
41    pub const fn new(
42        host: String,
43        port: u16,
44        username: String,
45        password: String,
46        database: String,
47    ) -> Self {
48        Self {
49            host,
50            port,
51            username,
52            password,
53            database,
54        }
55    }
56
57    #[must_use]
58    pub fn connection_string(&self) -> String {
59        format!(
60            "postgres://{username}:{password}@{host}:{port}/{database}",
61            username = self.username,
62            password = self.password,
63            host = self.host,
64            port = self.port,
65            database = self.database
66        )
67    }
68
69    #[must_use]
70    pub fn default_administrator() -> Self {
71        Self::new(
72            String::from("localhost"),
73            5432,
74            String::from("nautilus"),
75            String::from("pass"),
76            String::from("nautilus"),
77        )
78    }
79}
80
81impl Default for PostgresConnectOptions {
82    fn default() -> Self {
83        Self::new(
84            String::from("localhost"),
85            5432,
86            String::from("nautilus"),
87            String::from("pass"),
88            String::from("nautilus"),
89        )
90    }
91}
92
93impl From<PostgresConnectOptions> for PgConnectOptions {
94    fn from(opt: PostgresConnectOptions) -> Self {
95        Self::new()
96            .host(opt.host.as_str())
97            .port(opt.port)
98            .username(opt.username.as_str())
99            .password(opt.password.as_str())
100            .database(opt.database.as_str())
101            .disable_statement_logging()
102    }
103}
104
105/// Constructs `PostgresConnectOptions` by merging provided arguments, environment variables, and defaults.
106///
107/// # Panics
108///
109/// Panics if an environment variable for port cannot be parsed into a `u16`.
110#[must_use]
111pub fn get_postgres_connect_options(
112    host: Option<String>,
113    port: Option<u16>,
114    username: Option<String>,
115    password: Option<String>,
116    database: Option<String>,
117) -> PostgresConnectOptions {
118    let defaults = PostgresConnectOptions::default_administrator();
119    let host = host
120        .or_else(|| std::env::var("POSTGRES_HOST").ok())
121        .unwrap_or(defaults.host);
122    let port = port
123        .or_else(|| {
124            std::env::var("POSTGRES_PORT")
125                .map(|port| port.parse::<u16>().unwrap())
126                .ok()
127        })
128        .unwrap_or(defaults.port);
129    let username = username
130        .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
131        .unwrap_or(defaults.username);
132    let database = database
133        .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
134        .unwrap_or(defaults.database);
135    let password = password
136        .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
137        .unwrap_or(defaults.password);
138    PostgresConnectOptions::new(host, port, username, password, database)
139}
140
141/// Connects to a Postgres database with the provided connection `options` returning a connection pool.
142///
143/// # Errors
144///
145/// Returns an error if establishing the database connection fails.
146pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
147    Ok(PgPool::connect_with(options).await?)
148}
149
150/// Scans the current working directory for the `nautilus_trader` repository
151/// and constructs the path to the SQL schema directory.
152///
153/// # Errors
154///
155/// Returns an error if the `SCHEMA_DIR` environment variable is not set and the repository
156/// cannot be located in the current directory path.
157///
158/// # Panics
159///
160/// Panics if the current working directory cannot be determined or contains invalid UTF-8.
161fn get_schema_dir() -> anyhow::Result<String> {
162    std::env::var("SCHEMA_DIR").or_else(|_| {
163        let nautilus_git_repo_name = "nautilus_trader";
164        let binding = std::env::current_dir().unwrap();
165        let current_dir = binding.to_str().unwrap();
166        match current_dir.find(nautilus_git_repo_name){
167            Some(index) => {
168                let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
169                Ok(schema_path)
170            }
171            None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
172        }
173    })
174}
175
176/// Initializes the Postgres database by creating schema, roles, and executing SQL files from `schema_dir`.
177///
178/// # Errors
179///
180/// Returns an error if any SQL execution or file system operation fails.
181///
182/// # Panics
183///
184/// Panics if `schema_dir` is missing and cannot be determined or if other unwraps fail.
185pub async fn init_postgres(
186    pg: &PgPool,
187    database: String,
188    password: String,
189    schema_dir: Option<String>,
190) -> anyhow::Result<()> {
191    log::info!("Initializing Postgres database with target permissions and schema");
192
193    // Create public schema
194    match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
195        .execute(pg)
196        .await
197    {
198        Ok(_) => log::info!("Schema public created successfully"),
199        Err(e) => log::error!("Error creating schema public: {e:?}"),
200    }
201
202    // Create role if not exists
203    match sqlx::query(format!("CREATE ROLE {database} PASSWORD '{password}' LOGIN;").as_str())
204        .execute(pg)
205        .await
206    {
207        Ok(_) => log::info!("Role {database} created successfully"),
208        Err(e) => {
209            if e.to_string().contains("already exists") {
210                log::info!("Role {database} already exists");
211            } else {
212                log::error!("Error creating role {database}: {e:?}");
213            }
214        }
215    }
216
217    // Execute all the sql files in schema dir
218    let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
219    let sql_files = vec!["types.sql", "functions.sql", "partitions.sql", "tables.sql"];
220    let plpgsql_regex =
221        Regex::new(r"\$\$ LANGUAGE plpgsql(?:[ \t\r\n]+SECURITY[ \t\r\n]+DEFINER)?;")?;
222    for file_name in &sql_files {
223        log::info!("Executing schema file: {file_name:?}");
224        let file_path = format!("{}/{}", schema_dir, file_name);
225        let sql_content = std::fs::read_to_string(&file_path)?;
226        let sql_statements: Vec<String> = match *file_name {
227            "functions.sql" | "partitions.sql" => {
228                let mut statements = Vec::new();
229                let mut last_end = 0;
230
231                for mat in plpgsql_regex.find_iter(&sql_content) {
232                    let statement = sql_content[last_end..mat.end()].to_string();
233                    if !statement.trim().is_empty() {
234                        statements.push(statement);
235                    }
236                    last_end = mat.end();
237                }
238                statements
239            }
240            _ => sql_content
241                .split(';')
242                .filter(|s| !s.trim().is_empty())
243                .map(|s| format!("{s};"))
244                .collect(),
245        };
246
247        for sql_statement in sql_statements {
248            sqlx::query(&sql_statement)
249                .execute(pg)
250                .await
251                .map_err(|e| {
252                    if e.to_string().contains("already exists") {
253                        log::info!("Already exists error on statement, skipping");
254                    } else {
255                        panic!("Error executing statement {sql_statement} with error: {e:?}")
256                    }
257                })
258                .unwrap();
259        }
260    }
261
262    // Grant connect
263    match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
264        .execute(pg)
265        .await
266    {
267        Ok(_) => log::info!("Connect privileges granted to role {database}"),
268        Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
269    }
270
271    // Grant all schema privileges to the role
272    match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
273        .execute(pg)
274        .await
275    {
276        Ok(_) => log::info!("All schema privileges granted to role {database}"),
277        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
278    }
279
280    // Grant all table privileges to the role
281    match sqlx::query(
282        format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
283    )
284    .execute(pg)
285    .await
286    {
287        Ok(_) => log::info!("All tables privileges granted to role {database}"),
288        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
289    }
290
291    // Grant all sequence privileges to the role
292    match sqlx::query(
293        format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
294    )
295    .execute(pg)
296    .await
297    {
298        Ok(_) => log::info!("All sequences privileges granted to role {database}"),
299        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
300    }
301
302    // Grant all function privileges to the role
303    match sqlx::query(
304        format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
305    )
306    .execute(pg)
307    .await
308    {
309        Ok(_) => log::info!("All functions privileges granted to role {database}"),
310        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
311    }
312
313    Ok(())
314}
315
316/// Drops the Postgres database with the given name using the provided connection pool.
317///
318/// # Errors
319///
320/// Returns an error if the DROP DATABASE command fails.
321pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
322    // Execute drop owned
323    match sqlx::query(format!("DROP OWNED BY {database}").as_str())
324        .execute(pg)
325        .await
326    {
327        Ok(_) => log::info!("Dropped owned objects by role {database}"),
328        Err(e) => {
329            let err_msg = e.to_string();
330            if err_msg.contains("2BP01") || err_msg.contains("required by the database system") {
331                log::warn!("Skipping system-required objects for role {database}");
332            } else {
333                log::error!("Error dropping owned by role {database}: {e:?}");
334            }
335        }
336    }
337
338    // Revoke connect
339    match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
340        .execute(pg)
341        .await
342    {
343        Ok(_) => log::info!("Revoked connect privileges from role {database}"),
344        Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
345    }
346
347    // Revoke privileges
348    match sqlx::query(
349        format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
350    )
351    .execute(pg)
352    .await
353    {
354        Ok(_) => log::info!("Revoked all privileges from role {database}"),
355        Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
356    }
357
358    // Execute drop schema
359    match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
360        .execute(pg)
361        .await
362    {
363        Ok(_) => log::info!("Dropped schema public"),
364        Err(e) => log::error!("Error dropping schema public: {e:?}"),
365    }
366
367    // Drop role
368    match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
369        .execute(pg)
370        .await
371    {
372        Ok(_) => log::info!("Dropped role {database}"),
373        Err(e) => {
374            let err_msg = e.to_string();
375            if err_msg.contains("55006") || err_msg.contains("current user cannot be dropped") {
376                log::warn!("Cannot drop currently connected role {database}");
377            } else {
378                log::error!("Error dropping role {database}: {e:?}");
379            }
380        }
381    }
382    Ok(())
383}