diff --git a/Cargo.lock b/Cargo.lock index 183f3c2..ff8f41e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -352,6 +352,15 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.0" @@ -431,6 +440,7 @@ version = "0.1.0" dependencies = [ "env_logger", "futures-util", + "ipnetwork", "log", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 1e88c09..8b9e077 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ forwarder = ["tokio-tungstenite", "tokio", "futures-util"] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.5" +ipnetwork = "0.20.0" # forwarder deps tokio-tungstenite = { version = "0.23.1", optional = true, features = ["native-tls"] } diff --git a/README.md b/README.md index 4abe825..eb6288d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ - # noteguard A high performance note filter plugin system for [strfry] @@ -107,6 +106,18 @@ The whitelist filter only allows notes to pass if it matches a particular pubkey Either criteria can match +### Blacklist + +* name: `blacklist` + +The blacklist filter blocks notes that match any pubkey, ip, or CIDR range: + +- `pubkeys` *optional*: a list of hex public keys to block + +- `ips` *optional*: a list of IP addresses to block + +- `cidrs` *optional*: a list of CIDR ranges to block + ### Kinds * name: `kinds` @@ -155,7 +166,6 @@ be queued if the connection goes down (up to the `queue_size` buffer limit) - `queue_size` *optional* - size of the note queue, this is used to buffer notes if the connection goes down. Default is 1000. - ## Testing You can test your filters like so: diff --git a/src/filters/blacklist.rs b/src/filters/blacklist.rs index e72fb2e..13e2f3a 100644 --- a/src/filters/blacklist.rs +++ b/src/filters/blacklist.rs @@ -1,15 +1,66 @@ use crate::{Action, InputMessage, NoteFilter, OutputMessage}; +use ipnetwork::IpNetwork; use serde::Deserialize; +use std::net::IpAddr; +use std::str::FromStr; #[derive(Deserialize, Default)] -pub struct Blacklist { +pub struct BlacklistConfig { pub pubkeys: Option>, pub ips: Option>, + pub cidrs: Option>, +} + +#[derive(Default)] +pub struct Blacklist { + pubkeys: Option>, + ips: Option>, + cidrs: Option>, +} + +impl<'de> Deserialize<'de> for Blacklist { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let config = BlacklistConfig::deserialize(deserializer)?; + Ok(Blacklist { + pubkeys: config.pubkeys, + ips: config.ips, + cidrs: config.cidrs.map(|cidrs| { + cidrs + .into_iter() + .filter_map(|s| IpNetwork::from_str(&s).ok()) + .collect() + }), + }) + } +} + +impl Blacklist { + fn is_ip_blocked(&self, ip: &str) -> bool { + if let Some(ips) = &self.ips { + if ips.contains(&ip.to_string()) { + return true; + } + } + + if let Ok(addr) = IpAddr::from_str(ip) { + if let Some(cidrs) = &self.cidrs { + if cidrs.iter().any(|network| network.contains(addr)) { + return true; + } + } + } + + false + } } impl NoteFilter for Blacklist { fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage { let reject_message = "blocked: pubkey/ip is blacklisted".to_string(); + if let Some(pubkeys) = &self.pubkeys { if pubkeys.contains(&msg.event.pubkey) { return OutputMessage::new( @@ -20,14 +71,8 @@ impl NoteFilter for Blacklist { } } - if let Some(ips) = &self.ips { - if ips.contains(&msg.source_info) { - return OutputMessage::new( - msg.event.id.clone(), - Action::Reject, - Some(reject_message), - ); - } + if self.is_ip_blocked(&msg.source_info) { + return OutputMessage::new(msg.event.id.clone(), Action::Reject, Some(reject_message)); } OutputMessage::new(msg.event.id.clone(), Action::Accept, None) diff --git a/src/main.rs b/src/main.rs index 16b1d2d..b6922bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -177,12 +177,14 @@ fn noteguard() { #[cfg(test)] mod tests { use super::*; - use noteguard::filters::{Kinds, ProtectedEvents, RateLimit, Whitelist}; use noteguard::{Action, Note}; - use serde_json::json; // Helper function to create a mock InputMessage - fn create_mock_input_message(event_id: &str, message_type: &str) -> InputMessage { + fn create_mock_input_message( + event_id: &str, + message_type: &str, + source_info: &str, + ) -> InputMessage { InputMessage { message_type: message_type.to_string(), event: Note { @@ -196,20 +198,7 @@ mod tests { }, received_at: 0, source_type: "mock_source".to_string(), - source_info: "mock_source_info".to_string(), - } - } - - // Helper function to create a mock OutputMessage - fn create_mock_output_message( - event_id: &str, - action: Action, - msg: Option<&str>, - ) -> OutputMessage { - OutputMessage { - id: event_id.to_string(), - action, - msg: msg.map(|s| s.to_string()), + source_info: source_info.to_string(), } } @@ -263,7 +252,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_1", "new"); + let input_message = create_mock_input_message("test_event_1", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Accept); @@ -287,7 +276,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_3", "new"); + let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Reject); @@ -311,7 +300,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_2", "new"); + let input_message = create_mock_input_message("test_event_2", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Reject); @@ -334,7 +323,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_3", "new"); + let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Reject); @@ -361,12 +350,49 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_4", "new"); + let input_message = create_mock_input_message("test_event_4", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Accept); } + #[test] + fn test_blacklist_cidr() { + let mut noteguard = Noteguard::new(); + + let config: Config = toml::from_str( + r#" + pipeline = ["blacklist"] + [filters.blacklist] + cidrs = ["127.0.0.1/24"] + "#, + ) + .expect("Failed to parse config"); + + noteguard + .load_config(&config) + .expect("Failed to load config"); + + let test_cases = [ + ("127.0.0.1", true), + ("127.0.0.2", true), + ("128.0.0.1", false), + ("127.1.0.1", false), + ("127.0.1.1", false), + ]; + + for (ip, should_reject) in test_cases.iter() { + let input_message = create_mock_input_message("event_id", "new", ip); + let output_message = noteguard.run(input_message); + + if *should_reject { + assert_eq!(output_message.action, Action::Reject); + } else { + assert_eq!(output_message.action, Action::Accept); + } + } + } + #[test] fn test_deserialize_input_message() { let input_json = r#"