
Unexpectedly low speed when using quantized inference on Qwen models #2102


Open · HaoKang-Timmy opened this issue Apr 22, 2025 · 2 comments

@HaoKang-Timmy

Hi, I have tested torchao's quantization configs with Qwen models from Hugging Face. I ran everything on a single H100, which has fp8 compute support. However, the single-batch inference performance is quite poor for every config except int8 weight-only. What could be causing this?

The results (times in seconds):

3B
🧾 Summary:
   quant           compile   prefill   decode    total
0  float8_wo       True      0.043461  0.008105  0.051566
1  float8_dyn_act  True      0.093110  0.010512  0.103623
2  int4_wo         True      0.124646  0.007891  0.132537
3  int8_wo         True      0.032728  0.007511  0.040238

7B

🧾 Summary:
   quant           compile   prefill   decode    total
0  float8_wo       True      0.065851  0.008619  0.074470
1  float8_dyn_act  True      0.073185  0.010788  0.083972
2  int4_wo         True      0.288922  0.009294  0.298216
3  int8_wo         True      0.048244  0.008691  0.056935

The code is:

import torch
import time
import argparse
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)
import pandas as pd

def get_quant_configs():
    return {
        "float8_wo": Float8WeightOnlyConfig(),
        "float8_dyn_act": Float8DynamicActivationFloat8WeightConfig(),
        "int4_wo": Int4WeightOnlyConfig(),
        "int8_wo": Int8WeightOnlyConfig(),
    }

def build_prompt():
    background = (
        "You are the best and most aggressive Street Fighter III 3rd strike player in the world. "
        "Your character is Ken. Your goal is to beat the other opponent."
    )
    hint = (
        "If you are far from opponent, use Move Closer and Fireball more often. "
        "If you are close to opponent or already moved closer, try to use Punch and Kick more often. "
        "Megapunch, Hurricane, and other combinations use more time but are more powerful. "
        "Use them when you are close to opponent and getting positive scores or winning. "
        "If you are getting negative scores or losing, try to Move Away and use Kick."
    )
    prompt = """
The moves you can use are:
- Move Closer
- Move Away
- Fireball
- Megapunch
- Hurricane
- Low Punch
- Medium Punch
- High Punch
- Low Kick
- Medium Kick
- High Kick
- Low Punch+Low Kick
- Medium Punch+Medium Kick
- High Punch+High Kick
- Jump Closer
- Jump Away
----

Example 1:
Context:
You are very far from the opponent. Move closer to the opponent. Your opponent is on the left.
Your last action was Medium Punch. The opponent's last action was Medium Punch.
Your current score is 108.0. You are winning. Keep attacking the opponent.

Your Response:
- Move closer
- Move closer
- Low Kick

Example 2:
Context:
You are close to the opponent. You should attack him.
Your last action was High Punch. The opponent's last action was High Punch.
Your current score is 37.0. You are winning. Keep attacking the opponent.

Your Response:
- High Punch
- Low Punch
- Hurricane

Example 3:
Context:
You are very far from the opponent. Move closer to the opponent. Your opponent is on the left.
Your last action was Low. The opponent's last action was Medium Punch.
Your current score is -75.0. You are losing. Continue to attack the opponent but don't get hit.
To increase your score, move toward the opponent and attack the opponent. To prevent your score from decreasing, don't get hit by the opponent.

Your Response:
- Move Away
- Low Punch
- Fireball

Now you are provided the following context, give your response using the same format as in the example.
Context:
"""
    context = """You are close to the opponent. You should attack him.
You can now use a powerful move. The names of the powerful moves are: Megafireball, Super attack 2.
Your last action was High Punch. The opponent's last action was High Punch.
Your current score is 37.0. You are winning. Keep attacking the opponent."""

    messages = [
        {"role": "system", "content": background + hint},
        {"role": "user", "content": prompt + context + "\nYour Response:\n"}
    ]
    return messages

def measure(model, inputs, max_new_tokens, runs=12):
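    # Timing approach: prefill is approximated by a 1-token generate; the
    # average decode time is derived as (avg full time) - (avg prefill time).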
    prefill_times = []
    full_times = []

    for _ in range(runs):
        # Prefill only
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

        starter.record()
        model.generate(**inputs, max_new_tokens=1, cache_implementation="static")
        ender.record()
        torch.cuda.synchronize()
        prefill_times.append(starter.elapsed_time(ender) / 1000)

        # Full (prefill + decode)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

        starter.record()
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, cache_implementation="static")
        ender.record()
        torch.cuda.synchronize()
        full_times.append(starter.elapsed_time(ender) / 1000)

    avg_prefill = sum(prefill_times) / len(prefill_times)
    avg_total = sum(full_times) / len(full_times)
    avg_decode = avg_total - avg_prefill
    return avg_prefill, avg_decode, avg_total, output

def run_inference(model_name, quant_name, device, runs=12, max_new_tokens=32, compile_flag=False):
    print(f"\n🚀 Running {quant_name} on {device} | torch.compile={compile_flag}")

    quant_config = get_quant_configs()[quant_name]
    qconfig = TorchAoConfig(quant_type=quant_config)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map=device,
        quantization_config=qconfig,
    )

    if compile_flag:
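        # NOTE: this wraps the whole model in torch.compile; the transformers
        # torchao docs (see the comments below) instead rely on generate() with
        # cache_implementation="static" to compile the decode step.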
        model = torch.compile(model)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = build_prompt()
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt")
    if "cuda" in device:
        inputs = inputs.to(device)

    # Warm-up (3 iterations)
    print("🔄 Warm-up...")
    for _ in range(3):
        _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="static")

    prefill, decode, total, output = measure(model, inputs, max_new_tokens, runs)

    print(f"✅ {quant_name} | torch.compile={compile_flag} | prefill: {prefill:.4f}s | decode: {decode:.4f}s | total: {total:.4f}s")
    print("📜 Response:")
    print(tokenizer.decode(output[0], skip_special_tokens=True).split("Your Response:")[-1].strip())

    return {
        "quant": quant_name,
        "compile": compile_flag,
        "prefill": prefill,
        "decode": decode,
        "total": total,
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--quant", type=str, default="all")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--runs", type=int, default=12)
    parser.add_argument("--max_new_tokens", type=int, default=2)
    parser.add_argument("--compile", type=str, default="false", choices=["true", "false"])
    parser.add_argument("--output_csv", type=str, default="benchmark_results.csv")
    args = parser.parse_args()

    configs = get_quant_configs().keys() if args.quant == "all" else [args.quant]
    compile_flag = args.compile.lower() == "true"

    results = []
    for name in configs:
        try:
            result = run_inference(args.model_name, name, device=args.device, runs=args.runs,
                                   max_new_tokens=args.max_new_tokens, compile_flag=compile_flag)
            results.append(result)
        except Exception as e:
            print(f"❌ Error for {name}: {e}")

    df = pd.DataFrame(results)
    print("\n🧾 Summary:")
    print(df)
    df.to_csv(args.output_csv, index=False)

The commands to run are:

python3 torchao_test.py --model_name Qwen/Qwen2.5-3B-Instruct --output_csv qwen_3.csv --compile true > qwen_3b_compile.txt
python3 torchao_test.py --model_name Qwen/Qwen2.5-7B-Instruct --output_csv qwen_7.csv --compile true > qwen_7b_compile.txt
@jerryzh168 (Contributor) commented Apr 23, 2025

Thanks for trying out torchao @HaoKang-Timmy. The benchmark code you posted may not be very reliable, I think; I'd recommend benchmarking with vLLM to get a better picture. Can you follow the flow in https://huggingface.co/jerryzh168/phi4-mini-float8dq? The https://huggingface.co/jerryzh168/phi4-mini-float8dq#model-performance section describes the benchmarks you can run.
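
For reference, a minimal vLLM latency check could look like the sketch below. This is only an illustration of vLLM's offline API, assuming the quantized checkpoint has been pushed to the Hub as in the phi4 model card; the model id shown is just the base checkpoint from this issue, not a quantized upload:

import time
from vllm import LLM, SamplingParams

# Assumption: swap in a torchao-quantized checkpoint uploaded to the Hub;
# this base model id is only a placeholder.
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct")
params = SamplingParams(temperature=0.0, max_tokens=32)

start = time.perf_counter()
outputs = llm.generate(["You are the best Street Fighter III player."], params)
print(f"latency: {time.perf_counter() - start:.3f}s")
print(outputs[0].outputs[0].text)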

@jerryzh168 (Contributor) commented:

Another thing is torch.compile: can you check out https://huggingface.co/docs/transformers/main/en/quantization/torchao for how to apply torch.compile to a Hugging Face model? Specifically this part:

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
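
Adapting that to the script in this issue, a minimal sketch (my assumptions: the same Qwen checkpoint and the int8 weight-only config from above) would drop the explicit torch.compile(model) wrap and let generate() with the static cache handle compilation:

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

model_name = "Qwen/Qwen2.5-3B-Instruct"  # assumed; use the model you benchmarked
qconfig = TorchAoConfig(quant_type=Int8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="cuda:0",
    quantization_config=qconfig,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer("Hello", return_tensors="pt").to("cuda:0")

# No torch.compile(model) wrap: cache_implementation="static" auto-compiles
# the decoding forward pass inside generate().
output = model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))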
