Skip to content

Commit f842d31

Browse files
authored
whisper : calculate mel spectrogram directly into a ggml_tensor (#2208)
* whisper : calculate mel spectrogram directly into a ggml_tensor * whisper : remove unused temp buffer from state * whisper : fix not initializing wstate.embd_enc
1 parent ffef323 commit f842d31

File tree

3 files changed

+144
-67
lines changed

3 files changed

+144
-67
lines changed

whisper-mel-cuda.cu

+9-12
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <cublas_v2.h>
99
#include <cuComplex.h>
1010
#include <cub/device/device_reduce.cuh>
11+
#include <device_launch_parameters.h>
1112

1213
#include <algorithm>
1314

@@ -301,27 +302,23 @@ public:
301302
&fzero,
302303
mel_data, int(n_mag_frames)));
303304

304-
float * log_mels = nullptr;
305-
CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream));
305+
whisper_mel ret;
306+
// Calculate semi-padded sample length to ensure compatibility
307+
int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
308+
ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
309+
assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
310+
311+
float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
306312

307313
calc_log_mel(
308314
mel_data, int(m_n_mel * n_mag_frames),
309-
m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
315+
m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
310316
log_mels, m_stream);
311317

312-
whisper_mel ret;
313-
ret.n_mel = m_n_mel;
314-
ret.n_len = int(n_mag_frames);
315-
// Calculate semi-padded sample length to ensure compatibility
316-
ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
317-
ret.data.resize(m_n_mel * n_mag_frames);
318-
CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream));
319-
320318
CUDA_CHECK(cudaStreamSynchronize(m_stream));
321319

322320
// cleanup
323321
CUFFT_CHECK(cufftDestroy(plan));
324-
CUDA_CHECK(cudaFreeAsync(log_mels, m_stream));
325322
CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
326323
CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
327324
CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));

whisper-mel.hpp

+16-4
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,23 @@
33
#include <vector>
44

55
struct whisper_mel {
6-
int n_len;
7-
int n_len_org;
8-
int n_mel;
6+
int n_len_org = 0;
97

10-
std::vector<float> data;
8+
ggml_tensor * tensor = nullptr;
9+
ggml_context * ctx = nullptr;
10+
ggml_backend_buffer_t buffer = nullptr;
11+
12+
whisper_mel() = default;
13+
~whisper_mel();
14+
15+
whisper_mel(const whisper_mel &) = delete;
16+
whisper_mel & operator=(const whisper_mel &) = delete;
17+
whisper_mel(whisper_mel &&) noexcept;
18+
whisper_mel & operator=(whisper_mel &&) noexcept;
19+
20+
void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
21+
void reset();
22+
void take(whisper_mel & other) noexcept;
1123
};
1224

