Hi, I have been testing torchao's quantization configs, using Qwen models from Hugging Face to verify them. I ran everything on a single H100, which supports fp8 computation. However, at batch size 1 the inference speedup from quantization is pretty bad for everything except int8 weight-only quantization. I wonder what causes this?
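(For context, whether the GPU actually has fp8 kernels available can be confirmed from its compute capability; this small check is mine and is not part of the benchmark script below:)

```python
import torch

# fp8 tensor-core kernels require compute capability >= 8.9 (Ada/Hopper);
# an H100 reports (9, 0). Illustrative check only, not from the script.
major, minor = torch.cuda.get_device_capability()
print(f"compute capability {major}.{minor}; fp8 supported: {(major, minor) >= (8, 9)}")
```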
The code is:

```python
import torch
import time
import argparse
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
)
import pandas as pd


def get_quant_configs():
    return {
        "float8_wo": Float8WeightOnlyConfig(),
        "float8_dyn_act": Float8DynamicActivationFloat8WeightConfig(),
        "int4_wo": Int4WeightOnlyConfig(),
        "int8_wo": Int8WeightOnlyConfig(),
    }


def build_prompt():
    background = (
        "You are the best and most aggressive Street Fighter III 3rd strike player in the world. "
        "Your character is Ken. Your goal is to beat the other opponent."
    )
    hint = (
        "If you are far from opponent, use Move Closer and Fireball more often. "
        "If you are close to opponent or already moved closer, try to use Punch and Kick more often. "
        "Megapunch, Hurricane, and other combinations use more time but are more powerful. "
        "Use them when you are close to opponent and getting positive scores or winning. "
        "If you are getting negative scores or losing, try to Move Away and use Kick."
    )
    prompt = """
The moves you can use are:
- Move Closer
- Move Away
- Fireball
- Megapunch
- Hurricane
- Low Punch
- Medium Punch
- High Punch
- Low Kick
- Medium Kick
- High Kick
- Low Punch+Low Kick
- Medium Punch+Medium Kick
- High Punch+High Kick
- Jump Closer
- Jump Away
----
Example 1:
Context:
You are very far from the opponent. Move closer to the opponent. Your opponent is on the left.
Your last action was Medium Punch. The opponent's last action was Medium Punch.
Your current score is 108.0. You are winning. Keep attacking the opponent.
Your Response:
- Move closer
- Move closer
- Low Kick
Example 2:
Context:
You are close to the opponent. You should attack him.
Your last action was High Punch. The opponent's last action was High Punch.
Your current score is 37.0. You are winning. Keep attacking the opponent.
Your Response:
- High Punch
- Low Punch
- Hurricane
Example 3:
Context:
You are very far from the opponent. Move closer to the opponent. Your opponent is on the left.
Your last action was Low. The opponent's last action was Medium Punch.
Your current score is -75.0. You are losing. Continue to attack the opponent but don't get hit.
To increase your score, move toward the opponent and attack the opponent. To prevent your score from decreasing, don't get hit by the opponent.
Your Response:
- Move Away
- Low Punch
- Fireball
Now you are provided the following context, give your response using the same format as in the example.
Context:
"""
context = """You are close to the opponent. You should attack him.
You can now use a powerfull move. The names of the powerful moves are: Megafireball, Super attack 2.
Your last action was High Punch. The opponent's last action was High Punch.
Your current score is 37.0. You are winning. Keep attacking the opponent."""
messages = [
{"role": "system", "content": background + hint},
{"role": "user", "content": prompt + context + "\nYour Response:\n"}
]
return messages


def measure(model, inputs, max_new_tokens, runs=12):
    prefill_times = []
    full_times = []
    for _ in range(runs):
        # "Prefill" pass: generates a single token, so it still includes one decode step
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()
        model.generate(**inputs, max_new_tokens=1, cache_implementation="static")
        ender.record()
        torch.cuda.synchronize()
        prefill_times.append(starter.elapsed_time(ender) / 1000)  # ms -> s
        # Full pass (prefill + decode)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        starter.record()
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, cache_implementation="static")
        ender.record()
        torch.cuda.synchronize()
        full_times.append(starter.elapsed_time(ender) / 1000)  # ms -> s
    avg_prefill = sum(prefill_times) / len(prefill_times)
    avg_total = sum(full_times) / len(full_times)
    avg_decode = avg_total - avg_prefill
    return avg_prefill, avg_decode, avg_total, output


def run_inference(model_name, quant_name, device, runs=12, max_new_tokens=32, compile_flag=False):
    print(f"\n🚀 Running {quant_name} on {device} | torch.compile={compile_flag}")
    quant_config = get_quant_configs()[quant_name]
    qconfig = TorchAoConfig(quant_type=quant_config)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map=device,
        quantization_config=qconfig,
    )
    if compile_flag:
        model = torch.compile(model)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    messages = build_prompt()
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt")
    if "cuda" in device:
        inputs = inputs.to(device)
    # Warm up 3 times
    print("🔄 Warm-up...")
    for _ in range(3):
        _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
    prefill, decode, total, output = measure(model, inputs, max_new_tokens, runs)
    print(f"✅ {quant_name} | torch.compile={compile_flag} | prefill: {prefill:.4f}s | decode: {decode:.4f}s | total: {total:.4f}s")
    print("📜 Response:")
    print(tokenizer.decode(output[0], skip_special_tokens=True).split("Your Response:")[-1].strip())
    return {
        "quant": quant_name,
        "compile": compile_flag,
        "prefill": prefill,
        "decode": decode,
        "total": total,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--quant", type=str, default="all")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--runs", type=int, default=12)
    parser.add_argument("--max_new_tokens", type=int, default=2)
    parser.add_argument("--compile", type=str, default="false", choices=["true", "false"])
    parser.add_argument("--output_csv", type=str, default="benchmark_results.csv")
    args = parser.parse_args()

    configs = get_quant_configs().keys() if args.quant == "all" else [args.quant]
    compile_flag = args.compile.lower() == "true"
    results = []
    for name in configs:
        try:
            result = run_inference(args.model_name, name, device=args.device, runs=args.runs,
                                   max_new_tokens=args.max_new_tokens, compile_flag=compile_flag)
            results.append(result)
        except Exception as e:
            print(f"❌ Error for {name}: {e}")

    df = pd.DataFrame(results)
    print("\n🧾 Summary:")
    print(df)
    df.to_csv(args.output_csv, index=False)
```
Generation uses `cache_implementation="static"` following this snippet from the torchao docs (`quantized_model` and `input_ids` are the docs' names):

```python
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
```
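The Hugging Face generation docs also show a related setup that compiles only the forward pass when using a static cache, instead of wrapping the whole model as the script above does. A minimal sketch under that assumption (reusing `model` and `inputs` from the script):

```python
# Sketch only, not the setup benchmarked above: compile just the decoder forward.
# mode="reduce-overhead" enables CUDA graphs, which mainly helps small-batch decode;
# fullgraph=True will raise if the forward cannot be captured as a single graph.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
output = model.generate(**inputs, max_new_tokens=32, cache_implementation="static")
```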
The result (all times in seconds):

3B:

```
🧾 Summary:
            quant  compile   prefill    decode     total
0       float8_wo     True  0.043461  0.008105  0.051566
1  float8_dyn_act     True  0.093110  0.010512  0.103623
2         int4_wo     True  0.124646  0.007891  0.132537
3         int8_wo     True  0.032728  0.007511  0.040238
```

7B:

```
🧾 Summary:
            quant  compile   prefill    decode     total
0       float8_wo     True  0.065851  0.008619  0.074470
1  float8_dyn_act     True  0.073185  0.010788  0.083972
2         int4_wo     True  0.288922  0.009294  0.298216
3         int8_wo     True  0.048244  0.008691  0.056935
```
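One way to read these numbers: the script's decode column is total − prefill, and the "prefill" run itself already generates one token, so decode covers max_new_tokens − 1 tokens. Assuming the default `--max_new_tokens 2` was used (my assumption; the actual value isn't stated), the int8_wo 3B row implies roughly:

```python
# Rough per-token decode latency, assuming max_new_tokens=2 (the script default),
# so the "decode" column covers a single extra token beyond the prefill run.
prefill, total, max_new_tokens = 0.032728, 0.040238, 2
per_token = (total - prefill) / (max_new_tokens - 1)
print(f"{per_token * 1e3:.2f} ms/token, ~{1 / per_token:.0f} tokens/s")  # ≈7.51 ms, ~133 tokens/s
```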
The command to run is of the form shown below (the original command wasn't included in the issue; the script name and model are illustrative):
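```bash
# Illustrative invocation only: assumes the script is saved as benchmark.py
# and that a Qwen checkpoint was used, per the issue text.
python benchmark.py --model_name Qwen/Qwen2.5-3B-Instruct --quant all --compile true
```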