aboutsummaryrefslogtreecommitdiff
path: root/vendor/rayon-core/src/broadcast/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/rayon-core/src/broadcast/mod.rs')
-rw-r--r--vendor/rayon-core/src/broadcast/mod.rs150
1 files changed, 150 insertions, 0 deletions
diff --git a/vendor/rayon-core/src/broadcast/mod.rs b/vendor/rayon-core/src/broadcast/mod.rs
new file mode 100644
index 0000000..96611e4
--- /dev/null
+++ b/vendor/rayon-core/src/broadcast/mod.rs
@@ -0,0 +1,150 @@
+use crate::job::{ArcJob, StackJob};
+use crate::latch::{CountLatch, LatchRef};
+use crate::registry::{Registry, WorkerThread};
+use std::fmt;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+mod test;
+
+/// Executes `op` within every thread in the current threadpool. If this is
+/// called from a non-Rayon thread, it will execute in the global threadpool.
+/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
+/// within that threadpool. When the call has completed on each thread, returns
+/// a vector containing all of their return values.
+///
+/// For more information, see the [`ThreadPool::broadcast()`][m] method.
+///
+/// [m]: struct.ThreadPool.html#method.broadcast
+pub fn broadcast<OP, R>(op: OP) -> Vec<R>
+where
+ OP: Fn(BroadcastContext<'_>) -> R + Sync,
+ R: Send,
+{
+ // We assert that current registry has not terminated.
+ unsafe { broadcast_in(op, &Registry::current()) }
+}
+
+/// Spawns an asynchronous task on every thread in this thread-pool. This task
+/// will run in the implicit, global scope, which means that it may outlast the
+/// current stack frame -- therefore, it cannot capture any references onto the
+/// stack (you will likely need a `move` closure).
+///
+/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
+///
+/// [m]: struct.ThreadPool.html#method.spawn_broadcast
+pub fn spawn_broadcast<OP>(op: OP)
+where
+ OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
+{
+ // We assert that current registry has not terminated.
+ unsafe { spawn_broadcast_in(op, &Registry::current()) }
+}
+
+/// Provides context to a closure called by `broadcast`.
+pub struct BroadcastContext<'a> {
+ worker: &'a WorkerThread,
+
+ /// Make sure to prevent auto-traits like `Send` and `Sync`.
+ _marker: PhantomData<&'a mut dyn Fn()>,
+}
+
+impl<'a> BroadcastContext<'a> {
+ pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
+ let worker_thread = WorkerThread::current();
+ assert!(!worker_thread.is_null());
+ f(BroadcastContext {
+ worker: unsafe { &*worker_thread },
+ _marker: PhantomData,
+ })
+ }
+
+ /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
+ #[inline]
+ pub fn index(&self) -> usize {
+ self.worker.index()
+ }
+
+ /// The number of threads receiving the broadcast in the thread pool.
+ ///
+ /// # Future compatibility note
+ ///
+ /// Future versions of Rayon might vary the number of threads over time, but
+ /// this method will always return the number of threads which are actually
+ /// receiving your particular `broadcast` call.
+ #[inline]
+ pub fn num_threads(&self) -> usize {
+ self.worker.registry().num_threads()
+ }
+}
+
+impl<'a> fmt::Debug for BroadcastContext<'a> {
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ fmt.debug_struct("BroadcastContext")
+ .field("index", &self.index())
+ .field("num_threads", &self.num_threads())
+ .field("pool_id", &self.worker.registry().id())
+ .finish()
+ }
+}
+
+/// Execute `op` on every thread in the pool. It will be executed on each
+/// thread when they have nothing else to do locally, before they try to
+/// steal work from other threads. This function will not return until all
+/// threads have completed the `op`.
+///
+/// Unsafe because `registry` must not yet have terminated.
+pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
+where
+ OP: Fn(BroadcastContext<'_>) -> R + Sync,
+ R: Send,
+{
+ let f = move |injected: bool| {
+ debug_assert!(injected);
+ BroadcastContext::with(&op)
+ };
+
+ let n_threads = registry.num_threads();
+ let current_thread = WorkerThread::current().as_ref();
+ let latch = CountLatch::with_count(n_threads, current_thread);
+ let jobs: Vec<_> = (0..n_threads)
+ .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
+ .collect();
+ let job_refs = jobs.iter().map(|job| job.as_job_ref());
+
+ registry.inject_broadcast(job_refs);
+
+ // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
+ latch.wait(current_thread);
+ jobs.into_iter().map(|job| job.into_result()).collect()
+}
+
+/// Execute `op` on every thread in the pool. It will be executed on each
+/// thread when they have nothing else to do locally, before they try to
+/// steal work from other threads. This function returns immediately after
+/// injecting the jobs.
+///
+/// Unsafe because `registry` must not yet have terminated.
+pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
+where
+ OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
+{
+ let job = ArcJob::new({
+ let registry = Arc::clone(registry);
+ move || {
+ registry.catch_unwind(|| BroadcastContext::with(&op));
+ registry.terminate(); // (*) permit registry to terminate now
+ }
+ });
+
+ let n_threads = registry.num_threads();
+ let job_refs = (0..n_threads).map(|_| {
+ // Ensure that registry cannot terminate until this job has executed
+ // on each thread. This ref is decremented at the (*) above.
+ registry.increment_terminate_count();
+
+ ArcJob::as_static_job_ref(&job)
+ });
+
+ registry.inject_broadcast(job_refs);
+}