-
-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ref(sourcebundle): Check UTF-8 validity memory efficiently
The current check to ensure a sourcebundle is valid UTF-8 reads the entire sourcebundle file into memory. This is inefficient for large files. This PR introduces a `Utf8Reader`, which wraps any reader. The `Utf8Reader` ensures that the stream is valid UTF-8 as it is being read, while only requiring a small amount of memory (currently 8 KiB) to be allocated as a buffer.
- Loading branch information
1 parent
ffd1ec5
commit d7e4852
Showing
2 changed files
with
322 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
//! UTF-8 reader used by the sourcebundle module to read files. | ||
use std::io::{BufRead, Error, ErrorKind, Read, Result}; | ||
use thiserror::Error; | ||
|
||
const MAX_UTF8_SEQUENCE_SIZE: usize = 4; | ||
const MAX_BUFFER_SIZE: usize = 8 * 1024; | ||
|
||
#[derive(Debug, Error)] | ||
#[error("Invalid UTF-8 sequence")] | ||
pub struct UTF8ReaderError { | ||
/// Make it impossible to construct this struct outside of this module. | ||
_a: (), | ||
} | ||
|
||
impl UTF8ReaderError { | ||
fn new() -> Self { | ||
Self { _a: () } | ||
} | ||
} | ||
|
||
/// A reader adapter that validates the wrapped stream as UTF-8 while it is
/// being read, using a small internal buffer instead of loading the whole
/// input into memory.
#[derive(Debug)]
pub struct Utf8Reader<R> {
    /// The wrapped reader that supplies the raw bytes.
    inner: R,

    /// buffer_vec always contains a valid UTF-8 sequence. It is possible that
    /// buffer_vec[buffer_pos..] does not contain a valid UTF-8 sequence, if
    /// buffer_pos is in the middle of a multi-byte UTF-8 character.
    buffer_vec: Vec<u8>,

