
Commit 7cadbe2

feat: Support for Mistral Small 3.1 24B VLM
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent 290649b commit 7cadbe2

File tree: 8 files changed, +255 -7 lines changed

examples/models/core/multimodal/run.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def print_result(model, input_text, output_text, args):
             0][0].lower()
     elif model.model_type in [
             'blip2', 'neva', 'phi-3-vision', 'llava_next',
-            'phi-4-multimodal'
+            'phi-4-multimodal', 'pixtral'
     ]:
         assert 'singapore' in output_text[0][0].lower()
     elif model.model_type == 'video-neva':

tensorrt_llm/models/llama/config.py

Lines changed: 5 additions & 0 deletions
@@ -137,6 +137,11 @@ def from_hugging_face(
             # InternLM-XComposer2 has a mask for partial lora
             # Therefore we need an additional flag for this mask
             has_partial_lora_mask = True
+        if hf_config.model_type == 'mistral3':
+            from transformers import Mistral3Config
+            hf_config = Mistral3Config.from_pretrained(
+                hf_config_dir).text_config
+            hf_config.architectures = ["MistralForCausalLM"]
 
         num_key_value_heads = getattr(hf_config, "num_key_value_heads",
                                       hf_config.num_attention_heads)
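
For context, Mistral Small 3.1 ships as a composite mistral3 checkpoint whose language model sits under text_config, which is what the branch above unwraps before handing the config to the existing Mistral/LLaMA converter. A minimal sketch of that relationship, assuming a placeholder local checkpoint path and a transformers release that ships the mistral3 model type:

from transformers import AutoConfig

# Placeholder path; point it at a Mistral-Small-3.1-24B-Instruct-2503 checkout.
model_dir = "/path/to/Mistral-Small-3.1-24B-Instruct-2503"

hf_config = AutoConfig.from_pretrained(model_dir)
assert hf_config.model_type == "mistral3"

# The decoder half of the VLM is a plain Mistral config under text_config.
text_config = hf_config.text_config
print(text_config.num_attention_heads, text_config.num_key_value_heads,
      text_config.vocab_size)
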

tensorrt_llm/runtime/multimodal_model_runner.py

Lines changed: 78 additions & 2 deletions
@@ -413,6 +413,13 @@ def __init__(self, args):
         if self.num_frames is None:
             self.num_frames = 8
         assert self.args.video_path is None or self.args.image_path is None
+        if self.model_type == "pixtral":
+            hf_config = AutoConfig.from_pretrained(self.args.hf_model_dir)
+            self.image_size = hf_config.vision_config.image_size
+            self.patch_size = hf_config.vision_config.patch_size
+            self.vocab_size = hf_config.text_config.vocab_size
+            self.image_token_index = hf_config.image_token_index
+            self.spatial_merge_size = hf_config.spatial_merge_size
 
         self.audio_input_names = self.audio_output_names = None
         if self.model_type == "mllama":
@@ -617,6 +624,10 @@ def init_processor(self):
             self.processor = AutoProcessor.from_pretrained(
                 self.args.hf_model_dir, trust_remote_code=True, num_crops=16)
 
+        elif 'pixtral' in self.model_type:
+            self.processor = AutoProcessor.from_pretrained(
+                self.args.hf_model_dir)
+
         elif 'internlm' in self.model_type:
             image_size = 490
             self.processor = transforms.Compose([
@@ -895,6 +906,33 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
             audio_mask = audio.new_ones(*audio.shape[:2])
             audio_mask[-1, -pad:] = 0
             other_audio_inputs['attention_mask'] = audio_mask.bool()
+        elif self.model_type == 'pixtral':
+            # Hold on to pixel_values and input_ids.
+            dtype = str_dtype_to_torch(self.vision_precision)
+            pixel_values = image["pixel_values"].to(device="cuda", dtype=dtype)
+            input_ids = image["input_ids"].to(device="cuda")
+
+            # Shape of pixel values from the processor varies with the raw image.
+            # So we create a new tensor with a fixed shape as expected by the vision
+            # encoder and create a corresponding attention mask.
+            image_size = self.image_size
+            patch_size = self.patch_size
+            d_min = torch.finfo(dtype).min
+            num_patches = (image_size // patch_size)
+            image = torch.full((1, 3, image_size, image_size),
+                               fill_value=0,
+                               dtype=dtype,
+                               device="cuda")
+            attention_mask = torch.full((1, num_patches, num_patches),
+                                        fill_value=d_min,
+                                        dtype=dtype,
+                                        device="cuda")
+            h, w = pixel_values.shape[-2:]
+            image[..., :h, :w] = pixel_values
+            attention_mask[..., :h // patch_size, :w // patch_size] = 0
+            other_vision_inputs = {
+                "attention_mask": attention_mask,
+            }
         elif self.model_type == 'llava_next':
             input = image
             image = input['pixel_values']
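
The additive mask built above assigns the dtype's minimum value to padded patch positions so attention effectively ignores them, while real patches get 0. A minimal standalone sketch of the same padding scheme, with toy sizes instead of the model's real image_size and patch_size:

import torch

# Toy sizes; the runner takes the real values from the HF vision config.
image_size, patch_size = 8, 2
num_patches = image_size // patch_size
dtype = torch.float32
d_min = torch.finfo(dtype).min

# Pretend the processor produced a 6x4 crop of the 8x8 canvas.
pixel_values = torch.randn(1, 3, 6, 4)

canvas = torch.zeros(1, 3, image_size, image_size, dtype=dtype)
mask = torch.full((1, num_patches, num_patches), d_min, dtype=dtype)

h, w = pixel_values.shape[-2:]
canvas[..., :h, :w] = pixel_values
mask[..., :h // patch_size, :w // patch_size] = 0  # real patches attend normally

print(mask[0])  # 0 where a patch carries image content, a large negative value elsewhere
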
@@ -1108,6 +1146,17 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
             audio_features = audio_features.unsqueeze(0).repeat(
                 self.args.batch_size, 1, 1)
             length = input_ids.shape[1]
+
+        elif self.model_type == 'pixtral':
+            relevant_patch_size = self.patch_size * self.spatial_merge_size
+            output_img_size = self.image_size // relevant_patch_size
+            visual_features = visual_features.reshape(
+                output_img_size, output_img_size,
+                -1)[:h // relevant_patch_size, :w //
+                    relevant_patch_size].flatten(0, 1)
+            input_ids = self.ptuning_setup_pixtral(input_ids=input_ids)
+            length = input_ids.shape[1]
+
         elif self.model_type == 'llava_next':
             visual_features = LlavaNextUtils.rearrange_image_features(
                 visual_features, self.image_newlines["image_newline"],
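
The reshape above keeps only the grid cells that contain real image content after the patch merger shrinks each spatial dimension by spatial_merge_size. A toy recomputation of the shapes involved (illustrative numbers, not the checkpoint's actual config):

import torch

# Illustrative values only; the real ones come from the HF config.
image_size, patch_size, spatial_merge_size = 64, 8, 2
h, w = 48, 32  # height/width of the unpadded, processed image

relevant_patch_size = patch_size * spatial_merge_size  # 16
output_img_size = image_size // relevant_patch_size    # 4

# One feature vector per merged patch of the full padded canvas.
visual_features = torch.randn(output_img_size * output_img_size, 10)

kept = visual_features.reshape(output_img_size, output_img_size, -1)[
    :h // relevant_patch_size, :w // relevant_patch_size].flatten(0, 1)
print(kept.shape)  # torch.Size([6, 10]): 3 x 2 merged patches of real content
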
@@ -1208,7 +1257,7 @@ def preprocess(self, pre_prompt, post_prompt, image, other_vision_inputs,
             torch.int32)
 
         if self.model_type in [
-                'fuyu', 'kosmos-2', 'phi-3-vision', 'llava_next'
+                'fuyu', 'kosmos-2', 'phi-3-vision', 'llava_next', 'pixtral'
         ]:
             return input_ids, input_lengths, [
                 visual_features
@@ -1976,6 +2025,20 @@ def ptuning_setup_fuyu(self, input_ids, image_patches_indices):
             res_input_ids.append(cur_input_ids)
         return res_input_ids
 
+    def ptuning_setup_pixtral(self, input_ids):
+        # input_ids obtained from the processor hold token ids for the text as
+        # well as the image tokens, where every image token is represented by
+        # the same image_token_index (10 for this model).
+        image_token_index = self.image_token_index
+        vocab_size = self.vocab_size
+        # Replace each image token with a unique token_id > text vocab_size.
+        # These ids are used to look up the prompt table.
+        replacer = vocab_size
+        for i in range(len(input_ids[0])):
+            if input_ids[0][i] == image_token_index:
+                input_ids[0][i] = replacer
+                replacer += 1
+        return input_ids
+
     def ptuning_setup_llava_next(self, visual_features, pre_prompt,
                                  post_prompt):
         input_ids = []
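
The remapping in ptuning_setup_pixtral gives every image token an id at or above the text vocabulary size; during generation those out-of-vocabulary ids select rows of the prompt table holding the projected visual features. A small numeric sketch with toy values (the real image_token_index and vocab_size come from the HF config):

import torch

image_token_index, vocab_size = 10, 32  # toy values
input_ids = torch.tensor([[1, 5, 10, 10, 10, 7, 2]])

replacer = vocab_size
for i in range(len(input_ids[0])):
    if input_ids[0][i] == image_token_index:
        input_ids[0][i] = replacer
        replacer += 1

print(input_ids)  # tensor([[ 1,  5, 32, 33, 34,  7,  2]])
# Ids 32, 33, 34 index rows 0, 1, 2 of the prompt table of visual features.
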
@@ -2342,6 +2405,18 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None):
                 audios=[raw_audio],
                 return_tensors="pt")
 
+        elif 'pixtral' in self.model_type:
+            # Send image and text prompt to processor.
+            pre_prompt = "<s>[INST][IMG]"
+            if input_text is None:
+                input_text = "What is in the image?"
+            post_prompt = "[/INST]"
+            prompt = pre_prompt + input_text + post_prompt
+            dtype = str_dtype_to_torch(self.vision_precision)
+            image = self.processor(text=prompt,
+                                   images=[raw_image],
+                                   return_tensors="pt").to(dtype)
+
         elif 'internvl' in self.model_type:
             pre_prompt = "<|system|>\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|end|><|user|>\n<image>\n"
             if input_text is None:
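
The pixtral branch above wraps the user text in the Mistral instruct template with a single [IMG] placeholder; the processor then expands that placeholder into the per-patch image tokens that ptuning_setup_pixtral later remaps. A trivial sketch of the assembled prompt string, using the defaults from the branch above:

pre_prompt = "<s>[INST][IMG]"
post_prompt = "[/INST]"
input_text = "What is in the image?"

prompt = pre_prompt + input_text + post_prompt
print(prompt)  # <s>[INST][IMG]What is in the image?[/INST]
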
@@ -2526,7 +2601,8 @@ def setup_inputs(self, input_text, raw_image, raw_audio=None):
         post_prompt = [post_prompt] * self.args.batch_size
         if self.model_type not in [
                 'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision',
-                'phi-4-multimodal', 'llava_next', 'internvl', 'llava_onevision'
+                'phi-4-multimodal', 'llava_next', 'internvl', 'llava_onevision',
+                'pixtral'
         ]:
             if image is not None:
                 if image.dim() == 5:

tensorrt_llm/tools/multimodal_builder.py

Lines changed: 158 additions & 1 deletion
@@ -41,7 +41,7 @@ def add_multimodal_arguments(parser):
                             'fuyu', 'pix2struct', 'neva', 'kosmos-2',
                             'video-neva', 'phi-3-vision', 'phi-4-multimodal',
                             'mllama', 'internvl', 'qwen2_vl',
-                            'internlm-xcomposer2', 'qwen2_audio'
+                            'internlm-xcomposer2', 'qwen2_audio', 'pixtral'
                         ],
                         help="Model type")
     parser.add_argument(
@@ -142,6 +142,8 @@ def build(self):
             build_qwen2_vl_engine(args)
         elif args.model_type == 'qwen2_audio':
             build_qwen2_audio_engine(args)
+        elif args.model_type == "pixtral":
+            build_pixtral_engine(args)
         else:
             raise RuntimeError(f"Invalid model type {args.model_type}")
 
@@ -1577,3 +1579,158 @@ def forward(self, x, mask):
         'num_mul_bins': args.num_mul_bins,
         'max_mel_seq_len': args.max_mel_seq_len
     })
+
+
+def build_pixtral_engine(args):
+    processor = AutoProcessor.from_pretrained(args.model_path)
+    hf_config = AutoConfig.from_pretrained(args.model_path)
+    vision_config = hf_config.vision_config
+    raw_image = Image.new(
+        'RGB',
+        [vision_config.image_size, vision_config.image_size])  # dummy image
+
+    inputs = processor(text="dummy", images=[raw_image], return_tensors="pt")
+    pixel_values = inputs["pixel_values"].to(args.device, torch.bfloat16)
+    attention_mask = torch.zeros(
+        1, vision_config.image_size // vision_config.patch_size,
+        vision_config.image_size // vision_config.patch_size).to(
+            args.device, torch.bfloat16)
+
+    # isort: off
+    from transformers.models.pixtral.modeling_pixtral import \
+        apply_rotary_pos_emb
+    from transformers import Mistral3ForConditionalGeneration
+    from transformers.models.pixtral.modeling_pixtral import (PixtralAttention,
+                                                              PixtralVisionModel)
+    from transformers.models.mistral3.modeling_mistral3 import (
+        Mistral3MultiModalProjector, Mistral3PatchMerger)
+    # isort: on
+    @torch.no_grad
+    def attn_forward(self,
+                     hidden_states,
+                     attention_mask,
+                     position_embeddings,
+                     output_attentions=False):
+        batch, patches, _ = hidden_states.size()
+
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
+
+        q = q.view(batch, patches, self.num_heads,
+                   self.head_dim).transpose(1, 2)
+        k = k.view(batch, patches, self.num_heads,
+                   self.head_dim).transpose(1, 2)
+        v = v.view(batch, patches, self.num_heads,
+                   self.head_dim).transpose(1, 2)
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask).transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch, patches, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None
+
+    @torch.no_grad
+    def vision_tower_forward(self, pixel_values, attention_mask):
+        patch_embeds = self.patch_conv(pixel_values)  # (bs, c, h, w)
+
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)  # (bs, h*w, c)
+        attention_mask = attention_mask.flatten(1)  # (bs, h*w)
+
+        patch_embeds = self.ln_pre(patch_embeds)
+        position_ids = self.position_ids.flatten()  # (h*w, )
+        position_embeddings = self.patch_positional_embedding(
+            patch_embeds, position_ids)
+
+        out = self.transformer(patch_embeds,
+                               attention_mask=attention_mask,
+                               position_embeddings=position_embeddings,
+                               output_hidden_states=False,
+                               output_attentions=False,
+                               return_dict=False)[0]
+        return out
+
+    @torch.no_grad
+    def patch_merger_forward(self, image_features, attention_mask):
+        h, w = attention_mask.shape[-2:]
+        bs, n, d = image_features.shape
+        image_grid = image_features.view(bs, h, w, d).permute(0, 3, 1, 2)
+        image_features = torch.nn.functional.unfold(image_grid, 2,
+                                                    stride=2).transpose(1, 2)
+        image_features = self.merging_layer(image_features)
+        return image_features
+
+    @torch.no_grad
+    def mm_projector_forward(self, image_features, attention_mask):
+        image_features = self.norm(image_features)
+        image_features = self.patch_merger(image_features, attention_mask)
+        hidden_states = self.linear_2(self.act(self.linear_1(image_features)))
+        return hidden_states
+
+    class PixtralVisionWrapper(torch.nn.Module):
+
+        def __init__(self, vision_tower, mm_projector):
+            super().__init__()
+            self.vision_tower = vision_tower
+            self.mm_projector = mm_projector
+
+        @torch.no_grad
+        def forward(self, pixel_values, attention_mask):
+            features = self.vision_tower(pixel_values, attention_mask)
+            out = self.mm_projector(features, attention_mask)
+            return out
+
+    model = Mistral3ForConditionalGeneration.from_pretrained(args.model_path,
+                                                             torch_dtype="auto")
+    vision_tower = model.vision_tower
+    mm_projector = model.multi_modal_projector
+
+    height = width = vision_config.image_size // vision_config.patch_size
+    mesh = torch.meshgrid(torch.arange(height),
+                          torch.arange(width),
+                          indexing="ij")
+    h_grid, v_grid = torch.stack(mesh, dim=-1).chunk(2, -1)
+    ids = h_grid[..., 0] * width + v_grid[..., 0]
+    vision_tower.register_buffer("position_ids", ids)
+
+    PixtralAttention.forward = attn_forward
+    PixtralVisionModel.forward = vision_tower_forward
+
+    Mistral3PatchMerger.forward = patch_merger_forward
+    Mistral3MultiModalProjector.forward = mm_projector_forward
+
+    vision_tower = vision_tower.to(args.device, torch.bfloat16)
+    mm_projector = mm_projector.to(args.device, torch.bfloat16)
+    vision_tower.eval()
+    mm_projector.eval()
+    wrapper = PixtralVisionWrapper(vision_tower, mm_projector)
+
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    part_name = 'vision'
+    onnx_dir = f"{args.output_dir}/{part_name}/onnx"
+
+    export_onnx(wrapper,
+                input=(pixel_values, attention_mask),
+                onnx_dir=onnx_dir,
+                input_names=['input', 'attention_mask'],
+                dynamic_axes={
+                    'input': {
+                        0: "batch"
+                    },
+                    'attention_mask': {
+                        0: "batch"
+                    }
+                })
+    build_trt_engine(
+        args.model_type,
+        input_sizes=[[list(pixel_values.shape[1:]) for _ in range(3)],
+                     [list(attention_mask.shape[1:]) for _ in range(3)]],
+        onnx_dir=onnx_dir,
+        engine_dir=args.output_dir,
+        max_batch_size=args.max_batch_size,
+        engine_name=f"model.engine",
+        dtype=torch.bfloat16)
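
The position_ids buffer registered above is just the row-major index of every patch in the (image_size // patch_size) square grid; the patched vision_tower_forward flattens it because the dummy export input always covers the full canvas. A quick standalone check with a toy grid size:

import torch

height = width = 4  # toy grid; really image_size // patch_size

mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
h_grid, v_grid = torch.stack(mesh, dim=-1).chunk(2, -1)
ids = h_grid[..., 0] * width + v_grid[..., 0]

# Row-major patch indices: identical to a plain arange over the grid.
assert torch.equal(ids, torch.arange(height * width).reshape(height, width))
print(ids.flatten())  # prints 0..15 in row-major order
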

tests/integration/defs/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -794,6 +794,8 @@ def multimodal_model_root(request, llm_venv):
         tllm_model_name = tllm_model_name + ".nemo"
     elif 'Llama-3.2' in tllm_model_name:
         models_root = os.path.join(llm_models_root(), 'llama-3.2-models')
+    elif 'Mistral-Small' in tllm_model_name:
+        models_root = llm_models_root()
 
     multimodal_model_root = os.path.join(models_root, tllm_model_name)

tests/integration/defs/examples/test_multimodal.py

Lines changed: 9 additions & 3 deletions
@@ -115,7 +115,7 @@ def _test_llm_multimodal_general(llm_venv,
     mllama_model = 'Llama-3.2' in model_name
     qwen2_vl_model = 'Qwen2-VL' in model_name
     internlm_model = 'internlm-xcomposer2' in model_name
-
+    mistral_model = 'Mistral-Small' in model_name
     if enc_dec_model:
         builder_root = enc_dec_example_root
         if nougat_model:
@@ -134,6 +134,8 @@ def _test_llm_multimodal_general(llm_venv,
         builder_root, model_type = internlm_example_root, "internlm"
     elif llava_model or vila_model:
         builder_root, model_type = llama_example_root, "llama"
+    elif mistral_model:
+        builder_root, model_type = llama_example_root, "llama"
     elif cogvlm_model:
         builder_root, model_type = cogvlm_example_root, "cogvlm"
     elif nemotron_model:
@@ -214,7 +216,7 @@ def _test_llm_multimodal_general(llm_venv,
     print("Build LLM engines...")
     model_name = model_name.split('/')[-1]  # Remove HF directory name
     llm_engine_dir = f"{engine_dir}/{model_name}/{world_size}-gpu"
-    if "opt" in model_name or llava_model or vila_model or gpt_model or nemotron_model or phi3_model or phi4_model or qwen2_vl_model:
+    if "opt" in model_name or llava_model or vila_model or gpt_model or nemotron_model or phi3_model or phi4_model or qwen2_vl_model or mistral_model:
         max_input_len_text = 1024
         max_output_len = 200
         if llava_next_model:
@@ -227,7 +229,9 @@ def _test_llm_multimodal_general(llm_venv,
             multimodal_len = 196
         elif phi3_model:
             multimodal_len = 5120
-        elif phi4_model:  # @B: Confirm this.
+        elif phi4_model:
+            multimodal_len = 5120
+        elif mistral_model:
             multimodal_len = 5120
         elif "fuyu" in model_name:
             multimodal_len = 2640
@@ -386,6 +390,7 @@ def _test_llm_multimodal_general(llm_venv,
     elif 'Llama-3.2' in model_name: vision_model_type = 'mllama'
     elif "Qwen2-VL" in model_name: vision_model_type = 'qwen2_vl'
     elif 'internlm' in model_name: vision_model_type = 'internlm-xcomposer2'
+    elif 'Mistral-Small' in model_name: vision_model_type = 'pixtral'
 
     vit_batch_size = batch_size
     if vision_model_type == "llava_next":
@@ -623,6 +628,7 @@ def _test_llm_multimodal_general(llm_venv,
         'Llama-3.2-11B-Vision',
         'Qwen2-VL-7B-Instruct',
         'internlm-xcomposer2-vl-7b',
+        'Mistral-Small-3.1-24B-Instruct-2503',
     ],
     indirect=True)
 def test_llm_multimodal_general(llm_venv, llm_root, llm_datasets_root,
