diff options
Diffstat (limited to 'vendor/rayon-core/src/broadcast')
-rw-r--r-- | vendor/rayon-core/src/broadcast/mod.rs | 150 | ||||
-rw-r--r-- | vendor/rayon-core/src/broadcast/test.rs | 263 |
2 files changed, 413 insertions, 0 deletions
diff --git a/vendor/rayon-core/src/broadcast/mod.rs b/vendor/rayon-core/src/broadcast/mod.rs new file mode 100644 index 0000000..96611e4 --- /dev/null +++ b/vendor/rayon-core/src/broadcast/mod.rs @@ -0,0 +1,150 @@ +use crate::job::{ArcJob, StackJob}; +use crate::latch::{CountLatch, LatchRef}; +use crate::registry::{Registry, WorkerThread}; +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + +mod test; + +/// Executes `op` within every thread in the current threadpool. If this is +/// called from a non-Rayon thread, it will execute in the global threadpool. +/// Any attempts to use `join`, `scope`, or parallel iterators will then operate +/// within that threadpool. When the call has completed on each thread, returns +/// a vector containing all of their return values. +/// +/// For more information, see the [`ThreadPool::broadcast()`][m] method. +/// +/// [m]: struct.ThreadPool.html#method.broadcast +pub fn broadcast<OP, R>(op: OP) -> Vec<R> +where + OP: Fn(BroadcastContext<'_>) -> R + Sync, + R: Send, +{ + // We assert that current registry has not terminated. + unsafe { broadcast_in(op, &Registry::current()) } +} + +/// Spawns an asynchronous task on every thread in this thread-pool. This task +/// will run in the implicit, global scope, which means that it may outlast the +/// current stack frame -- therefore, it cannot capture any references onto the +/// stack (you will likely need a `move` closure). +/// +/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method. +/// +/// [m]: struct.ThreadPool.html#method.spawn_broadcast +pub fn spawn_broadcast<OP>(op: OP) +where + OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, +{ + // We assert that current registry has not terminated. + unsafe { spawn_broadcast_in(op, &Registry::current()) } +} + +/// Provides context to a closure called by `broadcast`. +pub struct BroadcastContext<'a> { + worker: &'a WorkerThread, + + /// Make sure to prevent auto-traits like `Send` and `Sync`. + _marker: PhantomData<&'a mut dyn Fn()>, +} + +impl<'a> BroadcastContext<'a> { + pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { + let worker_thread = WorkerThread::current(); + assert!(!worker_thread.is_null()); + f(BroadcastContext { + worker: unsafe { &*worker_thread }, + _marker: PhantomData, + }) + } + + /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). + #[inline] + pub fn index(&self) -> usize { + self.worker.index() + } + + /// The number of threads receiving the broadcast in the thread pool. + /// + /// # Future compatibility note + /// + /// Future versions of Rayon might vary the number of threads over time, but + /// this method will always return the number of threads which are actually + /// receiving your particular `broadcast` call. + #[inline] + pub fn num_threads(&self) -> usize { + self.worker.registry().num_threads() + } +} + +impl<'a> fmt::Debug for BroadcastContext<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("BroadcastContext") + .field("index", &self.index()) + .field("num_threads", &self.num_threads()) + .field("pool_id", &self.worker.registry().id()) + .finish() + } +} + +/// Execute `op` on every thread in the pool. It will be executed on each +/// thread when they have nothing else to do locally, before they try to +/// steal work from other threads. This function will not return until all +/// threads have completed the `op`. +/// +/// Unsafe because `registry` must not yet have terminated. +pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R> +where + OP: Fn(BroadcastContext<'_>) -> R + Sync, + R: Send, +{ + let f = move |injected: bool| { + debug_assert!(injected); + BroadcastContext::with(&op) + }; + + let n_threads = registry.num_threads(); + let current_thread = WorkerThread::current().as_ref(); + let latch = CountLatch::with_count(n_threads, current_thread); + let jobs: Vec<_> = (0..n_threads) + .map(|_| StackJob::new(&f, LatchRef::new(&latch))) + .collect(); + let job_refs = jobs.iter().map(|job| job.as_job_ref()); + + registry.inject_broadcast(job_refs); + + // Wait for all jobs to complete, then collect the results, maybe propagating a panic. + latch.wait(current_thread); + jobs.into_iter().map(|job| job.into_result()).collect() +} + +/// Execute `op` on every thread in the pool. It will be executed on each +/// thread when they have nothing else to do locally, before they try to +/// steal work from other threads. This function returns immediately after +/// injecting the jobs. +/// +/// Unsafe because `registry` must not yet have terminated. +pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>) +where + OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, +{ + let job = ArcJob::new({ + let registry = Arc::clone(registry); + move || { + registry.catch_unwind(|| BroadcastContext::with(&op)); + registry.terminate(); // (*) permit registry to terminate now + } + }); + + let n_threads = registry.num_threads(); + let job_refs = (0..n_threads).map(|_| { + // Ensure that registry cannot terminate until this job has executed + // on each thread. This ref is decremented at the (*) above. + registry.increment_terminate_count(); + + ArcJob::as_static_job_ref(&job) + }); + + registry.inject_broadcast(job_refs); +} diff --git a/vendor/rayon-core/src/broadcast/test.rs b/vendor/rayon-core/src/broadcast/test.rs new file mode 100644 index 0000000..00ab4ad --- /dev/null +++ b/vendor/rayon-core/src/broadcast/test.rs @@ -0,0 +1,263 @@ +#![cfg(test)] + +use crate::ThreadPoolBuilder; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::mpsc::channel; +use std::sync::Arc; +use std::{thread, time}; + +#[test] +fn broadcast_global() { + let v = crate::broadcast(|ctx| ctx.index()); + assert!(v.into_iter().eq(0..crate::current_num_threads())); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn spawn_broadcast_global() { + let (tx, rx) = channel(); + crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); + + let mut v: Vec<_> = rx.into_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..crate::current_num_threads())); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn broadcast_pool() { + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let v = pool.broadcast(|ctx| ctx.index()); + assert!(v.into_iter().eq(0..7)); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn spawn_broadcast_pool() { + let (tx, rx) = channel(); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); + + let mut v: Vec<_> = rx.into_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..7)); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn broadcast_self() { + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let v = pool.install(|| crate::broadcast(|ctx| ctx.index())); + assert!(v.into_iter().eq(0..7)); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn spawn_broadcast_self() { + let (tx, rx) = channel(); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool.spawn(|| crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap())); + + let mut v: Vec<_> = rx.into_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..7)); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn broadcast_mutual() { + let count = AtomicUsize::new(0); + let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.install(|| { + pool2.broadcast(|_| { + pool1.broadcast(|_| { + count.fetch_add(1, Ordering::Relaxed); + }) + }) + }); + assert_eq!(count.into_inner(), 3 * 7); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn spawn_broadcast_mutual() { + let (tx, rx) = channel(); + let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap()); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.spawn({ + let pool1 = Arc::clone(&pool1); + move || { + pool2.spawn_broadcast(move |_| { + let tx = tx.clone(); + pool1.spawn_broadcast(move |_| tx.send(()).unwrap()) + }) + } + }); + assert_eq!(rx.into_iter().count(), 3 * 7); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn broadcast_mutual_sleepy() { + let count = AtomicUsize::new(0); + let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.install(|| { + thread::sleep(time::Duration::from_secs(1)); + pool2.broadcast(|_| { + thread::sleep(time::Duration::from_secs(1)); + pool1.broadcast(|_| { + thread::sleep(time::Duration::from_millis(100)); + count.fetch_add(1, Ordering::Relaxed); + }) + }) + }); + assert_eq!(count.into_inner(), 3 * 7); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn spawn_broadcast_mutual_sleepy() { + let (tx, rx) = channel(); + let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap()); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.spawn({ + let pool1 = Arc::clone(&pool1); + move || { + thread::sleep(time::Duration::from_secs(1)); + pool2.spawn_broadcast(move |_| { + let tx = tx.clone(); + thread::sleep(time::Duration::from_secs(1)); + pool1.spawn_broadcast(move |_| { + thread::sleep(time::Duration::from_millis(100)); + tx.send(()).unwrap(); + }) + }) + } + }); + assert_eq!(rx.into_iter().count(), 3 * 7); +} + +#[test] +#[cfg_attr(not(panic = "unwind"), ignore)] +fn broadcast_panic_one() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.broadcast(|ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() == 3 { + panic!("Hello, world!"); + } + }) + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} + +#[test] +#[cfg_attr(not(panic = "unwind"), ignore)] +fn spawn_broadcast_panic_one() { + let (tx, rx) = channel(); + let (panic_tx, panic_rx) = channel(); + let pool = ThreadPoolBuilder::new() + .num_threads(7) + .panic_handler(move |e| panic_tx.send(e).unwrap()) + .build() + .unwrap(); + pool.spawn_broadcast(move |ctx| { + tx.send(()).unwrap(); + if ctx.index() == 3 { + panic!("Hello, world!"); + } + }); + drop(pool); // including panic_tx + assert_eq!(rx.into_iter().count(), 7); + assert_eq!(panic_rx.into_iter().count(), 1); +} + +#[test] +#[cfg_attr(not(panic = "unwind"), ignore)] +fn broadcast_panic_many() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.broadcast(|ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() % 2 == 0 { + panic!("Hello, world!"); + } + }) + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} + +#[test] +#[cfg_attr(not(panic = "unwind"), ignore)] +fn spawn_broadcast_panic_many() { + let (tx, rx) = channel(); + let (panic_tx, panic_rx) = channel(); + let pool = ThreadPoolBuilder::new() + .num_threads(7) + .panic_handler(move |e| panic_tx.send(e).unwrap()) + .build() + .unwrap(); + pool.spawn_broadcast(move |ctx| { + tx.send(()).unwrap(); + if ctx.index() % 2 == 0 { + panic!("Hello, world!"); + } + }); + drop(pool); // including panic_tx + assert_eq!(rx.into_iter().count(), 7); + assert_eq!(panic_rx.into_iter().count(), 4); +} + +#[test] +#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)] +fn broadcast_sleep_race() { + let test_duration = time::Duration::from_secs(1); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let start = time::Instant::now(); + while start.elapsed() < test_duration { + pool.broadcast(|ctx| { + // A slight spread of sleep duration increases the chance that one + // of the threads will race in the pool's idle sleep afterward. + thread::sleep(time::Duration::from_micros(ctx.index() as u64)); + }); + } +} + +#[test] +fn broadcast_after_spawn_broadcast() { + let (tx, rx) = channel(); + + // Queue a non-blocking spawn_broadcast. + crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); + + // This blocking broadcast runs after all prior broadcasts. + crate::broadcast(|_| {}); + + // The spawn_broadcast **must** have run by now on all threads. + let mut v: Vec<_> = rx.try_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..crate::current_num_threads())); +} + +#[test] +fn broadcast_after_spawn() { + let (tx, rx) = channel(); + + // Queue a regular spawn on a thread-local deque. + crate::registry::in_worker(move |_, _| { + crate::spawn(move || tx.send(22).unwrap()); + }); + + // Broadcast runs after the local deque is empty. + crate::broadcast(|_| {}); + + // The spawn **must** have run by now. + assert_eq!(22, rx.try_recv().unwrap()); +} |