@@ -618,7 +618,9 @@ def from_plain(
 
         # Linear layers are (in_features, out_features) but the int_data that is reaching this point
         # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
+        # NOTE(reviewers): Please check if this is what I should do.
         q_w_24 = int_data.t()
+        scale = scale.reshape(-1, q_w_24.shape[1])
 
         if q_w_24.dtype != torch.int32:
             raise ValueError("Only `torch.int32` weights are supported.")
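For reviewers, a minimal shape sketch of what the transpose plus scale reshape above is expected to produce. The sizes and tensors here are illustrative, not taken from the PR, and the incoming scale is assumed to be laid out as (out_features, n_groups):

    import torch

    # Hypothetical sizes, for illustration only.
    in_features, out_features, n_groups = 256, 128, 2

    # Assumption: int_data arrives as (out_features, in_features), scale as (out_features, n_groups).
    int_data = torch.zeros(out_features, in_features, dtype=torch.int32)
    scale = torch.ones(out_features, n_groups)

    q_w_24 = int_data.t()                       # -> (in_features, out_features)
    scale = scale.reshape(-1, q_w_24.shape[1])  # -> (n_groups, out_features); shape change only, not a transpose

    assert q_w_24.shape == (in_features, out_features)
    assert scale.shape == (n_groups, out_features)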
@@ -631,15 +633,14 @@ def from_plain(
 
         # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
         # will require a bit more work to get our current quantization flow to work with it.
-        # Check the below link for a reference:
-        # https://github.com/neuralmagic/nm-vllm/tree/main
+        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
         num_bits = 4 if torch.max(q_w_24) < 16 else -1
         if num_bits not in [4]:
             raise ValueError(
-                f"Only {const.SUPPORTED_NUM_BITS} bits are supported, got {num_bits}."
+                f"Only {[4]} bits are supported, got {num_bits}."
             )
 
-        group_size = in_features // scale.shape[-1]
+        group_size = in_features // scale.shape[0]
         if group_size == 0:
             group_size = in_features
         assert group_size <= in_features, "Group size must be less than or equal to in_features."
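A quick arithmetic check of the new group_size computation, under the same assumed scale layout as in the sketch above (scale reshaped to (n_groups, out_features)):

    # scale.shape[0] is the number of quantization groups after the reshape in the earlier hunk.
    in_features = 4096
    n_groups = 32                         # hypothetical scale.shape[0]
    group_size = in_features // n_groups  # -> 128
    assert group_size == 128
    # Per-channel quantization collapses to one group: scale.shape[0] == 1 gives group_size == in_features.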
@@ -1043,27 +1044,44 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
         isinstance(weight_tensor.layout_type, MarlinSparseLayoutType)
     )
 
-
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
-    from torchao.sparsity.marlin import marlin_24_workspace
+    from torchao.sparsity.marlin import marlin_24_workspace, const
 
     sparse_w_int4 = weight_tensor.layout_tensor.int_data
     scale = weight_tensor.layout_tensor.scale
    meta = weight_tensor.layout_tensor.meta
     original_shape = weight_tensor.layout_tensor.original_shape
     num_bits = weight_tensor.layout_tensor.num_bits
 
+    # Saves batch size for reshaping back to original shape after the matmul
+    # Reshapes tensor to (m, k) where m is batch * seq_len and k is in_features
+    # NOTE(reviewers): Please check if I am handling the batch size correctly
+    batch_size = -1
+    if input_tensor.dim() == 3:
+        batch_size = input_tensor.size(0)
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+
     size_m = input_tensor.shape[0]
-    size_n = original_shape[0]
+    size_n = original_shape[1]
     size_k = input_tensor.shape[1]
     workspace_24 = marlin_24_workspace(original_shape[1])
 
+    # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
+    if size_k % const.TILE != 0:
+        pad_size = find_multiple(size_k, const.TILE)
+        input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_size - size_k))
+        size_k = pad_size
+
     out = torchao.ops.marlin_24_gemm(
         input_tensor, sparse_w_int4, meta, scale,
         workspace_24, num_bits, size_m, size_n, size_k
     )
     torch.cuda.synchronize()
 
+    # Reshape back to original shape
+    if batch_size != -1:
+        out = out.reshape(batch_size, -1, out.shape[-1])
+
     if bias is not None:
         out += bias.to(out.dtype)
     return out
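A self-contained sketch of the new batch handling and tile-size padding added in this impl. Everything here is illustrative: the sizes are hypothetical, and round_up is a stand-in that assumes find_multiple(n, k) returns the smallest multiple of k that is >= n:

    import torch

    TILE = 16  # marlin tile size, per the comment in the diff

    def round_up(n: int, k: int) -> int:
        # Stand-in for the assumed behaviour of find_multiple.
        return ((n + k - 1) // k) * k

    batch, seq, in_features, out_features = 2, 7, 100, 64
    x = torch.randn(batch, seq, in_features)

    # Flatten (batch, seq, in_features) -> (batch * seq, in_features) before the GEMM.
    batch_size = x.size(0) if x.dim() == 3 else -1
    x2d = x.reshape(-1, x.shape[-1]).contiguous()
    size_m, size_k = x2d.shape

    # Pad the reduction dimension up to a multiple of the tile size.
    if size_k % TILE != 0:
        pad_size = round_up(size_k, TILE)
        x2d = torch.nn.functional.pad(x2d, (0, pad_size - size_k))
        size_k = pad_size

    # Stand-in for the marlin_24_gemm output, which has shape (size_m, out_features).
    out = torch.zeros(size_m, out_features)

    # Restore the leading batch dimension if the input was 3D.
    if batch_size != -1:
        out = out.reshape(batch_size, -1, out.shape[-1])

    assert x2d.shape == (batch * seq, 112)  # 100 padded up to the next multiple of 16
    assert out.shape == (batch, seq, out_features)

The reshape at the end recovers the original sequence length because size_m is exactly batch * seq, so the -1 dimension resolves back to seq.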