Skip to content

Commit 7673ac6

Browse files
authored
feat: update pre-commit max-length=80 (#307)
* feat: update pre-commit length=120
* feat: update pre-commit max-length=80
1 parent 06c10ac commit 7673ac6

19 files changed

+428
-123
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ repos:
2626
rev: 24.10.0
2727
hooks:
2828
- id: black-jupyter
29+
args:
30+
- --line-length=80
2931
- repo: https://github.com/pre-commit/mirrors-clang-format
3032
rev: v18.1.8
3133
hooks:

kernels/elementwise/elementwise.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,8 @@ def run_benchmark(
8888
run_benchmark(lib.elementwise_add_f16, a_f16, b_f16, "f16", c_f16)
8989
run_benchmark(lib.elementwise_add_f16x2, a_f16, b_f16, "f16x2", c_f16)
9090
run_benchmark(lib.elementwise_add_f16x8, a_f16, b_f16, "f16x8", c_f16)
91-
run_benchmark(lib.elementwise_add_f16x8_pack, a_f16, b_f16, "f16x8pack", c_f16)
91+
run_benchmark(
92+
lib.elementwise_add_f16x8_pack, a_f16, b_f16, "f16x8pack", c_f16
93+
)
9294
run_benchmark(partial(torch.add, out=c_f16), a_f16, b_f16, "f16_th")
9395
print("-" * 85)

kernels/flash-attn/flash_attn_mma.py

Lines changed: 138 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@ def get_args():
3232
parser.add_argument("--check-all", action="store_true")
3333
parser.add_argument("--show-all", "--show", action="store_true")
3434
parser.add_argument("--show-matrix", action="store_true")
35-
parser.add_argument("--only-flops-matmul", "--flops-mm", action="store_true")
36-
parser.add_argument("--run-acc-f32", "--acc-f32", "--f32", action="store_true")
35+
parser.add_argument(
36+
"--only-flops-matmul", "--flops-mm", action="store_true"
37+
)
38+
parser.add_argument(
39+
"--run-acc-f32", "--acc-f32", "--f32", action="store_true"
40+
)
3741
parser.add_argument("--B", type=int, default=None)
3842
parser.add_argument("--H", type=int, default=None)
3943
parser.add_argument("--N", type=int, default=None)
@@ -46,7 +50,9 @@ def get_args():
4650
parser.add_argument("--iters", "--i", type=int, default=5)
4751
parser.add_argument("--range-k", "--gk", action="store_true")
4852
parser.add_argument("--build-others", "--others", action="store_true")
49-
parser.add_argument("--tag-hints", "--tags", "--hints", type=str, default=None)
53+
parser.add_argument(
54+
"--tag-hints", "--tags", "--hints", type=str, default=None
55+
)
5056
return parser.parse_args()
5157

5258

@@ -84,20 +90,30 @@ def get_build_sources():
8490
build_sources.append("./mma/basic/flash_attn_mma_share_kv_F32F16F16F32.cu")
8591
build_sources.append("./mma/basic/flash_attn_mma_share_qkv_F32F16F16F32.cu")
8692
build_sources.append("./mma/basic/flash_attn_mma_tiling_qk_F32F16F16F32.cu")
87-
build_sources.append("./mma/basic/flash_attn_mma_tiling_qkv_F32F16F16F32.cu")
93+
build_sources.append(
94+
"./mma/basic/flash_attn_mma_tiling_qkv_F32F16F16F32.cu"
95+
)
8896
# Swizzle
8997
build_sources.append("./mma/swizzle/flash_attn_mma_share_kv_swizzle_q.cu")
9098
build_sources.append("./mma/swizzle/flash_attn_mma_share_kv_swizzle_qk.cu")
9199
build_sources.append("./mma/swizzle/flash_attn_mma_share_kv_swizzle_qkv.cu")
92100
build_sources.append("./mma/swizzle/flash_attn_mma_share_qkv_swizzle_q.cu")
93101
build_sources.append("./mma/swizzle/flash_attn_mma_share_qkv_swizzle_qk.cu")
94-
build_sources.append("./mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu")
102+
build_sources.append(
103+
"./mma/swizzle/flash_attn_mma_share_qkv_swizzle_qkv.cu"
104+
)
95105
build_sources.append("./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_q.cu")
96106
build_sources.append("./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qk.cu")
97-
build_sources.append("./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu")
107+
build_sources.append(
108+
"./mma/swizzle/flash_attn_mma_tiling_qk_swizzle_qkv.cu"
109+
)
98110
build_sources.append("./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q.cu")
99-
build_sources.append("./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu")
100-
build_sources.append("./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu")
111+
build_sources.append(
112+
"./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qk.cu"
113+
)
114+
build_sources.append(
115+
"./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_qkv.cu"
116+
)
101117
build_sources.append(
102118
"./mma/swizzle/flash_attn_mma_tiling_qkv_swizzle_q_F32F16F16F32.cu"
103119
)
@@ -110,15 +126,21 @@ def get_build_sources():
110126
# Others
111127
if args.build_others:
112128
build_sources.append("./mma/others/flash_attn_mma_share_qkv_Os2g.cu")
113-
build_sources.append("./mma/others/flash_attn_mma_share_kv_F32F16F16F32_rr.cu")
114-
build_sources.append("./mma/others/flash_attn_mma_share_qkv_F32F16F16F32_rr.cu")
129+
build_sources.append(
130+
"./mma/others/flash_attn_mma_share_kv_F32F16F16F32_rr.cu"
131+
)
132+
build_sources.append(
133+
"./mma/others/flash_attn_mma_share_qkv_F32F16F16F32_rr.cu"
134+
)
115135
# Pybind
116136
build_sources.append("./pybind/flash_attn.cc")
117137
return build_sources
118138

119139

120140
def get_project_dir():
121-
return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
141+
return os.path.dirname(
142+
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
143+
)
122144

123145

124146
project_dir = get_project_dir()
@@ -153,7 +175,9 @@ def get_build_cuda_cflags(build_pkg: bool = False):
153175
extra_cuda_cflags.append(
154176
"-diag-suppress 177" if not build_pkg else "--ptxas-options=-v"
155177
)
156-
extra_cuda_cflags.append("-Xptxas -v" if not build_pkg else "--ptxas-options=-O3")
178+
extra_cuda_cflags.append(
179+
"-Xptxas -v" if not build_pkg else "--ptxas-options=-O3"
180+
)
157181
extra_cuda_cflags.append(f"-I {project_dir}/kernels/flash-attn")
158182
extra_cuda_cflags.append(f"-I {project_dir}/kernels/flash-attn/utils")
159183
extra_cuda_cflags.append(f"-I {project_dir}/kernels/flash-attn/mma")
@@ -163,14 +187,18 @@ def get_build_cuda_cflags(build_pkg: bool = False):
163187
extra_cuda_cflags.append(f"-I {project_dir}/kernels/flash-attn/cutlass")
164188
extra_cuda_cflags.append(f"-I {project_dir}/kernels/flash-attn/pybind")
165189
extra_cuda_cflags.append(f"-I {project_dir}/third-party/cutlass/include")
166-
extra_cuda_cflags.append(f"-I {project_dir}/third-party/cutlass/tools/util/include")
190+
extra_cuda_cflags.append(
191+
f"-I {project_dir}/third-party/cutlass/tools/util/include"
192+
)
167193
return extra_cuda_cflags
168194

169195

170196
def get_build_cflags():
171197
extra_cflags = []
172198
extra_cflags.append("-std=c++17")
173-
extra_cflags.append("-DBUILD_FLASH_ATTN_MMA_OTHERS" if args.build_others else "")
199+
extra_cflags.append(
200+
"-DBUILD_FLASH_ATTN_MMA_OTHERS" if args.build_others else ""
201+
)
174202
return extra_cflags
175203

176204

@@ -200,8 +228,12 @@ def pretty_print_line(m: str = "", sep: str = "-", width: int = 150):
200228
if not args.build_others:
201229
fake_fa_func = lambda q, k, v, o, s: o # fake FA func
202230
setattr(lib, "flash_attn_mma_stages_split_q_shared_qkv_Os2g", fake_fa_func)
203-
setattr(lib, "flash_attn_mma_stages_split_q_shared_kv_acc_f32_rr", fake_fa_func)
204-
setattr(lib, "flash_attn_mma_stages_split_q_shared_qkv_acc_f32_rr", fake_fa_func)
231+
setattr(
232+
lib, "flash_attn_mma_stages_split_q_shared_kv_acc_f32_rr", fake_fa_func
233+
)
234+
setattr(
235+
lib, "flash_attn_mma_stages_split_q_shared_qkv_acc_f32_rr", fake_fa_func
236+
)
205237

206238

207239
def get_mha_tflops(
@@ -342,7 +374,9 @@ def run_benchmark(
342374
mean_time = total_time / iters
343375
mean_secs = total_secs / iters
344376

345-
TFLOPS = get_mha_tflops(B, H, N, D, mean_secs, only_matmul=args.only_flops_matmul)
377+
TFLOPS = get_mha_tflops(
378+
B, H, N, D, mean_secs, only_matmul=args.only_flops_matmul
379+
)
346380
out_info = f"{tag}"
347381
out_val_first = out.flatten()[:3].detach().cpu().numpy().tolist()
348382
out_val_last = out.flatten()[-3:].detach().cpu().numpy().tolist()
@@ -562,17 +596,41 @@ def check_all_close(
562596
out_unfused, _ = run_benchmark(unfused_standard_attn, q, k, v, "(unfused)")
563597
# Split-KV
564598
out_mma_split_kv1, _ = run_benchmark(
565-
lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage1)", o, stages=1
599+
lib.flash_attn_mma_stages_split_kv,
600+
q,
601+
k,
602+
v,
603+
"mma(split-kv+stage1)",
604+
o,
605+
stages=1,
566606
)
567607
out_mma_split_kv2, _ = run_benchmark(
568-
lib.flash_attn_mma_stages_split_kv, q, k, v, "mma(split-kv+stage2)", o, stages=2
608+
lib.flash_attn_mma_stages_split_kv,
609+
q,
610+
k,
611+
v,
612+
"mma(split-kv+stage2)",
613+
o,
614+
stages=2,
569615
)
570616
# Split-Q
571617
out_mma_split_q1, _ = run_benchmark(
572-
lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage1)", o, stages=1
618+
lib.flash_attn_mma_stages_split_q,
619+
q,
620+
k,
621+
v,
622+
"mma(split-q+stage1)",
623+
o,
624+
stages=1,
573625
)
574626
out_mma_split_q2, _ = run_benchmark(
575-
lib.flash_attn_mma_stages_split_q, q, k, v, "mma(split-q+stage2)", o, stages=2
627+
lib.flash_attn_mma_stages_split_q,
628+
q,
629+
k,
630+
v,
631+
"mma(split-q+stage2)",
632+
o,
633+
stages=2,
576634
)
577635
# Split-Q + Shared KV SMEM + Swizzle
578636
out_mma_share_kv1, _ = run_benchmark(
@@ -1049,7 +1107,9 @@ def check_all_close(
10491107
)
10501108
# FA2, SDPA official
10511109
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
1052-
out_sdpa, _ = run_benchmark(partial(sdpa, use_flash=(D <= 256)), q, k, v, "(sdpa)")
1110+
out_sdpa, _ = run_benchmark(
1111+
partial(sdpa, use_flash=(D <= 256)), q, k, v, "(sdpa)"
1112+
)
10531113
pretty_print_line()
10541114

10551115
torch.cuda.synchronize()
@@ -1058,10 +1118,16 @@ def check_all_close(
10581118
pretty_print_line()
10591119
# Split-KV
10601120
check_all_close(
1061-
out_flash, out_mma_split_kv1, "out_mma_split_kv1", args.check_all
1121+
out_flash,
1122+
out_mma_split_kv1,
1123+
"out_mma_split_kv1",
1124+
args.check_all,
10621125
)
10631126
check_all_close(
1064-
out_flash, out_mma_split_kv2, "out_mma_split_kv2", args.check_all
1127+
out_flash,
1128+
out_mma_split_kv2,
1129+
"out_mma_split_kv2",
1130+
args.check_all,
10651131
)
10661132
# Split-Q
10671133
check_all_close(
@@ -1072,10 +1138,16 @@ def check_all_close(
10721138
)
10731139
# Split-Q + Shared KV SMEM
10741140
check_all_close(
1075-
out_flash, out_mma_share_kv1, "out_mma_share_kv1", args.check_all
1141+
out_flash,
1142+
out_mma_share_kv1,
1143+
"out_mma_share_kv1",
1144+
args.check_all,
10761145
)
10771146
check_all_close(
1078-
out_flash, out_mma_share_kv2, "out_mma_share_kv2", args.check_all
1147+
out_flash,
1148+
out_mma_share_kv2,
1149+
"out_mma_share_kv2",
1150+
args.check_all,
10791151
)
10801152
check_all_close(
10811153
out_flash,
@@ -1090,10 +1162,16 @@ def check_all_close(
10901162
args.check_all,
10911163
)
10921164
check_all_close(
1093-
out_flash, out_mma_share_kv_sq1, "out_mma_share_kv_sq1", args.check_all
1165+
out_flash,
1166+
out_mma_share_kv_sq1,
1167+
"out_mma_share_kv_sq1",
1168+
args.check_all,
10941169
)
10951170
check_all_close(
1096-
out_flash, out_mma_share_kv_sq2, "out_mma_share_kv_sq2", args.check_all
1171+
out_flash,
1172+
out_mma_share_kv_sq2,
1173+
"out_mma_share_kv_sq2",
1174+
args.check_all,
10971175
)
10981176
check_all_close(
10991177
out_flash,
@@ -1121,10 +1199,16 @@ def check_all_close(
11211199
)
11221200
# Split-Q + Fully Shared QKV SMEM
11231201
check_all_close(
1124-
out_flash, out_mma_share_qkv1, "out_mma_share_qkv1", args.check_all
1202+
out_flash,
1203+
out_mma_share_qkv1,
1204+
"out_mma_share_qkv1",
1205+
args.check_all,
11251206
)
11261207
check_all_close(
1127-
out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all
1208+
out_flash,
1209+
out_mma_share_qkv2,
1210+
"out_mma_share_qkv2",
1211+
args.check_all,
11281212
)
11291213
check_all_close(
11301214
out_flash,
@@ -1176,10 +1260,16 @@ def check_all_close(
11761260
)
11771261
# Split-Q + QK Fine-grained Tiling
11781262
check_all_close(
1179-
out_flash, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all
1263+
out_flash,
1264+
out_mma_tiling_qk1,
1265+
"out_mma_tiling_qk1",
1266+
args.check_all,
11801267
)
11811268
check_all_close(
1182-
out_flash, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all
1269+
out_flash,
1270+
out_mma_tiling_qk2,
1271+
"out_mma_tiling_qk2",
1272+
args.check_all,
11831273
)
11841274
check_all_close(
11851275
out_flash,
@@ -1231,10 +1321,16 @@ def check_all_close(
12311321
)
12321322
# Split-Q + Fully QKV Fine-grained Tiling
12331323
check_all_close(
1234-
out_flash, out_mma_tiling_qkv1, "out_mma_tiling_qkv1", args.check_all
1324+
out_flash,
1325+
out_mma_tiling_qkv1,
1326+
"out_mma_tiling_qkv1",
1327+
args.check_all,
12351328
)
12361329
check_all_close(
1237-
out_flash, out_mma_tiling_qkv2, "out_mma_tiling_qkv2", args.check_all
1330+
out_flash,
1331+
out_mma_tiling_qkv2,
1332+
"out_mma_tiling_qkv2",
1333+
args.check_all,
12381334
)
12391335
check_all_close(
12401336
out_flash,
@@ -1322,10 +1418,16 @@ def check_all_close(
13221418
)
13231419
# Others, O s2g, etc.
13241420
check_all_close(
1325-
out_flash, out_mma_share_kv_rr1, "out_mma_share_kv_rr1", args.check_all
1421+
out_flash,
1422+
out_mma_share_kv_rr1,
1423+
"out_mma_share_kv_rr1",
1424+
args.check_all,
13261425
)
13271426
check_all_close(
1328-
out_flash, out_mma_share_kv_rr2, "out_mma_share_kv_rr2", args.check_all
1427+
out_flash,
1428+
out_mma_share_kv_rr2,
1429+
"out_mma_share_kv_rr2",
1430+
args.check_all,
13291431
)
13301432
check_all_close(
13311433
out_flash,

kernels/flash-attn/tools/print_swizzle_layout.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -121,15 +121,21 @@ def print_smem_swizzle_layout(
121121
max_bank_str_len = 0
122122
if logical_col_stride >= 16 and (not use_logical_col_stride):
123123
for k in range(int(logical_col_stride / 16)):
124-
for j in range(banks_start, banks_end, banks_per_num_elems_per_128b):
124+
for j in range(
125+
banks_start, banks_end, banks_per_num_elems_per_128b
126+
):
125127
curr_bank_str = (
126128
f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|"
127129
)
128130
max_bank_str_len = max(max_bank_str_len, len(curr_bank_str))
129131
bank_layout_str += curr_bank_str
130132
else:
131-
for j in range(banks_start, banks_end, banks_per_num_elems_per_128b):
132-
curr_bank_str = f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|"
133+
for j in range(
134+
banks_start, banks_end, banks_per_num_elems_per_128b
135+
):
136+
curr_bank_str = (
137+
f"b{j:>2}~{j + banks_per_num_elems_per_128b - 1:<2}|"
138+
)
133139
max_bank_str_len = max(max_bank_str_len, len(curr_bank_str))
134140
bank_layout_str += curr_bank_str
135141

@@ -195,7 +201,8 @@ def print_smem_swizzle_layout(
195201
print("-" * str_len)
196202
pretty_print_line("swizzle layout", width=str_len)
197203
pretty_print_line(
198-
f"logical col 0~{logical_col_stride}, " f"step {num_elems_per_128b}",
204+
f"logical col 0~{logical_col_stride}, "
205+
f"step {num_elems_per_128b}",
199206
width=str_len,
200207
)
201208
pretty_print_line(
@@ -219,7 +226,9 @@ def get_args():
219226
parser = argparse.ArgumentParser()
220227
parser.add_argument("--rows", type=int, default=16)
221228
parser.add_argument("--smem-padding", "--pad", type=int, default=0)
222-
parser.add_argument("--num-elems-per-128b", "--num-elems", type=int, default=8)
229+
parser.add_argument(
230+
"--num-elems-per-128b", "--num-elems", type=int, default=8
231+
)
223232
parser.add_argument(
224233
"--logical-col-stride", "--logical-col", "--col", type=int, default=64
225234
)

0 commit comments

Comments
 (0)