@@ -32,8 +32,12 @@ def get_args():
32
32
parser .add_argument ("--check-all" , action = "store_true" )
33
33
parser .add_argument ("--show-all" , "--show" , action = "store_true" )
34
34
parser .add_argument ("--show-matrix" , action = "store_true" )
35
- parser .add_argument ("--only-flops-matmul" , "--flops-mm" , action = "store_true" )
36
- parser .add_argument ("--run-acc-f32" , "--acc-f32" , "--f32" , action = "store_true" )
35
+ parser .add_argument (
36
+ "--only-flops-matmul" , "--flops-mm" , action = "store_true"
37
+ )
38
+ parser .add_argument (
39
+ "--run-acc-f32" , "--acc-f32" , "--f32" , action = "store_true"
40
+ )
37
41
parser .add_argument ("--B" , type = int , default = None )
38
42
parser .add_argument ("--H" , type = int , default = None )
39
43
parser .add_argument ("--N" , type = int , default = None )
@@ -46,7 +50,9 @@ def get_args():
46
50
parser .add_argument ("--iters" , "--i" , type = int , default = 5 )
47
51
parser .add_argument ("--range-k" , "--gk" , action = "store_true" )
48
52
parser .add_argument ("--build-others" , "--others" , action = "store_true" )
49
- parser .add_argument ("--tag-hints" , "--tags" , "--hints" , type = str , default = None )
53
+ parser .add_argument (
54
+ "--tag-hints" , "--tags" , "--hints" , type = str , default = None
55
+ )
50
56
return parser .parse_args ()
51
57
52
58
@@ -84,20 +90,30 @@ def get_build_sources():
84
90
build_sources .append ("./mma/basic/flash_attn_mma_share_kv_F32F16F16F32.cu" )
85
91
build_sources .append ("./mma/basic/flash_attn_mma_share_qkv_F32F16F16F32.cu" )
86
92
build_sources .append ("./mma/basic/flash_attn_mma_tiling_qk_F32F16F16F32.cu" )
87
- build_sources .append ("./mma/basic/flash_attn_mma_tiling_qkv_F32F16F16F32.cu" )
93
+ build_sources .append (
94
+ "./mma/basic/flash_attn_mma_tiling_qkv_F32F16F16F32.cu"
95
+ )
88
96
# Swizzle
89
97
build_sources .append ("./mma/swizzle/flash_attn_mma_share_kv_swizzle_q.cu" )
90
98
build_sources .append ("./mma/swizzle/flash_attn_mma_share_kv_swizzle_qk.cu" )
91
99
build_sources .append ("./mma/swizzle/flash_attn_mma_share_kv_swizzle_qkv.cu" )
92
100
build_sources .append ("./mma/swizzle/flash_attn_mma_share_qkv_swizzle_q.cu" )
93
101
build_sources .append ("./mma/swizzle/flash_attn_mma_share_qkv_swizzle_qk.cu" )
94
- build_sources .append ("./mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu" )
102
+ build_sources .append (
103
+ "./mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu"
104
+ )
95
105
build_sources .append ("./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu" )
96
106
build_sources .append ("./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu" )
97
- build_sources .append ("./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu" )
107
+ build_sources .append (
108
+ "./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu"
109
+ )
98
110
build_sources .append ("./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu" )
99
- build_sources .append ("./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu" )
100
- build_sources .append ("./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu" )
111
+ build_sources .append (
112
+ "./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu"
113
+ )
114
+ build_sources .append (
115
+ "./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu"
116
+ )
101
117
build_sources .append (
102
118
"./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu"
103
119
)
@@ -110,15 +126,21 @@ def get_build_sources():
110
126
# Others
111
127
if args .build_others :
112
128
build_sources .append ("./mma/others/flash_attn_mma_share_qkv_Os2g.cu" )
113
- build_sources .append ("./mma/others/flash_attn_mma_share_kv_F32F16F16F32_rr.cu" )
114
- build_sources .append ("./mma/others/flash_attn_mma_share_qkv_F32F16F16F32_rr.cu" )
129
+ build_sources .append (
130
+ "./mma/others/flash_attn_mma_share_kv_F32F16F16F32_rr.cu"
131
+ )
132
+ build_sources .append (
133
+ "./mma/others/flash_attn_mma_share_qkv_F32F16F16F32_rr.cu"
134
+ )
115
135
# Pybind
116
136
build_sources .append ("./pybind/flash_attn.cc" )
117
137
return build_sources
118
138
119
139
120
140
def get_project_dir ():
121
- return os .path .dirname (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
141
+ return os .path .dirname (
142
+ os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))
143
+ )
122
144
123
145
124
146
project_dir = get_project_dir ()
@@ -153,7 +175,9 @@ def get_build_cuda_cflags(build_pkg: bool = False):
153
175
extra_cuda_cflags .append (
154
176
"-diag-suppress 177" if not build_pkg else "--ptxas-options=-v"
155
177
)
156
- extra_cuda_cflags .append ("-Xptxas -v" if not build_pkg else "--ptxas-options=-O3" )
178
+ extra_cuda_cflags .append (
179
+ "-Xptxas -v" if not build_pkg else "--ptxas-options=-O3"
180
+ )
157
181
extra_cuda_cflags .append (f"-I { project_dir } /kernels/flash-attn" )
158
182
extra_cuda_cflags .append (f"-I { project_dir } /kernels/flash-attn/utils" )
159
183
extra_cuda_cflags .append (f"-I { project_dir } /kernels/flash-attn/mma" )
@@ -163,14 +187,18 @@ def get_build_cuda_cflags(build_pkg: bool = False):
163
187
extra_cuda_cflags .append (f"-I { project_dir } /kernels/flash-attn/cutlass" )
164
188
extra_cuda_cflags .append (f"-I { project_dir } /kernels/flash-attn/pybind" )
165
189
extra_cuda_cflags .append (f"-I { project_dir } /third-party/cutlass/include" )
166
- extra_cuda_cflags .append (f"-I { project_dir } /third-party/cutlass/tools/util/include" )
190
+ extra_cuda_cflags .append (
191
+ f"-I { project_dir } /third-party/cutlass/tools/util/include"
192
+ )
167
193
return extra_cuda_cflags
168
194
169
195
170
196
def get_build_cflags ():
171
197
extra_cflags = []
172
198
extra_cflags .append ("-std=c++17" )
173
- extra_cflags .append ("-DBUILD_FLASH_ATTN_MMA_OTHERS" if args .build_others else "" )
199
+ extra_cflags .append (
200
+ "-DBUILD_FLASH_ATTN_MMA_OTHERS" if args .build_others else ""
201
+ )
174
202
return extra_cflags
175
203
176
204
@@ -200,8 +228,12 @@ def pretty_print_line(m: str = "", sep: str = "-", width: int = 150):
200
228
if not args .build_others :
201
229
fake_fa_func = lambda q , k , v , o , s : o # fake FA func
202
230
setattr (lib , "flash_attn_mma_stages_split_q_shared_qkv_Os2g" , fake_fa_func )
203
- setattr (lib , "flash_attn_mma_stages_split_q_shared_kv_acc_f32_rr" , fake_fa_func )
204
- setattr (lib , "flash_attn_mma_stages_split_q_shared_qkv_acc_f32_rr" , fake_fa_func )
231
+ setattr (
232
+ lib , "flash_attn_mma_stages_split_q_shared_kv_acc_f32_rr" , fake_fa_func
233
+ )
234
+ setattr (
235
+ lib , "flash_attn_mma_stages_split_q_shared_qkv_acc_f32_rr" , fake_fa_func
236
+ )
205
237
206
238
207
239
def get_mha_tflops (
@@ -342,7 +374,9 @@ def run_benchmark(
342
374
mean_time = total_time / iters
343
375
mean_secs = total_secs / iters
344
376
345
- TFLOPS = get_mha_tflops (B , H , N , D , mean_secs , only_matmul = args .only_flops_matmul )
377
+ TFLOPS = get_mha_tflops (
378
+ B , H , N , D , mean_secs , only_matmul = args .only_flops_matmul
379
+ )
346
380
out_info = f"{ tag } "
347
381
out_val_first = out .flatten ()[:3 ].detach ().cpu ().numpy ().tolist ()
348
382
out_val_last = out .flatten ()[- 3 :].detach ().cpu ().numpy ().tolist ()
@@ -562,17 +596,41 @@ def check_all_close(
562
596
out_unfused , _ = run_benchmark (unfused_standard_attn , q , k , v , "(unfused)" )
563
597
# Split-KV
564
598
out_mma_split_kv1 , _ = run_benchmark (
565
- lib .flash_attn_mma_stages_split_kv , q , k , v , "mma(split-kv+stage1)" , o , stages = 1
599
+ lib .flash_attn_mma_stages_split_kv ,
600
+ q ,
601
+ k ,
602
+ v ,
603
+ "mma(split-kv+stage1)" ,
604
+ o ,
605
+ stages = 1 ,
566
606
)
567
607
out_mma_split_kv2 , _ = run_benchmark (
568
- lib .flash_attn_mma_stages_split_kv , q , k , v , "mma(split-kv+stage2)" , o , stages = 2
608
+ lib .flash_attn_mma_stages_split_kv ,
609
+ q ,
610
+ k ,
611
+ v ,
612
+ "mma(split-kv+stage2)" ,
613
+ o ,
614
+ stages = 2 ,
569
615
)
570
616
# Split-Q
571
617
out_mma_split_q1 , _ = run_benchmark (
572
- lib .flash_attn_mma_stages_split_q , q , k , v , "mma(split-q+stage1)" , o , stages = 1
618
+ lib .flash_attn_mma_stages_split_q ,
619
+ q ,
620
+ k ,
621
+ v ,
622
+ "mma(split-q+stage1)" ,
623
+ o ,
624
+ stages = 1 ,
573
625
)
574
626
out_mma_split_q2 , _ = run_benchmark (
575
- lib .flash_attn_mma_stages_split_q , q , k , v , "mma(split-q+stage2)" , o , stages = 2
627
+ lib .flash_attn_mma_stages_split_q ,
628
+ q ,
629
+ k ,
630
+ v ,
631
+ "mma(split-q+stage2)" ,
632
+ o ,
633
+ stages = 2 ,
576
634
)
577
635
# Split-Q + Shared KV SMEM + Swizzle
578
636
out_mma_share_kv1 , _ = run_benchmark (
@@ -1049,7 +1107,9 @@ def check_all_close(
1049
1107
)
1050
1108
# FA2, SDPA official
1051
1109
out_flash , _ = run_benchmark (flash_attn_func , fq , fk , fv , "(flash)" )
1052
- out_sdpa , _ = run_benchmark (partial (sdpa , use_flash = (D <= 256 )), q , k , v , "(sdpa)" )
1110
+ out_sdpa , _ = run_benchmark (
1111
+ partial (sdpa , use_flash = (D <= 256 )), q , k , v , "(sdpa)"
1112
+ )
1053
1113
pretty_print_line ()
1054
1114
1055
1115
torch .cuda .synchronize ()
@@ -1058,10 +1118,16 @@ def check_all_close(
1058
1118
pretty_print_line ()
1059
1119
# Split-KV
1060
1120
check_all_close (
1061
- out_flash , out_mma_split_kv1 , "out_mma_split_kv1" , args .check_all
1121
+ out_flash ,
1122
+ out_mma_split_kv1 ,
1123
+ "out_mma_split_kv1" ,
1124
+ args .check_all ,
1062
1125
)
1063
1126
check_all_close (
1064
- out_flash , out_mma_split_kv2 , "out_mma_split_kv2" , args .check_all
1127
+ out_flash ,
1128
+ out_mma_split_kv2 ,
1129
+ "out_mma_split_kv2" ,
1130
+ args .check_all ,
1065
1131
)
1066
1132
# Split-Q
1067
1133
check_all_close (
@@ -1072,10 +1138,16 @@ def check_all_close(
1072
1138
)
1073
1139
# Split-Q + Shared KV SMEM
1074
1140
check_all_close (
1075
- out_flash , out_mma_share_kv1 , "out_mma_share_kv1" , args .check_all
1141
+ out_flash ,
1142
+ out_mma_share_kv1 ,
1143
+ "out_mma_share_kv1" ,
1144
+ args .check_all ,
1076
1145
)
1077
1146
check_all_close (
1078
- out_flash , out_mma_share_kv2 , "out_mma_share_kv2" , args .check_all
1147
+ out_flash ,
1148
+ out_mma_share_kv2 ,
1149
+ "out_mma_share_kv2" ,
1150
+ args .check_all ,
1079
1151
)
1080
1152
check_all_close (
1081
1153
out_flash ,
@@ -1090,10 +1162,16 @@ def check_all_close(
1090
1162
args .check_all ,
1091
1163
)
1092
1164
check_all_close (
1093
- out_flash , out_mma_share_kv_sq1 , "out_mma_share_kv_sq1" , args .check_all
1165
+ out_flash ,
1166
+ out_mma_share_kv_sq1 ,
1167
+ "out_mma_share_kv_sq1" ,
1168
+ args .check_all ,
1094
1169
)
1095
1170
check_all_close (
1096
- out_flash , out_mma_share_kv_sq2 , "out_mma_share_kv_sq2" , args .check_all
1171
+ out_flash ,
1172
+ out_mma_share_kv_sq2 ,
1173
+ "out_mma_share_kv_sq2" ,
1174
+ args .check_all ,
1097
1175
)
1098
1176
check_all_close (
1099
1177
out_flash ,
@@ -1121,10 +1199,16 @@ def check_all_close(
1121
1199
)
1122
1200
# Split-Q + Fully Shared QKV SMEM
1123
1201
check_all_close (
1124
- out_flash , out_mma_share_qkv1 , "out_mma_share_qkv1" , args .check_all
1202
+ out_flash ,
1203
+ out_mma_share_qkv1 ,
1204
+ "out_mma_share_qkv1" ,
1205
+ args .check_all ,
1125
1206
)
1126
1207
check_all_close (
1127
- out_flash , out_mma_share_qkv2 , "out_mma_share_qkv2" , args .check_all
1208
+ out_flash ,
1209
+ out_mma_share_qkv2 ,
1210
+ "out_mma_share_qkv2" ,
1211
+ args .check_all ,
1128
1212
)
1129
1213
check_all_close (
1130
1214
out_flash ,
@@ -1176,10 +1260,16 @@ def check_all_close(
1176
1260
)
1177
1261
# Split-Q + QK Fine-grained Tiling
1178
1262
check_all_close (
1179
- out_flash , out_mma_tiling_qk1 , "out_mma_tiling_qk1" , args .check_all
1263
+ out_flash ,
1264
+ out_mma_tiling_qk1 ,
1265
+ "out_mma_tiling_qk1" ,
1266
+ args .check_all ,
1180
1267
)
1181
1268
check_all_close (
1182
- out_flash , out_mma_tiling_qk2 , "out_mma_tiling_qk2" , args .check_all
1269
+ out_flash ,
1270
+ out_mma_tiling_qk2 ,
1271
+ "out_mma_tiling_qk2" ,
1272
+ args .check_all ,
1183
1273
)
1184
1274
check_all_close (
1185
1275
out_flash ,
@@ -1231,10 +1321,16 @@ def check_all_close(
1231
1321
)
1232
1322
# Split-Q + Fully QKV Fine-grained Tiling
1233
1323
check_all_close (
1234
- out_flash , out_mma_tiling_qkv1 , "out_mma_tiling_qkv1" , args .check_all
1324
+ out_flash ,
1325
+ out_mma_tiling_qkv1 ,
1326
+ "out_mma_tiling_qkv1" ,
1327
+ args .check_all ,
1235
1328
)
1236
1329
check_all_close (
1237
- out_flash , out_mma_tiling_qkv2 , "out_mma_tiling_qkv2" , args .check_all
1330
+ out_flash ,
1331
+ out_mma_tiling_qkv2 ,
1332
+ "out_mma_tiling_qkv2" ,
1333
+ args .check_all ,
1238
1334
)
1239
1335
check_all_close (
1240
1336
out_flash ,
@@ -1322,10 +1418,16 @@ def check_all_close(
1322
1418
)
1323
1419
# Others, O s2g, etc.
1324
1420
check_all_close (
1325
- out_flash , out_mma_share_kv_rr1 , "out_mma_share_kv_rr1" , args .check_all
1421
+ out_flash ,
1422
+ out_mma_share_kv_rr1 ,
1423
+ "out_mma_share_kv_rr1" ,
1424
+ args .check_all ,
1326
1425
)
1327
1426
check_all_close (
1328
- out_flash , out_mma_share_kv_rr2 , "out_mma_share_kv_rr2" , args .check_all
1427
+ out_flash ,
1428
+ out_mma_share_kv_rr2 ,
1429
+ "out_mma_share_kv_rr2" ,
1430
+ args .check_all ,
1329
1431
)
1330
1432
check_all_close (
1331
1433
out_flash ,
0 commit comments