From 08185977cc5049228c43f905f37e6d4b07c59e5b Mon Sep 17 00:00:00 2001 From: Fukusuke Takahashi <41001169+fukusuket@users.noreply.github.com> Date: Sat, 13 Aug 2022 19:56:30 +0900 Subject: [PATCH] fix race condition in insert_message. #639 (#660) --- Cargo.lock | 37 +++++++++++++++++++++++++ Cargo.toml | 3 +++ src/detections/message.rs | 57 +++++++++++++++++++++++++++++++++------ 3 files changed, 89 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e6bcafe3..ec99fbd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -770,6 +770,7 @@ dependencies = [ "pbr", "prettytable-rs", "quick-xml", + "rand", "regex", "serde", "serde_derive", @@ -1356,6 +1357,12 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + [[package]] name = "prettytable-rs" version = "0.8.0" @@ -1451,6 +1458,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +dependencies = [ + "getrandom 0.2.7", +] + [[package]] name = "rayon" version = "1.5.3" diff --git a/Cargo.toml b/Cargo.toml index 388881a4..b2ed103a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,9 @@ num-format = "*" [build-dependencies] static_vcruntime = "2.*" +[dev-dependencies] +rand = "0.8.*" + [target.'cfg(windows)'.dependencies] is_elevated = "0.1.*" diff --git a/src/detections/message.rs b/src/detections/message.rs index 63940876..55b33fe2 100644 --- a/src/detections/message.rs +++ b/src/detections/message.rs @@ -112,13 +112,9 @@ pub fn create_output_filter_config(path: &str) -> HashMap { /// メッセージの設定を行う関数。aggcondition対応のためrecordではなく出力をする対象時間がDatetime形式での入力としている pub fn insert_message(detect_info: DetectInfo, event_time: DateTime) { - if let Some(mut v) = MESSAGES.get_mut(&event_time) { - let (_, info) = v.pair_mut(); - info.push(detect_info); - } else { - let m = vec![detect_info; 1]; - MESSAGES.insert(event_time, m); - } + let mut v = MESSAGES.entry(event_time).or_default(); + let (_, info) = v.pair_mut(); + info.push(detect_info); } /// メッセージを設定 @@ -358,10 +354,14 @@ impl AlertMessage { #[cfg(test)] mod tests { - use crate::detections::message::AlertMessage; + use crate::detections::message::{get, insert_message, AlertMessage, DetectInfo}; use crate::detections::message::{parse_message, MESSAGES}; + use chrono::Utc; use hashbrown::HashMap; + use rand::Rng; use serde_json::Value; + use std::thread; + use std::time::Duration; use super::{create_output_filter_config, get_default_details}; @@ -621,4 +621,45 @@ mod tests { assert!(actual.get(k).unwrap_or(&String::default()) == v); } } + + #[ignore] + #[test] + fn test_insert_message_race_condition() { + MESSAGES.clear(); + + // Setup test detect_info before starting threads. + let mut sample_detects = vec![]; + let mut rng = rand::thread_rng(); + let sample_event_time = Utc::now(); + for i in 1..2001 { + let detect_info = DetectInfo { + rulepath: "".to_string(), + level: "".to_string(), + computername: "".to_string(), + eventid: i.to_string(), + detail: "".to_string(), + record_information: None, + ext_field: Default::default(), + }; + sample_detects.push((sample_event_time, detect_info, rng.gen_range(0..10))); + } + + // Starting threads and randomly insert_message in parallel. + let mut handles = vec![]; + for (event_time, detect_info, random_num) in sample_detects { + let handle = thread::spawn(move || { + thread::sleep(Duration::from_micros(random_num)); + insert_message(detect_info, event_time); + }); + handles.push(handle); + } + + // Wait for all threads execution completion. + for handle in handles { + handle.join().unwrap(); + } + + // Expect all sample_detects to be included, but the len() size will be different each time I run it + assert_eq!(get(sample_event_time).len(), 2000) + } }