Skip to content

reactor(cust_raw): Generate bindings for cuda types and restructure cust_raw's crates #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions crates/blastoff/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ impl CublasContext {
// cudaStream_t is the same as CUstream
cublas_sys::cublasSetStream(
self.raw,
mem::transmute::<*mut driver_sys::CUstream_st, *mut cublas_sys::CUstream_st>(
stream.as_inner(),
),
mem::transmute::<driver_sys::CUstream, cublas_sys::cudaStream_t>(stream.as_inner()),
)
.to_result()?;
let res = func(self)?;
Expand Down
25 changes: 16 additions & 9 deletions crates/cust_raw/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ name = "cust_raw"
version = "0.11.3"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Low level bindings to the CUDA Driver API"
description = "Low level bindings to the CUDA Toolkit SDK"
repository = "https://github.com/Rust-GPU/Rust-CUDA"
readme = "../../README.md"
links = "cuda"
build = "build/main.rs"

[dependencies]
libc = { version = "0.2", optional = true }

[build-dependencies]
bindgen = "0.71.1"
bimap = "0.6.3"
Expand All @@ -19,20 +22,24 @@ features = [
"driver",
"runtime",
"cublas",
"cublaslt",
"cublasxt",
"cudnn",
"cublasLt",
"cublasXt",
"nvptx-compiler",
"nvvm",
]

[features]
default = ["driver"]
driver = []
runtime = []
cublas = []
cublaslt = []
cublasxt = []
cudnn = []
runtime = ["driver_types", "vector_types", "texture_types", "surface_types"]
cuComplex = ["vector_types"]
driver_types = []
library_types = []
surface_types = []
texture_types = []
vector_types = []
cublas = ["runtime", "cuComplex", "library_types"]
cublasLt = ["cublas", "libc"]
cublasXt = ["cublas"]
nvptx-compiler = []
nvvm = []
4 changes: 1 addition & 3 deletions crates/cust_raw/build/driver_wrapper.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
#include "cuComplex.h"
#include "cuda.h"
#include "cudaProfiler.h"
#include "vector_types.h"
#include "cudaProfiler.h"
143 changes: 111 additions & 32 deletions crates/cust_raw/build/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use std::env;
use std::fs;
use std::path;

pub mod callbacks;
pub mod cuda_sdk;
mod callbacks;
mod cuda_sdk;

