actix_http/h1/
payload.rs

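//! Payload stream.
//!
//! A single-threaded, buffered channel of byte chunks: [`PayloadSender`] feeds data that
//! [`Payload`] yields as a `Stream`.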
2
3use std::{
4    cell::RefCell,
5    collections::VecDeque,
6    pin::Pin,
7    rc::{Rc, Weak},
8    task::{Context, Poll, Waker},
9};
10
11use bytes::Bytes;
12use futures_core::Stream;
13
14use crate::error::PayloadError;
15
/// Upper bound (32 KiB) on buffered bytes before the payload stops requesting more data.
pub(crate) const MAX_BUFFER_SIZE: usize = 32 * 1024;
18
/// Outcome of asking a [`PayloadSender`] whether more data should be fed to the payload.
///
/// The enum is a plain tag without data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadStatus {
    /// The receiver is alive and wants more data.
    Read,

    /// The receiver is alive but is not asking for data right now; the caller's waker has been
    /// stored so it can be woken when the receiver polls again.
    Pause,

    /// The receiving [`Payload`] no longer exists.
    Dropped,
}
25
/// Buffered stream of bytes chunks.
///
/// This is the receiver half of the pair returned by [`Payload::create`]. Chunks are kept in a
/// queue and handed out one at a time by `poll_next`. When no chunk is buffered, `poll_next`
/// stores the waker of the polling task, and that task is woken once the sender feeds data,
/// signals EOF or sets an error.
///
/// Payload can be used as `Response` body stream.
#[derive(Debug)]
pub struct Payload {
    /// State shared with the matching `PayloadSender`; this side holds the strong reference.
    inner: Rc<RefCell<Inner>>,
}
36
37impl Payload {
38    /// Creates a payload stream.
39    ///
40    /// This method construct two objects responsible for bytes stream generation:
41    /// - `PayloadSender` - *Sender* side of the stream
42    /// - `Payload` - *Receiver* side of the stream
43    pub fn create(eof: bool) -> (PayloadSender, Payload) {
44        let shared = Rc::new(RefCell::new(Inner::new(eof)));
45
46        (
47            PayloadSender::new(Rc::downgrade(&shared)),
48            Payload { inner: shared },
49        )
50    }
51
52    /// Creates an empty payload.
53    pub(crate) fn empty() -> Payload {
54        Payload {
55            inner: Rc::new(RefCell::new(Inner::new(true))),
56        }
57    }
58
59    /// Length of the data in this payload
60    #[cfg(test)]
61    pub fn len(&self) -> usize {
62        self.inner.borrow().len()
63    }
64
65    /// Is payload empty
66    #[cfg(test)]
67    pub fn is_empty(&self) -> bool {
68        self.inner.borrow().len() == 0
69    }
70
71    /// Put unused data back to payload
72    #[inline]
73    pub fn unread_data(&mut self, data: Bytes) {
74        self.inner.borrow_mut().unread_data(data);
75    }
76}
77
78impl Stream for Payload {
79    type Item = Result<Bytes, PayloadError>;
80
81    fn poll_next(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<Result<Bytes, PayloadError>>> {
85        Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx)
86    }
87}
88
89/// Sender part of the payload stream
90pub struct PayloadSender {
91    inner: Weak<RefCell<Inner>>,
92}
93
94impl PayloadSender {
95    fn new(inner: Weak<RefCell<Inner>>) -> Self {
96        Self { inner }
97    }
98
99    #[inline]
100    pub fn set_error(&mut self, err: PayloadError) {
101        if let Some(shared) = self.inner.upgrade() {
102            shared.borrow_mut().set_error(err)
103        }
104    }
105
106    #[inline]
107    pub fn feed_eof(&mut self) {
108        if let Some(shared) = self.inner.upgrade() {
109            shared.borrow_mut().feed_eof()
110        }
111    }
112
113    #[inline]
114    pub fn feed_data(&mut self, data: Bytes) {
115        if let Some(shared) = self.inner.upgrade() {
116            shared.borrow_mut().feed_data(data)
117        }
118    }
119
120    #[allow(clippy::needless_pass_by_ref_mut)]
121    #[inline]
122    pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
123        // we check need_read only if Payload (other side) is alive,
124        // otherwise always return true (consume payload)
125        if let Some(shared) = self.inner.upgrade() {
126            if shared.borrow().need_read {
127                PayloadStatus::Read
128            } else {
129                shared.borrow_mut().register_io(cx);
130                PayloadStatus::Pause
131            }
132        } else {
133            PayloadStatus::Dropped
134        }
135    }
136}
137
/// State shared between a `Payload` and its `PayloadSender`.
#[derive(Debug)]
struct Inner {
    /// Total number of bytes currently held in `items`.
    len: usize,
    /// End-of-stream flag; once set and the buffer is drained, polling yields `None`.
    eof: bool,
    /// Pending error; handed to the receiver after all buffered chunks have been read.
    err: Option<PayloadError>,
    /// Whether the sender is asked to keep feeding data.
    need_read: bool,
    /// Buffered chunks, oldest at the front.
    items: VecDeque<Bytes>,
    /// Waker of the task polling the receiving side.
    task: Option<Waker>,
    /// Waker of the task feeding data through the sending side.
    io_task: Option<Waker>,
}
148
149impl Inner {
150    fn new(eof: bool) -> Self {
151        Inner {
152            eof,
153            len: 0,
154            err: None,
155            items: VecDeque::new(),
156            need_read: true,
157            task: None,
158            io_task: None,
159        }
160    }
161
162    /// Wake up future waiting for payload data to be available.
163    fn wake(&mut self) {
164        if let Some(waker) = self.task.take() {
165            waker.wake();
166        }
167    }
168
169    /// Wake up future feeding data to Payload.
170    fn wake_io(&mut self) {
171        if let Some(waker) = self.io_task.take() {
172            waker.wake();
173        }
174    }
175
176    /// Register future waiting data from payload.
177    /// Waker would be used in `Inner::wake`
178    fn register(&mut self, cx: &Context<'_>) {
179        if self
180            .task
181            .as_ref()
182            .map_or(true, |w| !cx.waker().will_wake(w))
183        {
184            self.task = Some(cx.waker().clone());
185        }
186    }
187
188    // Register future feeding data to payload.
189    /// Waker would be used in `Inner::wake_io`
190    fn register_io(&mut self, cx: &Context<'_>) {
191        if self
192            .io_task
193            .as_ref()
194            .map_or(true, |w| !cx.waker().will_wake(w))
195        {
196            self.io_task = Some(cx.waker().clone());
197        }
198    }
199
200    #[inline]
201    fn set_error(&mut self, err: PayloadError) {
202        self.err = Some(err);
203        self.wake();
204    }
205
206    #[inline]
207    fn feed_eof(&mut self) {
208        self.eof = true;
209        self.wake();
210    }
211
212    #[inline]
213    fn feed_data(&mut self, data: Bytes) {
214        self.len += data.len();
215        self.items.push_back(data);
216        self.need_read = self.len < MAX_BUFFER_SIZE;
217        self.wake();
218    }
219
220    #[cfg(test)]
221    fn len(&self) -> usize {
222        self.len
223    }
224
225    fn poll_next(
226        mut self: Pin<&mut Self>,
227        cx: &Context<'_>,
228    ) -> Poll<Option<Result<Bytes, PayloadError>>> {
229        if let Some(data) = self.items.pop_front() {
230            self.len -= data.len();
231            self.need_read = self.len < MAX_BUFFER_SIZE;
232
233            if self.need_read && !self.eof {
234                self.register(cx);
235            }
236            self.wake_io();
237            Poll::Ready(Some(Ok(data)))
238        } else if let Some(err) = self.err.take() {
239            Poll::Ready(Some(Err(err)))
240        } else if self.eof {
241            Poll::Ready(None)
242        } else {
243            self.need_read = true;
244            self.register(cx);
245            self.wake_io();
246            Poll::Pending
247        }
248    }
249
250    fn unread_data(&mut self, data: Bytes) {
251        self.len += data.len();
252        self.items.push_front(data);
253    }
254}
255
#[cfg(test)]
mod tests {
    use std::{task::Poll, time::Duration};

    use actix_rt::time::timeout;
    use actix_utils::future::poll_fn;
    use futures_util::{FutureExt, StreamExt};
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use tokio::sync::oneshot;

    use super::*;

    // `Payload` is expected to be movable after pinning but confined to a single thread.
    assert_impl_all!(Payload: Unpin);
    assert_not_impl_any!(Payload: Send, Sync);

    assert_impl_all!(Inner: Unpin, Send, Sync);

    /// Longest time a test waits for the spawned task to be woken and finish.
    const WAKE_TIMEOUT: Duration = Duration::from_secs(2);

    /// Spawns a task that parks on `payload` and, once woken, checks the next item.
    ///
    /// `expected` describes the item the task must observe after the wake-up: `Some(Ok(()))` for
    /// a data chunk, `Some(Err(()))` for an error and `None` for end of stream. The returned
    /// receiver fires as soon as the task has registered its waker with the payload.
    fn prepare_waking_test(
        mut payload: Payload,
        expected: Option<Result<(), ()>>,
    ) -> (oneshot::Receiver<()>, actix_rt::task::JoinHandle<()>) {
        let (parked_tx, parked_rx) = oneshot::channel();

        let task = actix_rt::spawn(async move {
            // The first poll has to be pending; it is what stores this task's waker.
            poll_fn(|cx| {
                assert!(payload.poll_next_unpin(cx).is_pending());
                Poll::Ready(())
            })
            .await;
            parked_tx.send(()).unwrap();

            // actix-rt is single-threaded, so this won't race with the receiver being awaited
            let mut yielded = false;
            poll_fn(|_| {
                // Go pending exactly once without storing a waker: the one stored by the first
                // `poll_fn` is the only way this task resumes, i.e. the sender must wake it.
                if std::mem::replace(&mut yielded, true) {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            })
            .await;

            let item = payload.next().now_or_never().unwrap();
            match expected {
                Some(Ok(())) => assert!(matches!(item, Some(Ok(_)))),
                Some(Err(())) => assert!(matches!(item, Some(Err(_)))),
                None => assert!(item.is_none()),
            }
        });

        (parked_rx, task)
    }

    #[actix_rt::test]
    async fn wake_on_error() {
        let (mut sender, payload) = Payload::create(false);
        let (parked, task) = prepare_waking_test(payload, Some(Err(())));

        parked.await.unwrap();
        sender.set_error(PayloadError::Incomplete(None));

        let joined = timeout(WAKE_TIMEOUT, task).await;
        joined.unwrap().unwrap();
    }

    #[actix_rt::test]
    async fn wake_on_eof() {
        let (mut sender, payload) = Payload::create(false);
        let (parked, task) = prepare_waking_test(payload, None);

        parked.await.unwrap();
        sender.feed_eof();

        let joined = timeout(WAKE_TIMEOUT, task).await;
        joined.unwrap().unwrap();
    }

    #[actix_rt::test]
    async fn test_unread_data() {
        let (_, mut payload) = Payload::create(false);
        let chunk = Bytes::from("data");

        payload.unread_data(chunk.clone());
        assert!(!payload.is_empty());
        assert_eq!(payload.len(), 4);

        let polled = poll_fn(|cx| payload.poll_next_unpin(cx)).await;
        assert_eq!(chunk, polled.unwrap().unwrap());
    }
}