Created October 3, 2023 20:09
-
-
Save vadv/617ba6b74fd494307546d84c83607c52 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/.gitignore b/.gitignore | |
index 94d2001..2078a91 100644 | |
--- a/.gitignore | |
+++ b/.gitignore | |
@@ -1,3 +1,4 @@ | |
/target | |
.idea | |
/examples | |
+/vendor | |
diff --git a/Cargo.lock b/Cargo.lock | |
index 9a82980..f105049 100644 | |
--- a/Cargo.lock | |
+++ b/Cargo.lock | |
@@ -605,6 +605,9 @@ name = "ipnet" | |
version = "2.8.0" | |
source = "registry+https://github.com/rust-lang/crates.io-index" | |
checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" | |
+dependencies = [ | |
+ "serde", | |
+] | |
[[package]] | |
name = "itoa" | |
diff --git a/Cargo.toml b/Cargo.toml | |
index 1b71a16..87e2ceb 100644 | |
--- a/Cargo.toml | |
+++ b/Cargo.toml | |
@@ -12,7 +12,7 @@ log = "0.4.20" | |
clap = { version = "4.3.1", features = ["derive", "env"] } | |
serde = { version = "1", features = ["derive"] } | |
serde_derive = "1" | |
-ipnet = "2.8.0" | |
+ipnet = { version = "2.8.0", features = ["serde"] } | |
once_cell = "1" | |
arc-swap = "1" | |
toml = "0.7" | |
diff --git a/pg_doorman.toml b/pg_doorman.toml | |
index 1d3e79a..0b1440a 100644 | |
--- a/pg_doorman.toml | |
+++ b/pg_doorman.toml | |
@@ -27,6 +27,8 @@ admin_password = "doorman_admin_password" | |
prometheus_exporter_port = 9075 | |
+hba = ["10.0.0.0/8", "192.168.0.0/16"] | |
+ | |
[pools] | |
[pools.example_db] | |
diff --git a/src/client.rs b/src/client.rs | |
index 1b130be..7c98876 100644 | |
--- a/src/client.rs | |
+++ b/src/client.rs | |
@@ -12,9 +12,7 @@ use tokio::sync::broadcast::Receiver; | |
use tokio::sync::mpsc::Sender; | |
use crate::admin::{generate_server_parameters_for_admin, handle_admin}; | |
-use crate::config::{ | |
- get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, | |
-}; | |
+use crate::config::{get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, addr_in_hba}; | |
use crate::constants::*; | |
use crate::messages::*; | |
use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; | |
@@ -463,6 +461,16 @@ where | |
return Err(Error::ShuttingDown); | |
} | |
+ if !addr_in_hba(addr.ip()) { | |
+ error_response_terminal( | |
+ &mut write, | |
+ "hba forbidden for this ip address", | |
+ ).await?; | |
+ return Err(Error::HbaForbiddenError(format!( | |
+ "hba forbidden client: {} from address: {:?}", client_identifier, addr | |
+ ))); | |
+ } | |
+ | |
// Generate random backend ID and secret key | |
let process_id: i32 = rand::random(); | |
let secret_key: i32 = rand::random(); | |
diff --git a/src/config.rs b/src/config.rs | |
index c4d6bad..4c36865 100644 | |
--- a/src/config.rs | |
+++ b/src/config.rs | |
@@ -7,6 +7,8 @@ use log::{error, info}; | |
use std::collections::{BTreeMap, HashMap}; | |
use std::path::Path; | |
use std::collections::hash_map::DefaultHasher; | |
+use std::net::IpAddr; | |
+use ipnet::IpNet; | |
use once_cell::sync::Lazy; | |
use tokio::fs::File; | |
use tokio::io::AsyncReadExt; | |
@@ -259,6 +261,8 @@ pub struct General { | |
#[serde(default = "General::default_prepared_statements_cache_size")] | |
pub prepared_statements_cache_size: usize, | |
+ | |
+ pub hba: Vec<IpNet>, | |
} | |
impl General { | |
@@ -375,6 +379,7 @@ impl Default for General { | |
validate_config: true, | |
prepared_statements: false, | |
prepared_statements_cache_size: 500, | |
+ hba: vec![], | |
} | |
} | |
} | |
@@ -859,8 +864,17 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E | |
} | |
} | |
+/// Reports whether the client address `addr` passes the `general.hba` allow-list. | |
+/// An empty list permits every address; otherwise `addr` must lie in at least one listed network. | |
+pub fn addr_in_hba(addr: IpAddr) -> bool { | |
+    let config = get_config(); | |
+    let hba = &config.general.hba; | |
+    hba.is_empty() || hba.iter().any(|net| net.contains(&addr)) | |
+} | |
+ | |
#[cfg(test)] | |
mod test { | |
+ use std::net::Ipv4Addr; | |
use super::*; | |
#[tokio::test] | |
@@ -883,6 +897,9 @@ mod test { | |
assert_eq!(get_config().pools["example_db"].users["1"].username, "example_user_2"); | |
assert_eq!(get_config().pools["example_db"].users["0"].pool_size, 40); | |
assert_eq!(get_config().pools["example_db"].users["0"].pool_mode, Some(PoolMode::Session)); | |
+ assert_eq!(addr_in_hba(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))), true); | |
+ assert_eq!(addr_in_hba(IpAddr::V4(Ipv4Addr::new(172, 0, 0, 1))), false); | |
+ assert_eq!(addr_in_hba(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))), true); | |
} | |
#[tokio::test] | |
diff --git a/src/errors.rs b/src/errors.rs | |
index 49553d8..6a73402 100644 | |
--- a/src/errors.rs | |
+++ b/src/errors.rs | |
@@ -26,6 +26,7 @@ pub enum Error { | |
QueryError(String), | |
ScramClientError(String), | |
ScramServerError(String), | |
+ HbaForbiddenError(String), | |
} | |
#[derive(Clone, PartialEq, Debug)] | |
@@ -51,7 +52,7 @@ impl std::fmt::Display for ClientIdentifier { | |
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
write!( | |
f, | |
- "{{ {}@{}{}/?application_name={} }}", | |
+ "{{ {}@{}/{}?application_name={} }}", | |
self.username, self.addr, self.pool_name, self.application_name | |
) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.