Skip to content

Commit 68afac5

Browse files
AlexMaclean authored and github-actions[bot] committed
Automerge: [NVPTX] Basic support for fp128 as a storage type (#136006)
While fp128 operations are not natively supported in hardware, emulation for them is supported by nvcc. This change adds basic support for fp128 as a storage type allowing for lowering of IR containing these types. Fixes: llvm/llvm-project#95471
2 parents 51b9f5a + 3001387 commit 68afac5

File tree

7 files changed

+137
-96
lines changed

7 files changed

+137
-96
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

+45-53
Original file line number | Diff line number | Diff line change
@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
249249
return MCOperand::createExpr(Expr);
250250
}
251251

252-
static bool ShouldPassAsArray(Type *Ty) {
253-
return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
254-
Ty->isHalfTy() || Ty->isBFloatTy();
255-
}
256-
257252
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
258253
const DataLayout &DL = getDataLayout();
259254
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
264259
return;
265260
O << " (";
266261

267-
if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
268-
!ShouldPassAsArray(Ty)) {
269-
unsigned size = 0;
270-
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
271-
size = ITy->getBitWidth();
272-
} else {
273-
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
274-
size = Ty->getPrimitiveSizeInBits();
275-
}
276-
size = promoteScalarArgumentSize(size);
277-
O << ".param .b" << size << " func_retval0";
278-
} else if (isa<PointerType>(Ty)) {
279-
O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
280-
<< " func_retval0";
281-
} else if (ShouldPassAsArray(Ty)) {
282-
unsigned totalsz = DL.getTypeAllocSize(Ty);
283-
Align RetAlignment = TLI->getFunctionArgumentAlignment(
262+
auto PrintScalarRetVal = [&](unsigned Size) {
263+
O << ".param .b" << promoteScalarArgumentSize(Size) << " func_retval0";
264+
};
265+
if (shouldPassAsArray(Ty)) {
266+
const unsigned TotalSize = DL.getTypeAllocSize(Ty);
267+
const Align RetAlignment = TLI->getFunctionArgumentAlignment(
284268
F, Ty, AttributeList::ReturnIndex, DL);
285269
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
286-
<< totalsz << "]";
270+
<< TotalSize << "]";
271+
} else if (Ty->isFloatingPointTy()) {
272+
PrintScalarRetVal(Ty->getPrimitiveSizeInBits());
273+
} else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
274+
PrintScalarRetVal(ITy->getBitWidth());
275+
} else if (isa<PointerType>(Ty)) {
276+
PrintScalarRetVal(TLI->getPointerTy(DL).getSizeInBits());
287277
} else
288278
llvm_unreachable("Unknown return type");
289279
O << ") ";
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
975965
O << " .align "
976966
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
977967

978-
if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
979-
(ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
968+
if (ETy->isPointerTy() || ((ETy->isIntegerTy() || ETy->isFloatingPointTy()) &&
969+
ETy->getScalarSizeInBits() <= 64)) {
980970
O << " .";
981971
// Special case: ABI requires that we use .u8 for predicates
982972
if (ETy->isIntegerTy(1))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
10161006
// and vectors are lowered into arrays of bytes.
10171007
switch (ETy->getTypeID()) {
10181008
case Type::IntegerTyID: // Integers larger than 64 bits
1009+
case Type::FP128TyID:
10191010
case Type::StructTyID:
10201011
case Type::ArrayTyID:
10211012
case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
12661257
O << " .align "
12671258
<< GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
12681259

1269-
// Special case for i128
1270-
if (ETy->isIntegerTy(128)) {
1260+
// Special case for i128/fp128
1261+
if (ETy->getScalarSizeInBits() == 128) {
12711262
O << " .b8 ";
12721263
getSymbol(GVar)->print(O, MAI);
12731264
O << "[16]";
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
13831374
continue;
13841375
}
13851376

1386-
if (ShouldPassAsArray(Ty)) {
1377+
if (shouldPassAsArray(Ty)) {
13871378
// Just print .param .align <a> .b8 .param[size];
13881379
// <a> = optimal alignment for the element type; always multiple of
13891380
// PAL.getParamAlignment
@@ -1682,48 +1673,49 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
16821673
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
16831674
AggBuffer *aggBuffer) {
16841675
const DataLayout &DL = getDataLayout();
1685-
int Bytes;
1676+
1677+
auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
1678+
for (unsigned I : llvm::seq(Val.getBitWidth() / 8))
1679+
Buffer->addByte(Val.extractBitsAsZExtValue(8, I * 8));
1680+
};
16861681

16871682
// Integers of arbitrary width
16881683
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1689-
APInt Val = CI->getValue();
1690-
for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1691-
uint8_t Byte = Val.getLoBits(8).getZExtValue();
1692-
aggBuffer->addBytes(&Byte, 1, 1);
1693-
Val.lshrInPlace(8);
1694-
}
1684+
ExtendBuffer(CI->getValue(), aggBuffer);
16951685
return;
16961686
}
16971687

1688+
// f128
1689+
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1690+
if (CFP->getType()->isFP128Ty()) {
1691+
ExtendBuffer(CFP->getValueAPF().bitcastToAPInt(), aggBuffer);
1692+
return;
1693+
}
1694+
}
1695+
16981696
// Old constants
16991697
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1700-
if (CPV->getNumOperands())
1701-
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1702-
bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1698+
for (const auto &Op : CPV->operands())
1699+
bufferLEByte(cast<Constant>(Op), 0, aggBuffer);
17031700
return;
17041701
}
17051702

1706-
if (const ConstantDataSequential *CDS =
1707-
dyn_cast<ConstantDataSequential>(CPV)) {
1708-
if (CDS->getNumElements())
1709-
for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1710-
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1711-
aggBuffer);
1703+
if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
1704+
for (unsigned I : llvm::seq(CDS->getNumElements()))
1705+
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(I)), 0, aggBuffer);
17121706
return;
17131707
}
17141708

