From f1571c96fcd5f11f7d60489d01db99e97932290b Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Fri, 19 Apr 2024 15:07:32 +0200
Subject: [PATCH 1/4] Add backdoor to ggml to use DirectStorage to load tensors.

---
 CMakeLists.txt | 39 ++++++++++++++++++++--------
 ggml-backend.c |  3 ++-
 ggml-cuda.cu   | 70 ++++++++++++++++++++++++++++++++++++++++++++++++--
 llama.cpp      | 51 ++++++++++++++++++++++++++++++++++--
 4 files changed, 148 insertions(+), 15 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f134a153bb4ff..6b3b7ea42e51f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -103,6 +103,8 @@ option(LLAMA_BLAS "llama: use BLAS"
 option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ${LLAMA_LLAMAFILE_DEFAULT})
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUDA "llama: use CUDA" OFF)
+option(LLAMA_CUDA_DIRECT_STORAGE "llama: use DirectStorage to upload tensors" OFF)
+set(LLAMA_DIRECT_STORAGE_DIR "" CACHE PATH "llama: path to DirectStorage directory fetched with nuget. See https://devblogs.microsoft.com/directx/directstorage-api-downloads/")
 option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
 option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
 option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
@@ -152,7 +154,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)
 #
 # Compile flags
 #
 
-if (LLAMA_SYCL)
+if (LLAMA_SYCL OR LLAMA_CUDA_DIRECT_STORAGE)
     set(CMAKE_CXX_STANDARD 17)
 else()
     set(CMAKE_CXX_STANDARD 11)
@@ -412,6 +414,15 @@ if (LLAMA_CUDA)
         file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
         list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
 
+        if (LLAMA_CUDA_DIRECT_STORAGE)
+            file(GLOB GGML_SOURCES_CUDA_C "ggml-cuda/*.cpp")
+            file(GLOB GGML_SOURCES_CUDA_H "ggml-cuda/*.h")
+            list(APPEND GGML_SOURCES_CUDA ${GGML_SOURCES_CUDA_C})
+            list(APPEND GGML_SOURCES_CUDA ${GGML_SOURCES_CUDA_H})
+
+            add_compile_definitions(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+        endif()
+
         add_compile_definitions(GGML_USE_CUDA)
         if (LLAMA_CUDA_FORCE_DMMV)
             add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -1172,15 +1183,15 @@ add_library(ggml OBJECT
             ggml-backend.h
             ggml-quants.c
             ggml-quants.h
-            ${GGML_SOURCES_CUDA}     ${GGML_HEADERS_CUDA}
-            ${GGML_SOURCES_OPENCL}   ${GGML_HEADERS_OPENCL}
-            ${GGML_SOURCES_METAL}    ${GGML_HEADERS_METAL}
-            ${GGML_SOURCES_MPI}      ${GGML_HEADERS_MPI}
-            ${GGML_SOURCES_EXTRA}    ${GGML_HEADERS_EXTRA}
-            ${GGML_SOURCES_SYCL}     ${GGML_HEADERS_SYCL}
-            ${GGML_SOURCES_KOMPUTE}  ${GGML_HEADERS_KOMPUTE}
-            ${GGML_SOURCES_VULKAN}   ${GGML_HEADERS_VULKAN}
-            ${GGML_SOURCES_ROCM}     ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_CUDA}     ${GGML_HEADERS_CUDA}
+            ${GGML_SOURCES_OPENCL}   ${GGML_HEADERS_OPENCL}
+            ${GGML_SOURCES_METAL}    ${GGML_HEADERS_METAL}
+            ${GGML_SOURCES_MPI}      ${GGML_HEADERS_MPI}
+            ${GGML_SOURCES_EXTRA}    ${GGML_HEADERS_EXTRA}
+            ${GGML_SOURCES_SYCL}     ${GGML_HEADERS_SYCL}
+            ${GGML_SOURCES_KOMPUTE}  ${GGML_HEADERS_KOMPUTE}
+            ${GGML_SOURCES_VULKAN}   ${GGML_HEADERS_VULKAN}
+            ${GGML_SOURCES_ROCM}     ${GGML_HEADERS_ROCM}
             ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
             )
@@ -1198,6 +1209,14 @@ if (BUILD_SHARED_LIBS)
     install(TARGETS ggml_shared LIBRARY)
 endif()
 
+if (LLAMA_CUDA_DIRECT_STORAGE)
+    set_property(TARGET ggml PROPERTY VS_PACKAGE_REFERENCES "Microsoft.Direct3D.DirectStorage_1.2.2")
+
+    target_include_directories(ggml PRIVATE "${LLAMA_DIRECT_STORAGE_DIR}/native/include")
+    target_link_directories(ggml PRIVATE "${LLAMA_DIRECT_STORAGE_DIR}/native/lib/x64")
+    target_link_libraries(ggml PUBLIC "${LLAMA_DIRECT_STORAGE_DIR}/native/lib/x64/dstorage.lib" cuda cudart d3d12)
+endif()
+
 # llama
 
 add_library(llama
diff --git a/ggml-backend.c b/ggml-backend.c
index 402d86ef3ac8b..abf29ac095dce 100644
--- a/ggml-backend.c
+++ b/ggml-backend.c
@@ -223,7 +223,8 @@ GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void *
 
     GGML_ASSERT(buf != NULL && "tensor buffer not set");
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    //GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(offset + (size & ~(1u << 31)) <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
     if (!size) {
         return;
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index d277104d12177..8e1207697374b 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -29,6 +29,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/dsc.h"
 
 #include <algorithm>
 #include <array>
@@ -45,6 +46,8 @@
 #include <stdio.h>
 #include <string>
 #include <vector>
+#include <filesystem>
+#include <iostream> // debug
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
@@ -79,6 +82,10 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+std::unique_ptr<DirectStorageCUDA> dsc;
+#endif
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -149,6 +156,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
     // configure logging to stdout
     // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+    dsc = std::move(DirectStorageCUDA::create(8 * 1024 * 1024, 64));
+#endif
+
     return info;
 }
@@ -418,12 +429,67 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }
 
+struct FileInfo {
+    std::vector<DirectStorageCUDA::File> handles;
+    size_t handle_idx = 0;
+
+    DirectStorageCUDA::File& getFile() {
+        auto& temp = handles[handle_idx];
+        ++handle_idx;
+        handle_idx %= handles.size();
+        return temp;
+    }
+};
+
+std::map<std::string, FileInfo> files;
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
     ggml_cuda_set_device(ctx->device);
-    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
-    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+    if (size & (1u << 31)) {
+        size &= ~(1u << 31);
+        if (data == nullptr) {
+            dsc->flush();
+            return;
+        }
+        struct Temp {
+            const char* filename;
+            size_t weights_off;
+        };
+        Temp* t = (Temp*)data;
+
+        std::string filename = t->filename;
+        auto it = files.find(filename);
+        if (it == files.end()) {
+            files[filename].handles.push_back(dsc->openFile(filename));
+
+#if 0
+            // This is a hack to evaluate how fast data can be read from a 2nd disk.
+            std::filesystem::path p(filename);
+            std::filesystem::path p2("d:");
+            p2 /= "\\lmcache";
+            p2 /= p.filename().c_str();
+            std::cout << p2.string() << std::endl;
+            if (std::filesystem::exists(p2)) {
+                std::cout << "opening " << p2.string() << std::endl;
+                files[filename].handles.push_back(dsc->openFile(p2.string().c_str()));
+            }
+            std::cout << "2nd file" << std::endl;
+#endif
+
+            it = files.find(filename);
+        }
+
+        dsc->loadFile(it->second.getFile(), t->weights_off, size, (char*)tensor->data + offset);
+    }
+    else
+#endif
+    {
+        CUDA_CHECK(cudaMemcpyAsync((char*)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+    }
 }
 
 GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
diff --git a/llama.cpp b/llama.cpp
index fa7c022f29130..992337e0cc154 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7,6 +7,9 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+#include <chrono>
+#include <iostream>
+
 #ifdef GGML_USE_CUDA
 #  include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
@@ -1176,8 +1179,10 @@ struct llama_file {
     // use FILE * so we don't have to re-open the file to mmap
     FILE * fp;
     size_t size;
+    std::string filename;
 
     llama_file(const char * fname, const char * mode) {
+        filename = fname;
         fp = ggml_fopen(fname, mode);
         if (fp == NULL) {
             throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
@@ -3459,7 +3464,9 @@ struct llama_model_loader {
     size_t size_data = 0;
     std::vector<std::pair<size_t, size_t>> mmaps_used;
 
-    // Returns false if cancelled by progress_callback
+
+
+    // Returns false if canceled by progress_callback
     bool load_all_data(
             struct ggml_context * ctx,
             llama_buf_map & bufs_mmap,
@@ -3468,6 +3475,14 @@ struct llama_model_loader {
             void * progress_callback_user_data) {
         GGML_ASSERT(size_data != 0 && "call init_mappings() first");
 
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+        struct ggml_tensor* last_tensor = nullptr;
+
+        // debug statistics
+        size_t total_data_read = 0;
+        auto start = std::chrono::high_resolution_clock::now();
+#endif
+
         std::vector<no_init<uint8_t>> read_buf;
         for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
             const auto * weight = get_weight(ggml_get_name(cur));
@@ -3511,16 +3526,39 @@ struct llama_model_loader {
                     file->seek(weight->offs, SEEK_SET);
                     file->read_raw(cur->data, ggml_nbytes(cur));
                 } else {
+
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+                    // backdoor to load tensors with DirectStorage
+                    last_tensor = cur;
+                    struct Temp {
+                        const char* filename;
+                        size_t weights_off;
+                    };
+
+                    Temp t;
+                    t.filename = file->filename.c_str();
+                    t.weights_off = weight->offs;
+
+                    ggml_backend_tensor_set(cur, &t, 0, n_size | (1u << 31));
+#else
                     read_buf.resize(ggml_nbytes(cur));
                     file->seek(weight->offs, SEEK_SET);
                     file->read_raw(read_buf.data(), ggml_nbytes(cur));
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
+#endif
+
                 }
             }
 
             size_done += n_size;
         }
 
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+        // trigger flush of unread data
+        if (last_tensor)
+            ggml_backend_tensor_set(last_tensor, 0, 0, 1u << 31);
+#endif
+
         // check if this is the last call and do final cleanup
         if (size_done >= size_data) {
             // unmap offloaded tensors and metadata
@@ -3541,6 +3579,14 @@ struct llama_model_loader {
             }
         }
 
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+        auto end = std::chrono::high_resolution_clock::now();
+        std::chrono::duration<double> delta(end - start);
+        //auto seconds = std::chrono::duration_cast<std::chrono::seconds>(delta);
+        std::cout << "load time: " << delta.count() << std::endl;
+#endif
+
+
         return true;
     }
 };
@@ -5874,6 +5920,7 @@ static bool llm_load_tensors(
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = ggml_time_us() - model.t_start_us;
+    std::cout << "model load time: " << model.t_load_us / 1000.0f << "ms" << std::endl;
 
     return true;
 }
@@ -14213,7 +14260,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     // mmap consistently increases speed Linux, and also increases speed on Windows with
     // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
-#if defined(__linux__) || defined(_WIN32)
+#if false && defined(__linux__) || defined(_WIN32)
     constexpr bool use_mmap = true;
 #else
     constexpr bool use_mmap = false;
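The mechanism behind this first patch, restated outside the diff: ggml_backend_tensor_set() keeps its public signature, and the loader abuses bit 31 of the size argument as a tag. When the tag is set, data no longer points at host memory but at a small {filename, file offset} descriptor (the local Temp struct in the diff), and the relaxed assert in ggml-backend.c masks the tag bit out before the bounds check. A minimal caller-side sketch, assuming the patched backend above; DirectStorageRequest, DS_FLAG and the function names are illustrative, not part of the patch:

    // Illustrative only: how a request is smuggled through the unchanged API.
    #include "ggml-backend.h"
    #include <cstddef>

    struct DirectStorageRequest {   // mirrors the patch's local `Temp` struct
        const char * filename;      // model file to read from
        size_t       weights_off;   // byte offset of the tensor data in that file
    };

    static const size_t DS_FLAG = 1u << 31;

    static void load_tensor_direct(struct ggml_tensor * tensor, const char * fname,
                                   size_t file_off, size_t n_bytes) {
        DirectStorageRequest req = { fname, file_off };
        // n_bytes must stay below 2 GiB so the tag bit cannot collide with a real size
        ggml_backend_tensor_set(tensor, &req, 0, n_bytes | DS_FLAG);
    }

    static void flush_pending_loads(struct ggml_tensor * last_tensor) {
        // data == nullptr plus the tag bit tells the backend to flush its request queue,
        // which is exactly what load_all_data() does after the tensor loop
        ggml_backend_tensor_set(last_tensor, nullptr, 0, DS_FLAG);
    }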
From b11224c5e1e43d9bb44f90f1c0690ec4d017b0a2 Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Tue, 23 Apr 2024 09:37:28 +0200
Subject: [PATCH 2/4] add missing DirectStorageCUDA files

---
 ggml-cuda/dsc.cpp | 652 ++++++++++++++++++++++++++++++++++++++++++++++
 ggml-cuda/dsc.h   |  56 ++++
 2 files changed, 708 insertions(+)
 create mode 100644 ggml-cuda/dsc.cpp
 create mode 100644 ggml-cuda/dsc.h

diff --git a/ggml-cuda/dsc.cpp b/ggml-cuda/dsc.cpp
new file mode 100644
index 0000000000000..e390cd588d059
--- /dev/null
+++ b/ggml-cuda/dsc.cpp
@@ -0,0 +1,652 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+#include "dsc.h"
+#include <cstdint>
+#include <cstdio>
+#include <cstring>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include <windows.h>
+#include <d3d12.h>
+#include <dxgi1_4.h>
+#include <dstorage.h>
+#include <winrt/base.h>
+#include <aclapi.h> // WindowsSecurityAttributes
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_runtime_api.h>
+
+#include <nvtx3/nvToolsExt.h>
+
+
+class WindowsSecurityAttributes
+{
+protected:
+    SECURITY_ATTRIBUTES m_winSecurityAttributes = {};
+    SECURITY_DESCRIPTOR m_securityDescriptor = {};
+    PSID pSID = 0;
+    PACL pACL = 0;
+
+public:
+    WindowsSecurityAttributes()
+    {
+        InitializeSecurityDescriptor(&m_securityDescriptor, SECURITY_DESCRIPTOR_REVISION);
+
+        SID_IDENTIFIER_AUTHORITY sidIdentifierAuthority = SECURITY_WORLD_SID_AUTHORITY;
+        AllocateAndInitializeSid(&sidIdentifierAuthority, 1, SECURITY_WORLD_RID, 0, 0, 0, 0, 0, 0, 0, &pSID);
+
+        EXPLICIT_ACCESS explicitAccess = {};
+        explicitAccess.grfAccessPermissions = STANDARD_RIGHTS_ALL | SPECIFIC_RIGHTS_ALL;
+        explicitAccess.grfAccessMode = SET_ACCESS;
+        explicitAccess.grfInheritance = INHERIT_ONLY;
+        explicitAccess.Trustee.TrusteeForm = TRUSTEE_IS_SID;
+        explicitAccess.Trustee.TrusteeType = TRUSTEE_IS_WELL_KNOWN_GROUP;
+        explicitAccess.Trustee.ptstrName = reinterpret_cast<LPTSTR>(pSID);
+
+        SetEntriesInAcl(1, &explicitAccess, NULL, &pACL);
+        SetSecurityDescriptorDacl(&m_securityDescriptor, TRUE, pACL, FALSE);
+
+        m_winSecurityAttributes.nLength = sizeof(m_winSecurityAttributes);
+        m_winSecurityAttributes.lpSecurityDescriptor = &m_securityDescriptor;
+        m_winSecurityAttributes.bInheritHandle = TRUE;
+    }
+
+    WindowsSecurityAttributes(WindowsSecurityAttributes const& rhs) = delete;
+    WindowsSecurityAttributes(WindowsSecurityAttributes const&& rhs) = delete;
+
+    ~WindowsSecurityAttributes() {
+        if (pSID)
+        {
+            FreeSid(pSID);
+        }
+        if (pACL)
+        {
+            LocalFree(pACL);
+        }
+    }
+
+    operator SECURITY_ATTRIBUTES const* () const {
+        return &m_winSecurityAttributes;
+    }
+};
+
+DirectStorageCUDA::~DirectStorageCUDA()
+{
+}
+
+struct DirectStorageCUDAFileHandleImpl : DirectStorageCUDAFileHandle
+{
+    ~DirectStorageCUDAFileHandleImpl() {};
+
+    using File = winrt::com_ptr<IDStorageFile>;
+    File file;
+
+    IDStorageFile* get() { return file.get(); }
+    IDStorageFile** put() { return file.put(); }
+};
+
+InteropBuffer::~InteropBuffer()
+{
+}
+
+class InteropBufferImpl : public InteropBuffer
+{
+public:
+    InteropBufferImpl(winrt::com_ptr<ID3D12Device> const& d3d_device, size_t size)
+    {
+        // Create the ID3D12Resource buffer which will be used as temporary scratch space for d3d
+        // since it's not possible to import CUDA memory into DX.
+        D3D12_HEAP_PROPERTIES bufferHeapProps = {};
+        bufferHeapProps.Type = D3D12_HEAP_TYPE_DEFAULT;
+
+        D3D12_HEAP_DESC hd = {};
+        hd.SizeInBytes = size;
+        hd.Properties = bufferHeapProps;
+        hd.Flags = D3D12_HEAP_FLAG_SHARED;
+        hd.Alignment = 0;
+        d3d_device->CreateHeap(&hd, IID_PPV_ARGS(&m_d3d_heap));
+
+
+        D3D12_RESOURCE_DESC bufferDesc = {};
+        bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+        bufferDesc.Width = size;
+        bufferDesc.Height = 1;
+        bufferDesc.DepthOrArraySize = 1;
+        bufferDesc.MipLevels = 1;
+        bufferDesc.Format = DXGI_FORMAT_UNKNOWN;
+        bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+        bufferDesc.SampleDesc.Count = 1;
+
+//#define USE_BUFFER
+#if defined(USE_BUFFER) //
+        winrt::check_hresult(d3d_device->CreateCommittedResource(
+            &bufferHeapProps,
+            D3D12_HEAP_FLAG_NONE | D3D12_HEAP_FLAG_SHARED,
+            &bufferDesc,
+            D3D12_RESOURCE_STATE_COMMON,
+            nullptr,
+            IID_PPV_ARGS(m_d3d_buffer.put())));
+#else
+        winrt::check_hresult(d3d_device->CreatePlacedResource(
+            m_d3d_heap.get(),
+            0,
+            &bufferDesc,
+            D3D12_RESOURCE_STATE_COMMON,
+            nullptr,
+            IID_PPV_ARGS(m_d3d_buffer.put())));
+#endif
+
+#if 0
+    // debug begin
+        bufferHeapProps.Type = D3D12_HEAP_TYPE_READBACK;
+        winrt::check_hresult(d3d_device->CreateCommittedResource(
+            &bufferHeapProps,
+            D3D12_HEAP_FLAG_NONE,
+            &bufferDesc,
+            D3D12_RESOURCE_STATE_COMMON,
+            nullptr,
+            IID_PPV_ARGS(m_host_buffer.put())));
+
+        m_host_buffer->Map(0, nullptr, &m_host_ptr);
+        d3d_device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&m_cmdallocator));
+        d3d_device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, m_cmdallocator.get(), nullptr, IID_PPV_ARGS(&m_cmdlist));
+#endif
+
+        D3D12_COMMAND_QUEUE_DESC qd = {};
+        qd.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
+        d3d_device->CreateCommandQueue(&qd, IID_PPV_ARGS(&m_cmd_queue));
+    // debug end
+
+        // create a shared handle required to import the d3d buffer into CUDA
+        HANDLE sharedHandle;
+        WindowsSecurityAttributes windowsSecurityAttributes;
+        LPCWSTR name = NULL;
+#if USE_BUFFER
+        d3d_device->CreateSharedHandle(m_d3d_buffer.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
+
+        cudaExternalMemoryHandleDesc externalMemoryHandleDesc = {};
+        externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Resource;
+#else
+        d3d_device->CreateSharedHandle(m_d3d_heap.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
+
+        cudaExternalMemoryHandleDesc externalMemoryHandleDesc = {};
+        externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Heap;
+#endif
+        externalMemoryHandleDesc.handle.win32.handle = sharedHandle;
+        externalMemoryHandleDesc.size = bufferDesc.Width;
+        externalMemoryHandleDesc.flags = cudaExternalMemoryDedicated;
+        auto result = cudaImportExternalMemory(&m_externalMemory, &externalMemoryHandleDesc);
+
+        CloseHandle(sharedHandle);
+
+        // get pointer to external memory imported from d3d
+        cudaExternalMemoryBufferDesc externalMemoryBufferDesc = {};
+        externalMemoryBufferDesc.offset = 0;
+        externalMemoryBufferDesc.size = externalMemoryHandleDesc.size;
+        externalMemoryBufferDesc.flags = 0;
+
+        result = cudaExternalMemoryGetMappedBuffer(&m_cuda_dev_ptr, m_externalMemory, &externalMemoryBufferDesc);
+        result = cudaDeviceSynchronize();
+
+        auto err = cudaMemset(m_cuda_dev_ptr, 255, 512*1024*1024);
+        result = cudaDeviceSynchronize();
+        std::cout << "err: " << err << std::endl;
+    }
+
+    ~InteropBufferImpl() {
+        auto result = cudaDestroyExternalMemory(m_externalMemory);
+        cudaFree(m_cuda_dev_ptr);
+        if (result != cudaSuccess) {
+            std::cout << "cudaDestroyExternalMemory interop buffer: " << result << std::endl;
+        }
+    }
+
+    void* get_device_ptr() const {
+        return m_cuda_dev_ptr;
+    }
+
+    ID3D12Resource* get_d3d_buffer() const {
+        return m_d3d_buffer.get();
+    }
+
+    void* get_host_ptr() const {
+#if 0
+        m_cmdlist->Reset(m_cmdallocator.get(), nullptr);
+        m_cmdlist->CopyResource(m_host_buffer.get(), m_d3d_buffer.get());
+        m_cmdlist->Close();
+
+        ID3D12CommandList *ptr = m_cmdlist.get();
+        m_cmd_queue->ExecuteCommandLists(1, &ptr);
+        Sleep(2);
+#endif
+
+        return m_host_ptr;
+    }
+
+private:
+    winrt::com_ptr<ID3D12CommandQueue> m_cmd_queue = {};
+    winrt::com_ptr<ID3D12Resource> m_d3d_buffer = {};
+    winrt::com_ptr<ID3D12Heap> m_d3d_heap = {};
+
+    cudaExternalMemory_t m_externalMemory;
+    void* m_cuda_dev_ptr;
+
+    // debug
+    winrt::com_ptr<ID3D12Resource> m_host_buffer = {};
+    winrt::com_ptr<ID3D12GraphicsCommandList> m_cmdlist = {};
+    winrt::com_ptr<ID3D12CommandAllocator> m_cmdallocator = {};
+    void* m_host_ptr;
+};
+
+class DirectStorageCUDAImpl : public DirectStorageCUDA
+{
+public:
+    DirectStorageCUDAImpl(int scratch_size, int number_of_scratch_spaces);
+
+    virtual ~DirectStorageCUDAImpl() {
+        flush(true);
+        std::cout << "~DirectStorageCUDAImpl" << std::endl;
+    }
+
+    struct FileInfo {
+        const std::string& filename;
+        void* cuda_device_ptr;
+        size_t offset;
+        size_t size;
+    };
+
+    virtual std::unique_ptr<InteropBuffer> create_interop_buffer(size_t size);
+    virtual DirectStorageCUDA::File openFile(std::string const& filename);
+    virtual void loadFile(DirectStorageCUDA::File const& file, size_t read_start, size_t read_len, void* cuda_dst_ptr);
+    virtual void loadFile(File const& file, size_t read_start, size_t read_len, InteropBuffer* interop_buffer, size_t interop_buffer_offset);
+    virtual void flush(bool last);
+private:
+    class StagingArea
+    {
+    public:
+        StagingArea(winrt::com_ptr<ID3D12Device> d3d_device, winrt::com_ptr<IDStorageFactory> d3d_factory, size_t chunk_size, size_t number_of_chunks)
+            : m_d3d_device(d3d_device)
+            , m_d3d_factory(d3d_factory)
+            , m_chunk_size(chunk_size)
+            , m_number_of_chunks(number_of_chunks)
+            , m_total_staging_space(chunk_size * number_of_chunks)
+        {
+            // Create a DirectStorage queue which will be used to load data into a
+            // buffer on the GPU.
+            DSTORAGE_QUEUE_DESC queueDesc{};
+            queueDesc.Capacity = DSTORAGE_MAX_QUEUE_CAPACITY;
+            queueDesc.Priority = DSTORAGE_PRIORITY_NORMAL;
+            queueDesc.SourceType = DSTORAGE_REQUEST_SOURCE_FILE;
+            queueDesc.Device = m_d3d_device.get();
+
+            winrt::check_hresult(m_d3d_factory->CreateQueue(&queueDesc, IID_PPV_ARGS(m_d3d_storage_queue.put())));
+
+            // Create the ID3D12Resource buffer which will be used as temporary scratch space for d3d
+            // since it's not possible to import CUDA memory into DX.
+            D3D12_HEAP_PROPERTIES bufferHeapProps = {};
+            bufferHeapProps.Type = D3D12_HEAP_TYPE_DEFAULT;
+
+            D3D12_RESOURCE_DESC bufferDesc = {};
+            bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+            bufferDesc.Width = m_chunk_size * m_number_of_chunks;
+            bufferDesc.Height = 1;
+            bufferDesc.DepthOrArraySize = 1;
+            bufferDesc.MipLevels = 1;
+            bufferDesc.Format = DXGI_FORMAT_UNKNOWN;
+            bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+            bufferDesc.SampleDesc.Count = 1;
+
+            winrt::check_hresult(m_d3d_device->CreateCommittedResource(
+                &bufferHeapProps,
+                D3D12_HEAP_FLAG_NONE | D3D12_HEAP_FLAG_SHARED,
+                &bufferDesc,
+                D3D12_RESOURCE_STATE_COMMON,
+                nullptr,
+                IID_PPV_ARGS(m_d3d_scratch_space.put())));
+
+
+            // create a shared handle required to import the d3d buffer into CUDA
+            HANDLE sharedHandle;
+            WindowsSecurityAttributes windowsSecurityAttributes;
+            LPCWSTR name = NULL;
+            m_d3d_device->CreateSharedHandle(m_d3d_scratch_space.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
+
+            cudaExternalMemoryHandleDesc externalMemoryHandleDesc = {};
+            externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Resource;
+            externalMemoryHandleDesc.handle.win32.handle = sharedHandle;
+            externalMemoryHandleDesc.size = bufferDesc.Width;
+            externalMemoryHandleDesc.flags = cudaExternalMemoryDedicated;
+            auto result = cudaImportExternalMemory(&m_externalMemory, &externalMemoryHandleDesc);
+
+            CloseHandle(sharedHandle);
+
+            // get pointer to external memory imported from d3d
+            cudaExternalMemoryBufferDesc externalMemoryBufferDesc = {};
+            externalMemoryBufferDesc.offset = 0;
+            externalMemoryBufferDesc.size = externalMemoryHandleDesc.size;
+            externalMemoryBufferDesc.flags = 0;
+
+            result = cudaExternalMemoryGetMappedBuffer(&m_cuda_scratch_space, m_externalMemory, &externalMemoryBufferDesc);
+
+            // create d3d fence for synchronization
+            auto resultDx = m_d3d_device->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&m_d3d_fence));
+
+            // import d3d fence as semaphore into CUDA.
+            cudaExternalSemaphoreHandleDesc extSemHandleDesc = {};
+            extSemHandleDesc.type = cudaExternalSemaphoreHandleTypeD3D12Fence;
+            m_d3d_device->CreateSharedHandle(m_d3d_fence.get(), nullptr, GENERIC_ALL, nullptr, &extSemHandleDesc.handle.win32.handle);
+            result = cudaImportExternalSemaphore(&m_externalSemaphore, &extSemHandleDesc);
+
+            cudaStreamCreate(&m_cudaStream);
+
+            // initialize fence to wait for in flush
+            waitParams.params.fence.value = 1;
+        }
+
+        ~StagingArea()
+        {
+            std::cout << "~StagingArea" << std::endl;
+            auto result = cudaDestroyExternalMemory(m_externalMemory);
+            cudaFree(m_cuda_scratch_space);
+            if (result != cudaSuccess) {
+                std::cout << "cudaDestroyExternalMemory interop buffer: " << result << std::endl;
+            }
+            // TODO ensure that no resources are being leaked
+        }
+
+        // enqueue as much data as possible into the current staging area.
+        // enqueue returns true if the staging area ran full and was flushed, i.e. data remains;
+        // it returns false once all data has been enqueued.
+        // read_start, len and cuda_dst_ptr will be updated.
+        bool enqueue(DirectStorageCUDA::File const& file, size_t& read_start, size_t& len, void*& cuda_dst_ptr)
+        {
+            if (len == 0)
+                return false;
+
+            m_enqueued = true;
+
+            size_t memcpy_src_start = m_current_staging_offset;
+
+            static size_t load_cnt = 0;
+            size_t read_end = read_start + len;
+            for (size_t src_start = read_start; src_start < read_end; src_start += m_chunk_size)
+            {
+                ++load_cnt;
+                size_t src_end = min(read_end, src_start + m_chunk_size);
+                size_t src_size = src_end - src_start;
+
+                if (m_current_staging_offset + src_size >= m_total_staging_space)
+                {
+                    //std::cout << load_cnt << std::endl;
+                    load_cnt = 0;
+
+                    size_t processed_len = m_current_staging_offset - memcpy_src_start;
+                    m_staging_memcpies.push_back(MemcpyOp(cuda_dst_ptr, (void*)((char*)m_cuda_scratch_space + memcpy_src_start), processed_len));
+
+                    cuda_dst_ptr = reinterpret_cast<void*>(reinterpret_cast<char*>(cuda_dst_ptr) + processed_len);
+                    read_start += processed_len;
+                    len -= processed_len;
+
+                    flush(false);
+
+                    memcpy_src_start = m_current_staging_offset;
+
+                    return true;
+                }
+
+                DSTORAGE_REQUEST request = {};
+                request.Options.SourceType = DSTORAGE_REQUEST_SOURCE_FILE;
+                request.Options.DestinationType = DSTORAGE_REQUEST_DESTINATION_BUFFER;
+                request.Source.File.Source = static_cast<DirectStorageCUDAFileHandleImpl*>(file.get())->get();
+                request.Source.File.Offset = src_start;
+                request.Source.File.Size = src_size; // filesize
+                request.UncompressedSize = src_size; // filesize
+                request.Destination.Buffer.Resource = m_d3d_scratch_space.get();
+                request.Destination.Buffer.Offset = m_current_staging_offset;
+                request.Destination.Buffer.Size = src_size;
+
+                m_d3d_storage_queue->EnqueueRequest(&request);
+
+                m_current_staging_offset += request.Destination.Buffer.Size;
+            }
+
+            m_staging_memcpies.push_back(MemcpyOp((void*)((char*)cuda_dst_ptr), (void*)((char*)m_cuda_scratch_space + memcpy_src_start), m_current_staging_offset - memcpy_src_start));
+
+            size_t processed_len = m_current_staging_offset - memcpy_src_start;
+            cuda_dst_ptr = reinterpret_cast<void*>(reinterpret_cast<char*>(cuda_dst_ptr) + m_current_staging_offset - memcpy_src_start);
+            read_start += processed_len;
+            len -= processed_len;
+
+            return false;
+        }
+
+        void enqueue(DirectStorageCUDA::File const& file, size_t& read_start, size_t& read_len, InteropBuffer* interop_buffer, size_t interop_buffer_offset)
+        {
+            InteropBufferImpl* ibi = static_cast<InteropBufferImpl*>(interop_buffer);
+            bool flushed;
+            while (read_len) {
+                size_t request_size = min(m_chunk_size, read_len);
+
+                DSTORAGE_REQUEST request = {};
+                request.Options.SourceType = DSTORAGE_REQUEST_SOURCE_FILE;
+                request.Options.DestinationType = DSTORAGE_REQUEST_DESTINATION_BUFFER;
+                request.Source.File.Source = static_cast<DirectStorageCUDAFileHandleImpl*>(file.get())->get();
+                request.Source.File.Offset = read_start;
+                request.Source.File.Size = request_size; // filesize
+                request.UncompressedSize = request_size; // filesize
+                request.Destination.Buffer.Resource = ibi->get_d3d_buffer();
+                request.Destination.Buffer.Offset = interop_buffer_offset;
+                request.Destination.Buffer.Size = request_size;
+                //std::cout << read_start / (1024*1024) << " / " << interop_buffer_offset / (1024 * 1024) << "/" << request_size / (1024 * 1024) << std::endl;
+
+                m_d3d_storage_queue->EnqueueRequest(&request);
+
+                read_len -= request_size;
+                interop_buffer_offset += request_size;
+                read_start += request_size;
+
+                m_enqueued = true;
+                //flush(true);
+            };
+
+        }
+
+
+        void wait()
+        {
+            if (m_enqueued) {
+                cudaStreamSynchronize(m_cudaStream);
+                m_enqueued = false;
+            }
+        }
+
+        void flush(bool last)
+        {
+            m_d3d_storage_queue->EnqueueSignal(m_d3d_fence.get(), waitParams.params.fence.value);
+            m_d3d_storage_queue->Submit();
+
+            nvtxRangePop();
+            nvtxRangePush("wait");
+            cudaWaitExternalSemaphoresAsync(&m_externalSemaphore, &waitParams, 1, m_cudaStream);
+            nvtxRangePop();
+            nvtxRangePush("memcpy");
+#if 1
+            for (auto const& op : m_staging_memcpies) {
+                auto result = cudaMemcpyAsync(op.m_dst, op.m_src, op.m_size, cudaMemcpyDeviceToDevice, m_cudaStream);
+            }
+#endif
+            nvtxRangePop();
+            nvtxRangePush("sync");
+            //cudaStreamSynchronize(m_cudaStream);
+            nvtxRangePop();
+
+            // increase fence value by 1 for next flush call
+            waitParams.params.fence.value += 1;
+
+            // reset staging area
+            m_staging_memcpies.clear();
+            m_current_staging_offset = 0;
+
+#if 1
+            if (last) {
+                DSTORAGE_ERROR_RECORD errorRecord{};
+                m_d3d_storage_queue->RetrieveErrorRecord(&errorRecord);
+                if (FAILED(errorRecord.FirstFailure.HResult))
+                {
+                    //
+                    // errorRecord.FailureCount - The number of failed requests in the queue since the last
+                    // RetrieveErrorRecord call.
+                    // errorRecord.FirstFailure - Detailed record about the first failed command in the enqueue order.
+                    //
+                    std::cout << "The DirectStorage request failed! HRESULT=0x" << std::hex << errorRecord.FirstFailure.HResult << std::endl;
+                }
+            }
+#endif
+        }
+
+        winrt::com_ptr<ID3D12Device> m_d3d_device = {};
+        winrt::com_ptr<IDStorageFactory> m_d3d_factory = {};
+        winrt::com_ptr<IDStorageQueue> m_d3d_storage_queue = {};
+        winrt::com_ptr<ID3D12Resource> m_d3d_scratch_space = {};
+        winrt::com_ptr<ID3D12Fence> m_d3d_fence = {};
+
+        // cuda external memory resources
+        cudaExternalMemoryHandleType m_externalMemoryHandleType;
+        cudaExternalMemory_t m_externalMemory;
+        cudaExternalSemaphore_t m_externalSemaphore;
+        cudaExternalSemaphoreWaitParams waitParams = {};
+
+        size_t m_chunk_size;
+        size_t m_number_of_chunks;
+        size_t m_total_staging_space;
+
+        cudaStream_t m_cudaStream;
+        void* m_cuda_scratch_space;
+        bool m_enqueued = false; // is any data enqueued
+
+        // memcpy
+        size_t m_current_staging_offset = 0; // current offset in the staging buffer
+
+        // memcpies from the staging buffer to the actual CUDA memory
+        struct MemcpyOp {
+            MemcpyOp(void* dst, void* src, size_t size)
+                : m_dst(dst), m_src(src), m_size(size) {}
+            void* m_dst;
+            void* m_src;
+            size_t m_size;
+        };
+        std::vector<MemcpyOp> m_staging_memcpies;
+
+    };
+
+    winrt::com_ptr<ID3D12Device> m_d3d_device = {};
+    winrt::com_ptr<IDStorageFactory> m_d3d_factory = {};
+
+    size_t m_chunk_size;
+    size_t m_number_of_chunks;
+
+    std::vector<std::unique_ptr<StagingArea>> m_staging_areas;
+    size_t m_staging_index = 0;
+};
+
+std::unique_ptr<DirectStorageCUDA> DirectStorageCUDA::create(int scratch_size, int number_of_scratch_spaces)
+{
+    return std::make_unique<DirectStorageCUDAImpl>(scratch_size, number_of_scratch_spaces);
+}
+
+// copy read_len bytes starting at read_start from the given file to the given cuda ptr
+void DirectStorageCUDAImpl::loadFile(DirectStorageCUDA::File const& file, size_t read_start, size_t read_len, void* cuda_dst_ptr)
+{
+    bool flushed;
+    while (read_len) {
+        flushed = m_staging_areas[m_staging_index]->enqueue(file, read_start, read_len, cuda_dst_ptr);
+        if (flushed) {
+            m_staging_index = (m_staging_index + 1) % m_staging_areas.size();
+            m_staging_areas[m_staging_index]->wait();
+        }
+    };
+}
+
+void DirectStorageCUDAImpl::loadFile(DirectStorageCUDA::File const& file, size_t read_start, size_t read_len, InteropBuffer *interop_buffer, size_t interop_buffer_offset)
+{
+    if (!interop_buffer)
+        return;
+
+    m_staging_areas[m_staging_index]->enqueue(file, read_start, read_len, interop_buffer, interop_buffer_offset);
+}
+
+
+void DirectStorageCUDAImpl::flush(bool last)
+{
+    for (auto& sa : m_staging_areas) {
+        sa->flush(last);
+    }
+    if (last) {
+        for (auto& sa : m_staging_areas) {
+            sa->wait();
+        }
+    }
+}
+
+DirectStorageCUDAImpl::DirectStorageCUDAImpl(int scratch_size, int number_of_scratch_spaces)
+    : m_chunk_size(scratch_size)
+    , m_number_of_chunks(number_of_scratch_spaces)
+{
+    DSTORAGE_CONFIGURATION direct_storage_config = {};
+    direct_storage_config.NumSubmitThreads = 1;
+    DStorageSetConfiguration(&direct_storage_config);
+
+    winrt::check_hresult(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_12_1, IID_PPV_ARGS(&m_d3d_device)));
+    winrt::check_hresult(DStorageGetFactory(IID_PPV_ARGS(m_d3d_factory.put())));
+
+    size_t num_staging_areas = 2;
+    for (size_t idx = 0; idx < num_staging_areas; ++idx) {
+        m_staging_areas.emplace_back(std::make_unique<StagingArea>(m_d3d_device, m_d3d_factory, m_chunk_size, m_number_of_chunks));
+    }
+}
+
+DirectStorageCUDAImpl::File DirectStorageCUDAImpl::openFile(std::string const& filename)
+{
+    File file = std::make_unique<DirectStorageCUDAFileHandleImpl>();
+    std::wstring wfilename(filename.begin(), filename.end());
+    nvtxRangePush("factory open file");
+    HRESULT hr = m_d3d_factory->OpenFile(wfilename.c_str(), IID_PPV_ARGS(static_cast<DirectStorageCUDAFileHandleImpl*>(file.get())->put()));
+    if (FAILED(hr))
+    {
+        std::wcout << L"The file '" << wfilename << L"' could not be opened. HRESULT=0x" << std::hex << hr << std::endl;
+        return {};
+    }
+    nvtxRangePop();
+    return file;
+}
+
+std::unique_ptr<InteropBuffer> DirectStorageCUDAImpl::create_interop_buffer(size_t size)
+{
+    return std::make_unique<InteropBufferImpl>(m_d3d_device, size);
+}
diff --git a/ggml-cuda/dsc.h b/ggml-cuda/dsc.h
new file mode 100644
index 0000000000000..95a8430d54497
--- /dev/null
+++ b/ggml-cuda/dsc.h
@@ -0,0 +1,56 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+struct DirectStorageCUDAFileHandle {
+    virtual ~DirectStorageCUDAFileHandle() {};
+};
+
+class InteropBuffer {
+public:
+    virtual ~InteropBuffer() = 0;
+    virtual void* get_device_ptr() const = 0;
+    virtual void* get_host_ptr() const = 0;
+};
+
+class DirectStorageCUDA
+{
+public:
+    virtual ~DirectStorageCUDA();
+
+    using File = std::unique_ptr<DirectStorageCUDAFileHandle>;
+
+    virtual std::unique_ptr<InteropBuffer> create_interop_buffer(size_t size) = 0;
+
+    virtual File openFile(std::string const& filename) = 0;
+    virtual void loadFile(File const& file, size_t read_start, size_t read_len, void* cuda_dst_ptr) = 0;
+    virtual void loadFile(File const& file, size_t read_start, size_t read_len, InteropBuffer *interop_buffer, size_t interop_buffer_offset) = 0;
+    virtual void flush(bool last = false) = 0;
+
+    static std::unique_ptr<DirectStorageCUDA> create(int scratch_size, int number_of_scratch_spaces);
+};

From 1a07f604516f2a3643b23ede4e11e75055d54387 Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Tue, 23 Apr 2024 12:00:10 +0200
Subject: [PATCH 3/4] Cleanup tweaks and DSC class. The file copy RAID
 functionality is now protected by a named ifdef.

---
 ggml-cuda.cu      | 86 +++++++++++++++++++++++++++++++++++++++++------
 ggml-cuda/dsc.cpp | 69 ++++++++++++++++++++-----------------------
 ggml-cuda/dsc.h   |  3 ++
 3 files changed, 111 insertions(+), 47 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 8e1207697374b..b2d9740805268 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -379,13 +379,35 @@ struct ggml_backend_cuda_buffer_context {
     int device;
     void * dev_ptr = nullptr;
     std::string name;
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+    std::unique_ptr<InteropBuffer> direct_storage_buffer;
+#endif
 
     ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
         device(device), dev_ptr(dev_ptr),
         name(GGML_CUDA_NAME + std::to_string(device)) {
+
+    }
+
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+    ggml_backend_cuda_buffer_context(int device, std::unique_ptr<InteropBuffer> && direct_storage_buffer_) :
+        device(device), dev_ptr(nullptr),
+        name(GGML_CUDA_NAME + std::to_string(device)),
+        direct_storage_buffer(std::move(direct_storage_buffer_))
+    {
+        dev_ptr = direct_storage_buffer->get_device_ptr();
     }
+#endif
+
     ~ggml_backend_cuda_buffer_context() {
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+        if (direct_storage_buffer) {
+            direct_storage_buffer.reset();
+            dev_ptr = nullptr;
+        }
+#endif
+
         CUDA_CHECK(cudaFree(dev_ptr));
     }
 };
@@ -451,7 +473,8 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t
     if (size & (1u << 31)) {
         size &= ~(1u << 31);
         if (data == nullptr) {
-            dsc->flush();
+            std::cout << "flush" << std::endl;
+            dsc->flush(true);
             return;
         }
         struct Temp {
@@ -465,7 +488,8 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t
         if (it == files.end()) {
             files[filename].handles.push_back(dsc->openFile(filename));
 
-#if 0
+#define COPY_RAID
+#if defined(COPY_RAID)
             // This is a hack to evaluate how fast data can be read from a 2nd disk.
             std::filesystem::path p(filename);
@@ -482,7 +506,34 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t
             it = files.find(filename);
         }
 
-        dsc->loadFile(it->second.getFile(), t->weights_off, size, (char*)tensor->data + offset);
+        //dsc->loadFile(it->second.getFile(), t->weights_off, size, (char*)tensor->data + offset);
+        if (ctx->direct_storage_buffer) {
+            size_t tensor_offset = (char*)tensor->data - (char*)ctx->direct_storage_buffer->get_device_ptr();
+#if defined(COPY_RAID)
+            size_t blocksize = 4 * 1024 * 1024;
+            for (size_t idx = 0; idx < size; idx += blocksize) {
+                size_t read_len = size - idx;
+                if (read_len > blocksize)
+                    read_len = blocksize;
+                dsc->loadFile(it->second.getFile(), t->weights_off + idx, read_len, ctx->direct_storage_buffer.get(), offset + tensor_offset + idx);
+            }
+#else
+            dsc->loadFile(it->second.getFile(), t->weights_off, size, ctx->direct_storage_buffer.get(), offset + tensor_offset);
+#endif
+        }
+        else {
+#if defined(COPY_RAID)
+            size_t blocksize = 2 * 1024 * 1024;
+            for (size_t idx = 0; idx < size; idx += blocksize) {
+                size_t read_len = size - idx;
+                if (read_len > blocksize)
+                    read_len = blocksize;
+                dsc->loadFile(it->second.getFile(), t->weights_off + idx, read_len, (char*)tensor->data + offset + idx);
+            }
+#else
+            dsc->loadFile(it->second.getFile(), t->weights_off, size, (char*)tensor->data + offset);
+#endif
+        }
     }
     else
 #endif
@@ -561,15 +612,20 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
 
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
-    void * dev_ptr;
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+    if (size < 512 * 1024 * 1024) {
+        auto interop_buffer = dsc->create_interop_buffer(size);
+        ggml_backend_cuda_buffer_context* ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, std::move(interop_buffer));
+        return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
+    }
+#endif
+    void* dev_ptr;
     cudaError_t err = cudaMalloc(&dev_ptr, size);
     if (err != cudaSuccess) {
-        fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err));
+        fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
         return nullptr;
     }
-
-    ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
-
+    ggml_backend_cuda_buffer_context* ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
     return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
 }
@@ -605,11 +661,21 @@ GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backen
     return buft_ctx->device == cuda_ctx->device;
 }
 
+GGML_CALL static size_t ggml_backend_cuda_get_max_size(ggml_backend_buffer_type_t buft)
+{
+#if defined(GGML_ENABLE_DIRECT_STORAGE_CUDA)
+    //return 512 * 1024 * 1024; // dx interop limit
+    return SIZE_MAX;
+#else
+    return SIZE_MAX;
+#endif
+}// allocation max size
+
 static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
     /* .get_name         = */ ggml_backend_cuda_buffer_type_name,
     /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
     /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+    /* .get_max_size     = */ ggml_backend_cuda_get_max_size, // defaults to SIZE_MAX
     /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
     /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
     /* .is_host          = */ NULL,
@@ -774,7 +840,7 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backe
 
 static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) {
     return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name;
-    GGML_UNUSED(ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds
+    GGML_UNUSED(&ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds
 }
 
 GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
diff --git a/ggml-cuda/dsc.cpp b/ggml-cuda/dsc.cpp
index e390cd588d059..3a47f0f02155c 100644
--- a/ggml-cuda/dsc.cpp
+++ b/ggml-cuda/dsc.cpp
@@ -124,14 +124,16 @@ class InteropBufferImpl : public InteropBuffer
         D3D12_HEAP_PROPERTIES bufferHeapProps = {};
         bufferHeapProps.Type = D3D12_HEAP_TYPE_DEFAULT;
 
+#define USE_HEAP
+#if defined(USE_HEAP)
         D3D12_HEAP_DESC hd = {};
         hd.SizeInBytes = size;
         hd.Properties = bufferHeapProps;
         hd.Flags = D3D12_HEAP_FLAG_SHARED;
         hd.Alignment = 0;
         d3d_device->CreateHeap(&hd, IID_PPV_ARGS(&m_d3d_heap));
-
+#endif
 
         D3D12_RESOURCE_DESC bufferDesc = {};
         bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
@@ -143,27 +145,25 @@ class InteropBufferImpl : public InteropBuffer
         bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
         bufferDesc.SampleDesc.Count = 1;
 
-//#define USE_BUFFER
-#if defined(USE_BUFFER) //
-        winrt::check_hresult(d3d_device->CreateCommittedResource(
-            &bufferHeapProps,
-            D3D12_HEAP_FLAG_NONE | D3D12_HEAP_FLAG_SHARED,
-            &bufferDesc,
-            D3D12_RESOURCE_STATE_COMMON,
-            nullptr,
-            IID_PPV_ARGS(m_d3d_buffer.put())));
-#else
+#if defined(USE_HEAP)
         winrt::check_hresult(d3d_device->CreatePlacedResource(
-            m_d3d_heap.get(),
-            0,
-            &bufferDesc,
-            D3D12_RESOURCE_STATE_COMMON,
-            nullptr,
-            IID_PPV_ARGS(m_d3d_buffer.put())));
+                m_d3d_heap.get(),
+                0,
+                &bufferDesc,
+                D3D12_RESOURCE_STATE_COMMON,
+                nullptr,
+                IID_PPV_ARGS(m_d3d_buffer.put())));
+#else
+        winrt::check_hresult(d3d_device->CreateCommittedResource(
+            &bufferHeapProps,
+            D3D12_HEAP_FLAG_NONE | D3D12_HEAP_FLAG_SHARED,
+            &bufferDesc,
+            D3D12_RESOURCE_STATE_COMMON,
+            nullptr,
+            IID_PPV_ARGS(m_d3d_buffer.put())));
 #endif
 
-#if 0
-    // debug begin
+#if defined(DEBUG_READBACK)
         bufferHeapProps.Type = D3D12_HEAP_TYPE_READBACK;
         winrt::check_hresult(d3d_device->CreateCommittedResource(
             &bufferHeapProps,
             D3D12_HEAP_FLAG_NONE,
             &bufferDesc,
             D3D12_RESOURCE_STATE_COMMON,
             nullptr,
             IID_PPV_ARGS(m_host_buffer.put())));
@@ -176,27 +176,26 @@ class InteropBufferImpl : public InteropBuffer
         m_host_buffer->Map(0, nullptr, &m_host_ptr);
         d3d_device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&m_cmdallocator));
         d3d_device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, m_cmdallocator.get(), nullptr, IID_PPV_ARGS(&m_cmdlist));
-#endif
 
         D3D12_COMMAND_QUEUE_DESC qd = {};
         qd.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
         d3d_device->CreateCommandQueue(&qd, IID_PPV_ARGS(&m_cmd_queue));
-    // debug end
+#endif
 
         // create a shared handle required to import the d3d buffer into CUDA
         HANDLE sharedHandle;
         WindowsSecurityAttributes windowsSecurityAttributes;
         LPCWSTR name = NULL;
-#if USE_BUFFER
-        d3d_device->CreateSharedHandle(m_d3d_buffer.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
+#if defined(USE_HEAP)
+        d3d_device->CreateSharedHandle(m_d3d_heap.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
 
         cudaExternalMemoryHandleDesc externalMemoryHandleDesc = {};
-        externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Resource;
+        externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Heap;
 #else
-        d3d_device->CreateSharedHandle(m_d3d_heap.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
+        d3d_device->CreateSharedHandle(m_d3d_buffer.get(), windowsSecurityAttributes, GENERIC_ALL, name, &sharedHandle);
 
         cudaExternalMemoryHandleDesc externalMemoryHandleDesc = {};
-        externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Heap;
+        externalMemoryHandleDesc.type = cudaExternalMemoryHandleTypeD3D12Resource;
 #endif
         externalMemoryHandleDesc.handle.win32.handle = sharedHandle;
         externalMemoryHandleDesc.size = bufferDesc.Width;
         externalMemoryHandleDesc.flags = cudaExternalMemoryDedicated;
         auto result = cudaImportExternalMemory(&m_externalMemory, &externalMemoryHandleDesc);
 
         CloseHandle(sharedHandle);
@@ -212,11 +211,6 @@ class InteropBufferImpl : public InteropBuffer
         externalMemoryBufferDesc.flags = 0;
 
         result = cudaExternalMemoryGetMappedBuffer(&m_cuda_dev_ptr, m_externalMemory, &externalMemoryBufferDesc);
-        result = cudaDeviceSynchronize();
-
-        auto err = cudaMemset(m_cuda_dev_ptr, 255, 512*1024*1024);
-        result = cudaDeviceSynchronize();
-        std::cout << "err: " << err << std::endl;
     }
 
     ~InteropBufferImpl() {
@@ -235,19 +229,19 @@ class InteropBufferImpl : public InteropBuffer
         return m_d3d_buffer.get();
     }
 
+#if defined(DEBUG_READBACK)
     void* get_host_ptr() const {
-#if 0
         m_cmdlist->Reset(m_cmdallocator.get(), nullptr);
         m_cmdlist->CopyResource(m_host_buffer.get(), m_d3d_buffer.get());
         m_cmdlist->Close();
 
         ID3D12CommandList *ptr = m_cmdlist.get();
         m_cmd_queue->ExecuteCommandLists(1, &ptr);
-        Sleep(2);
-#endif
+        Sleep(2); // actually one would have to wait for an event here
 
         return m_host_ptr;
     }
+#endif
 
 private:
     winrt::com_ptr<ID3D12CommandQueue> m_cmd_queue = {};
@@ -257,11 +251,12 @@ class InteropBufferImpl : public InteropBuffer
     cudaExternalMemory_t m_externalMemory;
     void* m_cuda_dev_ptr;
 
-    // debug
+#if defined(DEBUG_READBACK)
     winrt::com_ptr<ID3D12Resource> m_host_buffer = {};
     winrt::com_ptr<ID3D12GraphicsCommandList> m_cmdlist = {};
     winrt::com_ptr<ID3D12CommandAllocator> m_cmdallocator = {};
     void* m_host_ptr;
+#endif
 };
 
 class DirectStorageCUDAImpl : public DirectStorageCUDA
@@ -450,6 +445,7 @@ class DirectStorageCUDAImpl : public DirectStorageCUDA
         InteropBufferImpl* ibi = static_cast<InteropBufferImpl*>(interop_buffer);
         bool flushed;
         while (read_len) {
+            //std::cout << file.get() << std::endl;
             size_t request_size = min(m_chunk_size, read_len);
 
             DSTORAGE_REQUEST request = {};
@@ -462,7 +458,6 @@ class DirectStorageCUDAImpl : public DirectStorageCUDA
             request.Destination.Buffer.Resource = ibi->get_d3d_buffer();
             request.Destination.Buffer.Offset = interop_buffer_offset;
             request.Destination.Buffer.Size = request_size;
-            //std::cout << read_start / (1024*1024) << " / " << interop_buffer_offset / (1024 * 1024) << "/" << request_size / (1024 * 1024) << std::endl;
 
             m_d3d_storage_queue->EnqueueRequest(&request);
 
@@ -471,7 +466,7 @@ class DirectStorageCUDAImpl : public DirectStorageCUDA
             read_start += request_size;
 
             m_enqueued = true;
-            //flush(true);
+            //flush(false); // flushing less often improves perf a little bit, but removes ability to track current load status
         };
 
     }
diff --git a/ggml-cuda/dsc.h b/ggml-cuda/dsc.h
index 95a8430d54497..e5f9088a4f9d2 100644
--- a/ggml-cuda/dsc.h
+++ b/ggml-cuda/dsc.h
@@ -34,7 +34,10 @@ class InteropBuffer {
 public:
     virtual ~InteropBuffer() = 0;
     virtual void* get_device_ptr() const = 0;
+
+#if defined(DEBUG_READBACK)
     virtual void* get_host_ptr() const = 0;
+#endif
 };
 
 class DirectStorageCUDA
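With the cleanup in place, DirectStorageCUDA in dsc.h is a small self-contained queue API. A standalone usage sketch, assuming a build with GGML_ENABLE_DIRECT_STORAGE_CUDA; the file name and sizes are placeholders and error handling is omitted:

    #include "dsc.h"
    #include <cuda_runtime.h>

    int main() {
        // 8 MiB staging chunks, 64 chunks per staging area -- the values ggml_cuda_init() passes
        std::unique_ptr<DirectStorageCUDA> dsc = DirectStorageCUDA::create(8 * 1024 * 1024, 64);

        void * dst = nullptr;
        size_t n_bytes = 64 * 1024 * 1024;   // illustrative read size
        cudaMalloc(&dst, n_bytes);

        // queue a read of the first n_bytes of the file directly into device memory
        DirectStorageCUDA::File file = dsc->openFile("model.gguf");
        dsc->loadFile(file, /*read_start=*/0, /*read_len=*/n_bytes, dst);

        // submit all queued requests; flush(true) also drains the staging areas
        dsc->flush(true);

        cudaFree(dst);
        return 0;
    }

A note on the COPY_RAID path: FileInfo::getFile() in ggml-cuda.cu rotates through all handles opened for the same model file (one per physical disk), so splitting a tensor into 2-4 MiB blocks and issuing one loadFile() per block alternates consecutive blocks between the disks, a poor man's RAID-0 striping at the request level.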
From 18dbe4b8af23765ccc2c824adc13202a25f0afb1 Mon Sep 17 00:00:00 2001
From: Markus Tavenrath
Date: Fri, 26 Apr 2024 10:32:10 +0200
Subject: [PATCH 4/4] link direct storage to ggml_shared as well.

---
 CMakeLists.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6b3b7ea42e51f..98ebcbd712389 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1206,6 +1206,8 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
     add_library(ggml_shared SHARED $<TARGET_OBJECTS:ggml>)
     target_link_libraries(ggml_shared PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
+    target_link_libraries(ggml_shared PUBLIC "${LLAMA_DIRECT_STORAGE_DIR}/native/lib/x64/dstorage.lib" cuda cudart d3d12)
+
     install(TARGETS ggml_shared LIBRARY)
 endif()
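Building with this series is gated behind the CMake options from patch 1. A configure sketch, assuming the DirectStorage NuGet package has been downloaded and extracted (the exact path is machine-specific): run `cmake -B build -DLLAMA_CUDA=ON -DLLAMA_CUDA_DIRECT_STORAGE=ON -DLLAMA_DIRECT_STORAGE_DIR=<path to Microsoft.Direct3D.DirectStorage.1.2.2>` followed by `cmake --build build --config Release`. Enabling the option raises the C++ standard to 17 and links dstorage.lib, cuda, cudart and d3d12 into both ggml and ggml_shared; at run time the DirectStorage DLLs shipped in the same package are expected to be on the DLL search path of the process.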