Spaces:
Runtime error
Runtime error
| use std::error::Error; | |
| use std::io::{Read, Write}; | |
| use std::os::fd::OwnedFd; | |
| use std::sync::atomic::{AtomicBool, Ordering}; | |
| use std::sync::Arc; | |
| use interprocess::os::unix::unnamed_pipe::UnnamedPipeExt; | |
| use interprocess::unnamed_pipe::tokio::{Recver as TokioRecver, Sender as TokioSender}; | |
| use interprocess::unnamed_pipe::{Recver, Sender}; | |
| use nix::sys::signal::{SigSet, SigmaskHow}; | |
| use nix::unistd::{ForkResult, Pid}; | |
| use tokio::io::{AsyncReadExt, AsyncWriteExt}; | |
| use tokio::sync::{Mutex, RwLock}; | |
| use tokio_util::bytes::BufMut; | |
| pub struct PreforkedProcessPool { | |
| inner: Vec<( | |
| Arc<RwLock<Option<Result<Arc<Mutex<(TokioSender, TokioRecver)>>, std::io::Error>>>>, | |
| Pid, | |
| Arc<Mutex<Option<(OwnedFd, OwnedFd)>>>, | |
| )>, | |
| async_ipc_initialized: Arc<RwLock<AtomicBool>>, | |
| } | |
| impl PreforkedProcessPool { | |
| // This function is `unsafe`, due to forking function in `nix` crate also being `unsafe`. | |
| pub unsafe fn new( | |
| num_processes: usize, | |
| pool_fn: impl Fn(Sender, Recver), | |
| ) -> Result<Self, Box<dyn Error + Send + Sync>> { | |
| let mut processes = Vec::new(); | |
| for _ in 0..num_processes { | |
| // Create unnamed pipes | |
| let (tx_parent, rx_child) = interprocess::unnamed_pipe::pipe()?; | |
| let (tx_child, rx_parent) = interprocess::unnamed_pipe::pipe()?; | |
| // Set parent pipes to be non-blocking, because they'll be used in an asynchronous context | |
| tx_parent.set_nonblocking(true).unwrap_or_default(); | |
| rx_parent.set_nonblocking(true).unwrap_or_default(); | |
| // Obtain the file descriptors of the pipes | |
| let tx_parent_fd: OwnedFd = tx_parent.into(); | |
| let rx_parent_fd: OwnedFd = rx_parent.into(); | |
| let tx_child_fd: OwnedFd = tx_child.into(); | |
| let rx_child_fd: OwnedFd = rx_child.into(); | |
| match nix::unistd::fork() { | |
| Ok(ForkResult::Parent { child }) => { | |
| processes.push(( | |
| Arc::new(RwLock::new(None)), | |
| child, | |
| Arc::new(Mutex::new(Some((tx_parent_fd, rx_parent_fd)))), | |
| )); | |
| } | |
| Ok(ForkResult::Child) => { | |
| // Block all the signals | |
| nix::sys::signal::sigprocmask(SigmaskHow::SIG_SETMASK, Some(&SigSet::all()), None) | |
| .unwrap_or_default(); | |
| { | |
| // Stop child process after the parent process is stopped on Linux systems | |
| nix::sys::prctl::set_pdeathsig(nix::sys::signal::SIGKILL).unwrap_or_default(); | |
| } | |
| pool_fn(tx_child_fd.into(), rx_child_fd.into()); | |
| // Exit the process in the process pool | |
| std::process::exit(0); | |
| } | |
| Err(errno) => { | |
| Err(errno)?; | |
| } | |
| } | |
| } | |
| Ok(Self { | |
| inner: processes, | |
| async_ipc_initialized: Arc::new(RwLock::new(AtomicBool::new(false))), | |
| }) | |
| } | |
| pub async fn init_async_ipc(&self) { | |
| if !self | |
| .async_ipc_initialized | |
| .read() | |
| .await | |
| .load(Ordering::Relaxed) | |
| { | |
| for inner_process in &self.inner { | |
| let fds_option = inner_process.2.lock().await.take(); | |
| if let Some((tx_fd, rx_fd)) = fds_option { | |
| let ipc_io_result = match tx_fd.try_into() { | |
| Ok(tx) => match rx_fd.try_into() { | |
| Ok(rx) => Ok(Arc::new(Mutex::new((tx, rx)))), | |
| Err(err) => Err(err), | |
| }, | |
| Err(err) => Err(err), | |
| }; | |
| inner_process.0.write().await.replace(ipc_io_result); | |
| } | |
| } | |
| self | |
| .async_ipc_initialized | |
| .write() | |
| .await | |
| .store(true, Ordering::Relaxed); | |
| } | |
| } | |
| pub async fn obtain_process( | |
| &self, | |
| ) -> Result<Arc<Mutex<(TokioSender, TokioRecver)>>, Box<dyn Error + Send + Sync>> { | |
| if self.inner.is_empty() { | |
| Err(anyhow::anyhow!( | |
| "The process pool doesn't have any processes" | |
| ))? | |
| } else if self.inner.len() == 1 { | |
| let process_option = self.inner[0].0.read().await; | |
| let process = match process_option.as_ref() { | |
| Some(arc_mutex_result) => arc_mutex_result | |
| .as_ref() | |
| .map_err(|e| std::io::Error::new(e.kind(), e.to_string()))?, | |
| None => Err(anyhow::anyhow!("Asynchronous IPC not initialized yet"))?, | |
| }; | |
| Ok(process.clone()) | |
| } else { | |
| let first_random_choice = rand::random_range(0..self.inner.len()); | |
| let second_random_choice_reduced = rand::random_range(0..self.inner.len() - 1); | |
| let second_random_choice = if second_random_choice_reduced < first_random_choice { | |
| second_random_choice_reduced | |
| } else { | |
| second_random_choice_reduced + 1 | |
| }; | |
| let first_random_process_option = self.inner[first_random_choice].0.read().await; | |
| let second_random_process_option = self.inner[second_random_choice].0.read().await; | |
| let first_random_process = match first_random_process_option.as_ref() { | |
| Some(arc_mutex_result) => arc_mutex_result | |
| .as_ref() | |
| .map_err(|e| std::io::Error::new(e.kind(), e.to_string()))?, | |
| None => Err(anyhow::anyhow!("Asynchronous IPC not initialized yet"))?, | |
| }; | |
| let second_random_process = match second_random_process_option.as_ref() { | |
| Some(arc_mutex_result) => arc_mutex_result | |
| .as_ref() | |
| .map_err(|e| std::io::Error::new(e.kind(), e.to_string()))?, | |
| None => Err(anyhow::anyhow!("Asynchronous IPC not initialized yet"))?, | |
| }; | |
| let first_random_process_reference = Arc::strong_count(first_random_process); | |
| let second_random_process_reference = Arc::strong_count(second_random_process); | |
| if first_random_process_reference < second_random_process_reference { | |
| Ok(first_random_process.clone()) | |
| } else { | |
| Ok(second_random_process.clone()) | |
| } | |
| } | |
| } | |
| pub async fn obtain_process_with_init_async_ipc( | |
| &self, | |
| ) -> Result<Arc<Mutex<(TokioSender, TokioRecver)>>, Box<dyn Error + Send + Sync>> { | |
| self.init_async_ipc().await; | |
| self.obtain_process().await | |
| } | |
| } | |
| impl Drop for PreforkedProcessPool { | |
| fn drop(&mut self) { | |
| for inner_process in &self.inner { | |
| // Kill processes in the process pool when dropping the process pool | |
| nix::sys::signal::kill(inner_process.1, nix::sys::signal::SIGCHLD).unwrap_or_default(); | |
| } | |
| } | |
| } | |
| pub fn read_ipc_message(rx: &mut Recver) -> Result<Vec<u8>, std::io::Error> { | |
| let mut message_size_buffer = [0u8; 4]; | |
| rx.read_exact(&mut message_size_buffer)?; | |
| let message_size = u32::from_be_bytes(message_size_buffer); | |
| let mut buffer = vec![0u8; message_size as usize]; | |
| rx.read_exact(&mut buffer)?; | |
| Ok(buffer) | |
| } | |
| pub async fn read_ipc_message_async(rx: &mut TokioRecver) -> Result<Vec<u8>, std::io::Error> { | |
| let mut message_size_buffer = [0u8; 4]; | |
| rx.read_exact(&mut message_size_buffer).await?; | |
| let message_size = u32::from_be_bytes(message_size_buffer); | |
| let mut buffer = vec![0u8; message_size as usize]; | |
| rx.read_exact(&mut buffer).await?; | |
| Ok(buffer) | |
| } | |
| pub fn write_ipc_message(tx: &mut Sender, message: &[u8]) -> Result<(), std::io::Error> { | |
| let mut packet = Vec::new(); | |
| packet.put_slice(&(message.len() as u32).to_be_bytes()); | |
| packet.put_slice(message); | |
| tx.write_all(&packet)?; | |
| Ok(()) | |
| } | |
| pub async fn write_ipc_message_async( | |
| tx: &mut TokioSender, | |
| message: &[u8], | |
| ) -> Result<(), std::io::Error> { | |
| let mut packet = Vec::new(); | |
| packet.put_slice(&(message.len() as u32).to_be_bytes()); | |
| packet.put_slice(message); | |
| tx.write_all(&packet).await?; | |
| Ok(()) | |
| } | |
| mod tests { | |
| use super::*; | |
| use std::time::Duration; | |
| use tokio::time::timeout; | |
| fn dummy_pool_fn(mut tx: Sender, mut rx: Recver) { | |
| // Simulate child doing some work and echoing a message | |
| while let Ok(message) = read_ipc_message(&mut rx) { | |
| let _ = write_ipc_message(&mut tx, &message); | |
| } | |
| } | |
| async fn test_process_pool_creation() { | |
| let pool = unsafe { PreforkedProcessPool::new(2, dummy_pool_fn) }.unwrap(); | |
| assert_eq!(pool.inner.len(), 2); | |
| } | |
| async fn test_obtain_process_and_communication() { | |
| let pool = unsafe { PreforkedProcessPool::new(1, dummy_pool_fn) }.unwrap(); | |
| let proc = pool.obtain_process_with_init_async_ipc().await.unwrap(); | |
| let mut proc = proc.lock().await; | |
| let (tx, rx) = &mut *proc; | |
| // Write and read a message | |
| write_ipc_message_async(tx, b"hello").await.unwrap(); | |
| let message = timeout(Duration::from_secs(2), read_ipc_message_async(rx)) | |
| .await | |
| .expect("Timed out reading") | |
| .unwrap(); | |
| assert_eq!(&message, b"hello"); | |
| } | |
| async fn test_obtain_process_balancing() { | |
| let pool = unsafe { PreforkedProcessPool::new(3, dummy_pool_fn) }.unwrap(); | |
| let _p1 = pool.obtain_process_with_init_async_ipc().await.unwrap(); | |
| let _p2 = pool.obtain_process_with_init_async_ipc().await.unwrap(); | |
| let _p3 = pool.obtain_process_with_init_async_ipc().await.unwrap(); | |
| // This ensures reference counts differ | |
| let chosen = pool.obtain_process().await; | |
| assert!(chosen.is_ok()); | |
| } | |
| async fn test_obtain_process_empty_pool() { | |
| let empty_pool = PreforkedProcessPool { | |
| inner: Vec::new(), | |
| async_ipc_initialized: Arc::new(RwLock::new(AtomicBool::new(false))), | |
| }; | |
| let result = empty_pool.obtain_process_with_init_async_ipc().await; | |
| assert!(result.is_err()); | |
| } | |
| } | |