Allow blacklisting CIDR ranges

This commit is contained in:
Daniel Saxton 2025-01-04 19:20:37 -06:00
parent e6566e65f9
commit c22f1f0a24
4 changed files with 90 additions and 30 deletions

10
Cargo.lock generated
View file

@ -352,6 +352,15 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "ipnetwork"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf466541e9d546596ee94f9f69590f89473455f88372423e0008fc1a7daf100e"
dependencies = [
"serde",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.0"
@ -431,6 +440,7 @@ version = "0.1.0"
dependencies = [
"env_logger",
"futures-util",
"ipnetwork",
"log",
"serde",
"serde_json",

View file

@ -10,6 +10,7 @@ forwarder = ["tokio-tungstenite", "tokio", "futures-util"]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
toml = "0.5"
ipnetwork = "0.20.0"
# forwarder deps
tokio-tungstenite = { version = "0.23.1", optional = true, features = ["native-tls"] }

View file

@ -1,15 +1,44 @@
use crate::{Action, InputMessage, NoteFilter, OutputMessage};
use ipnetwork::IpNetwork;
use serde::Deserialize;
use std::net::IpAddr;
use std::str::FromStr;
#[derive(Deserialize, Default)]
pub struct Blacklist {
pub pubkeys: Option<Vec<String>>,
pub ips: Option<Vec<String>>,
pub cidrs: Option<Vec<String>>,
}
impl Blacklist {
fn is_ip_blocked(&self, ip: &str) -> bool {
if let Some(cidrs) = &self.cidrs {
for cidr in cidrs {
if let Ok(network) = IpNetwork::from_str(cidr) {
if let Ok(addr) = IpAddr::from_str(ip) {
if network.contains(addr) {
return true;
}
}
}
}
}
if let Some(ips) = &self.ips {
if ips.contains(&ip.to_string()) {
return true;
}
}
false
}
}
impl NoteFilter for Blacklist {
fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage {
let reject_message = "blocked: pubkey/ip is blacklisted".to_string();
if let Some(pubkeys) = &self.pubkeys {
if pubkeys.contains(&msg.event.pubkey) {
return OutputMessage::new(
@ -20,14 +49,8 @@ impl NoteFilter for Blacklist {
}
}
if let Some(ips) = &self.ips {
if ips.contains(&msg.source_info) {
return OutputMessage::new(
msg.event.id.clone(),
Action::Reject,
Some(reject_message),
);
}
if self.is_ip_blocked(&msg.source_info) {
return OutputMessage::new(msg.event.id.clone(), Action::Reject, Some(reject_message));
}
OutputMessage::new(msg.event.id.clone(), Action::Accept, None)

View file

@ -177,12 +177,14 @@ fn noteguard() {
#[cfg(test)]
mod tests {
use super::*;
use noteguard::filters::{Kinds, ProtectedEvents, RateLimit, Whitelist};
use noteguard::{Action, Note};
use serde_json::json;
// Helper function to create a mock InputMessage
fn create_mock_input_message(event_id: &str, message_type: &str) -> InputMessage {
fn create_mock_input_message(
event_id: &str,
message_type: &str,
source_info: &str,
) -> InputMessage {
InputMessage {
message_type: message_type.to_string(),
event: Note {
@ -196,20 +198,7 @@ mod tests {
},
received_at: 0,
source_type: "mock_source".to_string(),
source_info: "mock_source_info".to_string(),
}
}
// Helper function to create a mock OutputMessage
fn create_mock_output_message(
event_id: &str,
action: Action,
msg: Option<&str>,
) -> OutputMessage {
OutputMessage {
id: event_id.to_string(),
action,
msg: msg.map(|s| s.to_string()),
source_info: source_info.to_string(),
}
}
@ -263,7 +252,7 @@ mod tests {
.load_config(&config)
.expect("Failed to load config");
let input_message = create_mock_input_message("test_event_1", "new");
let input_message = create_mock_input_message("test_event_1", "new", "mock_source_info");
let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Accept);
@ -287,7 +276,7 @@ mod tests {
.load_config(&config)
.expect("Failed to load config");
let input_message = create_mock_input_message("test_event_3", "new");
let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info");
let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Reject);
@ -311,7 +300,7 @@ mod tests {
.load_config(&config)
.expect("Failed to load config");
let input_message = create_mock_input_message("test_event_2", "new");
let input_message = create_mock_input_message("test_event_2", "new", "mock_source_info");
let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Reject);
@ -334,7 +323,7 @@ mod tests {
.load_config(&config)
.expect("Failed to load config");
let input_message = create_mock_input_message("test_event_3", "new");
let input_message = create_mock_input_message("test_event_3", "new", "mock_source_info");
let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Reject);
@ -361,12 +350,49 @@ mod tests {
.load_config(&config)
.expect("Failed to load config");
let input_message = create_mock_input_message("test_event_4", "new");
let input_message = create_mock_input_message("test_event_4", "new", "mock_source_info");
let output_message = noteguard.run(input_message);
assert_eq!(output_message.action, Action::Accept);
}
#[test]
fn test_blacklist_cidr() {
let mut noteguard = Noteguard::new();
let config: Config = toml::from_str(
r#"
pipeline = ["blacklist"]
[filters.blacklist]
cidrs = ["127.0.0.1/24"]
"#,
)
.expect("Failed to parse config");
noteguard
.load_config(&config)
.expect("Failed to load config");
let test_cases = [
("127.0.0.1", true),
("127.0.0.2", true),
("128.0.0.1", false),
("127.1.0.1", false),
("127.0.1.1", false),
];
for (ip, should_reject) in test_cases.iter() {
let input_message = create_mock_input_message("event_id", "new", ip);
let output_message = noteguard.run(input_message);
if *should_reject {
assert_eq!(output_message.action, Action::Reject);
} else {
assert_eq!(output_message.action, Action::Accept);
}
}
}
#[test]
fn test_deserialize_input_message() {
let input_json = r#"