nautilus_infrastructure/sql/
pg.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use derive_builder::Builder;
17use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
18
19#[derive(Debug, Clone, Builder)]
20#[builder(default)]
21pub struct PostgresConnectOptions {
22    pub host: String,
23    pub port: u16,
24    pub username: String,
25    pub password: String,
26    pub database: String,
27}
28
29impl PostgresConnectOptions {
30    /// Creates a new [`PostgresConnectOptions`] instance.
31    pub fn new(
32        host: String,
33        port: u16,
34        username: String,
35        password: String,
36        database: String,
37    ) -> Self {
38        Self {
39            host,
40            port,
41            username,
42            password,
43            database,
44        }
45    }
46    pub fn connection_string(&self) -> String {
47        format!(
48            "postgres://{username}:{password}@{host}:{port}/{database}",
49            username = self.username,
50            password = self.password,
51            host = self.host,
52            port = self.port,
53            database = self.database
54        )
55    }
56
57    pub fn default_administrator() -> Self {
58        PostgresConnectOptions::new(
59            String::from("localhost"),
60            5432,
61            String::from("postgres"),
62            String::from("pass"),
63            String::from("nautilus"),
64        )
65    }
66}
67
68impl Default for PostgresConnectOptions {
69    fn default() -> Self {
70        PostgresConnectOptions::new(
71            String::from("localhost"),
72            5432,
73            String::from("nautilus"),
74            String::from("pass"),
75            String::from("nautilus"),
76        )
77    }
78}
79
80impl From<PostgresConnectOptions> for PgConnectOptions {
81    fn from(opt: PostgresConnectOptions) -> Self {
82        PgConnectOptions::new()
83            .host(opt.host.as_str())
84            .port(opt.port)
85            .username(opt.username.as_str())
86            .password(opt.password.as_str())
87            .database(opt.database.as_str())
88            .disable_statement_logging()
89    }
90}
91
92// Gets the postgres connect options from provided arguments, environment variables or defaults
93pub fn get_postgres_connect_options(
94    host: Option<String>,
95    port: Option<u16>,
96    username: Option<String>,
97    password: Option<String>,
98    database: Option<String>,
99) -> PostgresConnectOptions {
100    let defaults = PostgresConnectOptions::default_administrator();
101    let host = host
102        .or_else(|| std::env::var("POSTGRES_HOST").ok())
103        .unwrap_or(defaults.host);
104    let port = port
105        .or_else(|| {
106            std::env::var("POSTGRES_PORT")
107                .map(|port| port.parse::<u16>().unwrap())
108                .ok()
109        })
110        .unwrap_or(defaults.port);
111    let username = username
112        .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
113        .unwrap_or(defaults.username);
114    let database = database
115        .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
116        .unwrap_or(defaults.database);
117    let password = password
118        .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
119        .unwrap_or(defaults.password);
120    PostgresConnectOptions::new(host, port, username, password, database)
121}
122
123pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
124    Ok(PgPool::connect_with(options).await?)
125}
126
127/// Scans current path with keyword nautilus_trader and build schema dir
128fn get_schema_dir() -> anyhow::Result<String> {
129    std::env::var("SCHEMA_DIR").or_else(|_| {
130        let nautilus_git_repo_name = "nautilus_trader";
131        let binding = std::env::current_dir().unwrap();
132        let current_dir = binding.to_str().unwrap();
133        match current_dir.find(nautilus_git_repo_name){
134            Some(index) => {
135                let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
136                Ok(schema_path)
137            }
138            None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
139        }
140    })
141}
142
143pub async fn init_postgres(
144    pg: &PgPool,
145    database: String,
146    password: String,
147    schema_dir: Option<String>,
148) -> anyhow::Result<()> {
149    log::info!("Initializing Postgres database with target permissions and schema");
150
151    // Create public schema
152    match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
153        .execute(pg)
154        .await
155    {
156        Ok(_) => log::info!("Schema public created successfully"),
157        Err(e) => log::error!("Error creating schema public: {e:?}"),
158    }
159
160    // Create role if not exists
161    match sqlx::query(format!("CREATE ROLE {database} PASSWORD '{password}' LOGIN;").as_str())
162        .execute(pg)
163        .await
164    {
165        Ok(_) => log::info!("Role {database} created successfully"),
166        Err(e) => {
167            if e.to_string().contains("already exists") {
168                log::info!("Role {database} already exists");
169            } else {
170                log::error!("Error creating role {database}: {e:?}");
171            }
172        }
173    }
174
175    // Execute all the sql files in schema dir
176    let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
177    let mut sql_files =
178        std::fs::read_dir(schema_dir)?.collect::<Result<Vec<_>, std::io::Error>>()?;
179    for file in &mut sql_files {
180        let file_name = file.file_name();
181        log::info!("Executing schema file: {file_name:?}");
182        let file_path = file.path();
183        let sql_content = std::fs::read_to_string(file_path.clone())?;
184        // if filename is functions.sql, split by plpgsql; if not then by ;
185        let delimiter = match file_name.to_str() {
186            Some("functions.sql") => "$$ LANGUAGE plpgsql;",
187            _ => ";",
188        };
189        let sql_statements = sql_content
190            .split(delimiter)
191            .filter(|s| !s.trim().is_empty())
192            .map(|s| format!("{s}{delimiter}"));
193
194        for sql_statement in sql_statements {
195            sqlx::query(&sql_statement)
196                .execute(pg)
197                .await
198                .map_err(|e| {
199                    if e.to_string().contains("already exists") {
200                        log::info!("Already exists error on statement, skipping");
201                    } else {
202                        panic!("Error executing statement {sql_statement} with error: {e:?}")
203                    }
204                })
205                .unwrap();
206        }
207    }
208
209    // Grant connect
210    match sqlx::query(format!("GRANT CONNECT ON DATABASE {0} TO {0};", database).as_str())
211        .execute(pg)
212        .await
213    {
214        Ok(_) => log::info!("Connect privileges granted to role {database}"),
215        Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
216    }
217
218    // Grant all schema privileges to the role
219    match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
220        .execute(pg)
221        .await
222    {
223        Ok(_) => log::info!("All schema privileges granted to role {database}"),
224        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
225    }
226
227    // Grant all table privileges to the role
228    match sqlx::query(
229        format!(
230            "GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {};",
231            database
232        )
233        .as_str(),
234    )
235    .execute(pg)
236    .await
237    {
238        Ok(_) => log::info!("All tables privileges granted to role {database}"),
239        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
240    }
241
242    // Grant all sequence privileges to the role
243    match sqlx::query(
244        format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
245    )
246    .execute(pg)
247    .await
248    {
249        Ok(_) => log::info!("All sequences privileges granted to role {database}"),
250        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
251    }
252
253    // Grant all function privileges to the role
254    match sqlx::query(
255        format!(
256            "GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {};",
257            database
258        )
259        .as_str(),
260    )
261    .execute(pg)
262    .await
263    {
264        Ok(_) => log::info!("All functions privileges granted to role {database}"),
265        Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
266    }
267
268    Ok(())
269}
270
271pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
272    // Execute drop owned
273    match sqlx::query(format!("DROP OWNED BY {database}").as_str())
274        .execute(pg)
275        .await
276    {
277        Ok(_) => log::info!("Dropped owned objects by role {database}"),
278        Err(e) => log::error!("Error dropping owned by role {database}: {e:?}"),
279    }
280
281    // Revoke connect
282    match sqlx::query(format!("REVOKE CONNECT ON DATABASE {0} FROM {0};", database).as_str())
283        .execute(pg)
284        .await
285    {
286        Ok(_) => log::info!("Revoked connect privileges from role {database}"),
287        Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
288    }
289
290    // Revoke privileges
291    match sqlx::query(format!("REVOKE ALL PRIVILEGES ON DATABASE {0} FROM {0};", database).as_str())
292        .execute(pg)
293        .await
294    {
295        Ok(_) => log::info!("Revoked all privileges from role {database}"),
296        Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
297    }
298
299    // Execute drop schema
300    match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
301        .execute(pg)
302        .await
303    {
304        Ok(_) => log::info!("Dropped schema public"),
305        Err(e) => log::error!("Error dropping schema public: {e:?}"),
306    }
307
308    // Drop role
309    match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
310        .execute(pg)
311        .await
312    {
313        Ok(_) => log::info!("Dropped role {database}"),
314        Err(e) => log::error!("Error dropping role {database}: {e:?}"),
315    }
316    Ok(())
317}