@@ -1287,6 +1287,27 @@ static llama_vocab::id llama_sample_top_p_top_k(
 // quantization
 //

+#include "ggml_internal.h"
+
+struct error_stats {
+    size_t num_samples;
+    double total_error;
+    double max_error;
+};
+
+static void update_error_stats(int64_t nelements, const float * input, const float * output, error_stats & stats) {
+    for (int64_t i = 0; i < nelements; i++) {
+        double diff = input[i] - output[i];
+        stats.total_error += diff * diff;
+        stats.max_error = fmax(fabs(diff), stats.max_error);
+    }
+    stats.num_samples += nelements;
+}
+
+static void print_error_stats(const std::string & name, const error_stats & stats) {
+    printf("%-50s: mse %.8f, maxerr %.8f\n", name.c_str(), stats.total_error / (double) stats.num_samples, stats.max_error);
+}
+
 // TODO: reuse code from the llama_model_load() somehow
 static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
     ggml_type type = GGML_TYPE_Q4_1;
@@ -1312,10 +1333,17 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
         return false;
     }

-    auto fout = std::ofstream(fname_out, std::ios::binary);
-    if (!fout) {
-        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
-        return false;
+    bool stats = fname_out.empty();
+    error_stats total_error {};
+    std::vector<float> output_scratch;
+
+    std::ofstream fout;
+    if (!stats) {
+        fout.open(fname_out, std::ios::binary);
+        if (!fout) {
+            fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
+            return false;
+        }
     }

     // verify magic
@@ -1549,6 +1577,15 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
                     printf("%5.3f ", hist_cur[i] / float(nelements));
                 }
                 printf("\n");
+
+                if (stats && !std::regex_match(name, std::regex("norm"))) {
+                    quantize_fns_t qfns = ggml_internal_get_quantize_fn(type);
+                    #define QK 32
+                    assert(nelements % QK == 0);
+                    output_scratch.resize(nelements);
+                    qfns.dequantize_row_q(work.data(), output_scratch.data(), nelements);
+                    update_error_stats(nelements, data_f32.data(), output_scratch.data(), total_error);
+                }
             } else {
                 printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
                 fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
@@ -1578,6 +1615,11 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
     finp.close();
     fout.close();

+    if (stats) {
+        static const char * ggml_type_str[] = { "q4_0", "q4_1", };
+        print_error_stats(ggml_type_str[type], total_error);
+    }
+
     return true;
 }
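For readers who want to see what the new helpers report without running a full model quantization, here is a minimal standalone sketch. Only error_stats, update_error_stats, and print_error_stats mirror the code added in the diff; the main() driver, the synthetic test vector, and the fixed ±0.001 perturbation (a stand-in for a real quantize/dequantize round trip) are illustrative assumptions, not part of the commit.

// Illustrative sketch only: exercises the error-stats helpers on synthetic data.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct error_stats {
    size_t num_samples;
    double total_error;   // running sum of squared differences
    double max_error;     // largest absolute difference seen so far
};

static void update_error_stats(int64_t nelements, const float * input, const float * output, error_stats & stats) {
    for (int64_t i = 0; i < nelements; i++) {
        double diff = input[i] - output[i];
        stats.total_error += diff * diff;
        stats.max_error = fmax(fabs(diff), stats.max_error);
    }
    stats.num_samples += nelements;
}

static void print_error_stats(const std::string & name, const error_stats & stats) {
    printf("%-50s: mse %.8f, maxerr %.8f\n", name.c_str(), stats.total_error / (double) stats.num_samples, stats.max_error);
}

int main() {
    // Pretend "output" is the result of a quantize/dequantize round trip of "input".
    std::vector<float> input(1024), output(1024);
    for (size_t i = 0; i < input.size(); i++) {
        input[i]  = 0.01f * (float) i;
        output[i] = input[i] + ((i % 2) ? 0.001f : -0.001f);  // synthetic quantization error
    }

    error_stats stats {};
    update_error_stats((int64_t) input.size(), input.data(), output.data(), stats);
    print_error_stats("synthetic round trip", stats);  // prints mse ≈ 0.000001, maxerr ≈ 0.001
    return 0;
}

Because total_error accumulates across every call before the single division in print_error_stats, the reported mse is the mean squared error over all sampled tensors combined, and max_error is the worst single-element deviation seen anywhere in the model.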