@@ -249,11 +249,6 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
249
249
return MCOperand::createExpr (Expr);
250
250
}
251
251
252
- static bool ShouldPassAsArray (Type *Ty) {
253
- return Ty->isAggregateType () || Ty->isVectorTy () || Ty->isIntegerTy (128 ) ||
254
- Ty->isHalfTy () || Ty->isBFloatTy ();
255
- }
256
-
257
252
void NVPTXAsmPrinter::printReturnValStr (const Function *F, raw_ostream &O) {
258
253
const DataLayout &DL = getDataLayout ();
259
254
const NVPTXSubtarget &STI = TM.getSubtarget <NVPTXSubtarget>(*F);
@@ -264,26 +259,21 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
264
259
return ;
265
260
O << " (" ;
266
261
267
- if ((Ty->isFloatingPointTy () || Ty->isIntegerTy ()) &&
268
- !ShouldPassAsArray (Ty)) {
269
- unsigned size = 0 ;
270
- if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
271
- size = ITy->getBitWidth ();
272
- } else {
273
- assert (Ty->isFloatingPointTy () && " Floating point type expected here" );
274
- size = Ty->getPrimitiveSizeInBits ();
275
- }
276
- size = promoteScalarArgumentSize (size);
277
- O << " .param .b" << size << " func_retval0" ;
278
- } else if (isa<PointerType>(Ty)) {
279
- O << " .param .b" << TLI->getPointerTy (DL).getSizeInBits ()
280
- << " func_retval0" ;
281
- } else if (ShouldPassAsArray (Ty)) {
282
- unsigned totalsz = DL.getTypeAllocSize (Ty);
283
- Align RetAlignment = TLI->getFunctionArgumentAlignment (
262
+ auto PrintScalarRetVal = [&](unsigned Size ) {
263
+ O << " .param .b" << promoteScalarArgumentSize (Size ) << " func_retval0" ;
264
+ };
265
+ if (shouldPassAsArray (Ty)) {
266
+ const unsigned TotalSize = DL.getTypeAllocSize (Ty);
267
+ const Align RetAlignment = TLI->getFunctionArgumentAlignment (
284
268
F, Ty, AttributeList::ReturnIndex, DL);
285
269
O << " .param .align " << RetAlignment.value () << " .b8 func_retval0["
286
- << totalsz << " ]" ;
270
+ << TotalSize << " ]" ;
271
+ } else if (Ty->isFloatingPointTy ()) {
272
+ PrintScalarRetVal (Ty->getPrimitiveSizeInBits ());
273
+ } else if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
274
+ PrintScalarRetVal (ITy->getBitWidth ());
275
+ } else if (isa<PointerType>(Ty)) {
276
+ PrintScalarRetVal (TLI->getPointerTy (DL).getSizeInBits ());
287
277
} else
288
278
llvm_unreachable (" Unknown return type" );
289
279
O << " ) " ;
@@ -975,8 +965,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
975
965
O << " .align "
976
966
<< GVar->getAlign ().value_or (DL.getPrefTypeAlign (ETy)).value ();
977
967
978
- if (ETy->isFloatingPointTy () || ETy->isPointerTy () ||
979
- (ETy-> isIntegerTy () && ETy->getScalarSizeInBits () <= 64 )) {
968
+ if (ETy->isPointerTy () || (( ETy->isIntegerTy () || ETy-> isFloatingPointTy ()) &&
969
+ ETy->getScalarSizeInBits () <= 64 )) {
980
970
O << " ." ;
981
971
// Special case: ABI requires that we use .u8 for predicates
982
972
if (ETy->isIntegerTy (1 ))
@@ -1016,6 +1006,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1016
1006
// and vectors are lowered into arrays of bytes.
1017
1007
switch (ETy->getTypeID ()) {
1018
1008
case Type::IntegerTyID: // Integers larger than 64 bits
1009
+ case Type::FP128TyID:
1019
1010
case Type::StructTyID:
1020
1011
case Type::ArrayTyID:
1021
1012
case Type::FixedVectorTyID: {
@@ -1266,8 +1257,8 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1266
1257
O << " .align "
1267
1258
<< GVar->getAlign ().value_or (DL.getPrefTypeAlign (ETy)).value ();
1268
1259
1269
- // Special case for i128
1270
- if (ETy->isIntegerTy ( 128 ) ) {
1260
+ // Special case for i128/fp128
1261
+ if (ETy->getScalarSizeInBits () == 128 ) {
1271
1262
O << " .b8 " ;
1272
1263
getSymbol (GVar)->print (O, MAI);
1273
1264
O << " [16]" ;
@@ -1383,7 +1374,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1383
1374
continue ;
1384
1375
}
1385
1376
1386
- if (ShouldPassAsArray (Ty)) {
1377
+ if (shouldPassAsArray (Ty)) {
1387
1378
// Just print .param .align <a> .b8 .param[size];
1388
1379
// <a> = optimal alignment for the element type; always multiple of
1389
1380
// PAL.getParamAlignment
@@ -1682,48 +1673,49 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1682
1673
void NVPTXAsmPrinter::bufferAggregateConstant (const Constant *CPV,
1683
1674
AggBuffer *aggBuffer) {
1684
1675
const DataLayout &DL = getDataLayout ();
1685
- int Bytes;
1676
+
1677
+ auto ExtendBuffer = [](APInt Val, AggBuffer *Buffer) {
1678
+ for (unsigned I : llvm::seq (Val.getBitWidth () / 8 ))
1679
+ Buffer->addByte (Val.extractBitsAsZExtValue (8 , I * 8 ));
1680
+ };
1686
1681
1687
1682
// Integers of arbitrary width
1688
1683
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1689
- APInt Val = CI->getValue ();
1690
- for (unsigned I = 0 , E = DL.getTypeAllocSize (CPV->getType ()); I < E; ++I) {
1691
- uint8_t Byte = Val.getLoBits (8 ).getZExtValue ();
1692
- aggBuffer->addBytes (&Byte , 1 , 1 );
1693
- Val.lshrInPlace (8 );
1694
- }
1684
+ ExtendBuffer (CI->getValue (), aggBuffer);
1695
1685
return ;
1696
1686
}
1697
1687
1688
+ // f128
1689
+ if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1690
+ if (CFP->getType ()->isFP128Ty ()) {
1691
+ ExtendBuffer (CFP->getValueAPF ().bitcastToAPInt (), aggBuffer);
1692
+ return ;
1693
+ }
1694
+ }
1695
+
1698
1696
// Old constants
1699
1697
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1700
- if (CPV->getNumOperands ())
1701
- for (unsigned i = 0 , e = CPV->getNumOperands (); i != e; ++i)
1702
- bufferLEByte (cast<Constant>(CPV->getOperand (i)), 0 , aggBuffer);
1698
+ for (const auto &Op : CPV->operands ())
1699
+ bufferLEByte (cast<Constant>(Op), 0 , aggBuffer);
1703
1700
return ;
1704
1701
}
1705
1702
1706
- if (const ConstantDataSequential *CDS =
1707
- dyn_cast<ConstantDataSequential>(CPV)) {
1708
- if (CDS->getNumElements ())
1709
- for (unsigned i = 0 ; i < CDS->getNumElements (); ++i)
1710
- bufferLEByte (cast<Constant>(CDS->getElementAsConstant (i)), 0 ,
1711
- aggBuffer);
1703
+ if (const auto *CDS = dyn_cast<ConstantDataSequential>(CPV)) {
1704
+ for (unsigned I : llvm::seq (CDS->getNumElements ()))
1705
+ bufferLEByte (cast<Constant>(CDS->getElementAsConstant (I)), 0 , aggBuffer);
1712
1706
return ;
1713
1707
}
1714
1708
1715
1709
if (isa<ConstantStruct>(CPV)) {
1716
1710
if (CPV->getNumOperands ()) {
1717
1711
StructType *ST = cast<StructType>(CPV->getType ());
1718
- for (unsigned i = 0 , e = CPV->getNumOperands (); i != e; ++i) {
1719
- if (i == (e - 1 ))
1720
- Bytes = DL.getStructLayout (ST)->getElementOffset (0 ) +
1721
- DL.getTypeAllocSize (ST) -
1722
- DL.getStructLayout (ST)->getElementOffset (i);
1723
- else
1724
- Bytes = DL.getStructLayout (ST)->getElementOffset (i + 1 ) -
1725
- DL.getStructLayout (ST)->getElementOffset (i);
1726
- bufferLEByte (cast<Constant>(CPV->getOperand (i)), Bytes, aggBuffer);
1712
+ for (unsigned I : llvm::seq (CPV->getNumOperands ())) {
1713
+ int EndOffset = (I + 1 == CPV->getNumOperands ())
1714
+ ? DL.getStructLayout (ST)->getElementOffset (0 ) +
1715
+ DL.getTypeAllocSize (ST)
1716
+ : DL.getStructLayout (ST)->getElementOffset (I + 1 );
1717
+ int Bytes = EndOffset - DL.getStructLayout (ST)->getElementOffset (I);
1718
+ bufferLEByte (cast<Constant>(CPV->getOperand (I)), Bytes, aggBuffer);
1727
1719
}
1728
1720
}
1729
1721
return ;
0 commit comments