opendal_core/services/memcached/
binary.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use tokio::io::AsyncReadExt;
19use tokio::io::AsyncWriteExt;
20use tokio::io::BufReader;
21use tokio::io::{self};
22use tokio::net::TcpStream;
23
24use crate::raw::*;
25use crate::*;
26
27pub(super) mod constants {
28    pub const OK_STATUS: u16 = 0x0;
29    pub const KEY_NOT_FOUND: u16 = 0x1;
30}
31
/// Request opcodes written into the `opcode` byte of a request header.
///
/// `#[repr(u8)]` pins the discriminant to the single header byte, so
/// `opcode as u8` is guaranteed lossless and an out-of-range discriminant
/// is rejected at compile time instead of being silently truncated.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Opcode {
    /// Fetch the value stored under a key.
    Get = 0x00,
    /// Store a value under a key; the request carries flags/expiration extras.
    Set = 0x01,
    /// Remove a key.
    Delete = 0x04,
    /// Ask the server for its version string.
    Version = 0x0b,
    /// Start SASL authentication (this client uses the `PLAIN` mechanism).
    StartAuth = 0x21,
}
39
/// Values for the leading `magic` byte of a packet header.
///
/// `#[repr(u8)]` guarantees `magic as u8` is a lossless conversion to the
/// single byte written on the wire.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Magic {
    /// Marks a packet sent from this client to the server.
    Request = 0x80,
}
43
/// Extras section of a `Set` request.
///
/// Sent as 8 bytes: `flags` followed by `expiration`, each as a big-endian `u32`.
#[derive(Debug)]
pub struct StoreExtras {
    /// Client-defined flags stored with the value (always `0` in this client).
    pub flags: u32,
    /// Expiration value forwarded unchanged from the caller of `set`.
    pub expiration: u32,
}
49
/// Fixed-size header that begins every request and response packet.
///
/// Fields are declared in wire order and total 24 bytes; multi-byte fields
/// are encoded big-endian. `Default` yields an all-zero header, which request
/// builders then override field by field.
#[derive(Debug, Default)]
pub struct PacketHeader {
    /// Packet marker byte; requests use `Magic::Request`.
    pub magic: u8,
    /// Command byte; see `Opcode`.
    pub opcode: u8,
    /// Byte length of the key section of the body.
    pub key_length: u16,
    /// Byte length of the extras section of the body.
    pub extras_length: u8,
    /// Data type byte; requests built in this file leave it at zero.
    pub data_type: u8,
    /// vBucket id in requests, status code in responses.
    pub vbucket_id_or_status: u16,
    /// Combined byte length of extras, key and value.
    pub total_body_length: u32,
    /// Opaque field; requests built in this file leave it at zero.
    pub opaque: u32,
    /// CAS field; requests built in this file leave it at zero.
    pub cas: u64,
}
62
63impl PacketHeader {
64    pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> {
65        writer.write_u8(self.magic).await?;
66        writer.write_u8(self.opcode).await?;
67        writer.write_u16(self.key_length).await?;
68        writer.write_u8(self.extras_length).await?;
69        writer.write_u8(self.data_type).await?;
70        writer.write_u16(self.vbucket_id_or_status).await?;
71        writer.write_u32(self.total_body_length).await?;
72        writer.write_u32(self.opaque).await?;
73        writer.write_u64(self.cas).await?;
74        Ok(())
75    }
76
77    pub async fn read(reader: &mut TcpStream) -> Result<PacketHeader, io::Error> {
78        let header = PacketHeader {
79            magic: reader.read_u8().await?,
80            opcode: reader.read_u8().await?,
81            key_length: reader.read_u16().await?,
82            extras_length: reader.read_u8().await?,
83            data_type: reader.read_u8().await?,
84            vbucket_id_or_status: reader.read_u16().await?,
85            total_body_length: reader.read_u32().await?,
86            opaque: reader.read_u32().await?,
87            cas: reader.read_u64().await?,
88        };
89        Ok(header)
90    }
91}
92
93pub struct Response {
94    header: PacketHeader,
95    _key: Vec<u8>,
96    _extras: Vec<u8>,
97    value: Vec<u8>,
98}
99
100#[derive(Debug)]
101pub struct Connection {
102    io: BufReader<TcpStream>,
103}
104
105impl Connection {
106    pub fn new(io: TcpStream) -> Self {
107        Self {
108            io: BufReader::new(io),
109        }
110    }
111
112    pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
113        let writer = self.io.get_mut();
114        let key = "PLAIN";
115        let request_header = PacketHeader {
116            magic: Magic::Request as u8,
117            opcode: Opcode::StartAuth as u8,
118            key_length: key.len() as u16,
119            total_body_length: (key.len() + username.len() + password.len() + 2) as u32,
120            ..Default::default()
121        };
122        request_header
123            .write(writer)
124            .await
125            .map_err(new_std_io_error)?;
126        writer
127            .write_all(key.as_bytes())
128            .await
129            .map_err(new_std_io_error)?;
130        writer
131            .write_all(format!("\x00{username}\x00{password}").as_bytes())
132            .await
133            .map_err(new_std_io_error)?;
134        writer.flush().await.map_err(new_std_io_error)?;
135        parse_response(writer).await?;
136        Ok(())
137    }
138
139    pub async fn version(&mut self) -> Result<String> {
140        let writer = self.io.get_mut();
141        let request_header = PacketHeader {
142            magic: Magic::Request as u8,
143            opcode: Opcode::Version as u8,
144            ..Default::default()
145        };
146        request_header
147            .write(writer)
148            .await
149            .map_err(new_std_io_error)?;
150        writer.flush().await.map_err(new_std_io_error)?;
151        let response = parse_response(writer).await?;
152        let version = String::from_utf8(response.value);
153        match version {
154            Ok(version) => Ok(version),
155            Err(e) => {
156                Err(Error::new(ErrorKind::Unexpected, "unexpected data received").set_source(e))
157            }
158        }
159    }
160
161    pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
162        let writer = self.io.get_mut();
163        let request_header = PacketHeader {
164            magic: Magic::Request as u8,
165            opcode: Opcode::Get as u8,
166            key_length: key.len() as u16,
167            total_body_length: key.len() as u32,
168            ..Default::default()
169        };
170        request_header
171            .write(writer)
172            .await
173            .map_err(new_std_io_error)?;
174        writer
175            .write_all(key.as_bytes())
176            .await
177            .map_err(new_std_io_error)?;
178        writer.flush().await.map_err(new_std_io_error)?;
179        match parse_response(writer).await {
180            Ok(response) => {
181                if response.header.vbucket_id_or_status == 0x1 {
182                    return Ok(None);
183                }
184                Ok(Some(response.value))
185            }
186            Err(e) => Err(e),
187        }
188    }
189
190    pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
191        let writer = self.io.get_mut();
192        let request_header = PacketHeader {
193            magic: Magic::Request as u8,
194            opcode: Opcode::Set as u8,
195            key_length: key.len() as u16,
196            extras_length: 8,
197            total_body_length: (8 + key.len() + val.len()) as u32,
198            ..Default::default()
199        };
200        let extras = StoreExtras {
201            flags: 0,
202            expiration,
203        };
204        request_header
205            .write(writer)
206            .await
207            .map_err(new_std_io_error)?;
208        writer
209            .write_u32(extras.flags)
210            .await
211            .map_err(new_std_io_error)?;
212        writer
213            .write_u32(extras.expiration)
214            .await
215            .map_err(new_std_io_error)?;
216        writer
217            .write_all(key.as_bytes())
218            .await
219            .map_err(new_std_io_error)?;
220        writer.write_all(val).await.map_err(new_std_io_error)?;
221        writer.flush().await.map_err(new_std_io_error)?;
222
223        parse_response(writer).await?;
224        Ok(())
225    }
226
227    pub async fn delete(&mut self, key: &str) -> Result<()> {
228        let writer = self.io.get_mut();
229        let request_header = PacketHeader {
230            magic: Magic::Request as u8,
231            opcode: Opcode::Delete as u8,
232            key_length: key.len() as u16,
233            total_body_length: key.len() as u32,
234            ..Default::default()
235        };
236        request_header
237            .write(writer)
238            .await
239            .map_err(new_std_io_error)?;
240        writer
241            .write_all(key.as_bytes())
242            .await
243            .map_err(new_std_io_error)?;
244        writer.flush().await.map_err(new_std_io_error)?;
245        parse_response(writer).await?;
246        Ok(())
247    }
248}
249
250pub async fn parse_response(reader: &mut TcpStream) -> Result<Response> {
251    let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?;
252
253    if header.vbucket_id_or_status != constants::OK_STATUS
254        && header.vbucket_id_or_status != constants::KEY_NOT_FOUND
255    {
256        return Err(
257            Error::new(ErrorKind::Unexpected, "unexpected status received")
258                .with_context("message", format!("{}", header.vbucket_id_or_status)),
259        );
260    }
261
262    let mut extras = vec![0x0; header.extras_length as usize];
263    reader
264        .read_exact(extras.as_mut_slice())
265        .await
266        .map_err(new_std_io_error)?;
267
268    let mut key = vec![0x0; header.key_length as usize];
269    reader
270        .read_exact(key.as_mut_slice())
271        .await
272        .map_err(new_std_io_error)?;
273
274    let mut value = vec![
275        0x0;
276        (header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length))
277            as usize
278    ];
279    reader
280        .read_exact(value.as_mut_slice())
281        .await
282        .map_err(new_std_io_error)?;
283
284    Ok(Response {
285        header,
286        _key: key,
287        _extras: extras,
288        value,
289    })
290}