1use std::cmp::min;
19use std::collections::VecDeque;
20use std::io::Read;
21use std::io::Write;
22use std::io::{self};
23use std::mem::size_of;
24use std::mem::MaybeUninit;
25use std::ops::Deref;
26use std::ptr::copy_nonoverlapping;
27
28use virtio_queue::DescriptorChain;
29use vm_memory::bitmap::Bitmap;
30use vm_memory::bitmap::BitmapSlice;
31use vm_memory::Address;
32use vm_memory::ByteValued;
33use vm_memory::GuestMemory;
34use vm_memory::GuestMemoryMmap;
35use vm_memory::GuestMemoryRegion;
36use vm_memory::VolatileMemory;
37use vm_memory::VolatileSlice;
38
39use crate::buffer::ReadWriteAtVolatile;
40use crate::error::*;
41
/// Walks the guest-memory slices of one virtio descriptor chain, tracking
/// how many bytes have been consumed from the front of the queue.
struct DescriptorChainConsumer<'a, B> {
    // Remaining guest-memory slices, in descriptor order; fully consumed
    // slices are popped from the front.
    buffers: VecDeque<VolatileSlice<'a, B>>,
    // Running total of bytes handed out by `consume`.
    bytes_consumed: usize,
}
47
impl<'a, B: BitmapSlice> DescriptorChainConsumer<'a, B> {
    /// Total bytes still available across all remaining buffers.
    #[cfg(test)]
    fn available_bytes(&self) -> usize {
        self.buffers.iter().fold(0, |count, vs| count + vs.len())
    }

    /// Bytes consumed so far via `consume`.
    fn bytes_consumed(&self) -> usize {
        self.bytes_consumed
    }

    /// Hands the front buffers (covering at least `count` bytes, if present)
    /// to `f`, then drops from the queue however many bytes `f` reports it
    /// actually used. Returns that byte count, or `Ok(0)` when no buffers
    /// remain.
    ///
    /// NOTE(review): the final slice passed to `f` is NOT truncated to
    /// `count`; callers are trusted to honor `count` themselves (the
    /// `Read`/`Write` impls below do, via `min`) — confirm for new callers.
    fn consume<F>(&mut self, count: usize, f: F) -> Result<usize>
    where
        F: FnOnce(&[&VolatileSlice<B>]) -> Result<usize>,
    {
        let mut len = 0;
        let mut bufs = Vec::with_capacity(self.buffers.len());
        for vs in &self.buffers {
            if len >= count {
                break;
            }
            bufs.push(vs);
            // Only count up to `count` bytes toward `len`, even though the
            // whole slice is pushed.
            let remain = count - len;
            if remain < vs.len() {
                len += remain;
            } else {
                len += vs.len();
            }
        }
        if bufs.is_empty() {
            return Ok(0);
        }
        let bytes_consumed = f(&bufs)?;
        // Guard the running total against usize overflow before committing it.
        let total_bytes_consumed =
            self.bytes_consumed
                .checked_add(bytes_consumed)
                .ok_or(new_vhost_user_fs_error(
                    "the combined length of all the buffers in DescriptorChain would overflow",
                    None,
                ))?;
        // Pop fully consumed buffers; a partially consumed one is replaced
        // by its unconsumed tail.
        let mut remain = bytes_consumed;
        while let Some(vs) = self.buffers.pop_front() {
            if remain < vs.len() {
                self.buffers.push_front(vs.offset(remain).unwrap());
                break;
            }
            remain -= vs.len();
        }
        self.bytes_consumed = total_bytes_consumed;
        Ok(bytes_consumed)
    }

