diff options
author | Valentin Popov <valentin@popov.link> | 2024-01-08 00:21:28 +0300 |
---|---|---|
committer | Valentin Popov <valentin@popov.link> | 2024-01-08 00:21:28 +0300 |
commit | 1b6a04ca5504955c571d1c97504fb45ea0befee4 (patch) | |
tree | 7579f518b23313e8a9748a88ab6173d5e030b227 /vendor/smawk/tests/complexity.rs | |
parent | 5ecd8cf2cba827454317368b68571df0d13d7842 (diff) | |
download | fparkan-1b6a04ca5504955c571d1c97504fb45ea0befee4.tar.xz fparkan-1b6a04ca5504955c571d1c97504fb45ea0befee4.zip |
Initial vendor packages
Signed-off-by: Valentin Popov <valentin@popov.link>
Diffstat (limited to 'vendor/smawk/tests/complexity.rs')
-rw-r--r-- | vendor/smawk/tests/complexity.rs | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/vendor/smawk/tests/complexity.rs b/vendor/smawk/tests/complexity.rs new file mode 100644 index 0000000..c9881ea --- /dev/null +++ b/vendor/smawk/tests/complexity.rs @@ -0,0 +1,83 @@ +#![cfg(feature = "ndarray")] + +use ndarray::{Array1, Array2}; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use smawk::online_column_minima; + +mod random_monge; +use random_monge::random_monge_matrix; + +#[derive(Debug)] +struct LinRegression { + alpha: f64, + beta: f64, + r_squared: f64, +} + +/// Square an expression. Works equally well for floats and matrices. +macro_rules! squared { + ($x:expr) => { + $x * $x + }; +} + +/// Compute the mean of a 1-dimensional array. +macro_rules! mean { + ($a:expr) => { + $a.mean().expect("Mean of empty array") + }; +} + +/// Compute a simple linear regression from the list of values. +/// +/// See <https://en.wikipedia.org/wiki/Simple_linear_regression>. +fn linear_regression(values: &[(usize, i32)]) -> LinRegression { + let xs = values.iter().map(|&(x, _)| x as f64).collect::<Array1<_>>(); + let ys = values.iter().map(|&(_, y)| y as f64).collect::<Array1<_>>(); + + let xs_mean = mean!(&xs); + let ys_mean = mean!(&ys); + let xs_ys_mean = mean!(&xs * &ys); + + let cov_xs_ys = ((&xs - xs_mean) * (&ys - ys_mean)).sum(); + let var_xs = squared!(&xs - xs_mean).sum(); + + let beta = cov_xs_ys / var_xs; + let alpha = ys_mean - beta * xs_mean; + let r_squared = squared!(xs_ys_mean - xs_mean * ys_mean) + / ((mean!(&xs * &xs) - squared!(xs_mean)) * (mean!(&ys * &ys) - squared!(ys_mean))); + + LinRegression { + alpha: alpha, + beta: beta, + r_squared: r_squared, + } +} + +/// Check that the number of matrix accesses in `online_column_minima` +/// grows as O(*n*) for *n* ✕ *n* matrix. +#[test] +fn online_linear_complexity() { + let mut rng = ChaCha20Rng::seed_from_u64(0); + let mut data = vec![]; + + for &size in &[1, 2, 3, 4, 5, 10, 15, 20, 30, 40, 50, 60, 70, 80, 90, 100] { + let matrix: Array2<i32> = random_monge_matrix(size, size, &mut rng); + let count = std::cell::RefCell::new(0); + online_column_minima(0, size, |_, i, j| { + *count.borrow_mut() += 1; + matrix[[i, j]] + }); + data.push((size, count.into_inner())); + } + + let lin_reg = linear_regression(&data); + assert!( + lin_reg.r_squared > 0.95, + "r² = {:.4} is lower than expected for a linear fit\nData points: {:?}\n{:?}", + lin_reg.r_squared, + data, + lin_reg + ); +} |