Explorar el Código

incorporate file sizes into message lengths

Justin Tracey hace 1 año
padre
commit
903b430b76
Se han modificado 3 ficheros con 56 adiciones y 3 borrados
  1. 3 0
      Cargo.toml
  2. 22 3
      src/bin/message-lens.rs
  3. 31 0
      src/lib.rs

+ 3 - 0
Cargo.toml

@@ -6,10 +6,13 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 
 [dependencies]
 [dependencies]
+anyhow = "1.0.75"
 counter = "0.5.7"
 counter = "0.5.7"
 glob = "0.3.1"
 glob = "0.3.1"
 itertools = "0.11.0"
 itertools = "0.11.0"
 pyo3 = "0.19.0"
 pyo3 = "0.19.0"
+rand = "0.8.5"
+rand_distr = "0.4.3"
 rayon = "1.7.0"
 rayon = "1.7.0"
 serde = { version = "1.0.164", features = ["derive"] }
 serde = { version = "1.0.164", features = ["derive"] }
 serde_json = "1.0.96"
 serde_json = "1.0.96"

+ 22 - 3
src/bin/message-lens.rs

@@ -1,3 +1,4 @@
+use rand_distr::Distribution;
 use rayon::prelude::*;
 use rayon::prelude::*;
 use sam_extractor::*;
 use sam_extractor::*;
 use std::collections::HashMap;
 use std::collections::HashMap;
@@ -15,10 +16,21 @@ fn main() {
     let this_program = args.next().unwrap();
     let this_program = args.next().unwrap();
 
 
     if args.len() < 2 {
     if args.len() < 2 {
-        panic!("Usage: {} stats_directory chat.json...", this_program);
+        panic!(
+            "Usage: {} [-s file_sizes.dat] stats_directory chat.json...",
+            this_program
+        );
     }
     }
 
 
-    let dists_dir = args.next().unwrap();
+    let first_arg = args.next().unwrap();
+    let (file_sizes, dists_dir) = if first_arg != "-s" {
+        (None, first_arg)
+    } else {
+        (
+            Some(parse_weights_file(args.next().unwrap()).unwrap()),
+            args.next().unwrap(),
+        )
+    };
 
 
     let conversations = args
     let conversations = args
         .flat_map(|a| glob::glob(a.as_str()).unwrap())
         .flat_map(|a| glob::glob(a.as_str()).unwrap())
@@ -33,9 +45,16 @@ fn main() {
         .collect::<Vec<_>>();
         .collect::<Vec<_>>();
 
 
     let mut users: HashMap<UserId, Vec<usize>> = HashMap::new();
     let mut users: HashMap<UserId, Vec<usize>> = HashMap::new();
+    let mut rng = rand::thread_rng();
     for conversation in conversations {
     for conversation in conversations {
         for message in conversation.messages {
         for message in conversation.messages {
-            let message_len = bytes_to_blocks(message.char_count + message.emoji_count as i32 * 4);
+            let file_size = if let Some((ref dist, ref sizes)) = file_sizes {
+                sizes[dist.sample(&mut rng)]
+            } else {
+                0
+            };
+            let message_len =
+                bytes_to_blocks(message.char_count + message.emoji_count as i32 * 4 + file_size);
             if let Some(lens) = users.get_mut(&message.user) {
             if let Some(lens) = users.get_mut(&message.user) {
                 lens.push(message_len);
                 lens.push(message_len);
             } else {
             } else {

+ 31 - 0
src/lib.rs

@@ -4,6 +4,7 @@ use serde_repr::Deserialize_repr;
 use std::cmp::min;
 use std::cmp::min;
 use std::collections::{BTreeMap, HashMap};
 use std::collections::{BTreeMap, HashMap};
 use std::io::Write;
 use std::io::Write;
+use std::str::FromStr;
 use time::{Duration, OffsetDateTime as Time};
 use time::{Duration, OffsetDateTime as Time};
 
 
 pub type UserId = i32;
 pub type UserId = i32;
@@ -223,3 +224,33 @@ where
 
 
     std::fs::write(file_path, data.as_bytes())
     std::fs::write(file_path, data.as_bytes())
 }
 }
+
+use rand_distr::WeightedAliasIndex;
+
+pub fn parse_weights_file<T>(path: String) -> anyhow::Result<(WeightedAliasIndex<u32>, Vec<T>)>
+where
+    T: FromStr,
+    <T as FromStr>::Err: std::error::Error,
+{
+    let weights_file = std::fs::read_to_string(path)?;
+    let mut weights_lines = weights_file.lines();
+    let weights = weights_lines
+        .next()
+        .unwrap()
+        .split(',')
+        .map(u32::from_str)
+        .collect::<Result<Vec<_>, _>>()?;
+    let vals = weights_lines
+        .next()
+        .expect("Weights file only has one line")
+        .split(',')
+        .map(T::from_str)
+        .collect::<Result<Vec<_>, _>>()
+        .unwrap();
+    assert!(
+        weights.len() == vals.len(),
+        "Weights file doesn't have the same number of weights and values."
+    );
+    let dist = WeightedAliasIndex::<u32>::new(weights)?;
+    Ok((dist, vals))
+}