fn main() {
let outdir = path::PathBuf::from(
Expand Down Expand Up @@ -63,8 +63,9 @@ fn main() {
println!("cargo::rerun-if-env-changed={}", e);
}

create_cuda_driver_bindings(&sdk, outdir.as_path());
create_cuda_runtime_bindings(&sdk, outdir.as_path());
create_driver_bindings(&sdk, outdir.as_path());
create_runtime_bindings(&sdk, outdir.as_path());
create_runtime_types_bindings(&sdk, outdir.as_path());
create_cublas_bindings(&sdk, outdir.as_path());
create_nptx_compiler_bindings(&sdk, outdir.as_path());
create_nvvm_bindings(&sdk, outdir.as_path());
Expand All @@ -73,8 +74,8 @@ fn main() {
feature = "driver",
feature = "runtime",
feature = "cublas",
feature = "cublaslt",
feature = "cublasxt"
feature = "cublasLt",
feature = "cublasXt"
)) {
for libdir in sdk.cuda_library_paths() {
println!("cargo::rustc-link-search=native={}", libdir.display());
Expand All @@ -84,11 +85,11 @@ fn main() {
if cfg!(feature = "runtime") {
println!("cargo::rustc-link-lib=dylib=cudart");
}
if cfg!(feature = "cublas") || cfg!(feature = "cublasxt") {
if cfg!(feature = "cublas") || cfg!(feature = "cublasXt") {
println!("cargo::rustc-link-lib=dylib=cublas");
}
if cfg!(feature = "cublaslt") {
println!("cargo::rustc-link-lib=dylib=cublaslt");
if cfg!(feature = "cublasLt") {
println!("cargo::rustc-link-lib=dylib=cublasLt");
}
if cfg!(feature = "nvvm") {
for libdir in sdk.nvvm_library_paths() {
Expand All @@ -101,7 +102,53 @@ fn main() {
}
}

fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
fn create_runtime_types_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
let params = &[
(cfg!(feature = "driver_types"), "driver_types"),
(cfg!(feature = "library_types"), "library_types"),
(cfg!(feature = "vector_types"), "vector_types"),
(cfg!(feature = "texture_types"), "texture_types"),
(cfg!(feature = "surface_types"), "surface_types"),
(cfg!(feature = "cuComplex"), "cuComplex"),
];
for (should_generate, pkg) in params {
if !should_generate {
continue;
}
let bindgen_path = path::PathBuf::from(format!("{}/{}_sys.rs", outdir.display(), pkg));
let header = sdk
.cuda_root()
.join(format!("include/{}.h", pkg))
.display()
.to_string();
let bindings = bindgen::Builder::default()
.header(&header)
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.clang_args(
sdk.cuda_include_paths()
.iter()
.map(|p| format!("-I{}", p.display())),
)
.allowlist_file(format!(r".*{pkg}\.h"))
.allowlist_recursively(false)
.default_enum_style(bindgen::EnumVariation::Rust {
non_exhaustive: false,
})
.derive_default(true)
.derive_eq(true)
.derive_hash(true)
.derive_ord(true)
.size_t_is_usize(true)
.layout_tests(true)
.generate()
.unwrap_or_else(|e| panic!("Unable to generate {pkg} bindings: {e}"));
bindings
.write_to_file(bindgen_path.as_path())
.unwrap_or_else(|e| panic!("Cannot write {pkg} bindgen output to file: {e}"));
}
}

fn create_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
if !cfg!(feature = "driver") {
return;
}
Expand All @@ -121,13 +168,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
.iter()
.map(|p| format!("-I{}", p.display())),
)
.allowlist_type("^CU.*")
.allowlist_type("^cuuint(32|64)_t")
.allowlist_type("^cudaError_enum")
.allowlist_type("^cu.*Complex$")
.allowlist_type("^cuda.*")
.allowlist_var("^CU.*")
.allowlist_function("^cu.*")
.allowlist_file(r".*cuda[^/\\]*\.h")
.default_enum_style(bindgen::EnumVariation::Rust {
non_exhaustive: false,
})
Expand All @@ -145,7 +186,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
.expect("Cannot write CUDA driver bindgen output to file.");
}

fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
fn create_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
if !cfg!(feature = "runtime") {
return;
}
Expand All @@ -165,14 +206,13 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
.iter()
.map(|p| format!("-I{}", p.display())),
)
.allowlist_type("^CU.*")
.allowlist_type("^cuda.*")
.allowlist_type("^libraryPropertyType.*")
.allowlist_var("^CU.*")
.allowlist_function("^cu.*")
.allowlist_file(r".*cuda[^/\\]*\.h")
.allowlist_file(r".*cuComplex\.h")
.allowlist_recursively(false)
.default_enum_style(bindgen::EnumVariation::Rust {
non_exhaustive: false,
})
.disable_nested_struct_naming()
.derive_default(true)
.derive_eq(true)
.derive_hash(true)
Expand All @@ -188,19 +228,51 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
}

fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
#[rustfmt::skip]
let params = &[
(cfg!(feature = "cublas"), "cublas", "^cublas.*", "^CUBLAS.*"),
(cfg!(feature = "cublaslt"), "cublasLt", "^cublasLt.*", "^CUBLASLT.*"),
(cfg!(feature = "cublasxt"), "cublasXt", "^cublasXt.*", "^CUBLASXT.*"),
(
cfg!(feature = "cublas"),
"cublas",
vec![r".*cublas(_api|_v2)\.h"],
vec![
r".*cuComplex\.h",
r".*driver_types\.h",
r".*library_types\.h",
r".*vector_types\.h",
],
),
(
cfg!(feature = "cublasLt"),
"cublasLt",
vec![r".*cublasLt\.h"],
vec![
r".*cublas(_api|_v2)*\.h",
r".*cuComplex\.h",
r".*driver_types\.h",
r".*library_types\.h",
r".*vector_types\.h",
r".*std\w+\.h",
],
),
(
cfg!(feature = "cublasXt"),
"cublasXt",
vec![r".*cublasXt\.h"],
vec![
r".*cublas(_api|_v2)*\.h",
r".*cuComplex\.h",
r".*driver_types\.h",
r".*library_types\.h",
r".*vector_types\.h",
],
),
];
for (should_generate, pkg, tf, var) in params {
for (should_generate, pkg, allowed, blocked) in params {
if !should_generate {
continue;
}
let bindgen_path = path::PathBuf::from(format!("{}/{pkg}_sys.rs", outdir.display()));
let header = format!("build/{pkg}_wrapper.h");
let bindings = bindgen::Builder::default()
let mut bindings = bindgen::Builder::default()
.header(&header)
.parse_callbacks(Box::new(callbacks::FunctionRenames::new(
pkg,
Expand All @@ -214,9 +286,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
.iter()
.map(|p| format!("-I{}", p.display())),
)
.allowlist_type(tf)
.allowlist_function(tf)
.allowlist_var(var)
.allowlist_recursively(false);

for file in allowed {
bindings = bindings.allowlist_file(file);
}
for file in blocked {
bindings = bindings.blocklist_file(file);
}

let bindings = bindings
.default_enum_style(bindgen::EnumVariation::Rust {
non_exhaustive: false,
})
Expand Down
5 changes: 0 additions & 5 deletions crates/cust_raw/src/cublas_sys.rs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

use libc::FILE;

use super::*;
use crate::types::driver::*;
use crate::types::library::*;

include!(concat!(env!("OUT_DIR"), "/cublasLt_sys.rs"));
18 changes: 18 additions & 0 deletions crates/cust_raw/src/cublas_sys/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//! Bindings to the CUDA Basic Linear Algebra Subprograms (cuBLAS) library.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

use crate::types::library::*;

pub use crate::runtime_sys::cudaStream_t;
pub use crate::types::complex::*;
pub use crate::types::library::cudaDataType;

include!(concat!(env!("OUT_DIR"), "/cublas_sys.rs"));

#[cfg(feature = "cublasLt")]
pub mod lt;

#[cfg(feature = "cublasXt")]
pub mod xt;
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

use super::*;

include!(concat!(env!("OUT_DIR"), "/cublasXt_sys.rs"));
2 changes: 2 additions & 0 deletions crates/cust_raw/src/driver_sys.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Bindings to the CUDA Driver API

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
Expand Down
18 changes: 14 additions & 4 deletions crates/cust_raw/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
//! # `cust_raw`: Bindings to the CUDA Toolkit SDK
//!
#[cfg(feature = "driver")]
pub mod driver_sys;

#[cfg(feature = "runtime")]
pub mod runtime_sys;

#[cfg(any(
feature = "driver_types",
feature = "vector_types",
feature = "texture_types",
feature = "surface_types",
feature = "cuComplex",
feature = "library_types"
))]
pub mod types;

#[cfg(feature = "cublas")]
pub mod cublas_sys;
#[cfg(feature = "cublaslt")]
pub mod cublaslt_sys;
#[cfg(feature = "cublasxt")]
pub mod cublasxt_sys;

#[cfg(feature = "nvptx-compiler")]
pub mod nvptx_compiler_sys;

#[cfg(feature = "nvvm")]
pub mod nvvm_sys;
2 changes: 2 additions & 0 deletions crates/cust_raw/src/nvptx_compiler_sys.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Bindings to the NVPTX Compiler library.

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
Expand Down
3 changes: 3 additions & 0 deletions crates/cust_raw/src/nvvm_sys.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//! Bindings to the libNVVM API, an interface for generating PTX code from both
//! binary and text NVVM IR inputs.

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
Expand Down
6 changes: 6 additions & 0 deletions crates/cust_raw/src/runtime_sys.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
//! Bindings to the CUDA Runtime API
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

pub use crate::types::driver::*;
pub use crate::types::surface::*;
pub use crate::types::texture::*;
pub use crate::types::vector::dim3;

include!(concat!(env!("OUT_DIR"), "/runtime_sys.rs"));
Loading
Loading