nautilus_infrastructure/sql/
pg.rs1use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
20fn validate_sql_identifier(value: &str, label: &str) -> anyhow::Result<()> {
21 if value.is_empty() {
22 anyhow::bail!("{label} must not be empty");
23 }
24 if !value.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
25 anyhow::bail!(
26 "{label} contains invalid characters (only alphanumeric and underscore allowed): {value}"
27 );
28 }
29 Ok(())
30}
31
/// Escapes `value` for embedding inside a single-quoted SQL string literal by doubling every
/// single quote (`'` becomes `''`); all other characters are copied unchanged.
///
/// NOTE(review): backslashes are passed through as-is, which is only safe while the server runs
/// with `standard_conforming_strings = on` — confirm for non-default server configurations.
fn escape_sql_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
35
36#[derive(Debug, Clone, Builder)]
37#[builder(default)]
38#[cfg_attr(
39 feature = "python",
40 pyo3::pyclass(
41 module = "nautilus_trader.core.nautilus_pyo3.infrastructure",
42 from_py_object
43 )
44)]
45#[cfg_attr(
46 feature = "python",
47 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.infrastructure")
48)]
49pub struct PostgresConnectOptions {
50 pub host: String,
51 pub port: u16,
52 pub username: String,
53 pub password: String,
54 pub database: String,
55}
56
57impl PostgresConnectOptions {
58 #[must_use]
60 pub const fn new(
61 host: String,
62 port: u16,
63 username: String,
64 password: String,
65 database: String,
66 ) -> Self {
67 Self {
68 host,
69 port,
70 username,
71 password,
72 database,
73 }
74 }
75
76 #[must_use]
77 pub fn connection_string(&self) -> String {
78 format!(
79 "postgres://{username}:{password}@{host}:{port}/{database}",
80 username = self.username,
81 password = self.password,
82 host = self.host,
83 port = self.port,
84 database = self.database
85 )
86 }
87
88 #[must_use]
90 pub fn connection_string_masked(&self) -> String {
91 format!(
92 "postgres://{username}:***@{host}:{port}/{database}",
93 username = self.username,
94 host = self.host,
95 port = self.port,
96 database = self.database
97 )
98 }
99
100 #[must_use]
101 pub fn default_administrator() -> Self {
102 Self::new(
103 String::from("localhost"),
104 5432,
105 String::from("nautilus"),
106 String::from("pass"),
107 String::from("nautilus"),
108 )
109 }
110}
111
112impl Default for PostgresConnectOptions {
113 fn default() -> Self {
114 Self::new(
115 String::from("localhost"),
116 5432,
117 String::from("nautilus"),
118 String::from("pass"),
119 String::from("nautilus"),
120 )
121 }
122}
123
124impl From<PostgresConnectOptions> for PgConnectOptions {
125 fn from(opt: PostgresConnectOptions) -> Self {
126 Self::new()
127 .host(opt.host.as_str())
128 .port(opt.port)
129 .username(opt.username.as_str())
130 .password(opt.password.as_str())
131 .database(opt.database.as_str())
132 .disable_statement_logging()
133 }
134}
135
136#[must_use]
142pub fn get_postgres_connect_options(
143 host: Option<String>,
144 port: Option<u16>,
145 username: Option<String>,
146 password: Option<String>,
147 database: Option<String>,
148) -> PostgresConnectOptions {
149 let defaults = PostgresConnectOptions::default_administrator();
150 let host = host
151 .or_else(|| std::env::var("POSTGRES_HOST").ok())
152 .unwrap_or(defaults.host);
153 let port = port
154 .or_else(|| {
155 std::env::var("POSTGRES_PORT")
156 .map(|port| port.parse::<u16>().unwrap())
157 .ok()
158 })
159 .unwrap_or(defaults.port);
160 let username = username
161 .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
162 .unwrap_or(defaults.username);
163 let database = database
164 .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
165 .unwrap_or(defaults.database);
166 let password = password
167 .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
168 .unwrap_or(defaults.password);
169 PostgresConnectOptions::new(host, port, username, password, database)
170}
171
172pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
178 Ok(PgPool::connect_with(options).await?)
179}
180
181fn get_schema_dir() -> anyhow::Result<String> {
193 std::env::var("SCHEMA_DIR").or_else(|_| {
194 let nautilus_git_repo_name = "nautilus_trader";
195 let binding = std::env::current_dir().unwrap();
196 let current_dir = binding.to_str().unwrap();
197 match current_dir.find(nautilus_git_repo_name){
198 Some(index) => {
199 let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
200 Ok(schema_path)
201 }
202 None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
203 }
204 })
205}
206
207pub async fn init_postgres(
217 pg: &PgPool,
218 database: String,
219 password: String,
220 schema_dir: Option<String>,
221) -> anyhow::Result<()> {
222 log::info!("Initializing Postgres database with target permissions and schema");
223
224 validate_sql_identifier(&database, "database")?;
225
226 match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
228 .execute(pg)
229 .await
230 {
231 Ok(_) => log::info!("Schema public created successfully"),
232 Err(e) => log::error!("Error creating schema public: {e:?}"),
233 }
234
235 let escaped_password = escape_sql_string(&password);
237 match sqlx::query(
238 format!("CREATE ROLE {database} PASSWORD '{escaped_password}' LOGIN;").as_str(),
239 )
240 .execute(pg)
241 .await
242 {
243 Ok(_) => log::info!("Role {database} created successfully"),
244 Err(e) => {
245 if e.to_string().contains("already exists") {
246 log::info!("Role {database} already exists");
247 } else {
248 log::error!("Error creating role {database}: {e:?}");
249 }
250 }
251 }
252
253 let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
255 let sql_files = vec!["types.sql", "functions.sql", "partitions.sql", "tables.sql"];
256 let plpgsql_regex =
257 Regex::new(r"\$\$ LANGUAGE plpgsql(?:[ \t\r\n]+SECURITY[ \t\r\n]+DEFINER)?;")?;
258 for file_name in &sql_files {
259 log::info!("Executing schema file: {file_name:?}");
260 let file_path = format!("{schema_dir}/{file_name}");
261 let sql_content = std::fs::read_to_string(&file_path)?;
262 let sql_statements: Vec<String> = match *file_name {
263 "functions.sql" | "partitions.sql" => {
264 let mut statements = Vec::new();
265 let mut last_end = 0;
266
267 for mat in plpgsql_regex.find_iter(&sql_content) {
268 let statement = sql_content[last_end..mat.end()].to_string();
269 if !statement.trim().is_empty() {
270 statements.push(statement);
271 }
272 last_end = mat.end();
273 }
274 statements
275 }
276 _ => sql_content
277 .split(';')
278 .filter(|s| !s.trim().is_empty())
279 .map(|s| format!("{s};"))
280 .collect(),
281 };
282
283 for sql_statement in sql_statements {
284 sqlx::query(&sql_statement)
285 .execute(pg)
286 .await
287 .map_err(|e| {
288 if e.to_string().contains("already exists") {
289 log::info!("Already exists error on statement, skipping");
290 } else {
291 panic!("Error executing statement {sql_statement} with error: {e:?}")
292 }
293 })
294 .unwrap();
295 }
296 }
297
298 match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
300 .execute(pg)
301 .await
302 {
303 Ok(_) => log::info!("Connect privileges granted to role {database}"),
304 Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
305 }
306
307 match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
309 .execute(pg)
310 .await
311 {
312 Ok(_) => log::info!("All schema privileges granted to role {database}"),
313 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
314 }
315
316 match sqlx::query(
318 format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
319 )
320 .execute(pg)
321 .await
322 {
323 Ok(_) => log::info!("All tables privileges granted to role {database}"),
324 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
325 }
326
327 match sqlx::query(
329 format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
330 )
331 .execute(pg)
332 .await
333 {
334 Ok(_) => log::info!("All sequences privileges granted to role {database}"),
335 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
336 }
337
338 match sqlx::query(
340 format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
341 )
342 .execute(pg)
343 .await
344 {
345 Ok(_) => log::info!("All functions privileges granted to role {database}"),
346 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
347 }
348
349 Ok(())
350}
351
352pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
358 validate_sql_identifier(&database, "database")?;
359
360 match sqlx::query(format!("DROP OWNED BY {database}").as_str())
362 .execute(pg)
363 .await
364 {
365 Ok(_) => log::info!("Dropped owned objects by role {database}"),
366 Err(e) => {
367 let err_msg = e.to_string();
368 if err_msg.contains("2BP01") || err_msg.contains("required by the database system") {
369 log::warn!("Skipping system-required objects for role {database}");
370 } else {
371 log::error!("Error dropping owned by role {database}: {e:?}");
372 }
373 }
374 }
375
376 match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
378 .execute(pg)
379 .await
380 {
381 Ok(_) => log::info!("Revoked connect privileges from role {database}"),
382 Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
383 }
384
385 match sqlx::query(
387 format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
388 )
389 .execute(pg)
390 .await
391 {
392 Ok(_) => log::info!("Revoked all privileges from role {database}"),
393 Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
394 }
395
396 match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
398 .execute(pg)
399 .await
400 {
401 Ok(_) => log::info!("Dropped schema public"),
402 Err(e) => log::error!("Error dropping schema public: {e:?}"),
403 }
404
405 match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
407 .execute(pg)
408 .await
409 {
410 Ok(_) => log::info!("Dropped role {database}"),
411 Err(e) => {
412 let err_msg = e.to_string();
413 if err_msg.contains("55006") || err_msg.contains("current user cannot be dropped") {
414 log::warn!("Cannot drop currently connected role {database}");
415 } else {
416 log::error!("Error dropping role {database}: {e:?}");
417 }
418 }
419 }
420 Ok(())
421}