virtiofs_opendal/
virtiofs_util.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::cmp::min;
19use std::collections::VecDeque;
20use std::io::Read;
21use std::io::Write;
22use std::io::{self};
23use std::mem::size_of;
24use std::mem::MaybeUninit;
25use std::ops::Deref;
26use std::ptr::copy_nonoverlapping;
27
28use virtio_queue::DescriptorChain;
29use vm_memory::bitmap::Bitmap;
30use vm_memory::bitmap::BitmapSlice;
31use vm_memory::Address;
32use vm_memory::ByteValued;
33use vm_memory::GuestMemory;
34use vm_memory::GuestMemoryMmap;
35use vm_memory::GuestMemoryRegion;
36use vm_memory::VolatileMemory;
37use vm_memory::VolatileSlice;
38
39use crate::buffer::ReadWriteAtVolatile;
40use crate::error::*;
41
42/// Used to consume and use data areas in shared memory between host and VMs.
43struct DescriptorChainConsumer<'a, B> {
44    buffers: VecDeque<VolatileSlice<'a, B>>,
45    bytes_consumed: usize,
46}
47
48impl<'a, B: BitmapSlice> DescriptorChainConsumer<'a, B> {
49    #[cfg(test)]
50    fn available_bytes(&self) -> usize {
51        self.buffers.iter().fold(0, |count, vs| count + vs.len())
52    }
53
54    fn bytes_consumed(&self) -> usize {
55        self.bytes_consumed
56    }
57
58    fn consume<F>(&mut self, count: usize, f: F) -> Result<usize>
59    where
60        F: FnOnce(&[&VolatileSlice<B>]) -> Result<usize>,
61    {
62        let mut len = 0;
63        let mut bufs = Vec::with_capacity(self.buffers.len());
64        for vs in &self.buffers {
65            if len >= count {
66                break;
67            }
68            bufs.push(vs);
69            let remain = count - len;
70            if remain < vs.len() {
71                len += remain;
72            } else {
73                len += vs.len();
74            }
75        }
76        if bufs.is_empty() {
77            return Ok(0);
78        }
79        let bytes_consumed = f(&bufs)?;
80        let total_bytes_consumed =
81            self.bytes_consumed
82                .checked_add(bytes_consumed)
83                .ok_or(new_vhost_user_fs_error(
84                    "the combined length of all the buffers in DescriptorChain would overflow",
85                    None,
86                ))?;
87        let mut remain = bytes_consumed;
88        while let Some(vs) = self.buffers.pop_front() {
89            if remain < vs.len() {
90                self.buffers.push_front(vs.offset(remain).unwrap());
91                break;
92            }
93            remain -= vs.len();
94        }
95        self.bytes_consumed = total_bytes_consumed;
96        Ok(bytes_consumed)
97    }
98
99    fn split_at(&mut self, offset: usize) -> Result<DescriptorChainConsumer<'a, B>> {
100        let mut remain = offset;
101        let pos = self.buffers.iter().position(|vs| {
102            if remain < vs.len() {
103                true
104            } else {
105                remain -= vs.len();
106                false
107            }
108        });
109        if let Some(at) = pos {
110            let mut other = self.buffers.split_off(at);
111            if remain > 0 {
112                let front = other.pop_front().expect("empty VecDeque after split");
113                self.buffers.push_back(
114                    front
115                        .subslice(0, remain)
116                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
117                );
118                other.push_front(
119                    front
120                        .offset(remain)
121                        .map_err(|_| new_vhost_user_fs_error("volatile memory error", None))?,
122                );
123            }
124            Ok(DescriptorChainConsumer {
125                buffers: other,
126                bytes_consumed: 0,
127            })
128        } else if remain == 0 {
129            Ok(DescriptorChainConsumer {
130                buffers: VecDeque::new(),
131                bytes_consumed: 0,
132            })
133        } else {
134            Err(new_vhost_user_fs_error(
135                "DescriptorChain split is out of bounds",
136                None,
137            ))
138        }
139    }
140}
141
142/// Provides a high-level interface for reading data in shared memory sequences.
143pub struct Reader<'a, B = ()> {
144    buffer: DescriptorChainConsumer<'a, B>,
145}
146
147impl<'a, B: Bitmap + BitmapSlice + 'static> Reader<'a, B> {
148    pub fn new<M>(
149        mem: &'a GuestMemoryMmap<B>,
150        desc_chain: DescriptorChain<M>,
151    ) -> Result<Reader<'a, B>>
152    where
153        M: Deref,
154        M::Target: GuestMemory + Sized,
155    {
156        let mut len: usize = 0;
157        let buffers = desc_chain
158            .readable()
159            .map(|desc| {
160                len = len
161                    .checked_add(desc.len() as usize)
162                    .ok_or(new_vhost_user_fs_error(
163                        "the combined length of all the buffers in DescriptorChain would overflow",
164                        None,
165                    ))?;
166                let region = mem.find_region(desc.addr()).ok_or(new_vhost_user_fs_error(
167                    "no memory region for this address range",
168                    None,
169                ))?;
170                let offset = desc
171                    .addr()
172                    .checked_sub(region.start_addr().raw_value())
173                    .unwrap();
174                region
175                    .deref()
176                    .get_slice(offset.raw_value() as usize, desc.len() as usize)
177                    .map_err(|err| {
178                        new_vhost_user_fs_error("volatile memory error", Some(err.into()))
179                    })
180            })
181            .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?;
182        Ok(Reader {
183            buffer: DescriptorChainConsumer {
184                buffers,
185                bytes_consumed: 0,
186            },
187        })
188    }
189
190    pub fn read_obj<T: ByteValued>(&mut self) -> io::Result<T> {
191        let mut obj = MaybeUninit::<T>::uninit();
192        let buf =
193            unsafe { std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>()) };
194        self.read_exact(buf)?;
195        Ok(unsafe { obj.assume_init() })
196    }
197
198    pub fn read_to_at<F: ReadWriteAtVolatile<B>>(
199        &mut self,
200        dst: F,
201        count: usize,
202    ) -> io::Result<usize> {
203        self.buffer
204            .consume(count, |bufs| dst.write_vectored_at_volatile(bufs))
205            .map_err(|err| err.into())
206    }
207
208    #[cfg(test)]
209    pub fn available_bytes(&self) -> usize {
210        self.buffer.available_bytes()
211    }
212
213    #[cfg(test)]
214    pub fn bytes_read(&self) -> usize {
215        self.buffer.bytes_consumed()
216    }
217}
218
219impl<B: BitmapSlice> io::Read for Reader<'_, B> {
220    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
221        self.buffer
222            .consume(buf.len(), |bufs| {
223                let mut rem = buf;
224                let mut total = 0;
225                for vs in bufs {
226                    let copy_len = min(rem.len(), vs.len());
227                    unsafe {
228                        copy_nonoverlapping(vs.ptr_guard().as_ptr(), rem.as_mut_ptr(), copy_len);
229                    }
230                    rem = &mut rem[copy_len..];
231                    total += copy_len;
232                }
233                Ok(total)
234            })
235            .map_err(|err| err.into())
236    }
237}
238
239/// Provides a high-level interface for writing data in shared memory sequences.
240pub struct Writer<'a, B = ()> {
241    buffer: DescriptorChainConsumer<'a, B>,
242}
243
244impl<'a, B: Bitmap + BitmapSlice + 'static> Writer<'a, B> {
245    pub fn new<M>(
246        mem: &'a GuestMemoryMmap<B>,
247        desc_chain: DescriptorChain<M>,
248    ) -> Result<Writer<'a, B>>
249    where
250        M: Deref,
251        M::Target: GuestMemory + Sized,
252    {
253        let mut len: usize = 0;
254        let buffers = desc_chain
255            .writable()
256            .map(|desc| {
257                len = len
258                    .checked_add(desc.len() as usize)
259                    .ok_or(new_vhost_user_fs_error(
260                        "the combined length of all the buffers in DescriptorChain would overflow",
261                        None,
262                    ))?;
263                let region = mem.find_region(desc.addr()).ok_or(new_vhost_user_fs_error(
264                    "no memory region for this address range",
265                    None,
266                ))?;
267                let offset = desc
268                    .addr()
269                    .checked_sub(region.start_addr().raw_value())
270                    .unwrap();
271                region
272                    .deref()
273                    .get_slice(offset.raw_value() as usize, desc.len() as usize)
274                    .map_err(|err| {
275                        new_vhost_user_fs_error("volatile memory error", Some(err.into()))
276                    })
277            })
278            .collect::<Result<VecDeque<VolatileSlice<'a, B>>>>()?;
279        Ok(Writer {
280            buffer: DescriptorChainConsumer {
281                buffers,
282                bytes_consumed: 0,
283            },
284        })
285    }
286
287    pub fn split_at(&mut self, offset: usize) -> Result<Writer<'a, B>> {
288        self.buffer.split_at(offset).map(|buffer| Writer { buffer })
289    }
290
291    pub fn write_from_at<F: ReadWriteAtVolatile<B>>(
292        &mut self,
293        src: F,
294        count: usize,
295    ) -> io::Result<usize> {
296        self.buffer
297            .consume(count, |bufs| src.read_vectored_at_volatile(bufs))
298            .map_err(|err| err.into())
299    }
300
301    #[cfg(test)]
302    pub fn available_bytes(&self) -> usize {
303        self.buffer.available_bytes()
304    }
305
306    pub fn bytes_written(&self) -> usize {
307        self.buffer.bytes_consumed()
308    }
309}
310
311impl<B: BitmapSlice> Write for Writer<'_, B> {
312    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
313        self.buffer
314            .consume(buf.len(), |bufs| {
315                let mut rem = buf;
316                let mut total = 0;
317                for vs in bufs {
318                    let copy_len = min(rem.len(), vs.len());
319                    unsafe {
320                        copy_nonoverlapping(rem.as_ptr(), vs.ptr_guard_mut().as_ptr(), copy_len);
321                    }
322                    vs.bitmap().mark_dirty(0, copy_len);
323                    rem = &rem[copy_len..];
324                    total += copy_len;
325                }
326                Ok(total)
327            })
328            .map_err(|err| err.into())
329    }
330
331    fn flush(&mut self) -> io::Result<()> {
332        Ok(())
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use virtio_queue::Queue;
    use virtio_queue::QueueOwnedT;
    use virtio_queue::QueueT;
    use vm_memory::Bytes;
    use vm_memory::GuestAddress;
    use vm_memory::Le16;
    use vm_memory::Le32;
    use vm_memory::Le64;

    // Descriptor flag: another descriptor follows via the `next` field.
    const VIRTQ_DESC_F_NEXT: u16 = 0x1;
    // Descriptor flag: the buffer is writable (otherwise it is readable).
    const VIRTQ_DESC_F_WRITE: u16 = 0x2;

    // Direction of a test descriptor.
    enum DescriptorType {
        Readable,
        Writable,
    }

    // Test-only mirror of a descriptor table entry as it is laid out in guest memory.
    #[derive(Copy, Clone, Debug, Default)]
    #[repr(C)]
    struct VirtqDesc {
        addr: Le64,
        len: Le32,
        flags: Le16,
        next: Le16,
    }

    // Test-only mirror of an available ring that holds a single entry.
    #[derive(Copy, Clone, Debug, Default)]
    #[repr(C)]
    struct VirtqAvail {
        flags: Le16,
        idx: Le16,
        ring: Le16,
    }

    unsafe impl ByteValued for VirtqAvail {}
    unsafe impl ByteValued for VirtqDesc {}

    // Writes the given descriptors (laid out back to back starting at
    // `buffers_start_addr`) plus a one-entry available ring into `memory`, then returns
    // the first chain a queue yields for them.
    fn create_descriptor_chain(
        memory: &GuestMemoryMmap,
        descriptor_array_addr: GuestAddress,
        mut buffers_start_addr: GuestAddress,
        descriptors: Vec<(DescriptorType, u32)>,
    ) -> DescriptorChain<&GuestMemoryMmap> {
        let descriptors_len = descriptors.len();
        for (index, (type_, size)) in descriptors.into_iter().enumerate() {
            let write_flag = match type_ {
                DescriptorType::Writable => VIRTQ_DESC_F_WRITE,
                DescriptorType::Readable => 0,
            };
            // Every descriptor except the last links to its successor.
            let next_flag = if index + 1 < descriptors_len {
                VIRTQ_DESC_F_NEXT
            } else {
                0
            };
            let flags = write_flag | next_flag;

            let desc = VirtqDesc {
                addr: buffers_start_addr.raw_value().into(),
                len: size.into(),
                flags: flags.into(),
                next: (index as u16 + 1).into(),
            };
            let desc_addr = descriptor_array_addr
                .checked_add((index * std::mem::size_of::<VirtqDesc>()) as u64)
                .unwrap();
            memory.write_obj(desc, desc_addr).unwrap();

            // The next buffer starts right after this one.
            buffers_start_addr = buffers_start_addr.checked_add(size as u64).unwrap();
        }

        // The available ring is placed directly behind the descriptor table.
        let avail_ring = descriptor_array_addr
            .checked_add((descriptors_len * std::mem::size_of::<VirtqDesc>()) as u64)
            .unwrap();
        memory
            .write_obj(
                VirtqAvail {
                    flags: 0.into(),
                    idx: 1.into(),
                    ring: 0.into(),
                },
                avail_ring,
            )
            .unwrap();

        let mut queue = Queue::new(4).unwrap();
        queue
            .try_set_desc_table_address(descriptor_array_addr)
            .unwrap();
        queue.try_set_avail_ring_address(avail_ring).unwrap();
        queue.set_ready(true);
        let mut chains = queue.iter(memory).unwrap();
        chains.next().unwrap()
    }

    #[test]
    fn simple_chain_reader_test() {
        let memory = GuestMemoryMmap::from_ranges(&[(GuestAddress(0x0), 0x1000)]).unwrap();
        let layout = vec![
            (DescriptorType::Readable, 8),
            (DescriptorType::Readable, 16),
            (DescriptorType::Readable, 18),
            (DescriptorType::Readable, 64),
        ];
        let chain =
            create_descriptor_chain(&memory, GuestAddress(0x0), GuestAddress(0x100), layout);

        let mut reader = Reader::new(&memory, chain).unwrap();
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        // A full 64-byte read leaves 42 of the 106 bytes.
        let mut buffer = [0; 64];
        reader.read_exact(&mut buffer).unwrap();
        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);

        // The second read is short: only the remaining 42 bytes are returned.
        assert_eq!(reader.read(&mut buffer).unwrap(), 42);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    #[test]
    fn simple_chain_writer_test() {
        let memory = GuestMemoryMmap::from_ranges(&[(GuestAddress(0x0), 0x1000)]).unwrap();
        let layout = vec![
            (DescriptorType::Writable, 8),
            (DescriptorType::Writable, 16),
            (DescriptorType::Writable, 18),
            (DescriptorType::Writable, 64),
        ];
        let chain =
            create_descriptor_chain(&memory, GuestAddress(0x0), GuestAddress(0x100), layout);

        let mut writer = Writer::new(&memory, chain).unwrap();
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        // A full 64-byte write leaves 42 of the 106 bytes.
        let buffer = [0; 64];
        writer.write_all(&buffer).unwrap();
        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);

        // The second write is short: only the remaining 42 bytes are accepted.
        assert_eq!(writer.write(&buffer).unwrap(), 42);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }
}