Spaces:
Runtime error
Runtime error
File size: 6,228 Bytes
9552aa0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | use memmem::{Searcher, TwoWaySearcher};
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
/// Size in bytes used both for the initial capacity of the header buffer and
/// for the temporary chunk buffer each read goes through; also the maximum
/// header size `get_head` will accept.
const RESPONSE_BUFFER_CAPACITY: usize = 16 * 1024;
/// Wraps an async byte stream so its leading header block can be split off
/// with `get_head`, after which the remaining bytes are read as the body.
pub struct CgiResponse<R: AsyncRead + Unpin> {
    /// The wrapped async byte stream.
    stream: R,
    /// Bytes pulled from `stream` while searching for the end of the headers.
    response_buf: Vec<u8>,
    /// `None` until `get_head` finishes; afterwards, the offset into
    /// `response_buf` from which buffered bytes are still to be handed out
    /// by `poll_read`.
    response_head_length: Option<usize>,
}
impl<R> CgiResponse<R>
where
    R: AsyncRead + Unpin,
{
    /// Creates a new `CgiResponse` wrapping `stream`.
    ///
    /// The internal header buffer is pre-allocated with
    /// `RESPONSE_BUFFER_CAPACITY` bytes; nothing is read from `stream` until
    /// `get_head` or `poll_read` is called.
    pub fn new(stream: R) -> Self {
        Self {
            stream,
            response_buf: Vec::with_capacity(RESPONSE_BUFFER_CAPACITY),
            response_head_length: None,
        }
    }

    /// Reads from the wrapped stream until the end of the header block is
    /// found and returns the header bytes, terminator included.
    ///
    /// The header block ends at the earliest occurrence of `\r\n\r\n`,
    /// `\n\r\n\r`, `\n\n` or `\r\r`. Bytes read past the terminator stay
    /// buffered and are returned first by this type's `AsyncRead` impl.
    ///
    /// An empty slice is returned, and every byte read so far is left to be
    /// served as body data, when the stream reaches EOF before a terminator is
    /// found or when the header block would be longer than
    /// `RESPONSE_BUFFER_CAPACITY` bytes.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from reading the wrapped stream.
    pub async fn get_head(&mut self) -> Result<&[u8], Error> {
        let mut temp_buf = [0u8; RESPONSE_BUFFER_CAPACITY];
        let rnrn = TwoWaySearcher::new(b"\r\n\r\n");
        let nrnr = TwoWaySearcher::new(b"\n\r\n\r");
        let nn = TwoWaySearcher::new(b"\n\n");
        let rr = TwoWaySearcher::new(b"\r\r");
        let to_parse_length = loop {
            let read_bytes = self.stream.read(&mut temp_buf).await?;
            if read_bytes == 0 {
                // EOF without a terminator: treat everything buffered as body.
                self.response_head_length = Some(0);
                return Ok(&[]);
            }
            // Bytes before these offsets were already searched on an earlier
            // pass; back up (pattern length - 1) bytes so a terminator that
            // straddles the old/new data boundary is still found.
            let begin_long = self.response_buf.len().saturating_sub(3);
            let begin_short = self.response_buf.len().saturating_sub(1);
            // Always keep the chunk, even if it ends up overflowing the limit,
            // so no bytes are lost when falling back to body-only mode.
            self.response_buf.extend_from_slice(&temp_buf[..read_bytes]);
            let buffered = &self.response_buf;
            // Each entry: (terminator searcher, search start, terminator length).
            let head_end = [
                (&rnrn, begin_long, 4),
                (&nrnr, begin_long, 4),
                (&nn, begin_short, 2),
                (&rr, begin_short, 2),
            ]
            .iter()
            .filter_map(|&(searcher, begin, len)| {
                searcher
                    .search_in(&buffered[begin..])
                    .map(|offset| (begin + offset, begin + offset + len))
            })
            // The headers end at whichever terminator starts first, so a
            // different-style terminator later in the body is never mistaken
            // for the end of the headers.
            .min_by_key(|&(start, _)| start)
            .map(|(_, end)| end);
            match head_end {
                Some(end) if end <= RESPONSE_BUFFER_CAPACITY => break end,
                // No terminator within the size limit and the limit has been
                // passed: give up on headers, keep everything as body.
                _ if self.response_buf.len() > RESPONSE_BUFFER_CAPACITY => {
                    self.response_head_length = Some(0);
                    return Ok(&[]);
                }
                // No terminator yet, still under the limit: read more.
                _ => {}
            }
        };
        // From here on this offset also marks where buffered body data begins.
        self.response_head_length = Some(to_parse_length);
        Ok(&self.response_buf[..to_parse_length])
    }
}
// Implementation of AsyncRead for the CgiResponse struct
impl<R> AsyncRead for CgiResponse<R>
where
    R: AsyncRead + Unpin,
{
    /// Hands out bytes that `get_head` buffered past the header terminator
    /// first, then reads directly from the wrapped stream.
    ///
    /// `response_head_length` doubles as a read cursor into `response_buf`:
    /// it is advanced by the number of buffered bytes copied into `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Every field is `Unpin` (given `R: Unpin`), so a plain `&mut` is fine.
        let this = &mut *self;
        if let Some(cursor) = this.response_head_length {
            if let Some(pending) = this
                .response_buf
                .get(cursor..)
                .filter(|rest| !rest.is_empty())
            {
                let n = pending.len().min(buf.remaining());
                buf.put_slice(&pending[..n]);
                this.response_head_length = Some(cursor + n);
                return Poll::Ready(Ok(()));
            }
        }
        // Buffer drained (or `get_head` never called): delegate to the stream.
        Pin::new(&mut this.stream).poll_read(cx, buf)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio_test::io::Builder;

    // A CRLF-terminated header block is returned in full.
    #[tokio::test]
    async fn test_get_head() {
        let data = b"Content-Type: text/plain\r\n\r\n";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);
        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");
    }

    // A bare-LF-terminated header block is returned in full.
    #[tokio::test]
    async fn test_get_head_nn() {
        let data = b"Content-Type: text/plain\n\n";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);
        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\n\n");
    }

    // Header data exceeding RESPONSE_BUFFER_CAPACITY with no terminator
    // yields an empty head.
    #[tokio::test]
    async fn test_get_head_large_headers() {
        let data = b"Content-Type: text/plain\r\n";
        let large_header = vec![b'A'; RESPONSE_BUFFER_CAPACITY + 10];
        let mut stream = Builder::new().read(data).read(&large_header).build();
        let mut response = CgiResponse::new(&mut stream);
        let result = response.get_head().await;
        assert_eq!(result.unwrap().len(), 0);
        // Consume the remaining data to avoid panicking
        let mut remaining_data = vec![0u8; RESPONSE_BUFFER_CAPACITY + 10];
        let _ = response.stream.read(&mut remaining_data).await;
    }

    // EOF before any header terminator yields an empty head.
    #[tokio::test]
    async fn test_get_head_premature_eof() {
        let data = b"Content-Type: text/plain\r\n";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);
        let result = response.get_head().await;
        assert_eq!(result.unwrap().len(), 0);
    }

    // Bytes read past the header terminator are served by `poll_read`.
    #[tokio::test]
    async fn test_poll_read() {
        let data = b"Content-Type: text/plain\r\n\r\nHello, world!";
        let mut stream = Builder::new().read(data).build();
        let mut response = CgiResponse::new(&mut stream);
        let head = response.get_head().await.unwrap();
        assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");
        let mut buf = vec![0u8; 13];
        let n = response.read(&mut buf).await.unwrap();
        assert_eq!(n, 13);
        assert_eq!(&buf[..n], b"Hello, world!");
    }
}
|