
Commit ca9a116 (1 parent: bfeb347)

Possibly slower, but cannot use larger batches without modifying the ggml library.

File tree: 2 files changed (+53, -11 lines)


gpttype_adapter.cpp (+11, -11)
@@ -432,10 +432,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         {
             rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
 
-            // if(inputs.gpulayers>0)
-            // {
-            //     rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
-            // }
+            if(inputs.gpulayers>0)
+            {
+                rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
+            }
 
             const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
             const size_t n_vocab = header.n_vocab;
@@ -1066,15 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             }
             else
             {
-                // if(embd.size()>1)
-                // {
-                //     evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
-                // }
-                // else
-                // {
+                if(embd.size()>1)
+                {
+                    evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+                }
+                else
+                {
                     bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
                     evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
-                //}
+                }
 
                 memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
                 rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
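
Note: the commit message says larger batches would require ggml changes; the hunk above therefore passes the whole pending run to rwkv_eval_sequence in one graph. A hedged sketch, not part of this commit: if a per-graph token cap were the constraint, the run could instead be fed in fixed-size chunks, carrying the recurrent state between calls. MAX_SEQ and eval_chunked are hypothetical; the rwkv_* calls and context members match the diff above.

#include <algorithm>
#include <cstdint>
#include <vector>
#include "rwkv_v3.h" // as included by otherarch/rwkv_v3.cpp; adjust to your include paths

// Evaluate a token run in chunks of at most MAX_SEQ tokens per graph.
static bool eval_chunked(struct rwkv_context * ctx, const std::vector<int> & embd) {
    const size_t MAX_SEQ = 64; // hypothetical per-graph token cap
    bool ok = true;
    for (size_t off = 0; off < embd.size() && ok; off += MAX_SEQ) {
        const size_t len = std::min(MAX_SEQ, embd.size() - off);
        ok = rwkv_eval_sequence(ctx, (uint32_t *)(embd.data() + off), len,
                                ctx->state_in, ctx->state_out, ctx->logits_out);
        ctx->state_in = ctx->state_out; // feed this chunk's output state into the next
    }
    return ok;
}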

otherarch/rwkv_v3.cpp (+42)
@@ -6,6 +6,13 @@
 #include "rwkv_v3.h"
 #include "ggml.h"
 
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+#if defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
+#endif
+
 #include <string>
 #include <vector>
 #include <cstring>
@@ -1058,7 +1065,11 @@ struct rwkv_future_tensor rwkv_future_graph_work(struct rwkv_future_ctx & ctx,
     const size_t n_threads,
     const size_t sequence_len = 1
 ) {
+#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS)
+    enum ggml_type mul_mat_type = type == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16;
+#else
     enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type;
+#endif
    return ctx.alloc(GGML_TYPE_I8, rwkv_future_tensor::size(mul_mat_type, ffn_key_height, sequence_len) * n_threads + 64 * (n_threads - 1));
 }
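Note: the branch above only changes how the mul-mat scratch buffer is sized: BLAS builds stage non-F32 data as F16 rather than as Q8_1 blocks. A rough, self-contained sketch of the resulting difference, assuming rwkv_future_tensor::size is approximately element count times bytes per element, and using a placeholder model dimension. F16 is 2 bytes per element; the Q8_1 figure of ~1.125 bytes per element (36-byte blocks of 32) is an assumption.

#include <cstddef>
#include <cstdio>

int main() {
    const size_t ffn_key_height = 4096; // assumed model dimension
    const size_t sequence_len   = 1;
    const size_t n_threads      = 8;

    const double bytes_f16  = 2.0;         // GGML_TYPE_F16
    const double bytes_q8_1 = 36.0 / 32.0; // assumed Q8_1 block overhead

    const size_t elems = ffn_key_height * sequence_len;
    const size_t pad   = 64 * (n_threads - 1); // matches the return expression above

    printf("BLAS build (F16): %.0f bytes\n", elems * bytes_f16  * n_threads + pad);
    printf("CPU build (Q8_1): %.0f bytes\n", elems * bytes_q8_1 * n_threads + pad);
    return 0;
}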

@@ -1545,7 +1556,38 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32
 }
 
 bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) {
+#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS)
+    printf("\nOffloading %u (or fewer) layers...",n_layers);
+    const auto offload = [&](struct ggml_tensor * tensor) {
+        // TODO support multi-GPU
+        tensor->backend = GGML_BACKEND_GPU;
+#if defined(GGML_USE_CLBLAST)
+        ggml_cl_transform_tensor(tensor->data, tensor);
+#else
+        ggml_cuda_transform_tensor(tensor->data, tensor);
+#endif
+    };
+
+    const size_t n_gpu = std::min(n_layers, ctx->instance->model.header.n_layer);
+
+    if (ctx->gpu_layers < n_gpu) {
+        for (size_t & i = ctx->gpu_layers; i < n_gpu; i++) {
+            const struct rwkv_layer & layer = ctx->instance->model.layers[i];
+
+            // TODO also offload other operations to GPU with ggml_cuda_assign_buffers
+            offload(layer.att_key);
+            offload(layer.att_value);
+            offload(layer.att_receptance);
+            offload(layer.att_output);
+
+            offload(layer.ffn_key);
+            offload(layer.ffn_value);
+            offload(layer.ffn_receptance);
+        }
 
+        return true;
+    }
+#endif
     return false;
 }
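
A hedged usage sketch tying the two files together: load a model, then request offload exactly as the gpttype_adapter.cpp hunk does. The model path, thread count, and layer count are placeholders; on CPU-only builds the whole body above is compiled out and the function simply returns false.

#include <cstdio>
#include "rwkv_v3.h" // header for the functions shown in this diff

int main() {
    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 8); // placeholder path and thread count
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }
    // Returns true only when a GPU backend is compiled in and new layers were
    // moved; asking for more layers than the model has is clamped via std::min.
    if (!rwkv_gpu_offload_layers(ctx, 20)) {
        fprintf(stderr, "no layers offloaded (CPU build or already offloaded)\n");
    }
    return 0;
}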
