
Commit fecb1f8

feat: tests pass & can execute llama2
1 parent 26cfc08 commit fecb1f8

4 files changed (+52 -8 lines)

test/sparsity/test_marlin.py (+2)

@@ -69,6 +69,8 @@ def test_pack_unpack_equivalence(self):
             w_24, scales, zeros, n_bit=4, groupsize=group_size
         )
 
+        scales = scales.reshape(-1, w_q_24.shape[1])
+
         # Test pack/unpack equivalence
         q_w_comp, packed_scales, meta = pack_to_marlin_24(
             w_q_24, scales, num_bits, group_size
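
A minimal shape sketch (not part of the commit) of what the added reshape hands to the packer: assuming the quantized 2:4 weight w_q_24 is laid out as (k, n), the scales end up with one row per group along k and one column per output channel. The sizes and the flat scale layout below are made-up assumptions.

import torch

k, n, group_size = 128, 256, 64                      # hypothetical sizes

w_q_24 = torch.zeros((k, n), dtype=torch.int32)      # stand-in quantized 2:4 weight
scales = torch.rand(n * (k // group_size))           # assumed flat per-group scales

scales = scales.reshape(-1, w_q_24.shape[1])         # -> (k // group_size, n)
assert scales.shape == (k // group_size, n)          # (2, 256)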

torchao/dtypes/affine_quantized_tensor.py (+25 -7)

@@ -618,7 +618,9 @@ def from_plain(
 
         # Linear layers are (in_features, out_features) but the int_data that is reaching this point
         # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
+        # NOTE(reviewers): Please check if this is what I should do.
         q_w_24 = int_data.t()
+        scale = scale.reshape(-1, q_w_24.shape[1])
 
         if q_w_24.dtype != torch.int32:
             raise ValueError("Only `torch.int32` weights are supported.")
@@ -631,15 +633,14 @@ def from_plain(
 
         # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
         # will require a bit more work to get our current quantization flow to work with it.
-        # Check the below link for a reference:
-        # https://github.com/neuralmagic/nm-vllm/tree/main
+        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
         num_bits = 4 if torch.max(q_w_24) < 16 else -1
         if num_bits not in [4]:
             raise ValueError(
-                f"Only {const.SUPPORTED_NUM_BITS} bits are supported, got {num_bits}."
+                f"Only {[4]} bits are supported, got {num_bits}."
             )
 
-        group_size = in_features // scale.shape[-1]
+        group_size = in_features // scale.shape[0]
         if group_size == 0:
             group_size = in_features
         assert group_size <= in_features, "Group size must be less than or equal to in_features."
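
A self-contained sketch (not the commit's code) of the shape bookkeeping in from_plain after these changes: int_data arrives as (out_features, in_features), gets transposed, the scale is reshaped to one row per quantization group, and group_size is then recovered from scale.shape[0]. The tensor sizes and the flat scale layout are assumptions for illustration only.

import torch

in_features, out_features, n_groups = 256, 512, 2    # hypothetical sizes

int_data = torch.zeros((out_features, in_features), dtype=torch.int32)  # as it reaches from_plain
scale = torch.rand(n_groups * out_features)                              # assumed flat layout

q_w_24 = int_data.t()                          # -> (in_features, out_features)
scale = scale.reshape(-1, q_w_24.shape[1])     # -> (n_groups, out_features)

group_size = in_features // scale.shape[0]     # 256 // 2 == 128
assert group_size == 128 and group_size <= in_features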
@@ -1043,27 +1044,44 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
         isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
     )
 
-
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
-    from torchao.sparsity.marlin import marlin_24_workspace
+    from torchao.sparsity.marlin import marlin_24_workspace, const
 
     sparse_w_int4 = weight_tensor.layout_tensor.int_data
     scale = weight_tensor.layout_tensor.scale
     meta = weight_tensor.layout_tensor.meta
     original_shape = weight_tensor.layout_tensor.original_shape
     num_bits = weight_tensor.layout_tensor.num_bits
 
+    # Saves batch size for reshaping back to original shape after the matmul
+    # Reshapes tensor to (m, k) where m is in_features * batch and k is out_features
+    # NOTE(reviewers): Please check if I am handling the batch size correctly
+    batch_size = -1
+    if input_tensor.dim() == 3:
+        batch_size = input_tensor.size(0)
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+
     size_m = input_tensor.shape[0]
-    size_n = original_shape[0]
+    size_n = original_shape[1]
     size_k = input_tensor.shape[1]
     workspace_24 = marlin_24_workspace(original_shape[1])
 
+    # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
+    if size_k % const.TILE != 0:
+        pad_size = find_multiple(size_k, const.TILE)
+        input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_size - size_k))
+        size_k = pad_size
+
     out = torchao.ops.marlin_24_gemm(
         input_tensor, sparse_w_int4, meta, scale,
         workspace_24, num_bits, size_m, size_n, size_k
     )
     torch.cuda.synchronize()
 
+    # Reshape back to original shape
+    if batch_size != -1:
+        out = out.reshape(batch_size, -1, out.shape[-1])
+
     if bias is not None:
         out += bias.to(out.dtype)
     return out
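
A standalone sketch (not the commit's code) of the activation handling added above: a 3D activation (batch, seq, k) is flattened to (m, k), where m is the flattened batch·seq dimension and k is the activation's last dimension, k is padded up to a multiple of the 16-element marlin tile, the GEMM runs on the 2D tensor, and the result is reshaped back to (batch, seq, n). The dense matmul stands in for torchao.ops.marlin_24_gemm, _next_multiple stands in for find_multiple, and all sizes are invented.

import torch

def _next_multiple(n: int, k: int) -> int:
    # Stand-in for find_multiple: smallest multiple of k that is >= n.
    return n if n % k == 0 else n + (k - n % k)

TILE = 16                                        # marlin tile size used by the padding above
batch, seq, in_features, out_features = 2, 5, 24, 32

x = torch.randn(batch, seq, in_features)
w = torch.randn(in_features, out_features)       # dense stand-in for the packed sparse weight

# Flatten (batch, seq, k) -> (m, k) so the kernel sees a 2D activation.
batch_size = x.size(0) if x.dim() == 3 else -1
x2d = x.reshape(-1, x.shape[-1]).contiguous()

# Pad k up to a multiple of the tile size, mirroring the padding in the diff.
size_k = x2d.shape[1]
if size_k % TILE != 0:
    pad_to = _next_multiple(size_k, TILE)
    x2d = torch.nn.functional.pad(x2d, (0, pad_to - size_k))
    w = torch.nn.functional.pad(w, (0, 0, 0, pad_to - size_k))  # pad the stand-in weight only so the toy matmul shapes line up
    size_k = pad_to

out = x2d @ w                                    # marlin_24_gemm sits here in the real path

# Restore the (batch, seq, n) shape.
if batch_size != -1:
    out = out.reshape(batch_size, -1, out.shape[-1])
assert out.shape == (batch, seq, out_features)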

torchao/sparsity/marlin/__init__.py (+1 -1)

@@ -333,7 +333,7 @@ def _from_marlin_scale(
     if group_size < size_k and group_size != -1:
         reverse_perms = reverse_marlin_24_scale_perm[num_bits]
         scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms]
-        return scales.reshape((size_n, -1))
+        return scales.reshape((size_k // group_size, size_n))
     else:
         reverse_perms = reverse_marlin_24_scale_perm_single[num_bits]
         scales = scales.reshape((-1, len(reverse_perms)))[:, reverse_perms]
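
A small shape-only sketch (not from the commit) of what the corrected return gives back: one row per quantization group and one column per output channel, instead of (size_n, -1). The permutation step is omitted and the sizes are invented.

import torch

size_k, size_n, group_size = 256, 128, 64
n_groups = size_k // group_size                        # 4 groups along k

flat_scales = torch.rand(n_groups * size_n)            # post-permutation stand-in

old_shape = flat_scales.reshape((size_n, -1)).shape                     # (128, 4): channels x groups
new_shape = flat_scales.reshape((size_k // group_size, size_n)).shape   # (4, 128): groups x channels

assert new_shape == (n_groups, size_n)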

wip_test_llama2.py (+24)

@@ -0,0 +1,24 @@
+import torch
+from torchao import quantize_
+from torchao.quantization import int4_weight_only
+from torchao.dtypes import MarlinSparseLayoutType
+from transformers import AutoTokenizer, LlamaForCausalLM
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+name = "meta-llama/Llama-2-7b-hf"
+token = "your token"
+
+model = LlamaForCausalLM.from_pretrained(name, torch_dtype=torch.float16, token=token).to(device)
+tokenizer = AutoTokenizer.from_pretrained(name, token=token)
+
+prompt = "Hey, are you conscious? Can you talk to me? I'm"
+inputs = tokenizer(prompt, return_tensors="pt")
+
+# Quantize
+quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+
+# Generate
+ids = inputs.input_ids.to(device)
+generate_ids = model.generate(ids, max_length=30)
+out = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+print(out)
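
An optional snippet that could be appended to the script above to get a rough generation latency number; it reuses the model and ids the script already defines and its imports, and adds nothing torchao-specific.

import time

# Rough wall-clock timing of one generate() call; reuses `model` and `ids` from above.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
    _ = model.generate(ids, max_length=30)
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"generate() took {time.perf_counter() - start:.2f} s")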
