File size: 6,228 Bytes
9552aa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
use memmem::{Searcher, TwoWaySearcher};
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};

/// Size, in bytes, of the scratch buffer used for each read while looking for
/// the end of the response head, and the most head data that will be buffered
/// before giving up (16 KiB).
const RESPONSE_BUFFER_CAPACITY: usize = 16 * 1024;

/// Wraps an async byte stream carrying a CGI-style response, so the response
/// head can be read separately from the body that follows it.
///
/// Trait bounds live on the `impl` blocks that need them, not on the type.
pub struct CgiResponse<R> {
  /// Stream the response bytes are read from.
  stream: R,
  /// Bytes read from `stream` while searching for the end of the head; may
  /// also hold the first bytes of the body.
  response_buf: Vec<u8>,
  /// `None` until `get_head` returns. After that, it is the offset into
  /// `response_buf` from which buffered bytes are still to be read; reading
  /// advances it.
  response_head_length: Option<usize>,
}

impl<R> CgiResponse<R>
where
  R: AsyncRead + Unpin,
{
  /// Creates a wrapper around `stream` with an empty, preallocated head buffer.
  pub fn new(stream: R) -> Self {
    Self {
      stream,
      response_buf: Vec::with_capacity(RESPONSE_BUFFER_CAPACITY),
      response_head_length: None,
    }
  }

  /// Reads from the stream until the end of the response head is found and
  /// returns the head, including its terminating sequence.
  ///
  /// The head ends at the earliest occurrence of `\r\n\r\n`, `\n\r\n\r`, `\n\n`
  /// or `\r\r`. Bytes read past the head stay buffered and are returned first
  /// by the `AsyncRead` implementation.
  ///
  /// An empty slice is returned in two cases. The first is when the stream
  /// ends before any terminator is found. The second is when more than
  /// `RESPONSE_BUFFER_CAPACITY` bytes have been buffered without one. In both
  /// cases every byte read so far stays in the buffer and is served from
  /// offset 0 by `AsyncRead`.
  ///
  /// # Errors
  ///
  /// Propagates any I/O error returned by the underlying stream.
  pub async fn get_head(&mut self) -> Result<&[u8], Error> {
    let mut temp_buf = [0u8; RESPONSE_BUFFER_CAPACITY];
    let rnrn = TwoWaySearcher::new(b"\r\n\r\n");
    let nrnr = TwoWaySearcher::new(b"\n\r\n\r");
    let nn = TwoWaySearcher::new(b"\n\n");
    let rr = TwoWaySearcher::new(b"\r\r");

    let head_length = loop {
      let read_bytes = self.stream.read(&mut temp_buf).await?;

      // EOF before a terminator: report no head.
      if read_bytes == 0 {
        self.response_head_length = Some(0);
        return Ok(&[]);
      }

      // A terminator may straddle the old and new data, so resume the search
      // `pattern length - 1` bytes before the start of the new chunk.
      let previously_buffered = self.response_buf.len();
      let long_from = previously_buffered.saturating_sub(3);
      let short_from = previously_buffered.saturating_sub(1);

      // Append before the size check: previously the chunk that crossed the
      // limit was dropped, losing those bytes for later reads.
      self.response_buf.extend_from_slice(&temp_buf[..read_bytes]);

      if self.response_buf.len() > RESPONSE_BUFFER_CAPACITY {
        self.response_head_length = Some(0);
        return Ok(&[]);
      }

      // (start, pattern length) for each terminator form found. The head ends
      // at the first blank line, so take the earliest match overall. Checking
      // forms in a fixed order could pick a `\r\n\r\n` inside the body over an
      // earlier `\n\n`.
      let candidates: [Option<(usize, usize)>; 4] = [
        rnrn.search_in(&self.response_buf[long_from..]).map(|i| (long_from + i, 4)),
        nrnr.search_in(&self.response_buf[long_from..]).map(|i| (long_from + i, 4)),
        nn.search_in(&self.response_buf[short_from..]).map(|i| (short_from + i, 2)),
        rr.search_in(&self.response_buf[short_from..]).map(|i| (short_from + i, 2)),
      ];
      if let Some(&(start, len)) = candidates.iter().flatten().min_by_key(|&&(start, _)| start) {
        break start + len;
      }
    };

    self.response_head_length = Some(head_length);
    Ok(&self.response_buf[..head_length])
  }
}

