WIP: Use DirectStorage with CUDA interop to load tensors more efficiently #7796

Draft: wants to merge 4 commits into master
41 changes: 31 additions & 10 deletions CMakeLists.txt
@@ -103,6 +103,8 @@ option(LLAMA_BLAS "llama: use BLAS"
option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUDA "llama: use CUDA" OFF)
option(LLAMA_CUDA_DIRECT_STORAGE "llama: use DirectStorage to upload tensors" OFF)
set(LLAMA_DIRECT_STORAGE_DIR "" CACHE PATH "llama: path to DirectStorage directory fetched with nuget. See https://devblogs.microsoft.com/directx/directstorage-api-downloads/" )
option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
@@ -152,7 +154,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)
# Compile flags
#

if (LLAMA_SYCL)
if (LLAMA_SYCL OR LLAMA_CUDA_DIRECT_STORAGE)
set(CMAKE_CXX_STANDARD 17)
else()
set(CMAKE_CXX_STANDARD 11)
@@ -412,6 +414,15 @@ if (LLAMA_CUDA)
file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")

if (LLAMA_CUDA_DIRECT_STORAGE)
file(GLOB GGML_SOURCES_CUDA_C "ggml-cuda/*.cpp")
file(GLOB GGML_SOURCES_CUDA_H "ggml-cuda/*.h")
list(APPEND GGML_SOURCES_CUDA ${GGML_SOURCES_CUDA_C})
list(APPEND GGML_SOURCES_CUDA ${GGML_SOURCES_CUDA_H})

add_compile_definitions(GGML_ENABLE_DIRECT_STORAGE_CUDA)
endif()

add_compile_definitions(GGML_USE_CUDA)
if (LLAMA_CUDA_FORCE_DMMV)
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -1172,15 +1183,15 @@ add_library(ggml OBJECT
ggml-backend.h
ggml-quants.c
ggml-quants.h
${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL}
${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN}
${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM}
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
)

@@ -1195,9 +1206,19 @@ if (BUILD_SHARED_LIBS)
set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(ggml_shared SHARED $<TARGET_OBJECTS:ggml>)
target_link_libraries(ggml_shared PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
target_link_libraries(ggml_shared PUBLIC "${LLAMA_DIRECT_STORAGE_DIR}/native/lib/x64/dstorage.lib" cuda cudart d3d12)

install(TARGETS ggml_shared LIBRARY)
endif()

if (LLAMA_CUDA_DIRECT_STORAGE)
set_property(TARGET ggml PROPERTY VS_PACKAGE_REFERENCES "Microsoft.Direct3D.DirectStorage_1.2.2")

target_include_directories(ggml PRIVATE "${LLAMA_DIRECT_STORAGE_DIR}/native/include")
target_link_directories(ggml PRIVATE "${LLAMA_DIRECT_STORAGE_DIR}/native/lib/x64")
target_link_libraries(ggml PUBLIC "${LLAMA_DIRECT_STORAGE_DIR}/native/lib/x64/dstorage.lib" cuda cudart d3d12)
endif()

# llama

add_library(llama
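For reference, a Windows build that exercises the new options could be configured roughly as follows; the NuGet path is hypothetical, and LLAMA_DIRECT_STORAGE_DIR must point at the extracted Microsoft.Direct3D.DirectStorage package (the directory that contains native/):

cmake -B build -DLLAMA_CUDA=ON -DLLAMA_CUDA_DIRECT_STORAGE=ON -DLLAMA_DIRECT_STORAGE_DIR="C:/nuget/Microsoft.Direct3D.DirectStorage.1.2.2"
cmake --build build --config Release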
3 changes: 2 additions & 1 deletion ggml-backend.c
@@ -223,7 +223,8 @@ GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void *

GGML_ASSERT(buf != NULL && "tensor buffer not set");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
//GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(offset + (size & ~(1u << 31)) <= ggml_nbytes(tensor) && "tensor write out of bounds");

if (!size) {
return;
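A minimal sketch of the convention that this relaxed assert accommodates, assuming individual uploads stay below 2 GiB (the mask is an unsigned int, so the bounds check only sees the low 31 bits of size); the flag name is illustrative, the backend just tests the raw bit:

// bit 31 of `size` is repurposed as an in-band flag meaning "this call describes a
// DirectStorage request rather than a plain host-to-device copy" (see ggml-cuda.cu below)
const size_t DSTORAGE_FLAG = 1u << 31;              // illustrative name, not from the PR
size_t payload_size = size & ~(1u << 31);           // number of bytes actually written
bool   is_dstorage  = (size & DSTORAGE_FLAG) != 0;  // route through DirectStorage, not cudaMemcpy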
150 changes: 141 additions & 9 deletions ggml-cuda.cu
@@ -29,6 +29,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/dsc.h"

#include <algorithm>
#include <array>
@@ -45,6 +46,8 @@
#include <stdio.h>
#include <string>
#include <vector>
#include <filesystem>
#include <iostream> // debug

static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");

@@ -79,6 +82,10 @@ int ggml_cuda_get_device() {
return id;
}

#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
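// Process-wide DirectStorage/CUDA context: created in ggml_cuda_init() below and used by the CUDA
// buffer implementation to queue reads from disk straight into device memory.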
std::unique_ptr<DirectStorageCUDA> dsc;
#endif

static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -149,6 +156,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
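// Create the shared DirectStorage context when the CUDA backend is initialized. DirectStorageCUDA
// is declared in ggml-cuda/dsc.h (included above); the two arguments presumably size the staging
// buffers (8 MiB) and the number of in-flight requests -- an assumption, the class definition is
// not part of this hunk.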
dsc = std::move(DirectStorageCUDA::create(8 * 1024 * 1024, 64));
#endif

return info;
}

@@ -368,13 +379,35 @@ struct ggml_backend_cuda_buffer_context {
int device;
void * dev_ptr = nullptr;
std::string name;
#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
std::unique_ptr<InteropBuffer> direct_storage_buffer;
#endif

ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
device(device), dev_ptr(dev_ptr),
name(GGML_CUDA_NAME + std::to_string(device)) {

}

#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
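// Alternative constructor for the DirectStorage path: the buffer's device memory is backed by an
// InteropBuffer created by DirectStorageCUDA, so DirectStorage requests can write file contents
// into it directly (see the loadFile() calls in set_tensor below).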
ggml_backend_cuda_buffer_context(int device, std::unique_ptr<InteropBuffer> && direct_storage_buffer_) :
device(device), dev_ptr(nullptr),
name(GGML_CUDA_NAME + std::to_string(device)),
direct_storage_buffer(std::move(direct_storage_buffer_))
{
dev_ptr = direct_storage_buffer->get_device_ptr();
}
#endif


~ggml_backend_cuda_buffer_context() {
#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
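// Interop-backed memory is owned by the InteropBuffer, so release it here and null dev_ptr,
// turning the cudaFree() below into a harmless no-op.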
if (direct_storage_buffer) {
direct_storage_buffer.reset();
dev_ptr = nullptr;
}
#endif

CUDA_CHECK(cudaFree(dev_ptr));
}
};
@@ -418,12 +451,96 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
}
}

struct FileInfo {
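// One DirectStorage handle per physical copy of the model file; getFile() hands the handles out
// round-robin so consecutive reads can be striped across disks (see the COPY_RAID hack below,
// which opens a second copy of the file from another drive).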
std::vector<DirectStorageCUDA::File> handles;
size_t handle_idx = 0;

DirectStorageCUDA::File& getFile() {
auto& temp = handles[handle_idx];
++handle_idx;
handle_idx %= handles.size();
return temp;
}
};

std::map<std::string, FileInfo> files;

GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

ggml_cuda_set_device(ctx->device);
#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
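// In-band request protocol: bit 31 of `size` marks a DirectStorage transfer. In that case `data`
// is not host tensor data: a null pointer requests a flush of all queued reads, anything else
// points at a {filename, file offset} descriptor for the tensor bytes on disk.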
if (size & (1u << 31)) {
size &= ~(1u << 31);
if (data == nullptr) {
std::cout << "flush" << std::endl;
dsc->flush(true);
return;
}
struct Temp {
const char* filename;
size_t weights_off;
};
Temp* t = (Temp*)data;

std::string filename = t->filename;
auto it = files.find(filename);
if (it == files.end()) {
files[filename].handles.push_back(dsc->openFile(filename));

#define COPY_RAID
#if defined(COPY_RAID)
// This is a hack to evaluate how fast data can be read from a 2nd disk.
std::filesystem::path p(filename);
std::filesystem::path p2("d:");
p2 /= "\\lmcache";
p2 /= p.filename().c_str();
std::cout << p2.string() << std::endl;
if (std::filesystem::exists(p2)) {
std::cout << "opening " << p2.string() << std::endl;
files[filename].handles.push_back(dsc->openFile(p2.string().c_str()));
}
std::cout << "2nd file" << std::endl;
#endif

it = files.find(filename);
}

//dsc->loadFile(it->second.getFile(), t->weights_off, size, (char*)tensor->data + offset);
if (ctx->direct_storage_buffer) {
size_t tensor_offset = (char*)tensor->data - (char*)ctx->direct_storage_buffer->get_device_ptr();
#if defined(COPY_RAID)
size_t blocksize = 4 * 1024 * 1024;
for (size_t idx = 0; idx < size; idx += blocksize) {
size_t read_len = size - idx;
if (read_len > blocksize)
read_len = blocksize;
dsc->loadFile(it->second.getFile(), t->weights_off + idx, read_len, ctx->direct_storage_buffer.get(), offset + tensor_offset + idx);
}
#else
dsc->loadFile(it->second.getFile(), t->weights_off, size, ctx->direct_storage_buffer.get(), offset + tensor_offset);
#endif
}
else {
#if defined(COPY_RAID)
size_t blocksize = 2 * 1024 * 1024;
for (size_t idx = 0; idx < size; idx += blocksize) {
size_t read_len = size - idx;
if (read_len > blocksize)
read_len = blocksize;
dsc->loadFile(it->second.getFile(), t->weights_off + idx, read_len, (char*)tensor->data + offset + idx);
}
#else
dsc->loadFile(it->second.getFile(), t->weights_off, size, (char*)tensor->data + offset);
#endif
}
}
else
#endif
{
CUDA_CHECK(cudaMemcpyAsync((char*)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
}

GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@@ -495,15 +612,20 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe

size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0

#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
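// Buffers smaller than 512 MiB are allocated as DirectStorage/CUDA interop buffers so file reads
// can land in them directly; larger requests (and all requests when the feature is disabled) fall
// back to plain cudaMalloc. The threshold matches the interop limit noted in get_max_size below.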
if (size < 512 * 1024 * 1024) {
auto interop_buffer = dsc->create_interop_buffer(size);
ggml_backend_cuda_buffer_context* ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, std::move(interop_buffer));
return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}
#endif
void* dev_ptr;
cudaError_t err = cudaMalloc(&dev_ptr, size);
if (err != cudaSuccess) {
fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err));
fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
return nullptr;
}

ggml_backend_cuda_buffer_context* ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}

@@ -539,11 +661,21 @@ GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backen
return buft_ctx->device == cuda_ctx->device;
}

GGML_CALL static size_t ggml_backend_cuda_get_max_size(ggml_backend_buffer_type_t buft)
{
#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
//return 512 * 1024 * 1024; // dx interop limit
return SIZE_MAX;
#else
return SIZE_MAX;
#endif
} // allocation max size

static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
/* .get_name = */ ggml_backend_cuda_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_cuda_get_max_size, // defaults to SIZE_MAX
/* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
/* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
/* .is_host = */ NULL,
@@ -708,7 +840,7 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backe

static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) {
return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name;
GGML_UNUSED(&ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds
}

GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
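The loader-side changes that drive this path are not part of the hunks shown above, so the following is a hedged sketch of how a caller could hand a DirectStorage request to the CUDA backend, assuming tensors smaller than 2 GiB. The descriptor layout must match the Temp struct decoded in ggml_backend_cuda_buffer_set_tensor; all names below are illustrative, not taken from the PR.

#include "ggml-backend.h"

// Hypothetical caller-side helpers (illustration only).
struct dstorage_request {      // must mirror the `Temp` struct in ggml-cuda.cu
    const char * filename;     // model file on disk
    size_t       weights_off;  // byte offset of the tensor data inside that file
};

static void set_tensor_from_file(ggml_tensor * t, const char * fname, size_t off, size_t nbytes) {
    dstorage_request req = { fname, off };
    // bit 31 of `size` flags a DirectStorage request; `data` carries the descriptor, not tensor bytes
    ggml_backend_tensor_set(t, &req, 0, nbytes | (1u << 31));
}

static void flush_dstorage_reads(ggml_tensor * any_tensor_in_buffer) {
    // a null `data` pointer with the flag bit set asks the backend to flush all queued reads
    ggml_backend_tensor_set(any_tensor_in_buffer, nullptr, 0, 1u << 31);
}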