whisper : calculate mel spectrogram directly into a ggml_tensor #2208

Merged (3 commits) on Jun 6, 2024

whisper-mel-cuda.cu (21 changes: 9 additions & 12 deletions)

@@ -8,6 +8,7 @@
 #include <cublas_v2.h>
 #include <cuComplex.h>
 #include <cub/device/device_reduce.cuh>
+#include <device_launch_parameters.h>

 #include <algorithm>
@@ -301,27 +302,23 @@ public:
             &fzero,
             mel_data, int(n_mag_frames)));

-    float * log_mels = nullptr;
-    CUDA_CHECK(cudaMallocAsync(&log_mels, m_n_mel * n_mag_frames * sizeof(float), m_stream));
+    whisper_mel ret;
+    // Calculate semi-padded sample length to ensure compatibility
+    int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
+    ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+    assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
+
+    float* log_mels = reinterpret_cast<float*>(ret.tensor->data);

     calc_log_mel(
         mel_data, int(m_n_mel * n_mag_frames),
-        m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
+        m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
         log_mels, m_stream);

-    whisper_mel ret;
-    ret.n_mel = m_n_mel;
-    ret.n_len = int(n_mag_frames);
-    // Calculate semi-padded sample length to ensure compatibility
-    ret.n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-    ret.data.resize(m_n_mel * n_mag_frames);
-    CUDA_CHECK(cudaMemcpyAsync(ret.data.data(), log_mels, ret.data.size() * sizeof(float), cudaMemcpyDeviceToHost, m_stream));

     CUDA_CHECK(cudaStreamSynchronize(m_stream));

     // cleanup
     CUFFT_CHECK(cufftDestroy(plan));
-    CUDA_CHECK(cudaFreeAsync(log_mels, m_stream));
     CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
     CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
     CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
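
Note on the CUDA hunk above: the log-mel kernel now writes straight into the ggml tensor owned by the returned whisper_mel, which is why the intermediate cudaMallocAsync buffer and the device-to-host cudaMemcpyAsync disappear. A condensed sketch of the new flow, using the names from this hunk:

    whisper_mel ret;
    ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);      // device-backed ggml tensor
    float * log_mels = reinterpret_cast<float *>(ret.tensor->data);  // raw device pointer
    calc_log_mel(mel_data, int(m_n_mel * n_mag_frames),
                 m_log_mel_temp_storage, int(m_log_mel_temp_storage_size),
                 log_mels, m_stream);                                // kernel writes in place
    CUDA_CHECK(cudaStreamSynchronize(m_stream));                     // no copy back to the host
    return ret;                                                      // spectrogram stays on the GPU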

whisper-mel.hpp (20 changes: 16 additions & 4 deletions)

@@ -3,11 +3,23 @@
 #include <vector>

 struct whisper_mel {
-    int n_len;
-    int n_len_org;
-    int n_mel;
+    int n_len_org = 0;

-    std::vector<float> data;
+    ggml_tensor * tensor = nullptr;
+    ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+
+    whisper_mel() = default;
+    ~whisper_mel();
+
+    whisper_mel(const whisper_mel &) = delete;
+    whisper_mel & operator=(const whisper_mel &) = delete;
+    whisper_mel(whisper_mel &&) noexcept;
+    whisper_mel & operator=(whisper_mel &&) noexcept;
+
+    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
+    void reset();
+    void take(whisper_mel & other) noexcept;
 };

 struct whisper_filters {
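The header change makes whisper_mel a move-only RAII owner of its tensor, its ggml context and its backend buffer. A hypothetical usage sketch (not part of the PR) of the resulting lifetime rules, assuming a valid ggml_backend_t:

    whisper_mel mel;
    mel.init(backend, /*n_len=*/3000, /*n_len_org=*/3000, /*n_mel=*/80); // allocates buffer + tensor
    whisper_mel other = std::move(mel); // take(): ownership transfers, `mel` is left empty
    other.reset();                      // frees the ggml context and backend buffer early
    // otherwise ~whisper_mel() calls reset(); copying is deleted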

whisper.cpp (170 changes: 119 additions & 51 deletions)

@@ -821,7 +821,6 @@ struct whisper_state {
     struct ggml_tensor * embd_enc = nullptr;

     // helpers for GPU offloading
-    std::vector<float> inp_mel;
     std::vector<float> inp_mask;

     // decode output (2-dimensional array: [n_tokens][n_vocab])
@@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {

 static struct ggml_cgraph * whisper_build_graph_conv(
         whisper_context & wctx,
-        whisper_state & wstate) {
+        whisper_state & wstate,
+        const int mel_offset) {
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;

@@ -1834,9 +1834,32 @@

ggml_cgraph * gf = ggml_new_graph(ctx0);

struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
ggml_set_name(mel, "mel");
ggml_set_input(mel);
ggml_tensor * mel_inp = wstate.mel.tensor;
ggml_tensor * mel;
if (mel_inp) {
const int n_len = int(mel_inp->ne[0]);
const int out_s = 2 * n_ctx;
const int i0 = std::min(mel_offset, n_len);
const int i1 = std::min(mel_offset + out_s, n_len);
const int mel_s = i1 - i0;

assert(mel_inp->type == GGML_TYPE_F32);
assert(mel_inp->ne[1] == n_mels);

ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));

if (mel_s < out_s) {
mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
}
else {
mel = ggml_cont(ctx0, cur);
}
}
else {
// just create some tensor so that the graph/buffer size estimation is correct
mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
}
ggml_set_name(mel, "mel"); // used with external encoding

struct ggml_tensor * cur = nullptr;
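
The hunk above moves the per-chunk windowing of the spectrogram into the graph itself: a 2*n_ctx-frame view into the state's mel tensor is taken at mel_offset, and when fewer frames remain the window is zero-padded. A minimal annotated sketch of the same ggml calls, under the shapes assumed above (F32, ne = [n_len, n_mels]):

    ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels,
                                     mel_inp->nb[1],                    // row stride of the source
                                     ggml_row_size(mel_inp->type, i0)); // byte offset = skip i0 frames
    mel = (mel_s < out_s)
        ? ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0) // append zeros along dim 0
        : ggml_cont(ctx0, cur);                       // materialize the view contiguously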

@@ -2218,45 +2241,21 @@ static bool whisper_encode_internal(
     {
         auto & alloc = wstate.alloc_conv.alloc;

-        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
+        ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);

         if (!ggml_gallocr_alloc_graph(alloc, gf)) {
             // should never happen as we pre-allocate the memory
             return false;
         }

-        struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
-
-        // set the input
-        {
-            const auto & mel_inp = wstate.mel;
-            const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
-
-            assert(mel->type == GGML_TYPE_F32);
-            assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
-
-            wstate.inp_mel.resize(ggml_nelements(mel));
-
-            float * dst = wstate.inp_mel.data();
-            memset(dst, 0, ggml_nbytes(mel));
-
-            const int i0 = std::min(mel_offset, mel_inp.n_len);
-            const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
-
-            for (int j = 0; j < mel_inp.n_mel; ++j) {
-                for (int i = i0; i < i1; ++i) {
-                    dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
-                }
-            }
-
-            ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
-        }
-
-        if (!whisper_encode_external(wstate)) {
-            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
-                return false;
-            }
-        } else {
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
+        }
+
+        if (whisper_encode_external(wstate)) {
+            ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
+            assert(mel->ne[1] == wctx.model.hparams.n_mels);
+            GGML_UNUSED(mel);
 #if defined(WHISPER_USE_COREML)
             whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
 #elif defined(WHISPER_USE_OPENVINO)
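
Control-flow note on the hunk above: the host-side copy of the mel window into the graph input is gone, and the conv graph is now computed unconditionally, since the slicing/padding of the mel input happens inside the graph; external encoders (Core ML / OpenVINO) then read the materialized "mel" tensor. Condensed:

    if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
        return false; // the graph now also produces the "mel" window
    }
    if (whisper_encode_external(wstate)) {
        ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
        // hand mel->data to the Core ML / OpenVINO encoder, as in the #if block above
    }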
@@ -2886,6 +2885,54 @@ struct whisper_global_cache {

 // Mel spectrogram

+whisper_mel::~whisper_mel() {
+    reset();
+}
+
+whisper_mel::whisper_mel(whisper_mel && other) noexcept {
+    take(other);
+}
+
+whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
+    if (this != &other) {
+        reset();
+        take(other);
+    }
+    return *this;
+}
+
+void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    this->n_len_org = n_len_org;
+    assert(!ctx);
+    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
+    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(buffer);
+    ggml_tallocr_alloc(&alloc, tensor);
+}
+
+void whisper_mel::reset() {
+    ggml_free(ctx);
+    ggml_backend_buffer_free(buffer);
+
+    n_len_org = 0;
+    tensor = nullptr;
+    ctx = nullptr;
+    buffer = nullptr;
+}
+
+void whisper_mel::take(whisper_mel & other) noexcept {
+    n_len_org = other.n_len_org;
+    tensor = other.tensor;
+    ctx = other.ctx;
+    buffer = other.buffer;
+
+    other.n_len_org = 0;
+    other.tensor = nullptr;
+    other.ctx = nullptr;
+    other.buffer = nullptr;
+}
+
 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable

 whisper_span<const float> whisper_mel_calc::hann_window() {
@@ -2973,9 +3020,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
     }
 }

-static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
+namespace {
+
+struct whisper_mel_data {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+    float* data;
+};
+
+void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                        int n_samples, int n_threads,
-                                       const whisper_filters & filters, whisper_mel & mel) {
+                                       const whisper_filters & filters, whisper_mel_data & mel) {
     const auto frame_size = WHISPER_N_FFT;
     const auto frame_step = WHISPER_HOP_LENGTH;
     std::vector<float> fft_in(frame_size, 0.0);
@@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
         }
     }
 }
-namespace {

 struct mel_calc_cpu : public whisper_mel_calc {
+    ggml_backend_t m_backend;
     const whisper_filters& m_filters;
-    mel_calc_cpu(const whisper_filters & filters) : m_filters(filters) {}
+    mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}

     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
     whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) const override {
@@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc {
         // reflective pad 200 samples at the beginning of audio
         std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());

-        whisper_mel mel;
+        whisper_mel_data mel;
         mel.n_mel = m_filters.n_mel;
         // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
         // Calculate number of frames + remove the last frame
         mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
         // Calculate semi-padded sample length to ensure compatibility
         mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        mel.data.resize(mel.n_mel * mel.n_len);
+
+        std::vector<float> host_mel_data;
+
+        whisper_mel ret;
+        ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+        if (ggml_backend_buffer_is_host(ret.buffer)) {
+            mel.data = reinterpret_cast<float*>(ret.tensor->data);
+        } else {
+            host_mel_data.resize(mel.n_len * mel.n_mel);
+            mel.data = host_mel_data.data();
+        }

         {
             std::vector<std::thread> workers(n_threads - 1);
@@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc {
             mel.data[i] = (mel.data[i] + 4.0)/4.0;
         }

-        return mel;
+        if (!host_mel_data.empty()) {
+            // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
+            ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
+        }
+
+        return ret;
     }
 };
 }
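
The CPU path above uses a staging pattern worth calling out: write directly into the tensor when its backend buffer is host-accessible, otherwise fill a temporary host vector and upload once at the end. A generic sketch of that pattern (variable names hypothetical, ggml calls as in the hunk):

    std::vector<float> staging;
    float * dst = nullptr;
    if (ggml_backend_buffer_is_host(ret.buffer)) {
        dst = reinterpret_cast<float *>(ret.tensor->data); // write in place (e.g. CPU backend)
    } else {
        staging.resize(ggml_nelements(ret.tensor));        // backend memory not host-accessible
        dst = staging.data();
    }
    // ... fill dst ...
    if (!staging.empty()) {
        ggml_backend_tensor_set(ret.tensor, staging.data(), 0, ggml_nbytes(ret.tensor)); // one upload
    }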
@@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
         return ret;
     } else
 #endif
-    return new mel_calc_cpu(filters);
+    return new mel_calc_cpu(backend, filters);
 }

 // split text into tokens
@@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     {
         bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
                 [&]() {
-                    return whisper_build_graph_conv(*ctx, *state);
+                    return whisper_build_graph_conv(*ctx, *state, 0);
                 });

         if (!ok) {
@@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state(
         return -1;
     }

-    state->mel.n_len = n_len;
-    state->mel.n_len_org = n_len;
-    state->mel.n_mel = n_mel;
-
-    state->mel.data.resize(n_len*n_mel);
-    memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
+    state->mel.reset();
+    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));

     return 0;
 }
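
For callers of the public API, whisper_set_mel_with_state keeps its signature from whisper.h; only the storage behind it changes (a backend tensor instead of a host-side vector). A usage sketch, assuming an externally computed spectrogram:

    // log-mel data laid out row-major as [n_mel][n_len], matching the 2-D tensor above
    std::vector<float> my_mel(n_len * n_mel);
    if (whisper_set_mel_with_state(ctx, state, my_mel.data(), n_len, n_mel) != 0) {
        // the mel dimensions were rejected
    }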