@@ -16076,19 +16076,21 @@ static int llama_decode_internal(
         return -1;
     }

-    for (uint32_t i = 0; i < n_tokens_all; ++i) {
-        if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= lctx.model.vocab.n_vocab) {
-            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
-            return -1;
-        }
-    }
-
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;

     GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT

+    if (batch_all.token) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
+                return -1;
+            }
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
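The guard matters for the embeddings-only input path: as the GGML_ASSERT above states, exactly one of batch_all.token / batch_all.embd is set, so when a caller feeds embeddings instead of token ids, batch_all.token is a null pointer and the old unconditional validation loop would dereference it. Below is a minimal caller-side sketch of that case, assuming the public llama.h batch API (llama_batch_init with a non-zero embd size allocates batch.embd and leaves batch.token null); the helper name decode_embd_batch is made up for illustration, not part of this patch.

    // Sketch only: decode one sequence from raw embeddings instead of token ids.
    // With embd != 0, llama_batch_init() allocates batch.embd and leaves
    // batch.token == nullptr -- the case the new `if (batch_all.token)` guard skips.
    #include <cstring>
    #include "llama.h"

    static int decode_embd_batch(llama_context * ctx, const float * embd, int32_t n_tokens, int32_t n_embd) {
        llama_batch batch = llama_batch_init(n_tokens, /*embd =*/ n_embd, /*n_seq_max =*/ 1);
        batch.n_tokens = n_tokens;
        for (int32_t i = 0; i < n_tokens; ++i) {
            memcpy(batch.embd + i*n_embd, embd + i*n_embd, n_embd*sizeof(float));
            batch.pos[i]       = i;
            batch.n_seq_id[i]  = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i]    = (i == n_tokens - 1); // request output only for the last position
        }
        const int ret = llama_decode(ctx, batch); // before this patch, the validation loop read the null token pointer
        llama_batch_free(batch);
        return ret;
    }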
@@ -16375,19 +16377,21 @@ static int llama_encode_internal(
         return -1;
     }

-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= lctx.model.vocab.n_vocab) {
-            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
-            return -1;
-        }
-    }
-
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
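The encoder path receives the identical fix: the token validation loop now runs only when batch.token is non-null, and also reuses the local model reference instead of going through lctx.model. For completeness, a caller-side sketch of the same range check the moved loops perform, assuming llama_n_vocab() from the public API; out-of-range ids are exactly what makes llama_decode()/llama_encode() return -1 here.

    // Sketch: validate token ids against the model vocabulary before submitting a
    // batch, mirroring the internal check that now runs only when batch.token != nullptr.
    #include <vector>
    #include "llama.h"

    static bool tokens_in_vocab(const llama_model * model, const std::vector<llama_token> & toks) {
        const int32_t n_vocab = llama_n_vocab(model);
        for (const llama_token t : toks) {
            if (t < 0 || t >= n_vocab) {
                return false; // llama_decode()/llama_encode() would reject this batch with -1
            }
        }
        return true;
    }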