Skip to content

Store StreamWrapper::inner as a raw pointer #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 94 additions & 79 deletions src/ffi/c.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::cmp;
use std::convert::TryFrom;
use std::fmt;
use std::marker;
use std::ops::{Deref, DerefMut};
use std::os::raw::{c_int, c_uint, c_void};
use std::ptr;

Expand All @@ -21,7 +20,10 @@ impl ErrorMessage {
}

pub struct StreamWrapper {
pub inner: Box<mz_stream>,
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure, and it must never be copied
// by Rust.
pub inner: *mut mz_stream,
}

impl fmt::Debug for StreamWrapper {
Expand All @@ -32,8 +34,12 @@ impl fmt::Debug for StreamWrapper {

impl Default for StreamWrapper {
fn default() -> StreamWrapper {
// SAFETY: The field `state` will be initialized across the FFI to
// point to the opaque type `mz_internal_state`, which will contain a copy
// of `inner`. This cyclic structure breaks the uniqueness invariant of
// &mut mz_stream, so we must use a raw pointer instead of Box<mz_stream>.
StreamWrapper {
inner: Box::new(mz_stream {
inner: Box::into_raw(Box::new(mz_stream {
next_in: ptr::null_mut(),
avail_in: 0,
total_in: 0,
Expand All @@ -54,11 +60,21 @@ impl Default for StreamWrapper {
zalloc: Some(zalloc),
#[cfg(not(all(feature = "any_zlib", not(feature = "cloudflare-zlib-sys"))))]
zfree: Some(zfree),
}),
})),
}
}
}

// Frees the heap-allocated `mz_stream` that `inner` points to when the
// wrapper goes out of scope.
impl Drop for StreamWrapper {
fn drop(&mut self) {
// SAFETY: `inner` is expected to be the pointer produced by `Box::into_raw`
// in `Default::default`, so rebuilding the `Box` here returns ownership of
// that allocation to Rust exactly once. At this point, every other
// allocation for the struct has been freed by `inflateEnd` or `deflateEnd`,
// and no copies of `inner` are retained by C, so it is safe to drop the
// struct as long as the user respects the invariant that `inner` must never
// be copied by Rust.
drop(unsafe { Box::from_raw(self.inner) });
}
}

const ALIGN: usize = std::mem::align_of::<usize>();

fn align_up(size: usize, align: usize) -> usize {
Expand Down Expand Up @@ -110,20 +126,6 @@ extern "C" fn zfree(_ptr: *mut c_void, address: *mut c_void) {
}
}

impl Deref for StreamWrapper {
type Target = mz_stream;

fn deref(&self) -> &Self::Target {
&*self.inner
}
}

impl DerefMut for StreamWrapper {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.inner
}
}

unsafe impl<D: Direction> Send for Stream<D> {}
unsafe impl<D: Direction> Sync for Stream<D> {}

Expand All @@ -148,7 +150,10 @@ pub struct Stream<D: Direction> {

impl<D: Direction> Stream<D> {
pub fn msg(&self) -> ErrorMessage {
let msg = self.stream_wrapper.msg;
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self`.
let msg = unsafe { (*self.stream_wrapper.inner).msg };
ErrorMessage(if msg.is_null() {
None
} else {
Expand All @@ -161,7 +166,7 @@ impl<D: Direction> Stream<D> {
impl<D: Direction> Drop for Stream<D> {
fn drop(&mut self) {
unsafe {
let _ = D::destroy(&mut *self.stream_wrapper);
let _ = D::destroy(self.stream_wrapper.inner);
}
}
}
Expand All @@ -185,9 +190,9 @@ pub struct Inflate {
impl InflateBackend for Inflate {
fn make(zlib_header: bool, window_bits: u8) -> Self {
unsafe {
let mut state = StreamWrapper::default();
let state = StreamWrapper::default();
let ret = mz_inflateInit2(
&mut *state,
state.inner,
if zlib_header {
window_bits as c_int
} else {
Expand All @@ -212,33 +217,38 @@ impl InflateBackend for Inflate {
output: &mut [u8],
flush: FlushDecompress,
) -> Result<Status, DecompressError> {
let raw = &mut *self.inner.stream_wrapper;
raw.msg = ptr::null_mut();
raw.next_in = input.as_ptr() as *mut u8;
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
raw.next_out = output.as_mut_ptr();
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = unsafe { mz_inflate(raw, flush as c_int) };

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
raw.next_in = ptr::null_mut();
raw.avail_in = 0;
raw.next_out = ptr::null_mut();
raw.avail_out = 0;

match rc {
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_NEED_DICT => mem::decompress_need_dict(raw.adler as u32),
c => panic!("unknown return code: {}", c),
let raw = self.inner.stream_wrapper.inner;
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self`.
unsafe {
(*raw).msg = ptr::null_mut();
(*raw).next_in = input.as_ptr() as *mut u8;
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
(*raw).next_out = output.as_mut_ptr();
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = mz_inflate(raw, flush as c_int);

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
(*raw).next_in = ptr::null_mut();
(*raw).avail_in = 0;
(*raw).next_out = ptr::null_mut();
(*raw).avail_out = 0;

match rc {
MZ_DATA_ERROR | MZ_STREAM_ERROR => mem::decompress_failed(self.inner.msg()),
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_NEED_DICT => mem::decompress_need_dict((*raw).adler as u32),
c => panic!("unknown return code: {}", c),
}
}
}

Expand All @@ -249,7 +259,7 @@ impl InflateBackend for Inflate {
-MZ_DEFAULT_WINDOW_BITS
};
unsafe {
inflateReset2(&mut *self.inner.stream_wrapper, bits);
inflateReset2(self.inner.stream_wrapper.inner, bits);
}
self.inner.total_out = 0;
self.inner.total_in = 0;
Expand All @@ -276,9 +286,9 @@ pub struct Deflate {
impl DeflateBackend for Deflate {
fn make(level: Compression, zlib_header: bool, window_bits: u8) -> Self {
unsafe {
let mut state = StreamWrapper::default();
let state = StreamWrapper::default();
let ret = mz_deflateInit2(
&mut *state,
state.inner,
level.0 as c_int,
MZ_DEFLATED,
if zlib_header {
Expand Down Expand Up @@ -306,39 +316,44 @@ impl DeflateBackend for Deflate {
output: &mut [u8],
flush: FlushCompress,
) -> Result<Status, CompressError> {
let raw = &mut *self.inner.stream_wrapper;
raw.msg = ptr::null_mut();
raw.next_in = input.as_ptr() as *mut _;
raw.avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
raw.next_out = output.as_mut_ptr();
raw.avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = unsafe { mz_deflate(raw, flush as c_int) };

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.
self.inner.total_in += (raw.next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += (raw.next_out as usize - output.as_ptr() as usize) as u64;

// reset these pointers so we don't accidentally read them later
raw.next_in = ptr::null_mut();
raw.avail_in = 0;
raw.next_out = ptr::null_mut();
raw.avail_out = 0;

match rc {
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
c => panic!("unknown return code: {}", c),
let raw = self.inner.stream_wrapper.inner;
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self`.
unsafe {
(*raw).msg = ptr::null_mut();
(*raw).next_in = input.as_ptr() as *mut _;
(*raw).avail_in = cmp::min(input.len(), c_uint::MAX as usize) as c_uint;
(*raw).next_out = output.as_mut_ptr();
(*raw).avail_out = cmp::min(output.len(), c_uint::MAX as usize) as c_uint;

let rc = mz_deflate(raw, flush as c_int);

// Unfortunately the total counters provided by zlib might be only
// 32 bits wide and overflow while processing large amounts of data.

self.inner.total_in += ((*raw).next_in as usize - input.as_ptr() as usize) as u64;
self.inner.total_out += ((*raw).next_out as usize - output.as_ptr() as usize) as u64;
// reset these pointers so we don't accidentally read them later
(*raw).next_in = ptr::null_mut();
(*raw).avail_in = 0;
(*raw).next_out = ptr::null_mut();
(*raw).avail_out = 0;

match rc {
MZ_OK => Ok(Status::Ok),
MZ_BUF_ERROR => Ok(Status::BufError),
MZ_STREAM_END => Ok(Status::StreamEnd),
MZ_STREAM_ERROR => mem::compress_failed(self.inner.msg()),
c => panic!("unknown return code: {}", c),
}
}
}

fn reset(&mut self) {
self.inner.total_in = 0;
self.inner.total_out = 0;
let rc = unsafe { mz_deflateReset(&mut *self.inner.stream_wrapper) };
let rc = unsafe { mz_deflateReset(self.inner.stream_wrapper.inner) };
assert_eq!(rc, MZ_OK);
}
}
Expand Down
30 changes: 20 additions & 10 deletions src/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,19 @@ impl Compress {
/// Returns the Adler-32 checksum of the dictionary.
#[cfg(feature = "any_zlib")]
pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result<u32, CompressError> {
let stream = &mut *self.inner.inner.stream_wrapper;
stream.msg = std::ptr::null_mut();
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self.inner.inner.stream_wrapper`.
let stream = self.inner.inner.stream_wrapper.inner;
let rc = unsafe {
(*stream).msg = std::ptr::null_mut();
assert!(dictionary.len() < ffi::uInt::MAX as usize);
ffi::deflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt)
};

match rc {
ffi::MZ_STREAM_ERROR => compress_failed(self.inner.inner.msg()),
ffi::MZ_OK => Ok(stream.adler as u32),
ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32),
c => panic!("unknown return code: {}", c),
}
}
Expand All @@ -299,9 +302,13 @@ impl Compress {
#[cfg(feature = "any_zlib")]
pub fn set_level(&mut self, level: Compression) -> Result<(), CompressError> {
use std::os::raw::c_int;
let stream = &mut *self.inner.inner.stream_wrapper;
stream.msg = std::ptr::null_mut();

// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self.inner.inner.stream_wrapper`.
let stream = self.inner.inner.stream_wrapper.inner;
unsafe {
(*stream).msg = std::ptr::null_mut();
}
let rc = unsafe { ffi::deflateParams(stream, level.0 as c_int, ffi::MZ_DEFAULT_STRATEGY) };

match rc {
Expand Down Expand Up @@ -476,17 +483,20 @@ impl Decompress {
/// Specifies the decompression dictionary to use.
#[cfg(feature = "any_zlib")]
pub fn set_dictionary(&mut self, dictionary: &[u8]) -> Result<u32, DecompressError> {
let stream = &mut *self.inner.inner.stream_wrapper;
stream.msg = std::ptr::null_mut();
// SAFETY: The field `inner` must always be accessed as a raw pointer,
// since it points to a cyclic structure. No copies of `inner` can be
// retained for longer than the lifetime of `self.inner.inner.stream_wrapper`.
let stream = self.inner.inner.stream_wrapper.inner;
let rc = unsafe {
(*stream).msg = std::ptr::null_mut();
assert!(dictionary.len() < ffi::uInt::MAX as usize);
ffi::inflateSetDictionary(stream, dictionary.as_ptr(), dictionary.len() as ffi::uInt)
};

match rc {
ffi::MZ_STREAM_ERROR => decompress_failed(self.inner.inner.msg()),
ffi::MZ_DATA_ERROR => decompress_need_dict(stream.adler as u32),
ffi::MZ_OK => Ok(stream.adler as u32),
ffi::MZ_DATA_ERROR => decompress_need_dict(unsafe { (*stream).adler } as u32),
ffi::MZ_OK => Ok(unsafe { (*stream).adler } as u32),
c => panic!("unknown return code: {}", c),
}
}
Expand Down