1use std::{
4 cell::RefCell,
5 collections::VecDeque,
6 pin::Pin,
7 rc::{Rc, Weak},
8 task::{Context, Poll, Waker},
9};
10
11use bytes::Bytes;
12use futures_core::Stream;
13
14use crate::error::PayloadError;
15
/// Buffered byte count (32 KiB) at or above which the payload stops asking its sender for
/// more data: `need_read` is true only while the buffered length is below this value.
pub(crate) const MAX_BUFFER_SIZE: usize = 32 * 1024;
18
/// Back-pressure signal returned by [`PayloadSender::need_read`].
#[derive(Eq, PartialEq, Debug)]
pub enum PayloadStatus {
    /// The payload wants more data; the producer may keep reading.
    Read,
    /// The payload does not currently want more data; the producer's waker has been
    /// registered and it will be woken when the consumer drains data.
    Pause,
    /// The receiving [`Payload`] no longer exists.
    Dropped,
}
25
26#[derive(Debug)]
33pub struct Payload {
34 inner: Rc<RefCell<Inner>>,
35}
36
37impl Payload {
38 pub fn create(eof: bool) -> (PayloadSender, Payload) {
44 let shared = Rc::new(RefCell::new(Inner::new(eof)));
45
46 (
47 PayloadSender::new(Rc::downgrade(&shared)),
48 Payload { inner: shared },
49 )
50 }
51
52 pub(crate) fn empty() -> Payload {
54 Payload {
55 inner: Rc::new(RefCell::new(Inner::new(true))),
56 }
57 }
58
59 #[cfg(test)]
61 pub fn len(&self) -> usize {
62 self.inner.borrow().len()
63 }
64
65 #[cfg(test)]
67 pub fn is_empty(&self) -> bool {
68 self.inner.borrow().len() == 0
69 }
70
71 #[inline]
73 pub fn unread_data(&mut self, data: Bytes) {
74 self.inner.borrow_mut().unread_data(data);
75 }
76}
77
78impl Stream for Payload {
79 type Item = Result<Bytes, PayloadError>;
80
81 fn poll_next(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
85 Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx)
86 }
87}
88
89pub struct PayloadSender {
91 inner: Weak<RefCell<Inner>>,
92}
93
94impl PayloadSender {
95 fn new(inner: Weak<RefCell<Inner>>) -> Self {
96 Self { inner }
97 }
98
99 #[inline]
100 pub fn set_error(&mut self, err: PayloadError) {
101 if let Some(shared) = self.inner.upgrade() {
102 shared.borrow_mut().set_error(err)
103 }
104 }
105
106 #[inline]
107 pub fn feed_eof(&mut self) {
108 if let Some(shared) = self.inner.upgrade() {
109 shared.borrow_mut().feed_eof()
110 }
111 }
112
113 #[inline]
114 pub fn feed_data(&mut self, data: Bytes) {
115 if let Some(shared) = self.inner.upgrade() {
116 shared.borrow_mut().feed_data(data)
117 }
118 }
119
120 #[allow(clippy::needless_pass_by_ref_mut)]
121 #[inline]
122 pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
123 if let Some(shared) = self.inner.upgrade() {
126 if shared.borrow().need_read {
127 PayloadStatus::Read
128 } else {
129 shared.borrow_mut().register_io(cx);
130 PayloadStatus::Pause
131 }
132 } else {
133 PayloadStatus::Dropped
134 }
135 }
136}
137
138#[derive(Debug)]
139struct Inner {
140 len: usize,
141 eof: bool,
142 err: Option<PayloadError>,
143 need_read: bool,
144 items: VecDeque<Bytes>,
145 task: Option<Waker>,
146 io_task: Option<Waker>,
147}
148
149impl Inner {
150 fn new(eof: bool) -> Self {
151 Inner {
152 eof,
153 len: 0,
154 err: None,
155 items: VecDeque::new(),
156 need_read: true,
157 task: None,
158 io_task: None,
159 }
160 }
161
162 fn wake(&mut self) {
164 if let Some(waker) = self.task.take() {
165 waker.wake();
166 }
167 }
168
169 fn wake_io(&mut self) {
171 if let Some(waker) = self.io_task.take() {
172 waker.wake();
173 }
174 }
175
176 fn register(&mut self, cx: &Context<'_>) {
179 if self
180 .task
181 .as_ref()
182 .map_or(true, |w| !cx.waker().will_wake(w))
183 {
184 self.task = Some(cx.waker().clone());
185 }
186 }
187
188 fn register_io(&mut self, cx: &Context<'_>) {
191 if self
192 .io_task
193 .as_ref()
194 .map_or(true, |w| !cx.waker().will_wake(w))
195 {
196 self.io_task = Some(cx.waker().clone());
197 }
198 }
199
200 #[inline]
201 fn set_error(&mut self, err: PayloadError) {
202 self.err = Some(err);
203 self.wake();
204 }
205
206 #[inline]
207 fn feed_eof(&mut self) {
208 self.eof = true;
209 self.wake();
210 }
211
212 #[inline]
213 fn feed_data(&mut self, data: Bytes) {
214 self.len += data.len();
215 self.items.push_back(data);
216 self.need_read = self.len < MAX_BUFFER_SIZE;
217 self.wake();
218 }
219
220 #[cfg(test)]
221 fn len(&self) -> usize {
222 self.len
223 }
224
225 fn poll_next(
226 mut self: Pin<&mut Self>,
227 cx: &Context<'_>,
228 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
229 if let Some(data) = self.items.pop_front() {
230 self.len -= data.len();
231 self.need_read = self.len < MAX_BUFFER_SIZE;
232
233 if self.need_read && !self.eof {
234 self.register(cx);
235 }
236 self.wake_io();
237 Poll::Ready(Some(Ok(data)))
238 } else if let Some(err) = self.err.take() {
239 Poll::Ready(Some(Err(err)))
240 } else if self.eof {
241 Poll::Ready(None)
242 } else {
243 self.need_read = true;
244 self.register(cx);
245 self.wake_io();
246 Poll::Pending
247 }
248 }
249
250 fn unread_data(&mut self, data: Bytes) {
251 self.len += data.len();
252 self.items.push_front(data);
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::{task::Poll, time::Duration};

    use actix_rt::time::timeout;
    use actix_utils::future::poll_fn;
    use futures_util::{FutureExt, StreamExt};
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use tokio::sync::oneshot;

    use super::*;

    // Compile-time checks: `Payload` is `Unpin` but neither `Send` nor `Sync` (it holds an
    // `Rc<RefCell<Inner>>`), while `Inner` itself is `Unpin + Send + Sync`.
    assert_impl_all!(Payload: Unpin);
    assert_not_impl_any!(Payload: Send, Sync);

    assert_impl_all!(Inner: Unpin, Send, Sync);

    /// Maximum time a test waits for the spawned task to be woken and complete.
    const WAKE_TIMEOUT: Duration = Duration::from_secs(2);

    /// Spawns a task that parks itself on `payload` and, once woken, checks the next item.
    ///
    /// The task proceeds in four steps:
    /// 1. Polls `payload` once and asserts `Pending`. This stores the task's waker in the
    ///    payload.
    /// 2. Signals the returned receiver, so the caller knows the waker is stored.
    /// 3. Returns `Pending` once without storing any new waker, so the task can only resume
    ///    if the payload wakes it.
    /// 4. Polls `payload.next()` exactly once via `now_or_never` and checks the result
    ///    against `expected`:
    ///    - `Some(Ok(_))` means a chunk.
    ///    - `Some(Err(_))` means an error.
    ///    - `None` means end of stream.
    ///
    /// Returns the step-2 receiver and the task's join handle.
    fn prepare_waking_test(
        mut payload: Payload,
        expected: Option<Result<(), ()>>,
    ) -> (oneshot::Receiver<()>, actix_rt::task::JoinHandle<()>) {
        let (tx, rx) = oneshot::channel();

        let handle = actix_rt::spawn(async move {
            // Step 1: the empty payload returns `Pending` and stores this task's waker.
            poll_fn(|cx| {
                assert!(payload.poll_next_unpin(cx).is_pending());
                Poll::Ready(())
            })
            .await;
            tx.send(()).unwrap();

            // Step 3: pend once while ignoring `cx`; only the waker stored in step 1 can
            // resume the task from here.
            let mut pend_once = false;
            poll_fn(|_| {
                if pend_once {
                    Poll::Ready(())
                } else {
                    pend_once = true;
                    Poll::Pending
                }
            })
            .await;

            // Step 4: the item must already be available; `unwrap` fails if it is not.
            let got = payload.next().now_or_never().unwrap();
            match expected {
                Some(Ok(_)) => assert!(got.unwrap().is_ok()),
                Some(Err(_)) => assert!(got.unwrap().is_err()),
                None => assert!(got.is_none()),
            }
        });
        (rx, handle)
    }

    /// `set_error` on the sender must wake the parked consumer, which then receives `Err`.
    #[actix_rt::test]
    async fn wake_on_error() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, Some(Err(())));

        rx.await.unwrap();
        sender.set_error(PayloadError::Incomplete(None));
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// `feed_eof` on the sender must wake the parked consumer, which then sees end of stream.
    #[actix_rt::test]
    async fn wake_on_eof() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, None);

        rx.await.unwrap();
        sender.feed_eof();
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// Data pushed back with `unread_data` is counted in the buffer length and yielded next.
    #[actix_rt::test]
    async fn test_unread_data() {
        let (_, mut payload) = Payload::create(false);

        payload.unread_data(Bytes::from("data"));
        assert!(!payload.is_empty());
        assert_eq!(payload.len(), 4);

        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| Pin::new(&mut payload).poll_next(cx))
                .await
                .unwrap()
                .unwrap()
        );
    }
}