File tree 3 files changed +8
-4
lines changed
3 files changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -272,7 +272,8 @@ class HfRunner:
272
272
def get_default_device (self ):
273
273
from vllm .platforms import current_platform
274
274
275
- return ("cpu" if current_platform .is_cpu () else "cuda" )
275
+ return ("cpu"
276
+ if current_platform .is_cpu () else current_platform .device_type )
276
277
277
278
def wrap_device (self , x : _T , device : Optional [str ] = None ) -> _T :
278
279
if x is None or isinstance (x , (bool , )):
Original file line number Diff line number Diff line change 6
6
import pytest
7
7
import torch
8
8
9
+ from vllm .platforms import current_platform
9
10
from vllm .utils import make_tensor_with_pad
10
11
from vllm .v1 .sample .metadata import SamplingMetadata
11
12
from vllm .v1 .sample .sampler import Sampler
12
13
13
14
VOCAB_SIZE = 1024
14
15
NUM_OUTPUT_TOKENS = 20
15
16
CUDA_DEVICES = [
16
- f"cuda:{ i } " for i in range (1 if torch .cuda .device_count () == 1 else 2 )
17
+ f"{ current_platform .device_type } :{ i } "
18
+ for i in range (1 if current_platform .device_count () == 1 else 2 )
17
19
]
18
20
MAX_NUM_PROMPT_TOKENS = 64
19
21
Original file line number Diff line number Diff line change 14
14
SamplerOutput ,
15
15
SamplingMetadata , get_logprobs ,
16
16
get_pythonized_sample_results )
17
+ from vllm .platforms import current_platform
17
18
from vllm .sequence import (CompletionSequenceGroupOutput , IntermediateTensors ,
18
19
Logprob , SequenceGroupMetadata , SequenceOutput )
19
20
from vllm .utils import PyObjectCache , async_tensor_h2d , current_stream
@@ -158,8 +159,8 @@ class StatefulModelInput(BroadcastableModelInput):
158
159
is_first_multi_step : bool = False
159
160
base_output_proc_callback : Optional [Callable ] = None
160
161
# ping-pong data structures for multi-step to wait on the previous step
161
- step_cuda_events : List [torch . cuda .Event ] = field (
162
- default_factory = lambda : [torch . cuda .Event (blocking = True )] * 2 )
162
+ step_cuda_events : List [current_platform .Event ] = field (
163
+ default_factory = lambda : [current_platform .Event (blocking = True )] * 2 )
163
164
num_seqs : int = - 1
164
165
num_queries : int = - 1
165
166
num_single_step_prefills : int = 0
You can’t perform that action at this time.
0 commit comments