Spaces:
Runtime error
Runtime error
| use std::pin::Pin; | |
| use pyo3::prelude::*; | |
| use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt}; | |
| pub struct WsgiInputStream { | |
| body_reader: Pin<Box<dyn AsyncBufRead + Send + Sync>>, | |
| } | |
| impl WsgiInputStream { | |
| pub fn new(body_reader: impl AsyncBufRead + Send + Sync + 'static) -> Self { | |
| Self { | |
| body_reader: Box::pin(body_reader), | |
| } | |
| } | |
| } | |
| impl WsgiInputStream { | |
| fn read(&mut self, size: usize) -> PyResult<Vec<u8>> { | |
| let mut buffer = vec![0u8; size]; | |
| let read_bytes = futures_lite::future::block_on(self.body_reader.read(&mut buffer))?; | |
| Ok(buffer[0..read_bytes].to_vec()) | |
| } | |
| fn readline(&mut self, size: Option<isize>) -> PyResult<Vec<u8>> { | |
| let mut buffer = Vec::new(); | |
| let size = if size.is_none_or(|s| s < 0) { | |
| None | |
| } else { | |
| size.map(|s| s as usize) | |
| }; | |
| loop { | |
| let reader_buffer = futures_lite::future::block_on(self.body_reader.fill_buf())?.to_vec(); | |
| if reader_buffer.is_empty() { | |
| break; | |
| } | |
| if let Some(eol_position) = reader_buffer.iter().position(|&char| char == b'\n') { | |
| buffer.extend_from_slice( | |
| &reader_buffer[0..size.map_or(eol_position + 1, |size| { | |
| std::cmp::min(size, eol_position + 1) | |
| })], | |
| ); | |
| self.body_reader.consume(eol_position + 1); | |
| break; | |
| } else { | |
| buffer.extend_from_slice(&reader_buffer[0..size.unwrap_or(reader_buffer.len())]); | |
| self.body_reader.consume(reader_buffer.len()); | |
| } | |
| } | |
| Ok(buffer) | |
| } | |
| fn readlines(&mut self, hint: Option<isize>) -> PyResult<Vec<Vec<u8>>> { | |
| let mut total_bytes = 0; | |
| let mut lines = Vec::new(); | |
| let hint = if hint.is_none_or(|s| s < 0) { | |
| None | |
| } else { | |
| hint.map(|s| s as usize) | |
| }; | |
| loop { | |
| let mut line = Vec::new(); | |
| let bytes_read = | |
| futures_lite::future::block_on(self.body_reader.read_until(b'\n', &mut line))?; | |
| if bytes_read == 0 { | |
| break; | |
| } | |
| total_bytes += line.len(); | |
| lines.push(line); | |
| if hint.is_some_and(|hint| hint > total_bytes) { | |
| break; | |
| } | |
| } | |
| Ok(lines) | |
| } | |
| fn __iter__(this: PyRef<'_, Self>) -> PyRef<'_, Self> { | |
| this | |
| } | |
| fn __next__(&mut self) -> PyResult<Option<Vec<u8>>> { | |
| let line = self.readline(None)?; | |
| if line.is_empty() { | |
| // If a "readline()" function in WSGI input stream Python class returns 0 bytes (not even "\n"), it means EOF. | |
| Ok(None) | |
| } else { | |
| Ok(Some(line)) | |
| } | |
| } | |
| } | |
| mod tests { | |
| use super::*; | |
| use std::io::Cursor; | |
| use tokio::io::BufReader; | |
| fn create_stream(data: &str) -> WsgiInputStream { | |
| let cursor = Cursor::new(data.as_bytes().to_vec()); | |
| let reader = BufReader::new(cursor); | |
| WsgiInputStream::new(reader) | |
| } | |
| fn test_read() { | |
| let mut stream = create_stream("Hello, world!"); | |
| let result = stream.read(5).unwrap(); | |
| assert_eq!(result, b"Hello"); | |
| } | |
| fn test_read_full() { | |
| let mut stream = create_stream("Hello"); | |
| let result = stream.read(10).unwrap(); // try to read more than available | |
| assert_eq!(result, b"Hello"); | |
| } | |
| fn test_readline_no_limit() { | |
| let mut stream = create_stream("line1\nline2\n"); | |
| let result = stream.readline(None).unwrap(); | |
| assert_eq!(result, b"line1\n"); | |
| let result = stream.readline(None).unwrap(); | |
| assert_eq!(result, b"line2\n"); | |
| } | |
| fn test_readline_with_limit() { | |
| let mut stream = create_stream("line1\nline2\n"); | |
| let result = stream.readline(Some(3)).unwrap(); | |
| assert_eq!(result, b"lin"); // Only 3 bytes | |
| } | |
| fn test_readlines_no_hint() { | |
| let mut stream = create_stream("line1\nline2\nline3\n"); | |
| let result = stream.readlines(None).unwrap(); | |
| assert_eq!(result, vec![b"line1\n", b"line2\n", b"line3\n"]); | |
| } | |
| fn test_readlines_with_hint() { | |
| let mut stream = create_stream("line1\nline2\nline3\n"); | |
| let result = stream.readlines(Some(10)).unwrap(); // Should stop when bytes exceed 10 | |
| let total: usize = result.iter().map(|l| l.len()).sum(); | |
| assert!(total > 0 && total <= 10); | |
| } | |
| fn test_iterator_behavior() { | |
| let mut stream = create_stream("line1\nline2\n"); | |
| let mut results = Vec::new(); | |
| while let Some(line) = stream.__next__().unwrap() { | |
| results.push(line); | |
| } | |
| assert_eq!(results, vec![b"line1\n", b"line2\n"]); | |
| } | |
| fn test_iterator_eof() { | |
| let mut stream = create_stream(""); | |
| let result = stream.__next__().unwrap(); | |
| assert_eq!(result, None); | |
| } | |
| } | |