
Commit 8b7b538

Revert "Add layout option to woq int4 api (#670)"
This reverts commit 009f55f.
1 parent 009f55f commit 8b7b538
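
In effect, the revert takes the public `int4_weight_only` API back from a `layout_type` object to a plain `inner_k_tiles` integer. A minimal before/after sketch of the call site, based only on the signatures visible in the diff below (`model` is a placeholder and imports are omitted):

    # after this revert: inner_k_tiles is passed directly as an integer
    quantize_(model, int4_weight_only(group_size=128, inner_k_tiles=8))

    # before the revert (the #670 API): a layout object carried inner_k_tiles
    quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))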

File tree: 2 files changed, +30 -16 lines


test/integration/test_integration.py (+3 -6)

@@ -19,7 +19,6 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayoutType
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -853,20 +852,18 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
+                    kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}

                     def api(mod):
-                        kwargs_copy = kwargs.copy()
                         if TORCH_VERSION_AFTER_2_4:
+                            kwargs_copy = kwargs.copy()
                             kwargs_copy["group_size"] = groupsize
                             del kwargs_copy["groupsize"]
                             quantize_(mod, int4_weight_only(**kwargs_copy))
                             if not TORCH_VERSION_AFTER_2_5:
                                 unwrap_tensor_subclass(mod)
                         else:
-                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
-                            del kwargs_copy["layout_type"]
-                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)

                     self._test_lin_weight_subclass_api_impl(
                         api,
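With the restored kwargs, the two branches of `api(mod)` above resolve to calls roughly like these (a sketch; `mod` and the 64/4 values stand in for whatever the loops supply):

    # torch >= 2.4 path: the test renames "groupsize" to "group_size" for the new API
    quantize_(mod, int4_weight_only(group_size=64, inner_k_tiles=4))

    # older torch: the legacy subclass swap takes the original kwargs unchanged
    change_linear_weights_to_int4_woqtensors(mod, groupsize=64, inner_k_tiles=4)
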

torchao/quantization/quant_api.py (+27 -10)

@@ -21,14 +21,7 @@
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes.uintx.Uintx import UintxLayoutType
-from torchao.dtypes import (
-    to_affine_quantized,
-    TensorCoreTiledLayoutType,
-    PlainLayoutType,
-    AffineQuantizedTensor,
-    SemiSparseLayoutType
-)
+from torchao.dtypes import PlainLayoutType
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -189,6 +182,9 @@ def _replace_with_custom_fn_if_matches_filter(


 def _is_linear(mod, *args):
+    # avoid circular dep
+    from torchao.dtypes import AffineQuantizedTensor
+
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -332,6 +328,9 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )

 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -340,6 +339,9 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight

+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -371,7 +373,7 @@ def insert_subclass(lin):
     return insert_subclass


-def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
+def int4_weight_only(group_size=128, inner_k_tiles=8):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -387,12 +389,16 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, choices are [256, 128, 64, 32]
-        `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
+        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
     """
     def apply_int4_weight_only_quant(weight):
         if weight.shape[-1] % group_size != 0:
             return weight

+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import TensorCoreTiledLayoutType
+
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int32
@@ -402,6 +408,7 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
+        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
@@ -412,6 +419,9 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes import to_affine_quantized
+
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -422,6 +432,8 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)

 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -441,6 +453,8 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
     if in_features <= 16:
         return weight

+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     def get_weight_block_size(x):
@@ -465,6 +479,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
+    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


@@ -480,6 +495,8 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
+    from torchao.dtypes.uintx.Uintx import UintxLayoutType
+    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
     def apply_uintx_weight_only_quant(weight):
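Two notes on the quant_api.py side of the revert. First, `TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)` is now constructed inside `apply_int4_weight_only_quant`, so callers only deal with integers. A hedged end-to-end sketch, assuming `quantize_` is exposed from `quant_api` alongside `int4_weight_only` (as the test's torch >= 2.4 path suggests) and a bfloat16 CUDA model as in the integration test:

    import torch
    from torchao.quantization.quant_api import quantize_, int4_weight_only

    # toy model; bfloat16 + CUDA mirror what the int4 tinygemm path expects in the test
    model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()
    quantize_(model, int4_weight_only(group_size=64, inner_k_tiles=4))

Second, the module-level imports from torchao.dtypes are replaced by function-local ones marked "avoid circular dep"; deferring those imports to call time presumably breaks an import cycle between torchao.dtypes and quant_api without changing any behavior.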