    /// Splits the chain at byte `offset`: `self` keeps the first `offset`
    /// bytes, the returned consumer gets the rest (with its own byte counter
    /// reset to 0). Errors if `offset` is past the end of the chain.
    fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a, B>> {
        let mut remain = offset;
        // Find the buffer containing the split point; `remain` becomes the
        // in-buffer offset of the split.
        let pos = self.buffers.iter().position(|vs| {
            if remain < vs.len() {
                true
            } else {
                remain -= vs.len();
                false
            }
        });
        if let Some(at) = pos {
            let mut other = self.buffers.split_off(at);
            if remain > 0 {
                // The split point is inside this buffer: keep its head,
                // give its tail to the new consumer.
                let front = other.pop_front().expect("empty VecDeque after split");
                self.buffers.push_back(
                    front
                        .subslice(0, remain)
                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
                );
                other.push_front(
                    front
                        .offset(remain)
                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
                );
            }
            Ok(DescriptorChainConsumer {
                buffers: other,
                bytes_consumed: 0,
            })
        } else if remain == 0 {
            // `offset` equals the total remaining length: the new consumer
            // is empty.
            Ok(DescriptorChainConsumer {
                buffers: VecDeque::new(),
                bytes_consumed: 0,
            })
        } else {
            Err(new_vhost_user_fs_error(
                "DescriptorChain split is out of bounds",
                None,
            ))
        }
    }
}
141
/// `std::io::Read` view over the device-readable descriptors of a virtio
/// descriptor chain.
pub struct Reader<'a, B = ()> {
    // Consumer over the guest-memory slices of the readable descriptors.
    buffer: DescriptorChainConsumer<'a, B>,
}
146
147impl<'a, B: Bitmap + BitmapSlice + 'static> Reader<'a, B> {
148 pub fn new<M>(
149 mem: &'a GuestMemoryMmap<B>,
150 desc_chain: DescriptorChain<M>,
151 ) -> Result<Reader<'a, B>>
152 where
153 M: Deref,
154 M::Target: GuestMemory + Sized,
155 {
156 let mut len: usize = 0;
157 let buffers = desc_chain
158 .readable()
159 .map(|desc| {
160 len = len
161 .checked_add(desc.len() as usize)
162 .ok_or(new_vhost_user_fs_error(
163 "the combined length of all the buffers in DescriptorChain would overflow",
164 None,
165 ))?;
166 let region = mem.find_region(desc.addr()).ok_or(new_vhost_user_fs_error(
167 "no memory region for this address range",
168 None,
169 ))?;
170 let offset = desc
171 .addr()
172 .checked_sub(region.start_addr().raw_value())
173 .unwrap();
174 region
175 .deref()
176 .get_slice(offset.raw_value() as usize, desc.len() as usize)
177 .map_err(|err| {
178 new_vhost_user_fs_error("volatile memory error", Some(err.into()))
179 })
180 })
181 .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?;
182 Ok(Reader {
183 buffer: DescriptorChainConsumer {
184 buffers,
185 bytes_consumed: 0,
186 },
187 })
188 }
189
190 pub fn read_obj<T: ByteValued>(&mut self) -> io::Result<T> {
191 let mut obj = MaybeUninit::<T>::uninit();
192 let buf =
193 unsafe { std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>()) };
194 self.read_exact(buf)?;
195 Ok(unsafe { obj.assume_init() })
196 }
197
198 pub fn read_to_at<F: ReadWriteAtVolatile<B>>(
199 &mut self,
200 dst: F,
201 count: usize,
202 ) -> io::Result<usize> {
203 self.buffer
204 .consume(count, |bufs| dst.write_vectored_at_volatile(bufs))
205 .map_err(|err| err.into())
206 }
207
208 #[cfg(test)]
209 pub fn available_bytes(&self) -> usize {
210 self.buffer.available_bytes()
211 }
212
213 #[cfg(test)]
214 pub fn bytes_read(&self) -> usize {
215 self.buffer.bytes_consumed()
216 }
217}
218
impl<B: BitmapSlice> io::Read for Reader<'_, B> {
    /// Copies up to `buf.len()` bytes from the chain into `buf`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.buffer
            .consume(buf.len(), |bufs| {
                let mut rem = buf;
                let mut total = 0;
                for vs in bufs {
                    // Never copy more than the destination has room for;
                    // `consume` may pass a final slice longer than requested.
                    let copy_len = min(rem.len(), vs.len());
                    // SAFETY: both ranges are valid for `copy_len` bytes and
                    // the mapped guest slice cannot overlap the local `buf`.
                    unsafe {
                        copy_nonoverlapping(vs.ptr_guard().as_ptr(), rem.as_mut_ptr(), copy_len);
                    }
                    rem = &mut rem[copy_len..];
                    total += copy_len;
                }
                Ok(total)
            })
            .map_err(|err| err.into())
    }
}
238
/// `std::io::Write` view over the device-writable descriptors of a virtio
/// descriptor chain.
pub struct Writer<'a, B = ()> {
    // Consumer over the guest-memory slices of the writable descriptors.
    buffer: DescriptorChainConsumer<'a, B>,
}
243
impl<'a, B: Bitmap + BitmapSlice + 'static> Writer<'a, B> {
    /// Builds a `Writer` over every writable descriptor in `desc_chain`,
    /// resolving each descriptor to a `VolatileSlice` of guest memory.
    ///
    /// Errors if the combined descriptor length overflows `usize`, a
    /// descriptor address falls outside `mem`, or a slice cannot be mapped.
    pub fn new<M>(
        mem: &'a GuestMemoryMmap<B>,
        desc_chain: DescriptorChain<M>,
    ) -> Result<Writer<'a, B>>
    where
        M: Deref,
        M::Target: GuestMemory + Sized,
    {
        let mut len: usize = 0;
        let buffers = desc_chain
            .writable()
            .map(|desc| {
                // Reject chains whose total length overflows usize.
                len = len
                    .checked_add(desc.len() as usize)
                    .ok_or(new_vhost_user_fs_error(
                        "the combined length of all the buffers in DescriptorChain would overflow",
                        None,
                    ))?;
                let region = mem.find_region(desc.addr()).ok_or(new_vhost_user_fs_error(
                    "no memory region for this address range",
                    None,
                ))?;
                // find_region guarantees desc.addr() >= region start, so the
                // subtraction cannot fail.
                let offset = desc
                    .addr()
                    .checked_sub(region.start_addr().raw_value())
                    .unwrap();
                region
                    .deref()
                    .get_slice(offset.raw_value() as usize, desc.len() as usize)
                    .map_err(|err| {
                        new_vhost_user_fs_error("volatile memory error", Some(err.into()))
                    })
            })
            .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?;
        Ok(Writer {
            buffer: DescriptorChainConsumer {
                buffers,
                bytes_consumed: 0,
            },
        })
    }

    /// Splits off a new `Writer` covering everything after the first
    /// `offset` bytes; `self` keeps the head.
    pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a, B>> {
        self.buffer.split_at(offset).map(|buffer| Writer { buffer })
    }

    /// Fills up to `count` bytes of the chain with data read from `src` via
    /// vectored positioned reads; returns the number of bytes transferred.
    pub fn write_from_at<F: ReadWriteAtVolatile<B>>(
        &mut self,
        src: F,
        count: usize,
    ) -> io::Result<usize> {
        self.buffer
            .consume(count, |bufs| src.read_vectored_at_volatile(bufs))
            .map_err(|err| err.into())
    }

    /// Bytes still writable in the chain.
    #[cfg(test)]
    pub fn available_bytes(&self) -> usize {
        self.buffer.available_bytes()
    }

    /// Total bytes written into the chain so far.
    pub fn bytes_written(&self) -> usize {
        self.buffer.bytes_consumed()
    }
}
310
impl<B: BitmapSlice> Write for Writer<'_, B> {
    /// Copies up to `buf.len()` bytes from `buf` into the chain.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.buffer
            .consume(buf.len(), |bufs| {
                let mut rem = buf;
                let mut total = 0;
                for vs in bufs {
                    // Never copy more than the source provides; `consume`
                    // may pass a final slice longer than requested.
                    let copy_len = min(rem.len(), vs.len());
                    // SAFETY: both ranges are valid for `copy_len` bytes and
                    // the mapped guest slice cannot overlap the local `buf`.
                    unsafe {
                        copy_nonoverlapping(rem.as_ptr(), vs.ptr_guard_mut().as_ptr(), copy_len);
                    }
                    // Record the touched range for dirty-page tracking.
                    vs.bitmap().mark_dirty(0, copy_len);
                    rem = &rem[copy_len..];
                    total += copy_len;
                }
                Ok(total)
            })
            .map_err(|err| err.into())
    }

    // Writes go straight to guest memory; there is nothing to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use virtio_queue::Queue;
    use virtio_queue::QueueOwnedT;
    use virtio_queue::QueueT;
    use vm_memory::Bytes;
    use vm_memory::GuestAddress;
    use vm_memory::Le16;
    use vm_memory::Le32;
    use vm_memory::Le64;

    // Virtio descriptor flags: chain continues / buffer is device-writable.
    const VIRTQ_DESC_F_NEXT: u16 = 0x1;
    const VIRTQ_DESC_F_WRITE: u16 = 0x2;

    // Whether a test descriptor is device-readable or device-writable.
    enum DescriptorType {
        Readable,
        Writable,
    }

    // Guest-layout virtqueue descriptor (little-endian fields).
    #[derive(Copy, Clone, Debug, Default)]
    #[repr(C)]
    struct VirtqDesc {
        addr: Le64,
        len: Le32,
        flags: Le16,
        next: Le16,
    }

    // Minimal available ring with room for a single ring entry.
    #[derive(Copy, Clone, Debug, Default)]
    #[repr(C)]
    struct VirtqAvail {
        flags: Le16,
        idx: Le16,
        ring: Le16,
    }

    // SAFETY: plain #[repr(C)] structs of Le* integers; every byte pattern
    // is a valid value.
    unsafe impl ByteValued for VirtqAvail {}
    unsafe impl ByteValued for VirtqDesc {}

    // Writes a descriptor table and a one-entry available ring into
    // `memory`, then pops the resulting chain from a freshly built queue.
    // Buffers are laid out back-to-back starting at `buffers_start_addr`.
    fn create_descriptor_chain(
        memory: &GuestMemoryMmap,
        descriptor_array_addr: GuestAddress,
        mut buffers_start_addr: GuestAddress,
        descriptors: Vec<(DescriptorType, u32)>,
    ) -> DescriptorChain<&GuestMemoryMmap> {
        let descriptors_len = descriptors.len();
        for (index, (type_, size)) in descriptors.into_iter().enumerate() {
            let mut flags = 0;
            if let DescriptorType::Writable = type_ {
                flags |= VIRTQ_DESC_F_WRITE;
            }
            // Every descriptor but the last links to its successor.
            if index + 1 < descriptors_len {
                flags |= VIRTQ_DESC_F_NEXT;
            }

            let desc = VirtqDesc {
                addr: buffers_start_addr.raw_value().into(),
                len: size.into(),
                flags: flags.into(),
                next: (index as u16 + 1).into(),
            };

            // Advance to the next buffer's start address.
            buffers_start_addr = buffers_start_addr.checked_add(size as u64).unwrap();

            memory
                .write_obj(
                    desc,
                    descriptor_array_addr
                        .checked_add((index * std::mem::size_of::<VirtqDesc>()) as u64)
                        .unwrap(),
                )
                .unwrap();
        }

        // Place the available ring right after the descriptor table and
        // advertise one chain starting at descriptor 0.
        let avail_ring = descriptor_array_addr
            .checked_add((descriptors_len * std::mem::size_of::<VirtqDesc>()) as u64)
            .unwrap();
        let avail = VirtqAvail {
            flags: 0.into(),
            idx: 1.into(),
            ring: 0.into(),
        };
        memory.write_obj(avail, avail_ring).unwrap();

        let mut queue = Queue::new(4).unwrap();
        queue
            .try_set_desc_table_address(descriptor_array_addr)
            .unwrap();
        queue.try_set_avail_ring_address(avail_ring).unwrap();
        queue.set_ready(true);
        queue.iter(memory).unwrap().next().unwrap()
    }

    // Reads across a 4-descriptor chain (8+16+18+64 = 106 bytes) and checks
    // the available/consumed byte accounting at each step.
    #[test]
    fn simple_chain_reader_test() {
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemoryMmap::from_ranges(&[(memory_start_addr, 0x1000)]).unwrap();

        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (DescriptorType::Readable, 8),
                (DescriptorType::Readable, 16),
                (DescriptorType::Readable, 18),
                (DescriptorType::Readable, 64),
            ],
        );

        let mut reader = Reader::new(&memory, chain).unwrap();
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        let mut buffer = [0; 64];
        reader.read_exact(&mut buffer).unwrap();
        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);
        // A read larger than the remainder returns the 42 bytes left.
        assert_eq!(reader.read(&mut buffer).unwrap(), 42);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    // Mirror of the reader test for the writable side of a chain.
    #[test]
    fn simple_chain_writer_test() {
        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemoryMmap::from_ranges(&[(memory_start_addr, 0x1000)]).unwrap();

        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (DescriptorType::Writable, 8),
                (DescriptorType::Writable, 16),
                (DescriptorType::Writable, 18),
                (DescriptorType::Writable, 64),
            ],
        );

        let mut writer = Writer::new(&memory, chain).unwrap();
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        let buffer = [0; 64];
        writer.write_all(&buffer).unwrap();
        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);
        // A write larger than the remainder consumes only the 42 bytes left.
        assert_eq!(writer.write(&buffer).unwrap(), 42);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }
}