Skip to content

Commit cec7cb9

Browse files
committed
quick & dirty: extend quantize.cpp for statistics
1 parent d51882b commit cec7cb9

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

examples/quantize/quantize.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
int main(int argc, char ** argv) {
1111
ggml_time_init();
1212

13-
if (argc != 4) {
13+
if (argc == 3) {
14+
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
15+
fprintf(stderr, "Assuming you just want statistics\n");
16+
}
17+
else if (argc != 4) {
1418
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
1519
fprintf(stderr, " type = 2 - q4_0\n");
1620
fprintf(stderr, " type = 3 - q4_1\n");
@@ -25,9 +29,9 @@ int main(int argc, char ** argv) {
2529
}
2630

2731
const std::string fname_inp = argv[1];
28-
const std::string fname_out = argv[2];
32+
const std::string fname_out = (argc == 4) ? argv[2] : "";
2933

30-
const int itype = atoi(argv[3]);
34+
const int itype = atoi(argv[argc - 1]);
3135

3236
const int64_t t_main_start_us = ggml_time_us();
3337

llama.cpp

+46-4
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,27 @@ static llama_vocab::id llama_sample_top_p_top_k(
12871287
// quantization
12881288
//
12891289

1290+
#include "ggml_internal.h"
1291+
1292+
// Running statistics of quantization round-trip error, accumulated
// per-tensor and totalled over the whole model.
struct error_stats {
    size_t num_samples;  // number of float values accumulated so far
    double total_error;  // sum of squared differences (input - dequantized output)
    double max_error;    // largest absolute difference seen so far
};

// Accumulate error statistics between `input` (original f32 weights) and
// `output` (weights after a quantize/dequantize round trip) over
// `nelements` values. Updates `stats` in place; callable repeatedly to
// aggregate across tensors.
static void update_error_stats(int64_t nelements, const float * input, const float * output, error_stats & stats) {
    for (int64_t i = 0; i < nelements; i++) {
        double diff = input[i] - output[i];
        stats.total_error += diff * diff;
        // track the worst single-value deviation
        stats.max_error = fmax(fabs(diff), stats.max_error);
    }
    stats.num_samples += nelements;
}

// Print the mean squared error and maximum absolute error for `name`.
// Guards against division by zero when no samples were accumulated
// (the original printed NaN in that case).
static void print_error_stats(const std::string & name, const error_stats & stats) {
    const double mse = stats.num_samples > 0
        ? stats.total_error / static_cast<double>(stats.num_samples)
        : 0.0;
    printf("%-50s: mse %.8f, maxerr %.8f\n", name.c_str(), mse, stats.max_error);
}
1310+
12901311
// TODO: reuse code from the llama_model_load() somehow
12911312
static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
12921313
ggml_type type = GGML_TYPE_Q4_1;
@@ -1312,10 +1333,17 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
13121333
return false;
13131334
}
13141335

1315-
auto fout = std::ofstream(fname_out, std::ios::binary);
1316-
if (!fout) {
1317-
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
1318-
return false;
1336+
bool stats = fname_out.empty();
1337+
error_stats total_error {};
1338+
std::vector<float> output_scratch;
1339+
1340+
std::ofstream fout;
1341+
if (!stats) {
1342+
fout.open(fname_out, std::ios::binary);
1343+
if (!fout) {
1344+
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
1345+
return false;
1346+
}
13191347
}
13201348

13211349
// verify magic
@@ -1549,6 +1577,15 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
15491577
printf("%5.3f ", hist_cur[i] / float(nelements));
15501578
}
15511579
printf("\n");
1580+
1581+
if (stats && !std::regex_match(name, std::regex("norm"))) {
1582+
quantize_fns_t qfns = ggml_internal_get_quantize_fn(type);
1583+
#define QK 32
1584+
assert(nelements % QK == 0);
1585+
output_scratch.resize(nelements);
1586+
qfns.dequantize_row_q(work.data(), output_scratch.data(), nelements);
1587+
update_error_stats(nelements, data_f32.data(), output_scratch.data(), total_error);
1588+
}
15521589
} else {
15531590
printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
15541591
fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
@@ -1578,6 +1615,11 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
15781615
finp.close();
15791616
fout.close();
15801617

1618+
if (stats) {
1619+
static const char * ggml_type_str[] = { "q4_0", "q4_1", };
1620+
print_error_stats(ggml_type_str[type], total_error);
1621+
}
1622+
15811623
return true;
15821624
}
15831625

0 commit comments

Comments
 (0)