Skip to content

Commit bd4b1da

Browse files
committed
Update torchao after pytorch/pytorch#129940
Summary: Fixes torchao code after the backward-compatibility (BC) breaking change in pytorch/pytorch#129940. Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4; python test/integration/test_integration.py -k test_save_load_int4woqtensors_2_cpu. Reviewers: Subscribers: Tasks: Tags:
1 parent 1029df3 commit bd4b1da

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

torchao/dtypes/affine_quantized_tensor.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -461,7 +461,8 @@ def __tensor_unflatten__(
461461

462462
@classmethod
463463
def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
464-
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
464+
assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
465+
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
465466
scale = scale.reshape(int_data.shape[0], -1)
466467
zero_point = zero_point.reshape(int_data.shape[0], -1)
467468
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)

torchao/quantization/quant_api.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -383,7 +383,7 @@ def apply_int4_weight_only_quant(weight):
383383

384384
mapping_type = MappingType.ASYMMETRIC
385385
block_size = (1, group_size)
386-
target_dtype = torch.int32
386+
target_dtype = torch.uint8
387387
quant_min = 0
388388
quant_max = 15
389389
eps = 1e-6

torchao/quantization/utils.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -344,11 +344,13 @@ def groupwise_affine_quantize_tensor_from_qparams(
344344
assert w.dim() == 2
345345

346346
block_size = (1, groupsize)
347-
output_dtype = torch.int32
347+
output_dtype = torch.uint8
348348
quant_min = 0
349349
quant_max = 2 ** n_bit - 1
350350

351-
return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
351+
int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
352+
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
353+
return int_data
352354

353355
def groupwise_affine_dequantize_tensor_from_qparams(
354356
w_int4x8,

0 commit comments

Comments (0)