diff --git a/Cargo.lock b/Cargo.lock index 183f3c2..ff8f41e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -352,6 +352,15 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "ipnetwork" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e" +dependencies = [ + "serde", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.0" @@ -431,6 +440,7 @@ version = "0.1.0" dependencies = [ "env_logger", "futures-util", + "ipnetwork", "log", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 1e88c09..8b9e077 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ forwarder = ["tokio-tungstenite", "tokio", "futures-util"] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.5" +ipnetwork = "0.20.0" # forwarder deps tokio-tungstenite = { version = "0.23.1", optional = true, features = ["native-tls"] } diff --git a/src/filters/blacklist.rs b/src/filters/blacklist.rs index e72fb2e..2eb99f1 100644 --- a/src/filters/blacklist.rs +++ b/src/filters/blacklist.rs @@ -1,15 +1,44 @@ use crate::{Action, InputMessage, NoteFilter, OutputMessage}; +use ipnetwork::IpNetwork; use serde::Deserialize; +use std::net::IpAddr; +use std::str::FromStr; #[derive(Deserialize, Default)] pub struct Blacklist { pub pubkeys: Option>, pub ips: Option>, + pub cidrs: Option>, +} + +impl Blacklist { + fn is_ip_blocked(&self, ip: &str) -> bool { + if let Some(cidrs) = &self.cidrs { + for cidr in cidrs { + if let Ok(network) = IpNetwork::from_str(cidr) { + if let Ok(addr) = IpAddr::from_str(ip) { + if network.contains(addr) { + return true; + } + } + } + } + } + + if let Some(ips) = &self.ips { + if ips.contains(&ip.to_string()) { + return true; + } + } + + false + } } impl NoteFilter for Blacklist { fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage { let reject_message = "blocked: pubkey/ip is blacklisted".to_string(); + if let Some(pubkeys) = &self.pubkeys { if pubkeys.contains(&msg.event.pubkey) { return OutputMessage::new( @@ -20,14 +49,8 @@ impl NoteFilter for Blacklist { } } - if let Some(ips) = &self.ips { - if ips.contains(&msg.source_info) { - return OutputMessage::new( - msg.event.id.clone(), - Action::Reject, - Some(reject_message), - ); - } + if self.is_ip_blocked(&msg.source_info) { + return OutputMessage::new(msg.event.id.clone(), Action::Reject, Some(reject_message)); } OutputMessage::new(msg.event.id.clone(), Action::Accept, None) diff --git a/src/main.rs b/src/main.rs index 16b1d2d..b6922bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -177,12 +177,14 @@ fn noteguard() { #[cfg(test)] mod tests { use super::*; - use noteguard::filters::{Kinds, ProtectedEvents, RateLimit, Whitelist}; use noteguard::{Action, Note}; - use serde_json::json; // Helper function to create a mock InputMessage - fn create_mock_input_message(event_id: &str, message_type: &str) -> InputMessage { + fn create_mock_input_message( + event_id: &str, + message_type: &str, + source_info: &str, + ) -> InputMessage { InputMessage { message_type: message_type.to_string(), event: Note { @@ -196,20 +198,7 @@ mod tests { }, received_at: 0, source_type: "mock_source".to_string(), - source_info: "mock_source_info".to_string(), - } - } - - // Helper function to create a mock OutputMessage - fn create_mock_output_message( - event_id: &str, - action: Action, - msg: Option<&str>, - ) -> OutputMessage { - OutputMessage { - id: event_id.to_string(), - action, - msg: msg.map(|s| s.to_string()), + source_info: source_info.to_string(), } } @@ -263,7 +252,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_1", "new"); + let input_message = create_mock_input_message("test_event_1", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Accept); @@ -287,7 +276,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_3", "new"); + let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Reject); @@ -311,7 +300,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_2", "new"); + let input_message = create_mock_input_message("test_event_2", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Reject); @@ -334,7 +323,7 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_3", "new"); + let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Reject); @@ -361,12 +350,49 @@ mod tests { .load_config(&config) .expect("Failed to load config"); - let input_message = create_mock_input_message("test_event_4", "new"); + let input_message = create_mock_input_message("test_event_4", "new", "mock_source_info"); let output_message = noteguard.run(input_message); assert_eq!(output_message.action, Action::Accept); } + #[test] + fn test_blacklist_cidr() { + let mut noteguard = Noteguard::new(); + + let config: Config = toml::from_str( + r#" + pipeline = ["blacklist"] + [filters.blacklist] + cidrs = ["127.0.0.1/24"] + "#, + ) + .expect("Failed to parse config"); + + noteguard + .load_config(&config) + .expect("Failed to load config"); + + let test_cases = [ + ("127.0.0.1", true), + ("127.0.0.2", true), + ("128.0.0.1", false), + ("127.1.0.1", false), + ("127.0.1.1", false), + ]; + + for (ip, should_reject) in test_cases.iter() { + let input_message = create_mock_input_message("event_id", "new", ip); + let output_message = noteguard.run(input_message); + + if *should_reject { + assert_eq!(output_message.action, Action::Reject); + } else { + assert_eq!(output_message.action, Action::Accept); + } + } + } + #[test] fn test_deserialize_input_message() { let input_json = r#"