nautilus_infrastructure/sql/
pg.rs1use derive_builder::Builder;
17use regex::Regex;
18use sqlx::{ConnectOptions, PgPool, postgres::PgConnectOptions};
19
/// Connection parameters for a Postgres server.
///
/// Convertible into sqlx's `PgConnectOptions` via `From`, or rendered as a `postgres://` URL with
/// `connection_string`. With `#[builder(default)]`, fields left unset on the generated builder
/// fall back to this type's `Default` values.
///
/// NOTE(review): the derived `Debug` impl prints every field, including `password` — avoid
/// logging this type with `{:?}` where credentials must not leak.
#[derive(Debug, Clone, Builder)]
#[builder(default)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.infrastructure")
)]
pub struct PostgresConnectOptions {
    /// Hostname or address of the Postgres server.
    pub host: String,
    /// TCP port of the Postgres server.
    pub port: u16,
    /// Role name used to authenticate.
    pub username: String,
    /// Password for `username`.
    pub password: String,
    /// Name of the database to connect to.
    pub database: String,
}
37
38impl PostgresConnectOptions {
39 #[must_use]
41 pub const fn new(
42 host: String,
43 port: u16,
44 username: String,
45 password: String,
46 database: String,
47 ) -> Self {
48 Self {
49 host,
50 port,
51 username,
52 password,
53 database,
54 }
55 }
56
57 #[must_use]
58 pub fn connection_string(&self) -> String {
59 format!(
60 "postgres://{username}:{password}@{host}:{port}/{database}",
61 username = self.username,
62 password = self.password,
63 host = self.host,
64 port = self.port,
65 database = self.database
66 )
67 }
68
69 #[must_use]
70 pub fn default_administrator() -> Self {
71 Self::new(
72 String::from("localhost"),
73 5432,
74 String::from("nautilus"),
75 String::from("pass"),
76 String::from("nautilus"),
77 )
78 }
79}
80
81impl Default for PostgresConnectOptions {
82 fn default() -> Self {
83 Self::new(
84 String::from("localhost"),
85 5432,
86 String::from("nautilus"),
87 String::from("pass"),
88 String::from("nautilus"),
89 )
90 }
91}
92
93impl From<PostgresConnectOptions> for PgConnectOptions {
94 fn from(opt: PostgresConnectOptions) -> Self {
95 Self::new()
96 .host(opt.host.as_str())
97 .port(opt.port)
98 .username(opt.username.as_str())
99 .password(opt.password.as_str())
100 .database(opt.database.as_str())
101 .disable_statement_logging()
102 }
103}
104
105#[must_use]
111pub fn get_postgres_connect_options(
112 host: Option<String>,
113 port: Option<u16>,
114 username: Option<String>,
115 password: Option<String>,
116 database: Option<String>,
117) -> PostgresConnectOptions {
118 let defaults = PostgresConnectOptions::default_administrator();
119 let host = host
120 .or_else(|| std::env::var("POSTGRES_HOST").ok())
121 .unwrap_or(defaults.host);
122 let port = port
123 .or_else(|| {
124 std::env::var("POSTGRES_PORT")
125 .map(|port| port.parse::<u16>().unwrap())
126 .ok()
127 })
128 .unwrap_or(defaults.port);
129 let username = username
130 .or_else(|| std::env::var("POSTGRES_USERNAME").ok())
131 .unwrap_or(defaults.username);
132 let database = database
133 .or_else(|| std::env::var("POSTGRES_DATABASE").ok())
134 .unwrap_or(defaults.database);
135 let password = password
136 .or_else(|| std::env::var("POSTGRES_PASSWORD").ok())
137 .unwrap_or(defaults.password);
138 PostgresConnectOptions::new(host, port, username, password, database)
139}
140
141pub async fn connect_pg(options: PgConnectOptions) -> anyhow::Result<PgPool> {
147 Ok(PgPool::connect_with(options).await?)
148}
149
150fn get_schema_dir() -> anyhow::Result<String> {
162 std::env::var("SCHEMA_DIR").or_else(|_| {
163 let nautilus_git_repo_name = "nautilus_trader";
164 let binding = std::env::current_dir().unwrap();
165 let current_dir = binding.to_str().unwrap();
166 match current_dir.find(nautilus_git_repo_name){
167 Some(index) => {
168 let schema_path = current_dir[0..index + nautilus_git_repo_name.len()].to_string() + "/schema/sql";
169 Ok(schema_path)
170 }
171 None => anyhow::bail!("Could not calculate schema dir from current directory path or SCHEMA_DIR env variable")
172 }
173 })
174}
175
176pub async fn init_postgres(
186 pg: &PgPool,
187 database: String,
188 password: String,
189 schema_dir: Option<String>,
190) -> anyhow::Result<()> {
191 log::info!("Initializing Postgres database with target permissions and schema");
192
193 match sqlx::query("CREATE SCHEMA IF NOT EXISTS public;")
195 .execute(pg)
196 .await
197 {
198 Ok(_) => log::info!("Schema public created successfully"),
199 Err(e) => log::error!("Error creating schema public: {e:?}"),
200 }
201
202 match sqlx::query(format!("CREATE ROLE {database} PASSWORD '{password}' LOGIN;").as_str())
204 .execute(pg)
205 .await
206 {
207 Ok(_) => log::info!("Role {database} created successfully"),
208 Err(e) => {
209 if e.to_string().contains("already exists") {
210 log::info!("Role {database} already exists");
211 } else {
212 log::error!("Error creating role {database}: {e:?}");
213 }
214 }
215 }
216
217 let schema_dir = schema_dir.unwrap_or_else(|| get_schema_dir().unwrap());
219 let sql_files = vec!["types.sql", "functions.sql", "partitions.sql", "tables.sql"];
220 let plpgsql_regex =
221 Regex::new(r"\$\$ LANGUAGE plpgsql(?:[ \t\r\n]+SECURITY[ \t\r\n]+DEFINER)?;")?;
222 for file_name in &sql_files {
223 log::info!("Executing schema file: {file_name:?}");
224 let file_path = format!("{}/{}", schema_dir, file_name);
225 let sql_content = std::fs::read_to_string(&file_path)?;
226 let sql_statements: Vec<String> = match *file_name {
227 "functions.sql" | "partitions.sql" => {
228 let mut statements = Vec::new();
229 let mut last_end = 0;
230
231 for mat in plpgsql_regex.find_iter(&sql_content) {
232 let statement = sql_content[last_end..mat.end()].to_string();
233 if !statement.trim().is_empty() {
234 statements.push(statement);
235 }
236 last_end = mat.end();
237 }
238 statements
239 }
240 _ => sql_content
241 .split(';')
242 .filter(|s| !s.trim().is_empty())
243 .map(|s| format!("{s};"))
244 .collect(),
245 };
246
247 for sql_statement in sql_statements {
248 sqlx::query(&sql_statement)
249 .execute(pg)
250 .await
251 .map_err(|e| {
252 if e.to_string().contains("already exists") {
253 log::info!("Already exists error on statement, skipping");
254 } else {
255 panic!("Error executing statement {sql_statement} with error: {e:?}")
256 }
257 })
258 .unwrap();
259 }
260 }
261
262 match sqlx::query(format!("GRANT CONNECT ON DATABASE {database} TO {database};").as_str())
264 .execute(pg)
265 .await
266 {
267 Ok(_) => log::info!("Connect privileges granted to role {database}"),
268 Err(e) => log::error!("Error granting connect privileges to role {database}: {e:?}"),
269 }
270
271 match sqlx::query(format!("GRANT ALL PRIVILEGES ON SCHEMA public TO {database};").as_str())
273 .execute(pg)
274 .await
275 {
276 Ok(_) => log::info!("All schema privileges granted to role {database}"),
277 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
278 }
279
280 match sqlx::query(
282 format!("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {database};").as_str(),
283 )
284 .execute(pg)
285 .await
286 {
287 Ok(_) => log::info!("All tables privileges granted to role {database}"),
288 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
289 }
290
291 match sqlx::query(
293 format!("GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO {database};").as_str(),
294 )
295 .execute(pg)
296 .await
297 {
298 Ok(_) => log::info!("All sequences privileges granted to role {database}"),
299 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
300 }
301
302 match sqlx::query(
304 format!("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA public TO {database};").as_str(),
305 )
306 .execute(pg)
307 .await
308 {
309 Ok(_) => log::info!("All functions privileges granted to role {database}"),
310 Err(e) => log::error!("Error granting all privileges to role {database}: {e:?}"),
311 }
312
313 Ok(())
314}
315
316pub async fn drop_postgres(pg: &PgPool, database: String) -> anyhow::Result<()> {
322 match sqlx::query(format!("DROP OWNED BY {database}").as_str())
324 .execute(pg)
325 .await
326 {
327 Ok(_) => log::info!("Dropped owned objects by role {database}"),
328 Err(e) => {
329 let err_msg = e.to_string();
330 if err_msg.contains("2BP01") || err_msg.contains("required by the database system") {
331 log::warn!("Skipping system-required objects for role {database}");
332 } else {
333 log::error!("Error dropping owned by role {database}: {e:?}");
334 }
335 }
336 }
337
338 match sqlx::query(format!("REVOKE CONNECT ON DATABASE {database} FROM {database};").as_str())
340 .execute(pg)
341 .await
342 {
343 Ok(_) => log::info!("Revoked connect privileges from role {database}"),
344 Err(e) => log::error!("Error revoking connect privileges from role {database}: {e:?}"),
345 }
346
347 match sqlx::query(
349 format!("REVOKE ALL PRIVILEGES ON DATABASE {database} FROM {database};").as_str(),
350 )
351 .execute(pg)
352 .await
353 {
354 Ok(_) => log::info!("Revoked all privileges from role {database}"),
355 Err(e) => log::error!("Error revoking all privileges from role {database}: {e:?}"),
356 }
357
358 match sqlx::query("DROP SCHEMA IF EXISTS public CASCADE")
360 .execute(pg)
361 .await
362 {
363 Ok(_) => log::info!("Dropped schema public"),
364 Err(e) => log::error!("Error dropping schema public: {e:?}"),
365 }
366
367 match sqlx::query(format!("DROP ROLE IF EXISTS {database};").as_str())
369 .execute(pg)
370 .await
371 {
372 Ok(_) => log::info!("Dropped role {database}"),
373 Err(e) => {
374 let err_msg = e.to_string();
375 if err_msg.contains("55006") || err_msg.contains("current user cannot be dropped") {
376 log::warn!("Cannot drop currently connected role {database}");
377 } else {
378 log::error!("Error dropping role {database}: {e:?}");
379 }
380 }
381 }
382 Ok(())
383}