@@ -237,6 +237,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
237
237
const session_config = {
238
238
dtype : selectedDtype ,
239
239
kv_cache_dtype,
240
+ device : selectedDevice ,
240
241
}
241
242
242
243
// Construct the model file name
@@ -417,6 +418,10 @@ function validateInputs(session, inputs) {
417
418
return checkedInputs ;
418
419
}
419
420
421
+ // Tail of a promise chain used to serialize inference calls in browser / web-worker environments. Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
422
+ // For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started"): each call appends its `session.run` via `.then(run)` and stores the resulting promise back in this variable. NOTE(review): `.then(run)` has no rejection handler, so once a run rejects this promise stays rejected and later `run` callbacks would be skipped — confirm whether the chain should be reset after a failure.
423
+ let webInferenceChain = Promise . resolve ( ) ;
424
+
420
425
/**
421
426
* Executes an InferenceSession using the specified inputs.
422
427
* NOTE: `inputs` must contain at least the input names of the model.
@@ -433,17 +438,28 @@ async function sessionRun(session, inputs) {
433
438
try {
434
439
// pass the original ort tensor
435
440
const ortFeed = Object . fromEntries ( Object . entries ( checkedInputs ) . map ( ( [ k , v ] ) => [ k , v . ort_tensor ] ) ) ;
436
- let output = await session . run ( ortFeed ) ;
437
- output = replaceTensors ( output ) ;
438
- return output ;
441
+ const run = ( ) => session . run ( ortFeed ) ;
442
+ const output = await ( ( apis . IS_BROWSER_ENV || apis . IS_WEBWORKER_ENV )
443
+ ? ( webInferenceChain = webInferenceChain . then ( run ) )
444
+ : run ( ) ) ;
445
+ return replaceTensors ( output ) ;
439
446
} catch ( e ) {
440
447
// Error messages can be long (nested) and uninformative. For this reason,
441
448
// we apply minor formatting to show the most important information
442
449
const formatted = Object . fromEntries ( Object . entries ( checkedInputs )
443
- . map ( ( [ k , { type , dims , data } ] ) => [ k , {
450
+ . map ( ( [ k , tensor ] ) => {
444
451
// Extract these properties from the underlying ORT tensor
445
- type, dims, data,
446
- } ] ) ) ;
452
+ const unpacked = {
453
+ type : tensor . type ,
454
+ dims : tensor . dims ,
455
+ location : tensor . location ,
456
+ }
457
+ if ( unpacked . location !== "gpu-buffer" ) {
458
+ // Only return the data if it's not a GPU buffer
459
+ unpacked . data = tensor . data ;
460
+ }
461
+ return [ k , unpacked ] ;
462
+ } ) ) ;
447
463
448
464
// This usually occurs when the inputs are of the wrong type.
449
465
console . error ( `An error occurred during model execution: "${ e } ".` ) ;
@@ -5223,7 +5239,7 @@ export class RTDetrV2ForObjectDetection extends RTDetrV2PreTrainedModel {
5223
5239
}
5224
5240
}
5225
5241
5226
- export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
5242
+ export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { } // Empty subclass: declares no members of its own, everything comes from RTDetrObjectDetectionOutput.
5227
5243
//////////////////////////////////////////////////
5228
5244
5229
5245
//////////////////////////////////////////////////
@@ -5238,7 +5254,7 @@ export class RFDetrForObjectDetection extends RFDetrPreTrainedModel {
5238
5254
}
5239
5255
}
5240
5256
5241
- export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
5257
+ export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { } // Empty subclass: declares no members of its own, everything comes from RTDetrObjectDetectionOutput.
5242
5258
//////////////////////////////////////////////////
5243
5259
5244
5260
//////////////////////////////////////////////////
0 commit comments