Commit 6c4fab6

Introduce wasm/webgpu inference chain to prevent "Session already started" errors (#1293)
* Add PKV generation unit test with onnxruntime-genai GQA model
* Introduce wasm/webgpu inference chain to prevent "Session already started" errors
1 parent ce3c266 commit 6c4fab6
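
The fix serializes every browser-side inference call through a single shared promise, so calls issued concurrently still execute one at a time. A minimal, self-contained sketch of that promise-chaining pattern (the names `chain`, `runSerialized`, and `sleep` are illustrative placeholders, not Transformers.js APIs):

// A shared promise acts as the tail of a queue; each task is appended to it.
let chain = Promise.resolve();

function runSerialized(task) {
    // Schedule this task after everything already in the chain.
    const result = chain.then(task);
    // Advance the tail, guarding with catch so one failed task
    // does not leave the chain rejected for later callers.
    chain = result.catch(() => { });
    return result;
}

// Usage: both tasks are issued at once, but they run sequentially.
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
runSerialized(() => sleep(50).then(() => console.log("first")));
runSerialized(() => sleep(10).then(() => console.log("second"))); // logs after "first"

The commit applies the same idea inside `sessionRun` via `webInferenceChain = webInferenceChain.then(run)`, gated to browser and web-worker environments, where the same ONNX Runtime session cannot be run concurrently.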

File tree

2 files changed: +85 −8 lines changed


src/models.js

Lines changed: 24 additions & 8 deletions
@@ -237,6 +237,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
     const session_config = {
         dtype: selectedDtype,
         kv_cache_dtype,
+        device: selectedDevice,
     }
 
     // Construct the model file name
@@ -417,6 +418,10 @@ function validateInputs(session, inputs) {
     return checkedInputs;
 }
 
+// Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
+// For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
+let webInferenceChain = Promise.resolve();
+
 /**
  * Executes an InferenceSession using the specified inputs.
  * NOTE: `inputs` must contain at least the input names of the model.
@@ -433,17 +438,28 @@ async function sessionRun(session, inputs) {
     try {
         // pass the original ort tensor
         const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
-        let output = await session.run(ortFeed);
-        output = replaceTensors(output);
-        return output;
+        const run = () => session.run(ortFeed);
+        const output = await ((apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV)
+            ? (webInferenceChain = webInferenceChain.then(run))
+            : run());
+        return replaceTensors(output);
     } catch (e) {
         // Error messages can be long (nested) and uninformative. For this reason,
         // we apply minor formatting to show the most important information
         const formatted = Object.fromEntries(Object.entries(checkedInputs)
-            .map(([k, { type, dims, data }]) => [k, {
+            .map(([k, tensor]) => {
                 // Extract these properties from the underlying ORT tensor
-                type, dims, data,
-            }]));
+                const unpacked = {
+                    type: tensor.type,
+                    dims: tensor.dims,
+                    location: tensor.location,
+                }
+                if (unpacked.location !== "gpu-buffer") {
+                    // Only return the data if it's not a GPU buffer
+                    unpacked.data = tensor.data;
+                }
+                return [k, unpacked];
+            }));
 
         // This usually occurs when the inputs are of the wrong type.
         console.error(`An error occurred during model execution: "${e}".`);
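
A note on the error-formatting change above: with WebGPU, onnxruntime-web tensors may live in GPU memory (`location === "gpu-buffer"`), where the synchronous `data` getter cannot be used; reading such a tensor back requires an asynchronous download. That is why the formatter only attaches `data` for CPU-resident tensors. A hedged sketch of the general read pattern (assuming onnxruntime-web's async `Tensor.getData()` accessor):

// Read tensor contents regardless of where they live.
async function tensorData(tensor) {
    return tensor.location === "gpu-buffer"
        ? await tensor.getData() // asynchronous copy out of the GPU buffer
        : tensor.data; // CPU-resident data is synchronously available
}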
@@ -5223,7 +5239,7 @@ export class RTDetrV2ForObjectDetection extends RTDetrV2PreTrainedModel {
     }
 }
 
-export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -5238,7 +5254,7 @@ export class RFDetrForObjectDetection extends RFDetrPreTrainedModel {
     }
 }
 
-export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////

tests/utils/generation.test.js

Lines changed: 61 additions & 0 deletions
@@ -282,6 +282,67 @@ describe("PKV caching", () => {
     }, MAX_MODEL_DISPOSE_TIME);
   });
 
+  describe("LlamaForCausalLM (onnxruntime-genai)", () => {
+    const model_id = "onnx-internal-testing/tiny-random-LlamaForCausalLM-GQA";
+    /** @type {LlamaForCausalLM} */
+    let model;
+    /** @type {LlamaTokenizer} */
+    let tokenizer;
+    beforeAll(async () => {
+      model = await LlamaForCausalLM.from_pretrained(model_id, DEFAULT_MODEL_OPTIONS);
+      tokenizer = await LlamaTokenizer.from_pretrained(model_id);
+    }, MAX_MODEL_LOAD_TIME);
+
+    it(
+      "batch_size=1",
+      async () => {
+        const inputs = tokenizer("1");
+
+        // Generate first sequence w/o PKV
+        // NOTE: `return_dict_in_generate=true` is required to get PKV
+        const { past_key_values, sequences } = await model.generate({
+          ...inputs,
+          max_new_tokens: 5,
+          do_sample: false,
+          return_dict_in_generate: true,
+        });
+
+        // Update output with new text
+        const decoded = tokenizer.batch_decode(sequences, {
+          skip_special_tokens: false,
+        })[0];
+        const new_inputs = tokenizer(decoded + "2", {
+          add_special_tokens: false,
+        });
+
+        // Run w/o PKV
+        const generated_ids = await model.generate({
+          ...new_inputs,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        // Run w/ PKV
+        const generated_ids_pkv = await model.generate({
+          ...new_inputs,
+          past_key_values,
+          max_new_tokens: 3,
+          do_sample: false,
+        });
+
+        const target = [[128000n, 16n, 34732n, 98805n, 116404n, 68265n, 99392n, 17n, 21855n, 60933n, 14285n]];
+
+        expect(generated_ids.tolist()).toEqual(target);
+        expect(generated_ids_pkv.tolist()).toEqual(target);
+      },
+      MAX_TEST_EXECUTION_TIME,
+    );
+
+    afterAll(async () => {
+      await model?.dispose();
+    }, MAX_MODEL_DISPOSE_TIME);
+  });
+
   describe("LlavaForConditionalGeneration", () => {
     const model_id = "Xenova/tiny-random-LlavaForConditionalGeneration";
     /** @type {LlavaForConditionalGeneration} */
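
What the new test checks: with `do_sample: false`, decoding is greedy and deterministic, so generation must reproduce the same pinned `target` ids whether or not the cached `past_key_values` is reused; any PKV regression for this GQA model therefore surfaces as an exact token-id mismatch. The `n`-suffixed literals are BigInt values, since token ids are int64 tensors and surface as BigInt in JavaScript.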
