Merge pull request #17 from dsaxton/dsaxton/blacklist-cidr

Allow blacklisting CIDR ranges
This commit is contained in:
William Casarin 2025-05-01 16:34:55 -07:00 committed by GitHub
commit 4d4d9d62d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 125 additions and 33 deletions

10
Cargo.lock generated
View file

@ -352,6 +352,15 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "ipnetwork"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "is_terminal_polyfill" name = "is_terminal_polyfill"
version = "1.70.0" version = "1.70.0"
@ -431,6 +440,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"env_logger", "env_logger",
"futures-util", "futures-util",
"ipnetwork",
"log", "log",
"serde", "serde",
"serde_json", "serde_json",

View file

@ -10,6 +10,7 @@ forwarder = ["tokio-tungstenite", "tokio", "futures-util"]
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
toml = "0.5" toml = "0.5"
ipnetwork = "0.20.0"
# forwarder deps # forwarder deps
tokio-tungstenite = { version = "0.23.1", optional = true, features = ["native-tls"] } tokio-tungstenite = { version = "0.23.1", optional = true, features = ["native-tls"] }

View file

@ -1,4 +1,3 @@
# noteguard # noteguard
A high performance note filter plugin system for [strfry] A high performance note filter plugin system for [strfry]
@ -107,6 +106,18 @@ The whitelist filter only allows notes to pass if it matches a particular pubkey
Either criteria can match Either criteria can match
### Blacklist
* name: `blacklist`
The blacklist filter blocks notes that match any pubkey, ip, or CIDR range:
- `pubkeys` *optional*: a list of hex public keys to block
- `ips` *optional*: a list of IP addresses to block
- `cidrs` *optional*: a list of CIDR ranges to block
### Kinds ### Kinds
* name: `kinds` * name: `kinds`
@ -155,7 +166,6 @@ be queued if the connection goes down (up to the `queue_size` buffer limit)
- `queue_size` *optional* - size of the note queue, this is used to buffer notes if the connection goes down. Default is 1000. - `queue_size` *optional* - size of the note queue, this is used to buffer notes if the connection goes down. Default is 1000.
## Testing ## Testing
You can test your filters like so: You can test your filters like so:

View file

