@@ -65,11 +65,7 @@ class VulkanLaunchFuncToVulkanCallsPass
65
65
void initializeCachedTypes () {
66
66
llvmFloatType = Float32Type::get (&getContext ());
67
67
llvmVoidType = LLVM::LLVMVoidType::get (&getContext ());
68
- if (useOpaquePointers)
69
- llvmPointerType = LLVM::LLVMPointerType::get (&getContext ());
70
- else
71
- llvmPointerType =
72
- LLVM::LLVMPointerType::get (IntegerType::get (&getContext (), 8 ));
68
+ llvmPointerType = LLVM::LLVMPointerType::get (&getContext ());
73
69
llvmInt32Type = IntegerType::get (&getContext (), 32 );
74
70
llvmInt64Type = IntegerType::get (&getContext (), 64 );
75
71
}
@@ -85,9 +81,6 @@ class VulkanLaunchFuncToVulkanCallsPass
85
81
// int64_t sizes[Rank]; // omitted when rank == 0
86
82
// int64_t strides[Rank]; // omitted when rank == 0
87
83
// };
88
- auto llvmPtrToElementType = useOpaquePointers
89
- ? llvmPointerType
90
- : LLVM::LLVMPointerType::get (elemenType);
91
84
auto llvmArrayRankElementSizeType =
92
85
LLVM::LLVMArrayType::get (getInt64Type (), rank);
93
86
@@ -96,7 +89,7 @@ class VulkanLaunchFuncToVulkanCallsPass
96
89
// [`rank` x i64], [`rank` x i64]}">`.
97
90
return LLVM::LLVMStructType::getLiteral (
98
91
&getContext (),
99
- {llvmPtrToElementType, llvmPtrToElementType , getInt64Type (),
92
+ {llvmPointerType, llvmPointerType , getInt64Type (),
100
93
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
101
94
}
102
95
@@ -280,13 +273,6 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
280
273
281
274
auto symbolName =
282
275
llvm::formatv (" bindMemRef{0}D{1}" , rank, stringifyType (type)).str ();
283
- // Special case for fp16 type. Since it is not a supported type in C we use
284
- // int16_t and bitcast the descriptor.
285
- if (!useOpaquePointers && isa<Float16Type>(type)) {
286
- auto memRefTy = getMemRefType (rank, IntegerType::get (&getContext (), 16 ));
287
- ptrToMemRefDescriptor = builder.create <LLVM::BitcastOp>(
288
- loc, LLVM::LLVMPointerType::get (memRefTy), ptrToMemRefDescriptor);
289
- }
290
276
// Create call to `bindMemRef`.
291
277
builder.create <LLVM::CallOp>(
292
278
loc, TypeRange (), StringRef (symbolName.data (), symbolName.size ()),
@@ -303,16 +289,9 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value launchCallArg,
303
289
if (!alloca )
304
290
return failure ();
305
291
306
- LLVM::LLVMStructType llvmDescriptorTy;
307
- if (std::optional<Type> elementType = alloca .getElemType ()) {
308
- llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
309
- } else {
310
- // This case is only possible if we are not using opaque pointers
311
- // since opaque pointer producing allocas require an element type.
312
- llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(
313
- alloca .getRes ().getType ().getElementType ());
314
- }
315
-
292
+ std::optional<Type> elementType = alloca .getElemType ();
293
+ assert (elementType && " expected to work with opaque pointers" );
294
+ auto llvmDescriptorTy = dyn_cast<LLVM::LLVMStructType>(*elementType);
316
295
// template <typename Elem, size_t Rank>
317
296
// struct {
318
297
// Elem *allocated;
@@ -379,10 +358,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
379
358
if (!module.lookupSymbol (fnName)) {
380
359
auto fnType = LLVM::LLVMFunctionType::get (
381
360
getVoidType (),
382
- {getPointerType (), getInt32Type (), getInt32Type (),
383
- useOpaquePointers
384
- ? llvmPointerType
385
- : LLVM::LLVMPointerType::get (getMemRefType (i, type))},
361
+ {llvmPointerType, getInt32Type (), getInt32Type (), llvmPointerType},
386
362
/* isVarArg=*/ false );
387
363
builder.create <LLVM::LLVMFuncOp>(loc, fnName, fnType);
388
364
}
@@ -410,8 +386,7 @@ Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant(
410
386
411
387
std::string entryPointGlobalName = (name + " _spv_entry_point_name" ).str ();
412
388
return LLVM::createGlobalString (loc, builder, entryPointGlobalName,
413
- shaderName, LLVM::Linkage::Internal,
414
- useOpaquePointers);
389
+ shaderName, LLVM::Linkage::Internal);
415
390
}
416
391
417
392
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall (
@@ -429,7 +404,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
429
404
// that data to runtime call.
430
405
Value ptrToSPIRVBinary = LLVM::createGlobalString (
431
406
loc, builder, kSPIRVBinary , spirvAttributes.blob .getValue (),
432
- LLVM::Linkage::Internal, useOpaquePointers );
407
+ LLVM::Linkage::Internal);
433
408
434
409
// Create LLVM constant for the size of SPIR-V binary shader.
435
410
Value binarySize = builder.create <LLVM::ConstantOp>(
0 commit comments