Skip to content

feat: I/O safety pipe, pipe2 & write #2100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/pty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl io::Read for PtyMaster {

impl io::Write for PtyMaster {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
unistd::write(self.0.as_raw_fd(), buf).map_err(io::Error::from)
unistd::write(&self.0, buf).map_err(io::Error::from)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
Expand All @@ -86,7 +86,7 @@ impl io::Read for &PtyMaster {

impl io::Write for &PtyMaster {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
unistd::write(self.0.as_raw_fd(), buf).map_err(io::Error::from)
unistd::write(&self.0, buf).map_err(io::Error::from)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions src/sys/epoll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl EpollEvent {
/// ```
/// # use nix::sys::{epoll::{Epoll, EpollEvent, EpollFlags, EpollCreateFlags}, eventfd::{eventfd, EfdFlags}};
/// # use nix::unistd::write;
/// # use std::os::unix::io::{OwnedFd, FromRawFd, AsRawFd, AsFd};
/// # use std::os::unix::io::{OwnedFd, FromRawFd, AsFd};
/// # use std::time::{Instant, Duration};
/// # fn main() -> nix::Result<()> {
/// const DATA: u64 = 17;
Expand All @@ -87,7 +87,7 @@ impl EpollEvent {
/// epoll.add(&eventfd, EpollEvent::new(EpollFlags::EPOLLIN,DATA))?;
///
/// // Arm eventfd & Time wait
/// write(eventfd.as_raw_fd(), &1u64.to_ne_bytes())?;
/// write(&eventfd, &1u64.to_ne_bytes())?;
/// let now = Instant::now();
///
/// // Wait on event
Expand Down
21 changes: 5 additions & 16 deletions src/sys/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ where
mod tests {
use super::*;
use crate::sys::time::{TimeVal, TimeValLike};
use crate::unistd::{close, pipe, write};
use std::os::unix::io::{FromRawFd, OwnedFd, RawFd};
use crate::unistd::{pipe, write};
use std::os::unix::io::RawFd;

#[test]
fn fdset_insert() {
Expand Down Expand Up @@ -466,12 +466,9 @@ mod tests {
#[test]
fn test_select() {
let (r1, w1) = pipe().unwrap();
let r1 = unsafe { OwnedFd::from_raw_fd(r1) };
let w1 = unsafe { OwnedFd::from_raw_fd(w1) };
let (r2, _w2) = pipe().unwrap();
let r2 = unsafe { OwnedFd::from_raw_fd(r2) };

write(w1.as_raw_fd(), b"hi!").unwrap();
write(&w1, b"hi!").unwrap();
let mut fd_set = FdSet::new();
fd_set.insert(&r1);
fd_set.insert(&r2);
Expand All @@ -483,18 +480,14 @@ mod tests {
);
assert!(fd_set.contains(&r1));
assert!(!fd_set.contains(&r2));
close(_w2).unwrap();
}

#[test]
fn test_select_nfds() {
let (r1, w1) = pipe().unwrap();
let (r2, _w2) = pipe().unwrap();
let r1 = unsafe { OwnedFd::from_raw_fd(r1) };
let w1 = unsafe { OwnedFd::from_raw_fd(w1) };
let r2 = unsafe { OwnedFd::from_raw_fd(r2) };

write(w1.as_raw_fd(), b"hi!").unwrap();
write(&w1, b"hi!").unwrap();
let mut fd_set = FdSet::new();
fd_set.insert(&r1);
fd_set.insert(&r2);
Expand All @@ -521,16 +514,13 @@ mod tests {
}
assert!(fd_set.contains(&r1));
assert!(!fd_set.contains(&r2));
close(_w2).unwrap();
}

#[test]
fn test_select_nfds2() {
let (r1, w1) = pipe().unwrap();
write(w1, b"hi!").unwrap();
write(&w1, b"hi!").unwrap();
let (r2, _w2) = pipe().unwrap();
let r1 = unsafe { OwnedFd::from_raw_fd(r1) };
let r2 = unsafe { OwnedFd::from_raw_fd(r2) };
let mut fd_set = FdSet::new();
fd_set.insert(&r1);
fd_set.insert(&r2);
Expand All @@ -549,6 +539,5 @@ mod tests {
);
assert!(fd_set.contains(&r1));
assert!(!fd_set.contains(&r2));
close(_w2).unwrap();
}
}
4 changes: 2 additions & 2 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,7 @@ impl<'a> ControlMessage<'a> {
/// let (r, w) = pipe().unwrap();
///
/// let iov = [IoSlice::new(b"hello")];
/// let fds = [r];
/// let fds = [r.as_raw_fd()];
/// let cmsg = ControlMessage::ScmRights(&fds);
/// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap();
/// ```
Expand All @@ -1496,7 +1496,7 @@ impl<'a> ControlMessage<'a> {
/// let (r, w) = pipe().unwrap();
///
/// let iov = [IoSlice::new(b"hello")];
/// let fds = [r];
/// let fds = [r.as_raw_fd()];
/// let cmsg = ControlMessage::ScmRights(&fds);
/// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap();
/// ```
Expand Down
26 changes: 16 additions & 10 deletions src/unistd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use std::ffi::{CString, OsStr};
use std::os::unix::ffi::OsStrExt;
use std::os::unix::ffi::OsStringExt;
use std::os::unix::io::RawFd;
use std::os::unix::io::{AsFd, AsRawFd};
use std::os::unix::io::{AsFd, AsRawFd, OwnedFd};
use std::path::PathBuf;
use std::{fmt, mem, ptr};

Expand Down Expand Up @@ -260,7 +260,7 @@ impl ForkResult {
/// }
/// Ok(ForkResult::Child) => {
/// // Unsafe to use `println!` (or `unwrap`) here. See Safety.
/// write(libc::STDOUT_FILENO, "I'm a new child process\n".as_bytes()).ok();
/// write(std::io::stdout(), "I'm a new child process\n".as_bytes()).ok();
/// unsafe { libc::_exit(0) };
/// }
/// Err(_) => println!("Fork failed"),
Expand Down Expand Up @@ -1115,9 +1115,13 @@ pub fn read(fd: RawFd, buf: &mut [u8]) -> Result<usize> {
/// Write to a raw file descriptor.
///
/// See also [write(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html)
pub fn write(fd: RawFd, buf: &[u8]) -> Result<usize> {
pub fn write<Fd: AsFd>(fd: Fd, buf: &[u8]) -> Result<usize> {
let res = unsafe {
libc::write(fd, buf.as_ptr() as *const c_void, buf.len() as size_t)
libc::write(
fd.as_fd().as_raw_fd(),
buf.as_ptr() as *const c_void,
buf.len() as size_t,
)
};

Errno::result(res).map(|r| r as usize)
Expand Down Expand Up @@ -1189,14 +1193,15 @@ pub fn lseek64(
/// Create an interprocess channel.
///
/// See also [pipe(2)](https://pubs.opengroup.org/onlinepubs/9699919799/functions/pipe.html)
pub fn pipe() -> std::result::Result<(RawFd, RawFd), Error> {
let mut fds = mem::MaybeUninit::<[c_int; 2]>::uninit();
pub fn pipe() -> std::result::Result<(OwnedFd, OwnedFd), Error> {
let mut fds = mem::MaybeUninit::<[OwnedFd; 2]>::uninit();

let res = unsafe { libc::pipe(fds.as_mut_ptr() as *mut c_int) };

Error::result(res)?;

unsafe { Ok((fds.assume_init()[0], fds.assume_init()[1])) }
let [read, write] = unsafe { fds.assume_init() };
Ok((read, write))
}

feature! {
Expand Down Expand Up @@ -1230,15 +1235,16 @@ feature! {
target_os = "openbsd",
target_os = "solaris"
))]
pub fn pipe2(flags: OFlag) -> Result<(RawFd, RawFd)> {
let mut fds = mem::MaybeUninit::<[c_int; 2]>::uninit();
pub fn pipe2(flags: OFlag) -> Result<(OwnedFd, OwnedFd)> {
let mut fds = mem::MaybeUninit::<[OwnedFd; 2]>::uninit();

let res =
unsafe { libc::pipe2(fds.as_mut_ptr() as *mut c_int, flags.bits()) };

Errno::result(res)?;

unsafe { Ok((fds.assume_init()[0], fds.assume_init()[1])) }
let [read, write] = unsafe { fds.assume_init() };
Ok((read, write))
}

/// Truncate a file to a specified length
Expand Down
10 changes: 3 additions & 7 deletions test/sys/test_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@ use nix::sys::select::*;
use nix::sys::signal::SigSet;
use nix::sys::time::{TimeSpec, TimeValLike};
use nix::unistd::{pipe, write};
use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd};
use std::os::unix::io::{AsRawFd, BorrowedFd};

#[test]
pub fn test_pselect() {
let _mtx = crate::SIGNAL_MTX.lock();

let (r1, w1) = pipe().unwrap();
write(w1, b"hi!").unwrap();
let r1 = unsafe { OwnedFd::from_raw_fd(r1) };
write(&w1, b"hi!").unwrap();
let (r2, _w2) = pipe().unwrap();
let r2 = unsafe { OwnedFd::from_raw_fd(r2) };

let mut fd_set = FdSet::new();
fd_set.insert(&r1);
Expand All @@ -31,10 +29,8 @@ pub fn test_pselect() {
#[test]
pub fn test_pselect_nfds2() {
let (r1, w1) = pipe().unwrap();
write(w1, b"hi!").unwrap();
let r1 = unsafe { OwnedFd::from_raw_fd(r1) };
write(&w1, b"hi!").unwrap();
let (r2, _w2) = pipe().unwrap();
let r2 = unsafe { OwnedFd::from_raw_fd(r2) };

let mut fd_set = FdSet::new();
fd_set.insert(&r1);
Expand Down
16 changes: 6 additions & 10 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ pub fn test_socketpair() {
SockFlag::empty(),
)
.unwrap();
write(fd1.as_raw_fd(), b"hello").unwrap();
write(&fd1, b"hello").unwrap();
let mut buf = [0; 5];
read(fd2.as_raw_fd(), &mut buf).unwrap();

Expand Down Expand Up @@ -757,7 +757,7 @@ pub fn test_scm_rights() {

{
let iov = [IoSlice::new(b"hello")];
let fds = [r];
let fds = [r.as_raw_fd()];
let cmsg = ControlMessage::ScmRights(&fds);
assert_eq!(
sendmsg::<()>(
Expand All @@ -770,7 +770,6 @@ pub fn test_scm_rights() {
.unwrap(),
5
);
close(r).unwrap();
}

{
Expand Down Expand Up @@ -803,12 +802,11 @@ pub fn test_scm_rights() {

let received_r = received_r.expect("Did not receive passed fd");
// Ensure that the received file descriptor works
write(w.as_raw_fd(), b"world").unwrap();
write(&w, b"world").unwrap();
let mut buf = [0u8; 5];
read(received_r.as_raw_fd(), &mut buf).unwrap();
assert_eq!(&buf[..], b"world");
close(received_r).unwrap();
close(w).unwrap();
}

// Disable the test on emulated platforms due to not enabled support of AF_ALG in QEMU from rust cross
Expand Down Expand Up @@ -1451,7 +1449,7 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
gid: getgid().as_raw(),
}
.into();
let fds = [r];
let fds = [r.as_raw_fd()];
let cmsgs = [
ControlMessage::ScmCredentials(&cred),
ControlMessage::ScmRights(&fds),
Expand All @@ -1467,7 +1465,6 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {
.unwrap(),
5
);
close(r).unwrap();
}

{
Expand Down Expand Up @@ -1510,12 +1507,11 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec<u8>) {

let received_r = received_r.expect("Did not receive passed fd");
// Ensure that the received file descriptor works
write(w.as_raw_fd(), b"world").unwrap();
write(&w, b"world").unwrap();
let mut buf = [0u8; 5];
read(received_r.as_raw_fd(), &mut buf).unwrap();
assert_eq!(&buf[..], b"world");
close(received_r).unwrap();
close(w).unwrap();
}

// Test creating and using named unix domain sockets
Expand Down Expand Up @@ -1548,7 +1544,7 @@ pub fn test_named_unixdomain() {
)
.expect("socket failed");
connect(s2.as_raw_fd(), &sockaddr).expect("connect failed");
write(s2.as_raw_fd(), b"hello").expect("write failed");
write(&s2, b"hello").expect("write failed");
});

let s3 = accept(s1.as_raw_fd()).expect("accept failed");
Expand Down
6 changes: 3 additions & 3 deletions test/sys/test_sockopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use nix::sys::socket::{
SockProtocol, SockType,
};
use rand::{thread_rng, Rng};
use std::os::unix::io::AsRawFd;
use std::os::unix::io::{AsRawFd, FromRawFd, OwnedFd};

// NB: FreeBSD supports LOCAL_PEERCRED for SOCK_SEQPACKET, but OSX does not.
#[cfg(any(target_os = "dragonfly", target_os = "freebsd",))]
Expand Down Expand Up @@ -151,7 +151,8 @@ fn test_so_tcp_maxseg() {
.unwrap();
connect(ssock.as_raw_fd(), &sock_addr).unwrap();
let rsess = accept(rsock.as_raw_fd()).unwrap();
write(rsess, b"hello").unwrap();
let rsess = unsafe { OwnedFd::from_raw_fd(rsess) };
write(&rsess, b"hello").unwrap();
let actual = getsockopt(&ssock, sockopt::TcpMaxSeg).unwrap();
// Actual max segment size takes header lengths into account, max IPv4 options (60 bytes) + max
// TCP options (40 bytes) are subtracted from the requested maximum as a lower boundary.
Expand Down Expand Up @@ -185,7 +186,6 @@ fn test_so_type() {
#[test]
fn test_so_type_unknown() {
use nix::errno::Errno;
use std::os::unix::io::{FromRawFd, OwnedFd};

require_capability!("test_so_type", CAP_NET_RAW);
let raw_fd = unsafe { libc::socket(libc::AF_PACKET, libc::SOCK_PACKET, 0) };
Expand Down
2 changes: 1 addition & 1 deletion test/sys/test_termios.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use nix::unistd::{read, write};
fn write_all<Fd: AsFd>(f: Fd, buf: &[u8]) {
let mut len = 0;
while len < buf.len() {
len += write(f.as_fd().as_raw_fd(), &buf[len..]).unwrap();
len += write(f.as_fd(), &buf[len..]).unwrap();
}
}

Expand Down
Loading