From 3ee02bbb5330fc2fad6fff924493991832aeadc7 Mon Sep 17 00:00:00 2001 From: Jorge Ortega Date: Sat, 5 Apr 2025 16:15:34 -0700 Subject: [PATCH] reactor(cust_raw): Generate bindings for cuda types and restructure `cust_raw`'s crates - Allow list files instead of types/var/functions. - Split out type headers from runtime/cublas to their own crates. --- crates/blastoff/src/context.rs | 4 +- crates/cust_raw/Cargo.toml | 25 +-- crates/cust_raw/build/driver_wrapper.h | 4 +- crates/cust_raw/build/main.rs | 143 ++++++++++++++---- crates/cust_raw/src/cublas_sys.rs | 5 - .../src/{cublaslt_sys.rs => cublas_sys/lt.rs} | 7 +- crates/cust_raw/src/cublas_sys/mod.rs | 18 +++ .../src/{cublasxt_sys.rs => cublas_sys/xt.rs} | 2 + crates/cust_raw/src/driver_sys.rs | 2 + crates/cust_raw/src/lib.rs | 18 ++- crates/cust_raw/src/nvptx_compiler_sys.rs | 2 + crates/cust_raw/src/nvvm_sys.rs | 3 + crates/cust_raw/src/runtime_sys.rs | 6 + crates/cust_raw/src/types/complex.rs | 4 + crates/cust_raw/src/types/driver.rs | 9 ++ crates/cust_raw/src/types/library.rs | 3 + crates/cust_raw/src/types/mod.rs | 18 +++ crates/cust_raw/src/types/surface.rs | 5 + crates/cust_raw/src/types/texture.rs | 6 + crates/cust_raw/src/types/vector.rs | 3 + 20 files changed, 230 insertions(+), 57 deletions(-) delete mode 100644 crates/cust_raw/src/cublas_sys.rs rename crates/cust_raw/src/{cublaslt_sys.rs => cublas_sys/lt.rs} (55%) create mode 100644 crates/cust_raw/src/cublas_sys/mod.rs rename crates/cust_raw/src/{cublasxt_sys.rs => cublas_sys/xt.rs} (90%) create mode 100644 crates/cust_raw/src/types/complex.rs create mode 100644 crates/cust_raw/src/types/driver.rs create mode 100644 crates/cust_raw/src/types/library.rs create mode 100644 crates/cust_raw/src/types/mod.rs create mode 100644 crates/cust_raw/src/types/surface.rs create mode 100644 crates/cust_raw/src/types/texture.rs create mode 100644 crates/cust_raw/src/types/vector.rs diff --git a/crates/blastoff/src/context.rs b/crates/blastoff/src/context.rs index 83c53aa..59c3646 100644 --- a/crates/blastoff/src/context.rs +++ b/crates/blastoff/src/context.rs @@ -147,9 +147,7 @@ impl CublasContext { // cudaStream_t is the same as CUstream cublas_sys::cublasSetStream( self.raw, - mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>( - stream.as_inner(), - ), + mem::transmute::(stream.as_inner()), ) .to_result()?; let res = func(self)?; diff --git a/crates/cust_raw/Cargo.toml b/crates/cust_raw/Cargo.toml index 046a713..92005aa 100644 --- a/crates/cust_raw/Cargo.toml +++ b/crates/cust_raw/Cargo.toml @@ -3,12 +3,15 @@ name = "cust_raw" version = "0.11.3" edition = "2021" license = "MIT OR Apache-2.0" -description = "Low level bindings to the CUDA Driver API" +description = "Low level bindings to the CUDA Toolkit SDK" repository = "https://github.com/Rust-GPU/Rust-CUDA" readme = "../../README.md" links = "cuda" build = "build/main.rs" +[dependencies] +libc = { version = "0.2", optional = true } + [build-dependencies] bindgen = "0.71.1" bimap = "0.6.3" @@ -19,9 +22,8 @@ features = [ "driver", "runtime", "cublas", - "cublaslt", - "cublasxt", - "cudnn", + "cublasLt", + "cublasXt", "nvptx-compiler", "nvvm", ] @@ -29,10 +31,15 @@ features = [ [features] default = ["driver"] driver = [] -runtime = [] -cublas = [] -cublaslt = [] -cublasxt = [] -cudnn = [] +runtime = ["driver_types", "vector_types", "texture_types", "surface_types"] +cuComplex = ["vector_types"] +driver_types = [] +library_types = [] +surface_types = [] +texture_types = [] +vector_types = [] +cublas = ["runtime", "cuComplex", "library_types"] +cublasLt = ["cublas", "libc"] +cublasXt = ["cublas"] nvptx-compiler = [] nvvm = [] diff --git a/crates/cust_raw/build/driver_wrapper.h b/crates/cust_raw/build/driver_wrapper.h index 624c339..bb2bd7e 100644 --- a/crates/cust_raw/build/driver_wrapper.h +++ b/crates/cust_raw/build/driver_wrapper.h @@ -1,4 +1,2 @@ -#include "cuComplex.h" #include "cuda.h" -#include "cudaProfiler.h" -#include "vector_types.h" \ No newline at end of file +#include "cudaProfiler.h" \ No newline at end of file diff --git a/crates/cust_raw/build/main.rs b/crates/cust_raw/build/main.rs index 07a8c9c..8c2964b 100644 --- a/crates/cust_raw/build/main.rs +++ b/crates/cust_raw/build/main.rs @@ -28,8 +28,8 @@ use std::env; use std::fs; use std::path; -pub mod callbacks; -pub mod cuda_sdk; +mod callbacks; +mod cuda_sdk; fn main() { let outdir = path::PathBuf::from( @@ -63,8 +63,9 @@ fn main() { println!("cargo::rerun-if-env-changed={}", e); } - create_cuda_driver_bindings(&sdk, outdir.as_path()); - create_cuda_runtime_bindings(&sdk, outdir.as_path()); + create_driver_bindings(&sdk, outdir.as_path()); + create_runtime_bindings(&sdk, outdir.as_path()); + create_runtime_types_bindings(&sdk, outdir.as_path()); create_cublas_bindings(&sdk, outdir.as_path()); create_nptx_compiler_bindings(&sdk, outdir.as_path()); create_nvvm_bindings(&sdk, outdir.as_path()); @@ -73,8 +74,8 @@ fn main() { feature = "driver", feature = "runtime", feature = "cublas", - feature = "cublaslt", - feature = "cublasxt" + feature = "cublasLt", + feature = "cublasXt" )) { for libdir in sdk.cuda_library_paths() { println!("cargo::rustc-link-search=native={}", libdir.display()); @@ -84,11 +85,11 @@ fn main() { if cfg!(feature = "runtime") { println!("cargo::rustc-link-lib=dylib=cudart"); } - if cfg!(feature = "cublas") || cfg!(feature = "cublasxt") { + if cfg!(feature = "cublas") || cfg!(feature = "cublasXt") { println!("cargo::rustc-link-lib=dylib=cublas"); } - if cfg!(feature = "cublaslt") { - println!("cargo::rustc-link-lib=dylib=cublaslt"); + if cfg!(feature = "cublasLt") { + println!("cargo::rustc-link-lib=dylib=cublasLt"); } if cfg!(feature = "nvvm") { for libdir in sdk.nvvm_library_paths() { @@ -101,7 +102,53 @@ fn main() { } } -fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { +fn create_runtime_types_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { + let params = &[ + (cfg!(feature = "driver_types"), "driver_types"), + (cfg!(feature = "library_types"), "library_types"), + (cfg!(feature = "vector_types"), "vector_types"), + (cfg!(feature = "texture_types"), "texture_types"), + (cfg!(feature = "surface_types"), "surface_types"), + (cfg!(feature = "cuComplex"), "cuComplex"), + ]; + for (should_generate, pkg) in params { + if !should_generate { + continue; + } + let bindgen_path = path::PathBuf::from(format!("{}/{}_sys.rs", outdir.display(), pkg)); + let header = sdk + .cuda_root() + .join(format!("include/{}.h", pkg)) + .display() + .to_string(); + let bindings = bindgen::Builder::default() + .header(&header) + .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) + .clang_args( + sdk.cuda_include_paths() + .iter() + .map(|p| format!("-I{}", p.display())), + ) + .allowlist_file(format!(r".*{pkg}\.h")) + .allowlist_recursively(false) + .default_enum_style(bindgen::EnumVariation::Rust { + non_exhaustive: false, + }) + .derive_default(true) + .derive_eq(true) + .derive_hash(true) + .derive_ord(true) + .size_t_is_usize(true) + .layout_tests(true) + .generate() + .unwrap_or_else(|e| panic!("Unable to generate {pkg} bindings: {e}")); + bindings + .write_to_file(bindgen_path.as_path()) + .unwrap_or_else(|e| panic!("Cannot write {pkg} bindgen output to file: {e}")); + } +} + +fn create_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { if !cfg!(feature = "driver") { return; } @@ -121,13 +168,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { .iter() .map(|p| format!("-I{}", p.display())), ) - .allowlist_type("^CU.*") - .allowlist_type("^cuuint(32|64)_t") - .allowlist_type("^cudaError_enum") - .allowlist_type("^cu.*Complex$") - .allowlist_type("^cuda.*") - .allowlist_var("^CU.*") - .allowlist_function("^cu.*") + .allowlist_file(r".*cuda[^/\\]*\.h") .default_enum_style(bindgen::EnumVariation::Rust { non_exhaustive: false, }) @@ -145,7 +186,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { .expect("Cannot write CUDA driver bindgen output to file."); } -fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { +fn create_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { if !cfg!(feature = "runtime") { return; } @@ -165,14 +206,13 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { .iter() .map(|p| format!("-I{}", p.display())), ) - .allowlist_type("^CU.*") - .allowlist_type("^cuda.*") - .allowlist_type("^libraryPropertyType.*") - .allowlist_var("^CU.*") - .allowlist_function("^cu.*") + .allowlist_file(r".*cuda[^/\\]*\.h") + .allowlist_file(r".*cuComplex\.h") + .allowlist_recursively(false) .default_enum_style(bindgen::EnumVariation::Rust { non_exhaustive: false, }) + .disable_nested_struct_naming() .derive_default(true) .derive_eq(true) .derive_hash(true) @@ -188,19 +228,51 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { } fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { - #[rustfmt::skip] let params = &[ - (cfg!(feature = "cublas"), "cublas", "^cublas.*", "^CUBLAS.*"), - (cfg!(feature = "cublaslt"), "cublasLt", "^cublasLt.*", "^CUBLASLT.*"), - (cfg!(feature = "cublasxt"), "cublasXt", "^cublasXt.*", "^CUBLASXT.*"), + ( + cfg!(feature = "cublas"), + "cublas", + vec![r".*cublas(_api|_v2)\.h"], + vec![ + r".*cuComplex\.h", + r".*driver_types\.h", + r".*library_types\.h", + r".*vector_types\.h", + ], + ), + ( + cfg!(feature = "cublasLt"), + "cublasLt", + vec![r".*cublasLt\.h"], + vec![ + r".*cublas(_api|_v2)*\.h", + r".*cuComplex\.h", + r".*driver_types\.h", + r".*library_types\.h", + r".*vector_types\.h", + r".*std\w+\.h", + ], + ), + ( + cfg!(feature = "cublasXt"), + "cublasXt", + vec![r".*cublasXt\.h"], + vec![ + r".*cublas(_api|_v2)*\.h", + r".*cuComplex\.h", + r".*driver_types\.h", + r".*library_types\.h", + r".*vector_types\.h", + ], + ), ]; - for (should_generate, pkg, tf, var) in params { + for (should_generate, pkg, allowed, blocked) in params { if !should_generate { continue; } let bindgen_path = path::PathBuf::from(format!("{}/{pkg}_sys.rs", outdir.display())); let header = format!("build/{pkg}_wrapper.h"); - let bindings = bindgen::Builder::default() + let mut bindings = bindgen::Builder::default() .header(&header) .parse_callbacks(Box::new(callbacks::FunctionRenames::new( pkg, @@ -214,9 +286,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) { .iter() .map(|p| format!("-I{}", p.display())), ) - .allowlist_type(tf) - .allowlist_function(tf) - .allowlist_var(var) + .allowlist_recursively(false); + + for file in allowed { + bindings = bindings.allowlist_file(file); + } + for file in blocked { + bindings = bindings.blocklist_file(file); + } + + let bindings = bindings .default_enum_style(bindgen::EnumVariation::Rust { non_exhaustive: false, }) diff --git a/crates/cust_raw/src/cublas_sys.rs b/crates/cust_raw/src/cublas_sys.rs deleted file mode 100644 index 712b9b7..0000000 --- a/crates/cust_raw/src/cublas_sys.rs +++ /dev/null @@ -1,5 +0,0 @@ -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(non_snake_case)] - -include!(concat!(env!("OUT_DIR"), "/cublas_sys.rs")); diff --git a/crates/cust_raw/src/cublaslt_sys.rs b/crates/cust_raw/src/cublas_sys/lt.rs similarity index 55% rename from crates/cust_raw/src/cublaslt_sys.rs rename to crates/cust_raw/src/cublas_sys/lt.rs index c806b3b..eda7e4e 100644 --- a/crates/cust_raw/src/cublaslt_sys.rs +++ b/crates/cust_raw/src/cublas_sys/lt.rs @@ -1,5 +1,10 @@ -#![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] +use libc::FILE; + +use super::*; +use crate::types::driver::*; +use crate::types::library::*; + include!(concat!(env!("OUT_DIR"), "/cublasLt_sys.rs")); diff --git a/crates/cust_raw/src/cublas_sys/mod.rs b/crates/cust_raw/src/cublas_sys/mod.rs new file mode 100644 index 0000000..4238092 --- /dev/null +++ b/crates/cust_raw/src/cublas_sys/mod.rs @@ -0,0 +1,18 @@ +//! Bindings to the CUDA Basic Linear Algebra Subprograms (cuBLAS) library. +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +use crate::types::library::*; + +pub use crate::runtime_sys::cudaStream_t; +pub use crate::types::complex::*; +pub use crate::types::library::cudaDataType; + +include!(concat!(env!("OUT_DIR"), "/cublas_sys.rs")); + +#[cfg(feature = "cublasLt")] +pub mod lt; + +#[cfg(feature = "cublasXt")] +pub mod xt; diff --git a/crates/cust_raw/src/cublasxt_sys.rs b/crates/cust_raw/src/cublas_sys/xt.rs similarity index 90% rename from crates/cust_raw/src/cublasxt_sys.rs rename to crates/cust_raw/src/cublas_sys/xt.rs index c4d934d..6b7c291 100644 --- a/crates/cust_raw/src/cublasxt_sys.rs +++ b/crates/cust_raw/src/cublas_sys/xt.rs @@ -2,4 +2,6 @@ #![allow(non_camel_case_types)] #![allow(non_snake_case)] +use super::*; + include!(concat!(env!("OUT_DIR"), "/cublasXt_sys.rs")); diff --git a/crates/cust_raw/src/driver_sys.rs b/crates/cust_raw/src/driver_sys.rs index 841e3c7..f073590 100644 --- a/crates/cust_raw/src/driver_sys.rs +++ b/crates/cust_raw/src/driver_sys.rs @@ -1,3 +1,5 @@ +//! Bindings to the CUDA Driver API + #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] diff --git a/crates/cust_raw/src/lib.rs b/crates/cust_raw/src/lib.rs index 62bd4b0..686529f 100644 --- a/crates/cust_raw/src/lib.rs +++ b/crates/cust_raw/src/lib.rs @@ -1,16 +1,26 @@ +//! # `cust_raw`: Bindings to the CUDA Toolkit SDK +//! #[cfg(feature = "driver")] pub mod driver_sys; + #[cfg(feature = "runtime")] pub mod runtime_sys; +#[cfg(any( + feature = "driver_types", + feature = "vector_types", + feature = "texture_types", + feature = "surface_types", + feature = "cuComplex", + feature = "library_types" +))] +pub mod types; + #[cfg(feature = "cublas")] pub mod cublas_sys; -#[cfg(feature = "cublaslt")] -pub mod cublaslt_sys; -#[cfg(feature = "cublasxt")] -pub mod cublasxt_sys; #[cfg(feature = "nvptx-compiler")] pub mod nvptx_compiler_sys; + #[cfg(feature = "nvvm")] pub mod nvvm_sys; diff --git a/crates/cust_raw/src/nvptx_compiler_sys.rs b/crates/cust_raw/src/nvptx_compiler_sys.rs index c3f1090..de0b00f 100644 --- a/crates/cust_raw/src/nvptx_compiler_sys.rs +++ b/crates/cust_raw/src/nvptx_compiler_sys.rs @@ -1,3 +1,5 @@ +//! Bindings to the NVPTX Compiler library. + #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] diff --git a/crates/cust_raw/src/nvvm_sys.rs b/crates/cust_raw/src/nvvm_sys.rs index 6911dc4..7e7c48c 100644 --- a/crates/cust_raw/src/nvvm_sys.rs +++ b/crates/cust_raw/src/nvvm_sys.rs @@ -1,3 +1,6 @@ +//! Bindings to the libNVVM API, an interface for generating PTX code from both +//! binary and text NVVM IR inputs. + #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] diff --git a/crates/cust_raw/src/runtime_sys.rs b/crates/cust_raw/src/runtime_sys.rs index 2ec7c5e..72b6552 100644 --- a/crates/cust_raw/src/runtime_sys.rs +++ b/crates/cust_raw/src/runtime_sys.rs @@ -1,5 +1,11 @@ +//! Bindings to the CUDA Runtime API #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] +pub use crate::types::driver::*; +pub use crate::types::surface::*; +pub use crate::types::texture::*; +pub use crate::types::vector::dim3; + include!(concat!(env!("OUT_DIR"), "/runtime_sys.rs")); diff --git a/crates/cust_raw/src/types/complex.rs b/crates/cust_raw/src/types/complex.rs new file mode 100644 index 0000000..5191c2b --- /dev/null +++ b/crates/cust_raw/src/types/complex.rs @@ -0,0 +1,4 @@ +// Bindings to CUDA complex number types. +#![allow(non_camel_case_types)] +use crate::types::vector::*; +include!(concat!(env!("OUT_DIR"), "/cuComplex_sys.rs")); diff --git a/crates/cust_raw/src/types/driver.rs b/crates/cust_raw/src/types/driver.rs new file mode 100644 index 0000000..f06d5ef --- /dev/null +++ b/crates/cust_raw/src/types/driver.rs @@ -0,0 +1,9 @@ +//! Bindings to driver types in the CUDA runtime API. +#![allow(dead_code)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(unused_imports)] +#![allow(clippy::missing_safety_doc)] +use crate::types::vector::dim3; +include!(concat!(env!("OUT_DIR"), "/driver_types_sys.rs")); diff --git a/crates/cust_raw/src/types/library.rs b/crates/cust_raw/src/types/library.rs new file mode 100644 index 0000000..30843f8 --- /dev/null +++ b/crates/cust_raw/src/types/library.rs @@ -0,0 +1,3 @@ +//! Bindings to types used to query CUDA library properties. +#![allow(non_camel_case_types)] +include!(concat!(env!("OUT_DIR"), "/library_types_sys.rs")); diff --git a/crates/cust_raw/src/types/mod.rs b/crates/cust_raw/src/types/mod.rs new file mode 100644 index 0000000..a3e4f21 --- /dev/null +++ b/crates/cust_raw/src/types/mod.rs @@ -0,0 +1,18 @@ +//! The CUDA runtime types bindings. +#[cfg(feature = "driver_types")] +pub mod driver; + +#[cfg(feature = "vector_types")] +pub mod vector; + +#[cfg(feature = "texture_types")] +pub mod texture; + +#[cfg(feature = "surface_types")] +pub mod surface; + +#[cfg(feature = "cuComplex")] +pub mod complex; + +#[cfg(feature = "library_types")] +pub mod library; diff --git a/crates/cust_raw/src/types/surface.rs b/crates/cust_raw/src/types/surface.rs new file mode 100644 index 0000000..b455501 --- /dev/null +++ b/crates/cust_raw/src/types/surface.rs @@ -0,0 +1,5 @@ +//! Bindings to CUDA surface types. +#![allow(non_camel_case_types)] +#![allow(non_upper_case_globals)] +#![allow(dead_code)] +include!(concat!(env!("OUT_DIR"), "/surface_types_sys.rs")); diff --git a/crates/cust_raw/src/types/texture.rs b/crates/cust_raw/src/types/texture.rs new file mode 100644 index 0000000..3873d45 --- /dev/null +++ b/crates/cust_raw/src/types/texture.rs @@ -0,0 +1,6 @@ +//! Bindings to CUDA texture types. +#![allow(dead_code)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +include!(concat!(env!("OUT_DIR"), "/texture_types_sys.rs")); diff --git a/crates/cust_raw/src/types/vector.rs b/crates/cust_raw/src/types/vector.rs new file mode 100644 index 0000000..13d51af --- /dev/null +++ b/crates/cust_raw/src/types/vector.rs @@ -0,0 +1,3 @@ +//! Binding to CUDA vector types. +#![allow(non_camel_case_types)] +include!(concat!(env!("OUT_DIR"), "/vector_types_sys.rs"));