diff options
Diffstat (limited to 'vendor/smawk/tests/complexity.rs')
-rw-r--r-- | vendor/smawk/tests/complexity.rs | 83 |
1 files changed, 0 insertions, 83 deletions
diff --git a/vendor/smawk/tests/complexity.rs b/vendor/smawk/tests/complexity.rs deleted file mode 100644 index c9881ea..0000000 --- a/vendor/smawk/tests/complexity.rs +++ /dev/null @@ -1,83 +0,0 @@ -#![cfg(feature = "ndarray")] - -use ndarray::{Array1, Array2}; -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; -use smawk::online_column_minima; - -mod random_monge; -use random_monge::random_monge_matrix; - -#[derive(Debug)] -struct LinRegression { - alpha: f64, - beta: f64, - r_squared: f64, -} - -/// Square an expression. Works equally well for floats and matrices. -macro_rules! squared { - ($x:expr) => { - $x * $x - }; -} - -/// Compute the mean of a 1-dimensional array. -macro_rules! mean { - ($a:expr) => { - $a.mean().expect("Mean of empty array") - }; -} - -/// Compute a simple linear regression from the list of values. -/// -/// See <https://en.wikipedia.org/wiki/Simple_linear_regression>. -fn linear_regression(values: &[(usize, i32)]) -> LinRegression { - let xs = values.iter().map(|&(x, _)| x as f64).collect::<Array1<_>>(); - let ys = values.iter().map(|&(_, y)| y as f64).collect::<Array1<_>>(); - - let xs_mean = mean!(&xs); - let ys_mean = mean!(&ys); - let xs_ys_mean = mean!(&xs * &ys); - - let cov_xs_ys = ((&xs - xs_mean) * (&ys - ys_mean)).sum(); - let var_xs = squared!(&xs - xs_mean).sum(); - - let beta = cov_xs_ys / var_xs; - let alpha = ys_mean - beta * xs_mean; - let r_squared = squared!(xs_ys_mean - xs_mean * ys_mean) - / ((mean!(&xs * &xs) - squared!(xs_mean)) * (mean!(&ys * &ys) - squared!(ys_mean))); - - LinRegression { - alpha: alpha, - beta: beta, - r_squared: r_squared, - } -} - -/// Check that the number of matrix accesses in `online_column_minima` -/// grows as O(*n*) for *n* ✕ *n* matrix. -#[test] -fn online_linear_complexity() { - let mut rng = ChaCha20Rng::seed_from_u64(0); - let mut data = vec![]; - - for &size in &[1, 2, 3, 4, 5, 10, 15, 20, 30, 40, 50, 60, 70, 80, 90, 100] { - let matrix: Array2<i32> = random_monge_matrix(size, size, &mut rng); - let count = std::cell::RefCell::new(0); - online_column_minima(0, size, |_, i, j| { - *count.borrow_mut() += 1; - matrix[[i, j]] - }); - data.push((size, count.into_inner())); - } - - let lin_reg = linear_regression(&data); - assert!( - lin_reg.r_squared > 0.95, - "r² = {:.4} is lower than expected for a linear fit\nData points: {:?}\n{:?}", - lin_reg.r_squared, - data, - lin_reg - ); -} |