    /// Index into `buffer_vec` of the first byte that has not been handed
    /// out to the caller yet.
    buffer_pos: usize,
}
|
||
impl<R> Utf8Reader<R> { | ||
pub fn new(inner: R) -> Self { | ||
Self { | ||
inner, | ||
buffer_vec: vec![], | ||
buffer_pos: 0, | ||
} | ||
} | ||
|
||
fn buffer(&self) -> &[u8] { | ||
&self.buffer_vec[self.buffer_pos..] | ||
} | ||
} | ||
|
||
impl<R> Utf8Reader<R> | ||
where | ||
R: Read, | ||
{ | ||
/// Reads a UTF-8 sequence from the inner reader into the buffer, returning the number | ||
/// of bytes read. Errors if this is not possible. | ||
/// | ||
/// The function guarantees that the function will return an Ok variant with a positive | ||
/// number of bytes read if it is possible to read a valid UTF-8 sequence from the inner | ||
/// reader. The function also guarantees that all of the bytes read from the inner reader will | ||
/// be stored in in the buffer, at indices up to the value returned by the function (in other | ||
/// words, no bytes are lost). | ||
/// | ||
/// Panics if the buffer is not at least `MAX_UTF8_SEQUENCE_SIZE` bytes long. | ||
fn read_utf8(&mut self, buf: &mut [u8]) -> Result<usize> { | ||
if buf.len() < MAX_UTF8_SEQUENCE_SIZE { | ||
panic!("Buffer needs to be at least {MAX_UTF8_SEQUENCE_SIZE} bytes long"); | ||
} | ||
|
||
// We need to leave at least three bytes in the buffer in case the first read | ||
// ends in the middle of a UTF-8 character. In the worst case, we end at the first | ||
// byte of a 4-byte UTF-8 character, and we need to read 3 more bytes to get a | ||
// valid sequence. | ||
let bytes_to_read = buf.len() - MAX_UTF8_SEQUENCE_SIZE + 1; | ||
let read_buf = &mut buf[..bytes_to_read]; | ||
|
||
let bytes_read = self.inner.read(read_buf)?; | ||
let read_portion = &buf[..bytes_read]; | ||
|
||
let valid_up_to = utf8_up_to(read_portion); | ||
let invalid_portion = &read_portion[valid_up_to..]; | ||
|
||
if invalid_portion.len() >= MAX_UTF8_SEQUENCE_SIZE { | ||
return Err(Error::new(ErrorKind::InvalidData, UTF8ReaderError::new())); | ||
} | ||
|
||
// We use this buffer to read up to 4 bytes from the inner reader | ||
// to try to get a valid UTF-8 sequence. | ||
let mut next_character_buffer = Vec::from(invalid_portion); | ||
|
||
// read_until_utf8 will not read anything if the buffer is empty, | ||
// since an empty buffer is a valid UTF-8 sequence. | ||
if self.read_until_utf8(&mut next_character_buffer)? { | ||
let next_character_bytes = next_character_buffer.len(); | ||
let total_bytes_read = valid_up_to + next_character_bytes; | ||
buf[valid_up_to..total_bytes_read].copy_from_slice(&next_character_buffer); | ||
Ok(total_bytes_read) | ||
} else { | ||
// We have 4 bytes in the buffer, but they do not form a valid | ||
// UTF-8 sequence. | ||
Err(Error::new(ErrorKind::InvalidData, UTF8ReaderError::new())) | ||
} | ||
} | ||
|
||
/// Reads a single byte at a time from the inner reader into the buffer until | ||
/// either the buffer contains a valid UTF-8 sequence, the reader is | ||
/// exhausted, or the buffer contains 4 bytes (the maximum size of a | ||
/// UTF-8 sequence). | ||
/// | ||
/// If the buffer is empty, we will return Ok(true) without reading anything | ||
/// from the reader. An empty byte sequence is a valid UTF-8 sequence (the | ||
/// empty string). | ||
/// | ||
/// Returns whether a valid UTF-8 sequence was found, or an error if the | ||
/// reader errors. | ||
fn read_until_utf8(&mut self, buffer: &mut Vec<u8>) -> Result<bool> { | ||
while std::str::from_utf8(buffer).is_err() { | ||
if buffer.len() >= MAX_UTF8_SEQUENCE_SIZE { | ||
// We already have 4 bytes in the buffer (maximum UTF-8 sequence size) | ||
// so we cannot form a valid UTF-8 sequence by reading more bytes. | ||
return Ok(false); | ||
} | ||
|
||
let mut byte = [0; 1]; | ||
if self.inner.read(&mut byte)? == 0 { | ||
// Stream has been exhausted without finding | ||
// a valid UTF-8 sequence. | ||
return Ok(false); | ||
} | ||
|
||
buffer.push(byte[0]); | ||
} | ||
|
||
Ok(true) | ||
} | ||
} | ||
|
||
impl<R> Read for Utf8Reader<R> | ||
where | ||
R: Read, | ||
{ | ||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { | ||
if self.buffer().is_empty() && buf.len() > MAX_BUFFER_SIZE { | ||
// If the buffer is bigger than the | ||
return self.read_utf8(buf); | ||
} | ||
|
||
self.fill_buf()?; | ||
|
||
let bytes_to_copy = std::cmp::min(buf.len(), self.buffer().len()); | ||
buf[..bytes_to_copy].copy_from_slice(&self.buffer()[..bytes_to_copy]); | ||
self.consume(bytes_to_copy); | ||
|
||
Ok(bytes_to_copy) | ||
} | ||
} | ||
|
||
impl<R> BufRead for Utf8Reader<R> | ||
where | ||
R: Read, | ||
{ | ||
fn fill_buf(&mut self) -> std::io::Result<&[u8]> { | ||
if !self.buffer().is_empty() { | ||
return Ok(self.buffer()); | ||
} | ||
|
||
let mut buf = [0; MAX_BUFFER_SIZE]; | ||
let bytes_read = self.read_utf8(&mut buf)?; | ||
|
||
self.buffer_vec = buf[..bytes_read].into(); | ||
self.buffer_pos = 0; | ||
|
||
Ok(self.buffer()) | ||
} | ||
|
||
fn consume(&mut self, amt: usize) { | ||
self.buffer_pos += amt; | ||
} | ||
} | ||
|
||
/// Returns the length of the longest prefix of `bytes` that is valid UTF-8.
///
/// This is the index of the first byte that is not part of a valid UTF-8
/// sequence, or `bytes.len()` when the whole slice is valid.
fn utf8_up_to(bytes: &[u8]) -> usize {
    std::str::from_utf8(bytes).map_or_else(|error| error.valid_up_to(), str::len)
}
|
||
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::Cursor;

    /// A character whose UTF-8 encoding needs the maximum of four bytes.
    /// Written as an escape so the literal survives re-encoding of this file.
    const FOUR_BYTE_CHAR: &str = "\u{1F600}";

    /// Returns `FOUR_BYTE_CHAR` with its last byte cut off, which is an
    /// incomplete (and therefore invalid) UTF-8 sequence.
    fn truncated_four_byte_char() -> &'static [u8] {
        &FOUR_BYTE_CHAR.as_bytes()[..FOUR_BYTE_CHAR.len() - 1]
    }

    #[test]
    fn test_read_empty() {
        let mut empty_reader = Utf8Reader::new(Cursor::new(b""));

        let mut buf = vec![];
        empty_reader
            .read_to_end(&mut buf)
            .expect("read_to_end errored");

        assert_eq!(buf, b"");
    }

    #[test]
    fn test_read_ascii_simple() {
        let mut reader = Utf8Reader::new(Cursor::new(b"Hello, world!"));

        let mut buf = vec![];
        reader.read_to_end(&mut buf).expect("read_to_end errored");

        assert_eq!(buf, b"Hello, world!");
    }

    #[test]
    fn test_read_utf8_simple() {
        const HELLO_WORLD: &str = "你好世界！";
        let mut reader = Utf8Reader::new(Cursor::new(HELLO_WORLD.as_bytes()));

        let mut buf = vec![];
        reader.read_to_end(&mut buf).expect("read_to_end errored");

        assert_eq!(buf, HELLO_WORLD.as_bytes());
    }

    #[test]
    fn multibyte_character_at_end_of_buffer() {
        // Having a multibyte character at the end of the buffer will cause us to hit
        // read_until_utf8.
        let mut read_buffer = vec![b'a'; MAX_BUFFER_SIZE - MAX_UTF8_SEQUENCE_SIZE];
        read_buffer.extend(FOUR_BYTE_CHAR.as_bytes());

        let mut reader = Utf8Reader::new(Cursor::new(&read_buffer));

        let mut buf = [0; MAX_BUFFER_SIZE];
        let bytes_read = reader.read(&mut buf).expect("read errored");

        // We expect the buffer to be filled with the read bytes. We first read all but the last
        // three bytes. But, we will need to fill these three bytes to get a valid UTF-8 sequence,
        // since FOUR_BYTE_CHAR is a 4-byte UTF-8 sequence.
        assert_eq!(bytes_read, buf.len(), "buffer not filled");
        assert_eq!(&buf[..], read_buffer);
    }

    #[test]
    fn multibyte_character_at_end_of_big_read() {
        // Big reads bypass buffering, so basically, we retest multibyte_character_at_end_of_buffer
        // for this case.
        let mut read_buffer = vec![b'a'; MAX_BUFFER_SIZE + 10];
        read_buffer.extend(FOUR_BYTE_CHAR.as_bytes());

        let mut reader = Utf8Reader::new(Cursor::new(&read_buffer));

        let mut buf = [0; MAX_BUFFER_SIZE + MAX_UTF8_SEQUENCE_SIZE + 10];
        let bytes_read = reader.read(&mut buf).expect("read errored");

        assert_eq!(bytes_read, buf.len(), "buffer not filled");
        assert_eq!(&buf[..], read_buffer);
    }

    #[test]
    fn small_reads_splitting_sequence() {
        let mut reader = Utf8Reader::new(Cursor::new(FOUR_BYTE_CHAR.as_bytes()));

        let mut buf = [0; MAX_UTF8_SEQUENCE_SIZE];

        for i in 0..MAX_UTF8_SEQUENCE_SIZE {
            // Read at most one byte at a time.
            let bytes_read = reader.read(&mut buf[i..i + 1]).expect("read errored");
            assert_eq!(bytes_read, 1, "bytes read");
        }

        assert_eq!(&buf[..], FOUR_BYTE_CHAR.as_bytes());
    }

    #[test]
    fn invalid_utf8_sequence() {
        let mut reader = Utf8Reader::new(Cursor::new([0b11111111]));

        let mut buf = [0; 1];
        reader.read(&mut buf).expect_err("read should have errored");
    }

    #[test]
    fn invalid_utf8_sequence_at_end_of_reader() {
        let mut read_buffer = Vec::from(b"Hello, world!");

        // Cutting off the last byte will invalidate the UTF-8 sequence.
        read_buffer.extend(truncated_four_byte_char());

        let mut reader = Utf8Reader::new(Cursor::new(&read_buffer));
        reader
            .read_to_end(&mut vec![])
            .expect_err("read should have errored");
    }

    #[test]
    fn invalid_utf8_sequence_at_end_of_reader_and_buffer() {
        let mut read_buffer = vec![b'a'; MAX_BUFFER_SIZE - MAX_UTF8_SEQUENCE_SIZE];

        // Cutting off the last byte will invalidate the UTF-8 sequence.
        read_buffer.extend(truncated_four_byte_char());

        let mut reader = Utf8Reader::new(Cursor::new(&read_buffer));

        let mut buf = [0; MAX_BUFFER_SIZE];
        reader.read(&mut buf).expect_err("read should have errored");
    }
}