Skip to content

Commit e283743

Browse files
authored
Revert "Pin PT version: Fix FPX Inductor error" (#843)
* Revert "Pin PT version: Fix FPX Inductor error (#790)" This reverts commit 287458c. * updates * yolo * yolo * yolo * yolo
1 parent d2226a4 commit e283743

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

.github/workflows/regression_test.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ jobs:
3636
torch-spec: 'torch==2.4.0'
3737
gpu-arch-type: "cuda"
3838
gpu-arch-version: "12.1"
39-
- name: CUDA Nightly (Aug 29, 2024)
39+
- name: CUDA Nightly
4040
runs-on: linux.g5.12xlarge.nvidia.gpu
41-
torch-spec: '--pre torch==2.5.0.dev20240829+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
41+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
4242
gpu-arch-type: "cuda"
4343
gpu-arch-version: "12.1"
4444

@@ -57,9 +57,9 @@ jobs:
5757
torch-spec: 'torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu'
5858
gpu-arch-type: "cpu"
5959
gpu-arch-version: ""
60-
- name: CPU Nightly (Aug 29, 2024)
60+
- name: CPU Nightly
6161
runs-on: linux.4xlarge
62-
torch-spec: '--pre torch==2.5.0.dev20240829+cpu --index-url https://download.pytorch.org/whl/nightly/cpu'
62+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
6363
gpu-arch-type: "cpu"
6464
gpu-arch-version: ""
6565

test/dtypes/test_bitnet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torchao.prototype.dtypes import BitnetTensor
55
from torchao.prototype.dtypes.uint2 import unpack_uint2
66
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
7-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
7+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5
88

99
if not TORCH_VERSION_AT_LEAST_2_4:
1010
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -58,6 +58,7 @@ def fn(mod):
5858
lambda mod, fqn: isinstance(mod, torch.nn.Linear),
5959
)
6060

61+
@pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="Regression introduced in nightlies")
6162
@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]])
6263
def test_uint2_quant(input_shape):
6364
device = 'cuda' if torch.cuda.is_available() else 'cpu'

test/integration/test_integration.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,9 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
817817
@parameterized.expand(COMMON_DEVICE_DTYPE)
818818
@unittest.skipIf(is_fbcode(), "broken in fbcode")
819819
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
820+
if TORCH_VERSION_AT_LEAST_2_5 and device == "cpu":
821+
self.skipTest("Regression introduced in PT nightlies")
822+
820823
undo_recommended_configs()
821824
self._test_lin_weight_subclass_api_impl(
822825
_int8wo_api, device, 40, test_dtype=dtype
@@ -826,6 +829,9 @@ def test_int8_weight_only_quant_subclass_api(self, device, dtype):
826829
@torch._inductor.config.patch({"freezing": True})
827830
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "freeze requires torch 2.4 and after.")
828831
def test_int8_weight_only_quant_with_freeze(self, device, dtype):
832+
if TORCH_VERSION_AT_LEAST_2_5 and device == "cpu":
833+
self.skipTest("Regression introduced in PT nightlies")
834+
829835
self._test_lin_weight_subclass_api_impl(
830836
_int8wo_api, device, 40, test_dtype=dtype
831837
)
@@ -1039,7 +1045,10 @@ def test_save_load_dqtensors(self, device, dtype):
10391045
@parameterized.expand(COMMON_DEVICE_DTYPE)
10401046
@torch.no_grad()
10411047
@unittest.skipIf(is_fbcode(), "broken in fbcode")
1042-
def test_save_load_int8woqtensors(self, device, dtype):
1048+
def test_save_load_int8woqtensors(self, device, dtype):
1049+
if TORCH_VERSION_AT_LEAST_2_5 and device == "cpu":
1050+
self.skipTest(f"Regression introduced in PT nightlies")
1051+
10431052
undo_recommended_configs()
10441053
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
10451054

0 commit comments

Comments
 (0)