Commit fadae72

Merge branch 'hipblas' into develop4Main
2 parents: 518eb2a + 8f8ab6c

13 files changed: +1568 −636 lines

CMakeLists.txt

Lines changed: 37 additions & 1 deletion
@@ -50,7 +50,10 @@ set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA
 set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
-option(LLAMA_HIPBLAS "llama: use hipBLAS" ON)
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
+option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
+option(LLAMA_METAL "llama: use Metal" OFF)
+option(LLAMA_MPI "llama: use MPI" OFF)
 option(LLAMA_K_QUANTS "llama: use k-quants" ON)
@@ -149,6 +152,39 @@ if (LLAMA_HIPBLAS)
 endif()
 endif()

+if (LLAMA_HIPBLAS)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+    find_package(rocblas)
+
+    if (${hipblas_FOUND} AND ${hip_FOUND})
+        message(STATUS "HIP and hipBLAS found")
+        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
+        target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
+        if (LLAMA_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
 if (NOT MSVC)
 set(c_flags
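A note on the design of the block above: the ROCm path defines both GGML_USE_HIPBLAS and GGML_USE_CUBLAS and compiles ggml-cuda.cu as C++, so the existing CUDA/cuBLAS code paths are reused under HIP while GGML_USE_HIPBLAS stays available for HIP-specific branches. A minimal sketch of how source code can key off these macros (the helper and its name are hypothetical, not code from this patch):

    // Illustrative only: the macro names come from the build changes above,
    // but this helper is a hypothetical example, not part of the repository.
    const char * ggml_gpu_backend_name(void) {
    #if defined(GGML_USE_HIPBLAS)
        return "hipBLAS (ROCm)";   // HIP build: both macros are defined, so check HIP first
    #elif defined(GGML_USE_CUBLAS)
        return "cuBLAS (CUDA)";    // native CUDA build
    #else
        return "CPU only";
    #endif
    }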

Makefile

Lines changed: 20 additions & 5 deletions
@@ -174,11 +174,6 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
 else
 	NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
 endif
-ifdef LLAMA_CUDA_MMQ_Y
-	NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
-else
-	NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
-endif # LLAMA_CUDA_MMQ_Y
 #ifdef LLAMA_CUDA_CUBLAS
 # NVCCFLAGS += -DGGML_CUDA_CUBLAS
 #endif # LLAMA_CUDA_CUBLAS
@@ -228,6 +223,26 @@ ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-l
 	$(CXX) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
 endif # LLAMA_HIPBLAS

+ifdef LLAMA_HIPBLAS
+	ROCM_PATH ?= /opt/rocm
+	CC  := $(ROCM_PATH)/llvm/bin/clang
+	CXX := $(ROCM_PATH)/llvm/bin/clang++
+	GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100
+	LLAMA_CUDA_DMMV_X ?= 32
+	LLAMA_CUDA_MMV_Y ?= 1
+	LLAMA_CUDA_KQUANTS_ITER ?= 2
+	CFLAGS   += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+	CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+	LDFLAGS  += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas
+	OBJS     += ggml-cuda.o
+ggml-cuda.o: CXXFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
+ggml-cuda.o: CXXFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+ggml-cuda.o: CXXFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
+ggml-cuda.o: CXXFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+	$(CXX) $(CXXFLAGS) -x hip -c -o $@ $<
+endif # LLAMA_HIPBLAS
+
 ifdef LLAMA_METAL
 	CFLAGS   += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 	CXXFLAGS += -DGGML_USE_METAL

examples/common.cpp

Lines changed: 9 additions & 2 deletions
@@ -194,6 +194,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = std::stof(argv[i]);
+        } else if (arg == "--rope-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = 1.0f/std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -564,8 +570,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --cfg-negative-prompt PROMPT \n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
    fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
-    fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+    fprintf(stdout, "  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
+    fprintf(stdout, "  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stdout, "  --rope-freq-scale N   RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stdout, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stdout, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
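On the flag semantics added above: `--rope-scale` is stored as its reciprocal in `params.rope_freq_scale`, so `--rope-scale 8` and `--rope-freq-scale 0.125` are equivalent. A minimal standalone sketch of that mapping (the struct and helper below are hypothetical stand-ins, not code from common.cpp):

    #include <string>

    // Hypothetical stand-in for the relevant field of gpt_params.
    struct rope_params {
        float rope_freq_scale = 1.0f;
    };

    // Mirrors the parsing logic added in gpt_params_parse above.
    void apply_rope_flag(rope_params & p, const std::string & flag, const std::string & value) {
        if (flag == "--rope-freq-scale") {
            p.rope_freq_scale = std::stof(value);        // stored as given
        } else if (flag == "--rope-scale") {
            p.rope_freq_scale = 1.0f / std::stof(value); // stored as the inverse
        }
    }

    // apply_rope_flag(p, "--rope-scale", "8") and
    // apply_rope_flag(p, "--rope-freq-scale", "0.125")
    // both leave p.rope_freq_scale == 0.125f.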

examples/llama.vim

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+" Requires an already running llama.cpp server
+" To install either copy or symlink to ~/.vim/autoload/llama.vim
+" Then start with either :call llama#doLlamaGen(),
+" or add a keybind to your vimrc such as
+" nnoremap Z :call llama#doLlamaGen()<CR>
+" Similarly, you could add an insert mode keybind with
+" inoremap <C-B> <Cmd>call llama#doLlamaGen()<CR>
+"
+" g:llama_api_url and g:llama_overrides can be configured in your .vimrc
+" let g:llama_api_url = "192.168.1.10:8080"
+" llama_overrides can also be set through buffer/window scopes. For instance
+" autocmd filetype python let b:llama_overrides = {"temp": 0.2}
+" could be added to your .vimrc to automatically set a lower temperature when
+" editing a python script
+" Additionally, an override dict can be stored at the top of a file
+" !*{"stop": ["User:"]}
+" could be added to the start of your chatlog.txt to set the stopping token
+" These parameter dicts are merged together from lowest to highest priority:
+" server default -> g:llama_overrides -> w:llama_overrides ->
+" b:llama_overrides -> in file (!*) overrides
+"
+" Sublists (like logit_bias and stop) are overridden, not merged
+" Example override:
+" !*{"logit_bias": [[13, -5], [2, false]], "temperature": 1, "top_k": 5, "top_p": 0.5, "n_predict": 256, "repeat_last_n": 256, "repeat_penalty": 1.17647}
+if !exists("g:llama_api_url")
+  let g:llama_api_url= "127.0.0.1:8080"
+endif
+if !exists("g:llama_overrides")
+  let g:llama_overrides = {}
+endif
+const s:querydata = {"n_predict": 256, "stop": [ "\n" ], "stream": v:true }
+const s:curlcommand = ['curl','--data-raw', "{\"prompt\":\"### System:\"}", '--silent', '--no-buffer', '--request', 'POST', '--url', g:llama_api_url .. '/completion', '--header', "Content-Type: application/json"]
+let s:linedict = {}
+
+func s:callbackHandler(bufn, channel, msg)
+  if len(a:msg) < 3
+    return
+  elseif a:msg[0] == "d"
+    let l:msg = a:msg[6:-1]
+  else
+    let l:msg = a:msg
+  endif
+  let l:decoded_msg = json_decode(l:msg)
+  let l:newtext = split(l:decoded_msg['content'], "\n", 1)
+  if len(l:newtext) > 0
+    call setbufline(a:bufn, s:linedict[a:bufn], getbufline(a:bufn, s:linedict[a:bufn])[0] .. newtext[0])
+  else
+    echo "nothing genned"
+  endif
+  if len(newtext) > 1
+    let l:failed = appendbufline(a:bufn, s:linedict[a:bufn], newtext[1:-1])
+    let s:linedict[a:bufn] = s:linedict[a:bufn] + len(newtext)-1
+  endif
+  if has_key(l:decoded_msg, "stop") && l:decoded_msg.stop
+    echo "Finished generation"
+  endif
+endfunction
+
+func llama#doLlamaGen()
+  if exists("b:job")
+    if job_status(b:job) == "run"
+      call job_stop(b:job)
+      return
+    endif
+  endif
+
+  let l:cbuffer = bufnr("%")
+  let s:linedict[l:cbuffer] = line('$')
+  let l:buflines = getbufline(l:cbuffer, 1, 1000)
+  let l:querydata = copy(s:querydata)
+  call extend(l:querydata, g:llama_overrides)
+  if exists("w:llama_overrides")
+    call extend(l:querydata, w:llama_overrides)
+  endif
+  if exists("b:llama_overrides")
+    call extend(l:querydata, b:llama_overrides)
+  endif
+  if l:buflines[0][0:1] == '!*'
+    let l:userdata = json_decode(l:buflines[0][2:-1])
+    call extend(l:querydata, l:userdata)
+    let l:buflines = l:buflines[1:-1]
+  endif
+  let l:querydata.prompt = join(l:buflines, "\n")
+  let l:curlcommand = copy(s:curlcommand)
+  let l:curlcommand[2] = json_encode(l:querydata)
+  let b:job = job_start(l:curlcommand, {"callback": function("s:callbackHandler", [l:cbuffer])})
+endfunction
+
+" Echoes the tokenization of the provided string, or cursor to end of word
+" Onus is placed on the user to include the preceding space
+func llama#tokenizeWord(...)
+  if (a:0 > 0)
+    let l:input = a:1
+  else
+    exe "normal \"*ye"
+    let l:input = @*
+  endif
+  let l:querydata = {"content": l:input}
+  let l:curlcommand = copy(s:curlcommand)
+  let l:curlcommand[2] = json_encode(l:querydata)
+  let l:curlcommand[8] = g:llama_api_url .. "/tokenize"
+  let s:token_job = job_start(l:curlcommand, {"callback": function("s:tokenizeWordCallback", [l:input])})
+endfunction
+
+func s:tokenizeWordCallback(plaintext, channel, msg)
+  echo '"' .. a:plaintext ..'" - ' .. string(json_decode(a:msg).tokens)
+endfunction
+
+
+" Echoes the token count of the entire buffer (or provided string)
+" Example usage :echo llama#tokenCount()
+func llama#tokenCount(...)
+  if (a:0 > 0)
+    let l:buflines = a:1
+  else
+    let l:buflines = getline(1,1000)
+    if l:buflines[0][0:1] == '!*'
+      let l:buflines = l:buflines[1:-1]
+    endif
+    let l:buflines = join(l:buflines, "\n")
+  endif
+  let l:querydata = {"content": l:buflines}
+  let l:curlcommand = copy(s:curlcommand)
+  let l:curlcommand[2] = json_encode(l:querydata)
+  let l:curlcommand[8] = g:llama_api_url .. "/tokenize"
+  let s:token_job = job_start(l:curlcommand, {"callback": "s:tokenCountCallback"})
+endfunction
+
+func s:tokenCountCallback(channel, msg)
+  let resp = json_decode(a:msg)
+  echo len(resp.tokens)
+endfunction
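The header comments of llama.vim describe how the override dictionaries are layered: server default, then g:llama_overrides, w:llama_overrides, b:llama_overrides, and finally the in-file `!*{...}` dict, with later layers winning key by key and list values replaced rather than merged. A rough C++ sketch of that precedence rule (illustrative only; the plugin itself does this with Vim's extend()):

    #include <initializer_list>
    #include <map>
    #include <string>

    // Values are kept as raw JSON fragments here just to keep the sketch short.
    using override_dict = std::map<std::string, std::string>;

    // Later layers overwrite earlier ones key by key; a list-valued key such as
    // "stop" is replaced wholesale, never merged element-wise.
    override_dict merge_overrides(std::initializer_list<override_dict> layers) {
        override_dict merged;
        for (const auto & layer : layers) {
            for (const auto & [key, value] : layer) {
                merged[key] = value;
            }
        }
        return merged;
    }

    // e.g. merge_overrides({server_defaults, global, window, buffer, in_file});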

examples/llm.vim

Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,5 @@
+" Basic plugin example
+
 function! Llm()

   let url = "http://127.0.0.1:8080/completion"
@@ -16,8 +18,10 @@ function! Llm()
   " Extract the content field from the response
   let content = json_decode(response).content

+  let split_newlines = split(content, '\n', 1)
+
   " Insert the content at the cursor position
-  call setline(line('.'), getline('.') . content)
+  call setline(line('.'), [ getline('.') . split_newlines[0] ] + split_newlines[1:])
 endfunction

 command! Llm call Llm()

examples/main/README.md

Lines changed: 6 additions & 0 deletions
@@ -140,6 +140,12 @@ The `--ctx-size` option allows you to set the size of the prompt context used by

 - `-c N, --ctx-size N`: Set the size of the prompt context (default: 512). The LLaMA models were built with a context of 2048, which will yield the best results on longer input/inference. However, increasing the context size beyond 2048 may lead to unpredictable results.

+### Extended Context Size
+
+Some fine-tuned models have extended the context length by scaling RoPE. For example, if the original pretrained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k, that is a scaling factor of 8. It should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
+
+- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.
+
 ### Keep Prompt

 The `--keep` option allows users to retain the original prompt when the model runs out of context, ensuring a connection to the initial instruction or conversation topic is maintained.
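To make the scaling arithmetic in the new "Extended Context Size" section concrete, here is a small standalone sketch; the numbers follow the example above, and the program itself is illustrative rather than part of the repository:

    #include <cstdio>

    int main() {
        const int   original_ctx    = 4096;                                 // pretrained context length
        const int   finetuned_ctx   = 32768;                                // extended context length
        const float rope_scale      = (float) finetuned_ctx / original_ctx; // value for --rope-scale
        const float rope_freq_scale = 1.0f / rope_scale;                    // what common.cpp stores internally

        // Prints: ./main -c 32768 --rope-scale 8  (rope_freq_scale = 0.125)
        printf("./main -c %d --rope-scale %.0f  (rope_freq_scale = %g)\n",
               finetuned_ctx, rope_scale, rope_freq_scale);
        return 0;
    }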

examples/server/README.md

Lines changed: 2 additions & 0 deletions
@@ -151,6 +151,8 @@ node .

 `mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).

+`grammar`: Set grammar for grammar-based sampling (default: no grammar)
+
 `seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).

 `ignore_eos`: Ignore end of stream token and continue generating (default: false).
