@@ -821,7 +821,6 @@ struct whisper_state {
821
821
struct ggml_tensor * embd_enc = nullptr ;
822
822
823
823
// helpers for GPU offloading
824
- std::vector<float > inp_mel;
825
824
std::vector<float > inp_mask;
826
825
827
826
// decode output (2-dimensional array: [n_tokens][n_vocab])
@@ -1815,7 +1814,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
1815
1814
1816
1815
static struct ggml_cgraph * whisper_build_graph_conv (
1817
1816
whisper_context & wctx,
1818
- whisper_state & wstate) {
1817
+ whisper_state & wstate,
1818
+ const int mel_offset) {
1819
1819
const auto & model = wctx.model ;
1820
1820
const auto & hparams = model.hparams ;
1821
1821
@@ -1834,9 +1834,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1834
1834
1835
1835
ggml_cgraph * gf = ggml_new_graph (ctx0);
1836
1836
1837
- struct ggml_tensor * mel = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, 2 *n_ctx, n_mels);
1838
- ggml_set_name (mel, " mel" );
1839
- ggml_set_input (mel);
1837
+ ggml_tensor * mel_inp = wstate.mel .tensor ;
1838
+ ggml_tensor * mel;
1839
+ if (mel_inp) {
1840
+ const int n_len = int (mel_inp->ne [0 ]);
1841
+ const int out_s = 2 * n_ctx;
1842
+ const int i0 = std::min (mel_offset, n_len);
1843
+ const int i1 = std::min (mel_offset + out_s, n_len);
1844
+ const int mel_s = i1 - i0;
1845
+
1846
+ assert (mel_inp->type == GGML_TYPE_F32);
1847
+ assert (mel_inp->ne [1 ] == n_mels);
1848
+
1849
+ ggml_tensor * cur = ggml_view_2d (ctx0, mel_inp, out_s, n_mels, mel_inp->nb [1 ], ggml_row_size (mel_inp->type , i0));
1850
+
1851
+ if (mel_s < out_s) {
1852
+ mel = ggml_pad (ctx0, cur, out_s - mel_s, 0 , 0 , 0 );
1853
+ }
1854
+ else {
1855
+ mel = ggml_cont (ctx0, cur);
1856
+ }
1857
+ }
1858
+ else {
1859
+ // just create some tensor so that the graph/buffer size estimation is correct
1860
+ mel = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
1861
+ }
1862
+ ggml_set_name (mel, " mel" ); // used with external encoding
1840
1863
1841
1864
struct ggml_tensor * cur = nullptr ;
1842
1865
@@ -2218,45 +2241,21 @@ static bool whisper_encode_internal(
2218
2241
{
2219
2242
auto & alloc = wstate.alloc_conv .alloc ;
2220
2243
2221
- ggml_cgraph * gf = whisper_build_graph_conv (wctx, wstate);
2244
+ ggml_cgraph * gf = whisper_build_graph_conv (wctx, wstate, mel_offset );
2222
2245
2223
2246
if (!ggml_gallocr_alloc_graph (alloc, gf)) {
2224
2247
// should never happen as we pre-allocate the memory
2225
2248
return false ;
2226
2249
}
2227
2250
2228
- struct ggml_tensor * mel = ggml_graph_get_tensor (gf, " mel" );
2229
-
2230
- // set the input
2231
- {
2232
- const auto & mel_inp = wstate.mel ;
2233
- const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model .hparams .n_audio_ctx ;
2234
-
2235
- assert (mel->type == GGML_TYPE_F32);
2236
- assert (mel_inp.n_mel == wctx.model .hparams .n_mels );
2237
-
2238
- wstate.inp_mel .resize (ggml_nelements (mel));
2239
-
2240
- float * dst = wstate.inp_mel .data ();
2241
- memset (dst, 0 , ggml_nbytes (mel));
2242
-
2243
- const int i0 = std::min (mel_offset, mel_inp.n_len );
2244
- const int i1 = std::min (mel_offset + 2 *n_ctx, mel_inp.n_len );
2245
-
2246
- for (int j = 0 ; j < mel_inp.n_mel ; ++j) {
2247
- for (int i = i0; i < i1; ++i) {
2248
- dst[j*2 *n_ctx + (i - i0)] = mel_inp.data [j*mel_inp.n_len + i];
2249
- }
2250
- }
2251
-
2252
- ggml_backend_tensor_set (mel, wstate.inp_mel .data (), 0 , ggml_nelements (mel)*sizeof (float ));
2251
+ if (!ggml_graph_compute_helper (wstate.backend , gf, n_threads)) {
2252
+ return false ;
2253
2253
}
2254
2254
2255
- if (!whisper_encode_external (wstate)) {
2256
- if (!ggml_graph_compute_helper (wstate.backend , gf, n_threads)) {
2257
- return false ;
2258
- }
2259
- } else {
2255
+ if (whisper_encode_external (wstate)) {
2256
+ ggml_tensor * mel = ggml_graph_get_tensor (gf, " mel" );
2257
+ assert (mel->ne [1 ] == wctx.model .hparams .n_mels );
2258
+ GGML_UNUSED (mel);
2260
2259
#if defined(WHISPER_USE_COREML)
2261
2260
whisper_coreml_encode (wstate.ctx_coreml , mel->ne [0 ], mel->ne [1 ], (float *) mel->data , (float *) wstate.embd_enc ->data );
2262
2261
#elif defined(WHISPER_USE_OPENVINO)
@@ -2886,6 +2885,54 @@ struct whisper_global_cache {
2886
2885
2887
2886
// Mel spectrogram
2888
2887
2888
+ whisper_mel::~whisper_mel () {
2889
+ reset ();
2890
+ }
2891
+
2892
+ whisper_mel::whisper_mel (whisper_mel && other) noexcept {
2893
+ take (other);
2894
+ }
2895
+
2896
+ whisper_mel & whisper_mel::operator =(whisper_mel && other) noexcept {
2897
+ if (this != &other) {
2898
+ reset ();
2899
+ take (other);
2900
+ }
2901
+ return *this ;
2902
+ }
2903
+
2904
+ void whisper_mel::init (ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
2905
+ this ->n_len_org = n_len_org;
2906
+ assert (!ctx);
2907
+ ctx = ggml_init ({ggml_tensor_overhead (), nullptr , true });
2908
+ tensor = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, n_len, n_mel);
2909
+ buffer = ggml_backend_alloc_buffer (backend, ggml_nbytes (tensor) + ggml_backend_get_alignment (backend));
2910
+ auto alloc = ggml_tallocr_new (buffer);
2911
+ ggml_tallocr_alloc (&alloc, tensor);
2912
+ }
2913
+
2914
+ void whisper_mel::reset () {
2915
+ ggml_free (ctx);
2916
+ ggml_backend_buffer_free (buffer);
2917
+
2918
+ n_len_org = 0 ;
2919
+ tensor = nullptr ;
2920
+ ctx = nullptr ;
2921
+ buffer = nullptr ;
2922
+ }
2923
+
2924
+ void whisper_mel::take (whisper_mel & other) noexcept {
2925
+ n_len_org = other.n_len_org ;
2926
+ tensor = other.tensor ;
2927
+ ctx = other.ctx ;
2928
+ buffer = other.buffer ;
2929
+
2930
+ other.n_len_org = 0 ;
2931
+ other.tensor = nullptr ;
2932
+ other.ctx = nullptr ;
2933
+ other.buffer = nullptr ;
2934
+ }
2935
+
2889
2936
whisper_mel_calc::~whisper_mel_calc () = default ; // export vtable
2890
2937
2891
2938
whisper_span<const float > whisper_mel_calc::hann_window () {
@@ -2973,9 +3020,18 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
2973
3020
}
2974
3021
}
2975
3022
2976
- static void log_mel_spectrogram_worker_thread (int ith, const float * hann, const std::vector<float > & samples,
3023
+ namespace {
3024
+
3025
// Plain, non-owning description of a mel spectrogram that lives in host memory.
// Used as the output target while the spectrogram is being computed on the CPU.
// Every member has a default value so that a default-constructed instance never
// carries an indeterminate size or pointer.
struct whisper_mel_data {
    int n_len     = 0;       // number of frames
    int n_len_org = 0;       // frame count derived from the semi-padded sample length
    int n_mel     = 0;       // number of mel bins
    float * data  = nullptr; // not owned; points at n_len * n_mel floats
};
3031
+
3032
+ void log_mel_spectrogram_worker_thread (int ith, const float * hann, const std::vector<float > & samples,
2977
3033
int n_samples, int n_threads,
2978
- const whisper_filters & filters, whisper_mel & mel) {
3034
+ const whisper_filters & filters, whisper_mel_data & mel) {
2979
3035
const auto frame_size = WHISPER_N_FFT;
2980
3036
const auto frame_step = WHISPER_HOP_LENGTH;
2981
3037
std::vector<float > fft_in (frame_size, 0.0 );
@@ -3041,10 +3097,11 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const
3041
3097
}
3042
3098
}
3043
3099
}
3044
- namespace {
3100
+
3045
3101
struct mel_calc_cpu : public whisper_mel_calc {
3102
+ ggml_backend_t m_backend;
3046
3103
const whisper_filters& m_filters;
3047
- mel_calc_cpu (const whisper_filters & filters) : m_filters(filters) {}
3104
+ mel_calc_cpu (ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
3048
3105
3049
3106
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
3050
3107
whisper_mel calculate (whisper_span<const float > ssamples, int n_threads) const override {
@@ -3069,15 +3126,24 @@ struct mel_calc_cpu : public whisper_mel_calc {
3069
3126
// reflective pad 200 samples at the beginning of audio
3070
3127
std::reverse_copy (samples + 1 , samples + 1 + stage_2_pad, samples_padded.begin ());
3071
3128
3072
- whisper_mel mel;
3129
+ whisper_mel_data mel;
3073
3130
mel.n_mel = m_filters.n_mel ;
3074
3131
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
3075
3132
// Calculate number of frames + remove the last frame
3076
3133
mel.n_len = (samples_padded.size () - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
3077
3134
// Calculate semi-padded sample length to ensure compatibility
3078
3135
mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
3079
- mel.data .resize (mel.n_mel * mel.n_len );
3080
3136
3137
+ std::vector<float > host_mel_data;
3138
+
3139
+ whisper_mel ret;
3140
+ ret.init (m_backend, mel.n_len , mel.n_len_org , mel.n_mel );
3141
+ if (ggml_backend_buffer_is_host (ret.buffer )) {
3142
+ mel.data = reinterpret_cast <float *>(ret.tensor ->data );
3143
+ } else {
3144
+ host_mel_data.resize (mel.n_len * mel.n_mel );
3145
+ mel.data = host_mel_data.data ();
3146
+ }
3081
3147
3082
3148
{
3083
3149
std::vector<std::thread> workers (n_threads - 1 );
@@ -3114,7 +3180,12 @@ struct mel_calc_cpu : public whisper_mel_calc {
3114
3180
mel.data [i] = (mel.data [i] + 4.0 )/4.0 ;
3115
3181
}
3116
3182
3117
- return mel;
3183
+ if (!host_mel_data.empty ()) {
3184
+ // the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
3185
+ ggml_backend_tensor_set (ret.tensor , host_mel_data.data (), 0 , ggml_nbytes (ret.tensor ));
3186
+ }
3187
+
3188
+ return ret;
3118
3189
}
3119
3190
};
3120
3191
}
@@ -3129,7 +3200,7 @@ whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper
3129
3200
return ret;
3130
3201
} else
3131
3202
#endif
3132
- return new mel_calc_cpu (filters);
3203
+ return new mel_calc_cpu (backend, filters);
3133
3204
}
3134
3205
3135
3206
// split text into tokens
@@ -3347,7 +3418,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3347
3418
{
3348
3419
bool ok = whisper_allocr_graph_init (state->alloc_conv , ctx->backend ,
3349
3420
[&]() {
3350
- return whisper_build_graph_conv (*ctx, *state);
3421
+ return whisper_build_graph_conv (*ctx, *state, 0 );
3351
3422
});
3352
3423
3353
3424
if (!ok) {
@@ -3763,12 +3834,9 @@ int whisper_set_mel_with_state(
3763
3834
return -1 ;
3764
3835
}
3765
3836
3766
- state->mel .n_len = n_len;
3767
- state->mel .n_len_org = n_len;
3768
- state->mel .n_mel = n_mel;
3769
-
3770
- state->mel .data .resize (n_len*n_mel);
3771
- memcpy (state->mel .data .data (), data, n_len*n_mel*sizeof (float ));
3837
+ state->mel .reset ();
3838
+ state->mel .init (ctx->backend , n_len, n_len, n_mel);
3839
+ ggml_backend_tensor_set (state->mel .tensor , data, 0 , ggml_nbytes (state->mel .tensor ));
3772
3840
3773
3841
return 0 ;
3774
3842
}
0 commit comments