@ -1,15 +1,66 @@
use crate::{Action, InputMessage, NoteFilter, OutputMessage}; use crate::{Action, InputMessage, NoteFilter, OutputMessage};
use ipnetwork::IpNetwork;
use serde::Deserialize; use serde::Deserialize;
use std::net::IpAddr;
use std::str::FromStr;
#[derive(Deserialize, Default)] #[derive(Deserialize, Default)]
pub struct Blacklist { pub struct BlacklistConfig {
pub pubkeys: Option<Vec<String>>, pub pubkeys: Option<Vec<String>>,
pub ips: Option<Vec<String>>, pub ips: Option<Vec<String>>,
pub cidrs: Option<Vec<String>>,
}
#[derive(Default)]
pub struct Blacklist {
pubkeys: Option<Vec<String>>,
ips: Option<Vec<String>>,
cidrs: Option<Vec<IpNetwork>>,
}
impl<'de> Deserialize<'de> for Blacklist {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let config = BlacklistConfig::deserialize(deserializer)?;
Ok(Blacklist {
pubkeys: config.pubkeys,
ips: config.ips,
cidrs: config.cidrs.map(|cidrs| {
cidrs
.into_iter()
.filter_map(|s| IpNetwork::from_str(&s).ok())
.collect()
}),
})
}
}
impl Blacklist {
fn is_ip_blocked(&self, ip: &str) -> bool {
if let Some(ips) = &self.ips {
if ips.contains(&ip.to_string()) {
return true;
}
}
if let Ok(addr) = IpAddr::from_str(ip) {
if let Some(cidrs) = &self.cidrs {
if cidrs.iter().any(|network| network.contains(addr)) {
return true;
}
}
}
false
}
} }
impl NoteFilter for Blacklist { impl NoteFilter for Blacklist {
fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage { fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage {
let reject_message = "blocked: pubkey/ip is blacklisted".to_string(); let reject_message = "blocked: pubkey/ip is blacklisted".to_string();
if let Some(pubkeys) = &self.pubkeys { if let Some(pubkeys) = &self.pubkeys {
if pubkeys.contains(&msg.event.pubkey) { if pubkeys.contains(&msg.event.pubkey) {
return OutputMessage::new( return OutputMessage::new(
@ -20,14 +71,8 @@ impl NoteFilter for Blacklist {
} }
} }
if let Some(ips) = &self.ips { if self.is_ip_blocked(&msg.source_info) {
if ips.contains(&msg.source_info) { return OutputMessage::new(msg.event.id.clone(), Action::Reject, Some(reject_message));
return OutputMessage::new(
msg.event.id.clone(),
Action::Reject,
Some(reject_message),
);
}
} }
OutputMessage::new(msg.event.id.clone(), Action::Accept, None) OutputMessage::new(msg.event.id.clone(), Action::Accept, None)

View file

@ -177,12 +177,14 @@ fn noteguard() {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use noteguard::filters::{Kinds, ProtectedEvents, RateLimit, Whitelist};
use noteguard::{Action, Note}; use noteguard::{Action, Note};
use serde_json::json;
// Helper function to create a mock InputMessage // Helper function to create a mock InputMessage
fn create_mock_input_message(event_id: &str, message_type: &str) -> InputMessage { fn create_mock_input_message(
event_id: &str,
message_type: &str,
source_info: &str,
) -> InputMessage {
InputMessage { InputMessage {
message_type: message_type.to_string(), message_type: message_type.to_string(),
event: Note { event: Note {
@ -196,20 +198,7 @@ mod tests {
}, },
received_at: 0, received_at: 0,
source_type: "mock_source".to_string(), source_type: "mock_source".to_string(),
source_info: "mock_source_info".to_string(), source_info: source_info.to_string(),
}
}
// Helper function to create a mock OutputMessage
fn create_mock_output_message(
event_id: &str,
action: Action,
msg: Option<&str>,
) -> OutputMessage {
OutputMessage {
id: event_id.to_string(),
action,
msg: msg.map(|s| s.to_string()),
} }
} }
@ -263,7 +252,7 @@ mod tests {
.load_config(&config) .load_config(&config)
.expect("Failed to load config"); .expect("Failed to load config");
let input_message = create_mock_input_message("test_event_1", "new"); let input_message = create_mock_input_message("test_event_1", "new", "mock_source_info");
let output_message = noteguard.run(input_message); let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Accept); assert_eq!(output_message.action, Action::Accept);
@ -287,7 +276,7 @@ mod tests {
.load_config(&config) .load_config(&config)
.expect("Failed to load config"); .expect("Failed to load config");
let input_message = create_mock_input_message("test_event_3", "new"); let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info");
let output_message = noteguard.run(input_message); let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Reject); assert_eq!(output_message.action, Action::Reject);
@ -311,7 +300,7 @@ mod tests {
.load_config(&config) .load_config(&config)
.expect("Failed to load config"); .expect("Failed to load config");
let input_message = create_mock_input_message("test_event_2", "new"); let input_message = create_mock_input_message("test_event_2", "new", "mock_source_info");
let output_message = noteguard.run(input_message); let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Reject); assert_eq!(output_message.action, Action::Reject);
@ -334,7 +323,7 @@ mod tests {
.load_config(&config) .load_config(&config)
.expect("Failed to load config"); .expect("Failed to load config");
let input_message = create_mock_input_message("test_event_3", "new"); let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info");
let output_message = noteguard.run(input_message); let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Reject); assert_eq!(output_message.action, Action::Reject);
@ -361,12 +350,49 @@ mod tests {
.load_config(&config) .load_config(&config)
.expect("Failed to load config"); .expect("Failed to load config");
let input_message = create_mock_input_message("test_event_4", "new"); let input_message = create_mock_input_message("test_event_4", "new", "mock_source_info");
let output_message = noteguard.run(input_message); let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Accept); assert_eq!(output_message.action, Action::Accept);
} }
#[test]
fn test_blacklist_cidr() {
let mut noteguard = Noteguard::new();
let config: Config = toml::from_str(
r#"
pipeline = ["blacklist"]
[filters.blacklist]
cidrs = ["127.0.0.1/24"]
"#,
)
.expect("Failed to parse config");
noteguard
.load_config(&config)
.expect("Failed to load config");
let test_cases = [
("127.0.0.1", true),
("127.0.0.2", true),
("128.0.0.1", false),
("127.1.0.1", false),
("127.0.1.1", false),
];
for (ip, should_reject) in test_cases.iter() {
let input_message = create_mock_input_message("event_id", "new", ip);
let output_message = noteguard.run(input_message);
if *should_reject {
assert_eq!(output_message.action, Action::Reject);
} else {
assert_eq!(output_message.action, Action::Accept);
}
}
}
#[test] #[test]
fn test_deserialize_input_message() { fn test_deserialize_input_message() {
let input_json = r#" let input_json = r#"