17151709
if (isa<ConstantStruct>(CPV)) {
17161710
if (CPV->getNumOperands()) {
17171711
StructType *ST = cast<StructType>(CPV->getType());
1718-
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1719-
if (i == (e - 1))
1720-
Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1721-
DL.getTypeAllocSize(ST) -
1722-
DL.getStructLayout(ST)->getElementOffset(i);
1723-
else
1724-
Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1725-
DL.getStructLayout(ST)->getElementOffset(i);
1726-
bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1712+
for (unsigned I : llvm::seq(CPV->getNumOperands())) {
1713+
int EndOffset = (I + 1 == CPV->getNumOperands())
1714+
? DL.getStructLayout(ST)->getElementOffset(0) +
1715+
DL.getTypeAllocSize(ST)
1716+
: DL.getStructLayout(ST)->getElementOffset(I + 1);
1717+
int Bytes = EndOffset - DL.getStructLayout(ST)->getElementOffset(I);
1718+
bufferLEByte(cast<Constant>(CPV->getOperand(I)), Bytes, aggBuffer);
17271719
}
17281720
}
17291721
return;

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h

+14-19
Original file line numberDiff line numberDiff line change
@@ -111,27 +111,22 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
111111

112112
// Copy Num bytes from Ptr.
113113
// if Bytes > Num, zero fill up to Bytes.
114-
unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) {
115-
assert((curpos + Num) <= size);
116-
assert((curpos + Bytes) <= size);
117-
for (int i = 0; i < Num; ++i) {
118-
buffer[curpos] = Ptr[i];
119-
curpos++;
120-
}
121-
for (int i = Num; i < Bytes; ++i) {
122-
buffer[curpos] = 0;
123-
curpos++;
124-
}
125-
return curpos;
114+
void addBytes(const unsigned char *Ptr, unsigned Num, unsigned Bytes) {
115+
for (unsigned I : llvm::seq(Num))
116+
addByte(Ptr[I]);
117+
if (Bytes > Num)
118+
addZeros(Bytes - Num);
126119
}
127120

128-
unsigned addZeros(int Num) {
129-
assert((curpos + Num) <= size);
130-
for (int i = 0; i < Num; ++i) {
131-
buffer[curpos] = 0;
132-
curpos++;
133-
}
134-
return curpos;
121+
void addByte(uint8_t Byte) {
122+
assert(curpos < size);
123+
buffer[curpos] = Byte;
124+
curpos++;
125+
}
126+
127+
void addZeros(unsigned Num) {
128+
for (unsigned _ : llvm::seq(Num))
129+
addByte(0);
135130
}
136131

137132
void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) {

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

+10-18
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
246246
SmallVector<uint64_t, 16> TempOffsets;
247247

248248
// Special case for i128 - decompose to (i64, i64)
249-
if (Ty->isIntegerTy(128)) {
250-
ValueVTs.push_back(EVT(MVT::i64));
251-
ValueVTs.push_back(EVT(MVT::i64));
249+
if (Ty->isIntegerTy(128) || Ty->isFP128Ty()) {
250+
ValueVTs.append({MVT::i64, MVT::i64});
252251

253-
if (Offsets) {
254-
Offsets->push_back(StartingOffset + 0);
255-
Offsets->push_back(StartingOffset + 8);
256-
}
252+
if (Offsets)
253+
Offsets->append({StartingOffset + 0, StartingOffset + 8});
257254

258255
return;
259256
}
@@ -1165,11 +1162,6 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
11651162
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
11661163
}
11671164

1168-
static bool IsTypePassedAsArray(const Type *Ty) {
1169-
return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
1170-
Ty->isHalfTy() || Ty->isBFloatTy();
1171-
}
1172-
11731165
std::string NVPTXTargetLowering::getPrototype(
11741166
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
11751167
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1186,7 +1178,7 @@ std::string NVPTXTargetLowering::getPrototype(
11861178
} else {
11871179
O << "(";
11881180
if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
1189-
!IsTypePassedAsArray(retTy)) {
1181+
!shouldPassAsArray(retTy)) {
11901182
unsigned size = 0;
11911183
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
11921184
size = ITy->getBitWidth();
@@ -1203,7 +1195,7 @@ std::string NVPTXTargetLowering::getPrototype(
12031195
O << ".param .b" << size << " _";
12041196
} else if (isa<PointerType>(retTy)) {
12051197
O << ".param .b" << PtrVT.getSizeInBits() << " _";
1206-
} else if (IsTypePassedAsArray(retTy)) {
1198+
} else if (shouldPassAsArray(retTy)) {
12071199
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
12081200
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
12091201
} else {
@@ -1224,7 +1216,7 @@ std::string NVPTXTargetLowering::getPrototype(
12241216
first = false;
12251217

12261218
if (!Outs[OIdx].Flags.isByVal()) {
1227-
if (IsTypePassedAsArray(Ty)) {
1219+
if (shouldPassAsArray(Ty)) {
12281220
Align ParamAlign =
12291221
getArgumentAlignment(&CB, Ty, i + AttributeList::FirstArgIndex, DL);
12301222
O << ".param .align " << ParamAlign.value() << " .b8 ";
@@ -1529,7 +1521,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15291521
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
15301522

15311523
bool NeedAlign; // Does argument declaration specify alignment?
1532-
bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
1524+
const bool PassAsArray = IsByVal || shouldPassAsArray(Ty);
15331525
if (IsVAArg) {
15341526
if (ParamCount == FirstVAArg) {
15351527
SDValue DeclareParamOps[] = {
@@ -1718,7 +1710,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
17181710
// .param .align N .b8 retval0[<size-in-bytes>], or
17191711
// .param .b<size-in-bits> retval0
17201712
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
1721-
if (!IsTypePassedAsArray(RetTy)) {
1713+
if (!shouldPassAsArray(RetTy)) {
17221714
resultsz = promoteScalarArgumentSize(resultsz);
17231715
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
17241716
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -3362,7 +3354,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
33623354

33633355
if (theArgs[i]->use_empty()) {
33643356
// argument is dead
3365-
if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
3357+
if (shouldPassAsArray(Ty) && !Ty->isVectorTy()) {
33663358
SmallVector<EVT, 16> vtparts;
33673359

33683360
ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,4 @@ bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM) {
372372
!isKernelFunction(*F);
373373
}
374374

375-
bool Isv2x16VT(EVT VT) {
376-
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
377-
}
378-
379375
} // namespace llvm

llvm/lib/Target/NVPTX/NVPTXUtilities.h

+8-1
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,14 @@ inline unsigned promoteScalarArgumentSize(unsigned size) {
8484

8585
bool shouldEmitPTXNoReturn(const Value *V, const TargetMachine &TM);
8686

87-
bool Isv2x16VT(EVT VT);
87+
inline bool Isv2x16VT(EVT VT) {
88+
return (VT == MVT::v2f16 || VT == MVT::v2bf16 || VT == MVT::v2i16);
89+
}
90+
91+
inline bool shouldPassAsArray(Type *Ty) {
92+
return Ty->isAggregateType() || Ty->isVectorTy() ||
93+
Ty->getScalarSizeInBits() == 128 || Ty->isHalfTy() || Ty->isBFloatTy();
94+
}
8895

8996
namespace NVPTX {
9097
inline std::string getValidPTXIdentifier(StringRef Name) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_20 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
4+
5+
target triple = "nvptx64-unknown-cuda"
6+
7+
define fp128 @identity(fp128 %x) {
8+
; CHECK-LABEL: identity(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b64 %rd<3>;
11+
; CHECK-EMPTY:
12+
; CHECK-NEXT: // %bb.0:
13+
; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [identity_param_0];
14+
; CHECK-NEXT: st.param.v2.b64 [func_retval0], {%rd1, %rd2};
15+
; CHECK-NEXT: ret;
16+
ret fp128 %x
17+
}
18+
19+
define void @load_store(ptr %in, ptr %out) {
20+
; CHECK-LABEL: load_store(
21+
; CHECK: {
22+
; CHECK-NEXT: .reg .b64 %rd<5>;
23+
; CHECK-EMPTY:
24+
; CHECK-NEXT: // %bb.0:
25+
; CHECK-NEXT: ld.param.u64 %rd1, [load_store_param_0];
26+
; CHECK-NEXT: ld.u64 %rd2, [%rd1+8];
27+
; CHECK-NEXT: ld.u64 %rd3, [%rd1];
28+
; CHECK-NEXT: ld.param.u64 %rd4, [load_store_param_1];
29+
; CHECK-NEXT: st.u64 [%rd4], %rd3;
30+
; CHECK-NEXT: st.u64 [%rd4+8], %rd2;
31+
; CHECK-NEXT: ret;
32+
%val = load fp128, ptr %in
33+
store fp128 %val, ptr %out
34+
ret void
35+
}
36+
37+
define void @call(fp128 %x) {
38+
; CHECK-LABEL: call(
39+
; CHECK: {
40+
; CHECK-NEXT: .reg .b64 %rd<3>;
41+
; CHECK-EMPTY:
42+
; CHECK-NEXT: // %bb.0:
43+
; CHECK-NEXT: ld.param.v2.u64 {%rd1, %rd2}, [call_param_0];
44+
; CHECK-NEXT: { // callseq 0, 0
45+
; CHECK-NEXT: .param .align 16 .b8 param0[16];
46+
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd1, %rd2};
47+
; CHECK-NEXT: call.uni
48+
; CHECK-NEXT: call,
49+
; CHECK-NEXT: (
50+
; CHECK-NEXT: param0
51+
; CHECK-NEXT: );
52+
; CHECK-NEXT: } // callseq 0
53+
; CHECK-NEXT: ret;
54+
call void @call(fp128 %x)
55+
ret void
56+
}

llvm/test/CodeGen/NVPTX/global-variable-big.ll

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
55
target triple = "nvptx64-nvidia-cuda"
66

7-
; Check that we can handle global variables of large integer type.
7+
; Check that we can handle global variables of large integer and fp128 type.
88

99
; (lsb) 0x0102'0304'0506...0F10 (msb)
1010
@gv = addrspace(1) externally_initialized global i128 21345817372864405881847059188222722561, align 16
1111
; CHECK: .visible .global .align 16 .b8 gv[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
1212

13+
@gv_fp128 = addrspace(1) externally_initialized global fp128 0xL0807060504030201100F0E0D0C0B0A09, align 16
14+
; CHECK: .visible .global .align 16 .b8 gv_fp128[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
15+
1316
; Make sure that we do not overflow on large number of elements.
1417
; CHECK: .visible .global .align 1 .b8 large_data[4831838208];
1518
@large_data = global [4831838208 x i8] zeroinitializer

0 commit comments

Comments
 (0)