/// Exposes the rest of the response as an `AsyncRead` stream.
///
/// Bytes that `get_head` buffered beyond the head are returned first. Once
/// they are used up, or if `get_head` has not returned yet, reads go straight
/// to the wrapped stream.
impl<R> AsyncRead for CgiResponse<R>
where
  R: AsyncRead + Unpin,
{
  fn poll_read(
    mut self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    buf: &mut ReadBuf<'_>,
  ) -> Poll<std::io::Result<()>> {
    // `response_head_length` doubles as the read cursor into `response_buf`;
    // bytes before it have already been handed out.
    if let Some(cursor) = self.response_head_length {
      if self.response_buf.len() > cursor {
        let pending = &self.response_buf[cursor..];
        let to_copy = pending.len().min(buf.remaining());
        buf.put_slice(&pending[..to_copy]);
        self.response_head_length = Some(cursor + to_copy);
        return Poll::Ready(Ok(()));
      }
    }

    // No buffered bytes left to serve: delegate to the wrapped stream.
    Pin::new(&mut self.stream).poll_read(cx, buf)
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use tokio::io::AsyncReadExt;
  use tokio_test::io::Builder;

  #[tokio::test]
  async fn test_get_head() {
    let data = b"Content-Type: text/plain\r\n\r\n";
    let mut stream = Builder::new().read(data).build();
    let mut response = CgiResponse::new(&mut stream);

    let head = response.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");
  }

  #[tokio::test]
  async fn test_get_head_nn() {
    let data = b"Content-Type: text/plain\n\n";
    let mut stream = Builder::new().read(data).build();
    let mut response = CgiResponse::new(&mut stream);

    let head = response.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\n\n");
  }

  #[tokio::test]
  async fn test_get_head_large_headers() {
    let data = b"Content-Type: text/plain\r\n";
    let large_header = vec![b'A'; RESPONSE_BUFFER_CAPACITY + 10];
    let mut stream = Builder::new().read(data).read(&large_header).build();
    let mut response = CgiResponse::new(&mut stream);

    let result = response.get_head().await;
    assert_eq!(result.unwrap().len(), 0);

    // The mock panics on drop if scripted data is left unread, so drain the
    // bytes that did not fit into the single read above.
    let mut remaining_data = vec![0u8; RESPONSE_BUFFER_CAPACITY + 10];
    let _ = response.stream.read(&mut remaining_data).await;
  }

  #[tokio::test]
  async fn test_get_head_premature_eof() {
    let data = b"Content-Type: text/plain\r\n";
    let mut stream = Builder::new().read(data).build();
    let mut response = CgiResponse::new(&mut stream);

    let result = response.get_head().await;
    assert_eq!(result.unwrap().len(), 0);
  }

  #[tokio::test]
  async fn test_poll_read() {
    let data = b"Content-Type: text/plain\r\n\r\nHello, world!";
    let mut stream = Builder::new().read(data).build();
    let mut response = CgiResponse::new(&mut stream);

    let head = response.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");

    let mut buf = vec![0u8; 13];
    let n = response.read(&mut buf).await.unwrap();
    assert_eq!(n, 13);
    assert_eq!(&buf[..n], b"Hello, world!");
  }

  #[tokio::test]
  async fn test_poll_read_body_in_later_chunk() {
    // Nothing is buffered past the head, so the body must come straight from
    // the wrapped stream.
    let mut stream = Builder::new()
      .read(b"Content-Type: text/plain\r\n\r\n")
      .read(b"Hello, world!")
      .build();
    let mut response = CgiResponse::new(&mut stream);

    let head = response.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\r\n\r\n");

    let mut body = Vec::new();
    response.read_to_end(&mut body).await.unwrap();
    assert_eq!(body, b"Hello, world!");
  }

  #[tokio::test]
  async fn test_poll_read_after_nn_head() {
    let data = b"Content-Type: text/plain\n\nHi";
    let mut stream = Builder::new().read(data).build();
    let mut response = CgiResponse::new(&mut stream);

    let head = response.get_head().await.unwrap();
    assert_eq!(head, b"Content-Type: text/plain\n\n");

    let mut body = Vec::new();
    response.read_to_end(&mut body).await.unwrap();
    assert_eq!(body, b"Hi");
  }
}