1325
struct whisper_filters {

whisper.cpp

+119-51
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,6 @@ struct whisper_state {
821821
struct ggml_tensor * embd_enc = nullptr;
822822

823823
// helpers for GPU offloading
824-
std::vector<float> inp_mel;
825824
std::vector<float> inp_mask;
826825

827826
// decode output (2-dimensional array: [n_tokens][n_vocab])
@@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
18151814

18161815
static struct ggml_cgraph * whisper_build_graph_conv(
18171816
whisper_context & wctx,
1818-
whisper_state & wstate) {
1817+
whisper_state & wstate,
1818+
const int mel_offset) {
18191819
const auto & model = wctx.model;
18201820
const auto & hparams = model.hparams;
18211821

@@ -1834,9 +1834,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
18341834

18351835
ggml_cgraph * gf = ggml_new_graph(ctx0);
18361836

1837-
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
1838-
ggml_set_name(mel, "mel");
1839-
ggml_set_input(mel);
1837+
ggml_tensor * mel_inp = wstate.mel.tensor;
1838+
ggml_tensor * mel;
1839+
if (mel_inp) {
1840+
const int n_len = int(mel_inp->ne[0]);
1841+
const int out_s = 2 * n_ctx;
1842+
const int i0 = std::min(mel_offset, n_len);
1843+
const int i1 = std::min(mel_offset + out_s, n_len);
1844+
const int mel_s = i1 - i0;
1845+
1846+
assert(mel_inp->type == GGML_TYPE_F32);
1847+
assert(mel_inp->ne[1] == n_mels);
1848+
1849+
ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
1850+
1851+
if (mel_s < out_s) {
1852+
mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
1853+
}
1854+
else {
1855+
mel = ggml_cont(ctx0, cur);
1856+
}
1857+
}
1858+
else {
1859+
// just create some tensor so that the graph/buffer size estimation is correct
1860+
mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
1861+
}
1862+
ggml_set_name(mel, "mel"); // used with external encoding
18401863

18411864
struct ggml_tensor * cur = nullptr;
18421865

@@ -2218,45 +2241,21 @@ static bool whisper_encode_internal(
22182241
{
22192242
auto & alloc = wstate.alloc_conv.alloc;
22202243

2221-
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
2244+
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
22222245

22232246
if (!ggml_gallocr_alloc_graph(alloc, gf)) {
22242247
// should never happen as we pre-allocate the memory
22252248
return false;
22262249
}
22272250

2228-
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
2229-
2230-
// set the input
2231-
{
2232-
const auto & mel_inp = wstate.mel;
2233-
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
2234-
2235-
assert(mel->type == GGML_TYPE_F32);
2236-
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
2237-
2238-
wstate.inp_mel.resize(ggml_nelements(mel));
2239-
2240-
float * dst = wstate.inp_mel.data();
2241-
memset(dst, 0, ggml_nbytes(mel));
2242-
2243-
const int i0 = std::min(mel_offset, mel_inp.n_len);
2244-
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
2245-
2246-
for (int j = 0; j < mel_inp.n_mel; ++j) {
2247-
for (int i = i0; i < i1; ++i) {
2248-
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
2249-
}
2250-
}
2251-
2252-
ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
2251+
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2252+
return false;
22532253
}
22542254

2255-
if (!whisper_encode_external(wstate)) {
2256-
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2257-
return false;
2258-
}
2259-
} else {
2255+
if (whisper_encode_external(wstate)) {
2256+
ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
2257+
assert(mel->ne[1] == wctx.model.hparams.n_mels);
2258+
GGML_UNUSED(mel);
22602259
#if defined(WHISPER_USE_COREML)
22612260
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
22622261
#elif defined(WHISPER_USE_OPENVINO)
@@ -2886,6 +2885,54 @@ struct whisper_global_cache {
28862885

28872886
// Mel spectrogram
28882887

2888+
whisper_mel::~whisper_mel() {
2889+
reset();
2890+
}
2891+
2892+
whisper_mel::whisper_mel(whisper_mel && other) noexcept {
2893+
take(other);
2894+
}
2895+
2896+
whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
2897+
if (this != &other) {
2898+
reset();
2899+
take(other);
2900+
}
2901+
return *this;
2902+
}
2903+
2904+
void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
2905+
this->n_len_org = n_len_org;
2906+
assert(!ctx);
2907+
ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
2908+
tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
2909+
buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
2910+
auto alloc = ggml_tallocr_new(buffer);
2911+
ggml_tallocr_alloc(&alloc, tensor);
2912+
}
2913+
2914+
void whisper_mel::reset() {
2915+
ggml_free(ctx);
2916+
ggml_backend_buffer_free(buffer);
2917+
2918+
n_len_org = 0;
2919+
tensor = nullptr;
2920+
ctx = nullptr;
2921+
buffer = nullptr;
2922+
}
2923+
2924+
void whisper_mel::take(whisper_mel & other) noexcept {
2925+
n_len_org = other.n_len_org;
2926+
tensor = other.tensor;
2927+
ctx = other.ctx;
2928+
buffer = other.buffer;
2929+
2930+
other.n_len_org = 0;
2931+
other.tensor = nullptr;
2932+
other.ctx = nullptr;
2933+
other.buffer = nullptr;
2934+
}
2935+
28892936
whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
28902937

28912938
whisper_span<const float> whisper_mel_calc::hann_window() {
@@ -2973,9 +3020,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
29733020
}
29743021
}
29753022

2976-
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
3023+
namespace {
3024+
3025+
struct whisper_mel_data {
3026+
int n_len;
3027+
int n_len_org;
3028+
int n_mel;
3029+
float* data;
3030+
};
3031+
3032+
void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
29773033
int n_samples, int n_threads,
2978-
const whisper_filters & filters, whisper_mel & mel) {
3034+
const whisper_filters & filters, whisper_mel_data & mel) {
29793035
const auto frame_size = WHISPER_N_FFT;
29803036
const auto frame_step = WHISPER_HOP_LENGTH;
29813037
std::vector<float> fft_in(frame_size, 0.0);
@@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
30413097
}
30423098
}
30433099
}
3044-
namespace {
3100+
30453101
struct mel_calc_cpu : public whisper_mel_calc {
3102+
ggml_backend_t m_backend;
30463103
const whisper_filters& m_filters;
3047-
mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {}
3104+
mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
30483105

30493106
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
30503107
whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
@@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc {
30693126
// reflective pad 200 samples at the beginning of audio
30703127
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
30713128

3072-
whisper_mel mel;
3129+
whisper_mel_data mel;
30733130
mel.n_mel = m_filters.n_mel;
30743131
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
30753132
// Calculate number of frames + remove the last frame
30763133
mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
30773134
// Calculate semi-padded sample length to ensure compatibility
30783135
mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
3079-
mel.data.resize(mel.n_mel * mel.n_len);
30803136

3137+
std::vector<float> host_mel_data;
3138+
3139+
whisper_mel ret;
3140+
ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
3141+
if (ggml_backend_buffer_is_host(ret.buffer)) {
3142+
mel.data = reinterpret_cast<float*>(ret.tensor->data);
3143+
} else {
3144+
host_mel_data.resize(mel.n_len * mel.n_mel);
3145+
mel.data = host_mel_data.data();
3146+
}
30813147

30823148
{
30833149
std::vector<std::thread> workers(n_threads - 1);
@@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc {
31143180
mel.data[i] = (mel.data[i] + 4.0)/4.0;
31153181
}
31163182

3117-
return mel;
3183+
if (!host_mel_data.empty()) {
3184+
// the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
3185+
ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
3186+
}
3187+
3188+
return ret;
31183189
}
31193190
};
31203191
}
@@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
31293200
return ret;
31303201
} else
31313202
#endif
3132-
return new mel_calc_cpu(filters);
3203+
return new mel_calc_cpu(backend, filters);
31333204
}
31343205

31353206
// split text into tokens
@@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
33473418
{
33483419
bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
33493420
[&]() {
3350-
return whisper_build_graph_conv(*ctx, *state);
3421+
return whisper_build_graph_conv(*ctx, *state, 0);
33513422
});
33523423

33533424
if (!ok) {
@@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state(
37633834
return -1;
37643835
}
37653836

3766-
state->mel.n_len = n_len;
3767-
state->mel.n_len_org = n_len;
3768-
state->mel.n_mel = n_mel;
3769-
3770-
state->mel.data.resize(n_len*n_mel);
3771-
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
3837+
state->mel.reset();
3838+
state->mel.init(ctx->backend, n_len, n_len, n_mel);
3839+
ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
37723840

37733841
return 0;
37743842
}

0 commit comments

Comments
 (0)