@@ -204,5 +204,60 @@ def _(
204
204
size_n : int ,
205
205
size_k : int ,
206
206
) -> Tensor :
207
- # NOTE: Checks in kernel
207
+ TILE_SIZE = 16
208
+ MIN_THREAD_N = 128
209
+ MAX_PARALLELISM = 64
210
+
211
+ # Verify num_bits
212
+ torch ._check (bits == 4 or bits == 8 , lambda : f"num_bits must be 4 or 8. Got = { bits } " )
213
+ pack_factor = 32 // bits
214
+
215
+ # Verify M
216
+ torch ._check (size_m == x .size (0 ), lambda : f"Shape mismatch: x.size(0) = { x .size (0 )} , size_m = { size_m } " )
217
+
218
+ # Verify K
219
+ torch ._check (size_k == x .size (1 ), lambda : f"Shape mismatch: x.size(1) = { x .size (1 )} , size_k = { size_k } " )
220
+ torch ._check (size_k % TILE_SIZE == 0 , lambda : f"size_k = { size_k } is not divisible by tile_size = { TILE_SIZE } " )
221
+ torch ._check ((size_k // TILE_SIZE // 2 ) == weight_marlin .size (0 ), lambda : f"Shape mismatch: weight_marlin.size(0) = { weight_marlin .size (0 )} , size_k = { size_k } , tile_size = { TILE_SIZE } " )
222
+
223
+ # Verify N
224
+ torch ._check (s .size (1 ) == size_n , lambda : f"s.size(1) = { s .size (1 )} , size_n = { size_n } " )
225
+ torch ._check (weight_marlin .size (1 ) % TILE_SIZE == 0 , lambda : f"weight_marlin.size(1) = { weight_marlin .size (1 )} is not divisible by tile_size = { TILE_SIZE } " )
226
+
227
+ actual_size_n = (weight_marlin .size (1 ) // TILE_SIZE ) * pack_factor
228
+ torch ._check (size_n == actual_size_n , lambda : f"size_n = { size_n } , actual_size_n = { actual_size_n } " )
229
+
230
+ # Verify meta
231
+ torch ._check (meta .size (0 ) == size_k // 8 // 2 // 2 , lambda : f"meta.size(0) = { meta .size (0 )} is not size_k / 8 / 2 / 2 = { size_k // 8 // 2 // 2 } " )
232
+ torch ._check (meta .size (1 ) == size_n * 2 , lambda : f"meta.size(1) = { meta .size (1 )} is not size_n * 2 = { size_n * 2 } " )
233
+
234
+ # Verify A device and strides
235
+ torch ._check (x .is_cuda , lambda : "x is not on GPU" )
236
+ torch ._check (x .is_contiguous (), lambda : "x is not contiguous" )
237
+
238
+ # Verify B device and strides
239
+ torch ._check (weight_marlin .is_cuda , lambda : "weight_marlin is not on GPU" )
240
+ torch ._check (weight_marlin .is_contiguous (), lambda : "weight_marlin is not contiguous" )
241
+
242
+ # Verify meta device and strides
243
+ torch ._check (meta .is_cuda , lambda : "meta is not on GPU" )
244
+ torch ._check (meta .is_contiguous (), lambda : "meta is not contiguous" )
245
+
246
+ # Verify scales device and strides
247
+ torch ._check (s .is_cuda , lambda : "s is not on GPU" )
248
+ torch ._check (s .is_contiguous (), lambda : "s is not contiguous" )
249
+
250
+ # Verify groupsize
251
+ groupsize = - 1
252
+ if s .size (0 ) > 1 :
253
+ torch ._check (size_k % s .size (0 ) == 0 , lambda : f"size_k = { size_k } is not divisible by s.size(0) = { s .size (0 )} " )
254
+ groupsize = size_k // s .size (0 )
255
+ groupsize //= 2 # Because of 24
256
+ torch ._check (groupsize == - 1 or groupsize == 64 , lambda : f"Unexpected groupsize = { groupsize } " )
257
+
258
+ # Verify workspace size
259
+ torch ._check (size_n % MIN_THREAD_N == 0 , lambda : f"size_n = { size_n } is not divisible by min_thread_n = { MIN_THREAD_N } " )
260
+ min_workspace_size = (size_n // MIN_THREAD_N ) * MAX_PARALLELISM
261
+ torch ._check (workspace .numel () >= min_workspace_size , lambda : f"workspace.numel = { workspace .numel ()} is below min_workspace_size = { min_workspace_size } " )
262
+
208
263
return torch .empty ((x .size (0 ), s .size (1 )), dtype = x .dtype , device = x .device )
0 commit comments