Skip to content

Commit 60ca463

Browse files
committed
[AArch64][PAC] Move emission of LR checks in tail calls to AsmPrinter
Move the emission of the checks performed on the authenticated LR value during tail calls to AArch64AsmPrinter class, so that different checker sequences can be reused by pseudo instructions expanded there. This adds one more option to AuthCheckMethod enumeration, the generic XPAC variant which is not restricted to checking the LR register. Shorten the generic XPAC-based non-trapping sequence by one `mov` instruction: perform XPAC on the tested register itself instead of the scratch one as XPACLRI cannot operate on the scratch register anyway.
1 parent f788c67 commit 60ca463

9 files changed

+192
-303
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

+112-31
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ class AArch64AsmPrinter : public AsmPrinter {
153153
void emitPtrauthCheckAuthenticatedValue(Register TestedReg,
154154
Register ScratchReg,
155155
AArch64PACKey::ID Key,
156+
AArch64PAuth::AuthCheckMethod Method,
156157
bool ShouldTrap,
157158
const MCSymbol *OnFailure);
158159

@@ -1731,7 +1732,8 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
17311732
/// of proceeding to the next instruction (only if ShouldTrap is false).
17321733
void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
17331734
Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key,
1734-
bool ShouldTrap, const MCSymbol *OnFailure) {
1735+
AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap,
1736+
const MCSymbol *OnFailure) {
17351737
// Insert a sequence to check if authentication of TestedReg succeeded,
17361738
// such as:
17371739
//
@@ -1757,38 +1759,70 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
17571759
// Lsuccess:
17581760
// ...
17591761
//
1760-
// This sequence is expensive, but we need more information to be able to
1761-
// do better.
1762-
//
1763-
// We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
1764-
// on failure.
1765-
// We can't TST the PAC bits because we don't always know how the address
1766-
// space is setup for the target environment (and the bottom PAC bit is
1767-
// based on that).
1768-
// Either way, we also don't always know whether TBI is enabled or not for
1769-
// the specific target environment.
1762+
// See the documentation on AuthCheckMethod enumeration constants for
1763+
// the specific code sequences that can be used to perform the check.
1764+
using AArch64PAuth::AuthCheckMethod;
17701765

1771-
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1766+
if (Method == AuthCheckMethod::None)
1767+
return;
1768+
if (Method == AuthCheckMethod::DummyLoad) {
1769+
EmitToStreamer(MCInstBuilder(AArch64::LDRWui)
1770+
.addReg(getWRegFromXReg(ScratchReg))
1771+
.addReg(TestedReg)
1772+
.addImm(0));
1773+
assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error");
1774+
return;
1775+
}
17721776

17731777
MCSymbol *SuccessSym = createTempSymbol("auth_success_");
1778+
if (Method == AuthCheckMethod::XPAC || Method == AuthCheckMethod::XPACHint) {
1779+
// mov Xscratch, Xtested
1780+
emitMovXReg(ScratchReg, TestedReg);
17741781

1775-
// mov Xscratch, Xtested
1776-
emitMovXReg(ScratchReg, TestedReg);
1777-
1778-
// xpac(i|d) Xscratch
1779-
EmitToStreamer(MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
1782+
if (Method == AuthCheckMethod::XPAC) {
1783+
// xpac(i|d) Xscratch
1784+
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1785+
EmitToStreamer(
1786+
MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
1787+
} else {
1788+
// xpaclri
1789+
1790+
// Note that this method applies XPAC to TestedReg instead of ScratchReg.
1791+
assert(TestedReg == AArch64::LR &&
1792+
"XPACHint mode is only compatible with checking the LR register");
1793+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1794+
"XPACHint mode is only compatible with I-keys");
1795+
EmitToStreamer(MCInstBuilder(AArch64::XPACLRI));
1796+
}
17801797

1781-
// cmp Xtested, Xscratch
1782-
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
1783-
.addReg(AArch64::XZR)
1784-
.addReg(TestedReg)
1785-
.addReg(ScratchReg)
1786-
.addImm(0));
1798+
// cmp Xtested, Xscratch
1799+
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
1800+
.addReg(AArch64::XZR)
1801+
.addReg(TestedReg)
1802+
.addReg(ScratchReg)
1803+
.addImm(0));
17871804

1788-
// b.eq Lsuccess
1789-
EmitToStreamer(MCInstBuilder(AArch64::Bcc)
1790-
.addImm(AArch64CC::EQ)
1791-
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1805+
// b.eq Lsuccess
1806+
EmitToStreamer(
1807+
MCInstBuilder(AArch64::Bcc)
1808+
.addImm(AArch64CC::EQ)
1809+
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1810+
} else if (Method == AuthCheckMethod::HighBitsNoTBI) {
1811+
// eor Xscratch, Xtested, Xtested, lsl #1
1812+
EmitToStreamer(MCInstBuilder(AArch64::EORXrs)
1813+
.addReg(ScratchReg)
1814+
.addReg(TestedReg)
1815+
.addReg(TestedReg)
1816+
.addImm(1));
1817+
// tbz Xscratch, #62, Lsuccess
1818+
EmitToStreamer(
1819+
MCInstBuilder(AArch64::TBZX)
1820+
.addReg(ScratchReg)
1821+
.addImm(62)
1822+
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1823+
} else {
1824+
llvm_unreachable("Unsupported check method");
1825+
}
17921826

17931827
if (ShouldTrap) {
17941828
assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap");
@@ -1802,9 +1836,26 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
18021836
// Note that this can introduce an authentication oracle (such as based on
18031837
// the high bits of the re-signed value).
18041838

1805-
// FIXME: Can we simply return the AUT result, already in TestedReg?
1806-
// mov Xtested, Xscratch
1807-
emitMovXReg(TestedReg, ScratchReg);
1839+
// FIXME: The XPAC method can be optimized by applying XPAC to TestedReg
1840+
// instead of ScratchReg, thus eliminating one `mov` instruction.
1841+
// Both XPAC and XPACHint can be further optimized by not using a
1842+
// conditional branch jumping over an unconditional one.
1843+
1844+
switch (Method) {
1845+
case AuthCheckMethod::XPACHint:
1846+
// LR is already XPAC-ed at this point.
1847+
break;
1848+
case AuthCheckMethod::XPAC:
1849+
// mov Xtested, Xscratch
1850+
emitMovXReg(TestedReg, ScratchReg);
1851+
break;
1852+
default:
1853+
// If Xtested was not XPAC-ed so far, emit XPAC here.
1854+
// xpac(i|d) Xtested
1855+
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1856+
EmitToStreamer(
1857+
MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg));
1858+
}
18081859

18091860
if (OnFailure) {
18101861
// b Lend
@@ -1830,7 +1881,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18301881
// ; sign x16 (if AUTPAC)
18311882
// Lend: ; if not trapping on failure
18321883
//
1833-
// with the checking sequence chosen depending on whether we should check
1884+
// with the checking sequence chosen depending on whether/how we should check
18341885
// the pointer and whether we should trap on failure.
18351886

18361887
// By default, auth/resign sequences check for auth failures.
@@ -1890,6 +1941,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18901941
EndSym = createTempSymbol("resign_end_");
18911942

18921943
emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AUTKey,
1944+
AArch64PAuth::AuthCheckMethod::XPAC,
18931945
ShouldTrap, EndSym);
18941946
}
18951947

@@ -2260,11 +2312,34 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
22602312
OutStreamer->emitLabel(LOHLabel);
22612313
}
22622314

2315+
// With Pointer Authentication, it may be needed to explicitly check the
2316+
// authenticated value in LR when performing a tail call.
2317+
// Otherwise, the callee may re-sign the invalid return address,
2318+
// introducing a signing oracle.
2319+
auto CheckLRInTailCall = [this](Register CallDestinationReg) {
2320+
if (!AArch64FI->shouldSignReturnAddress(*MF))
2321+
return;
2322+
2323+
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
2324+
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
2325+
return;
2326+
2327+
Register ScratchReg =
2328+
CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
2329+
AArch64PACKey::ID Key =
2330+
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
2331+
emitPtrauthCheckAuthenticatedValue(
2332+
AArch64::LR, ScratchReg, Key, LRCheckMethod,
2333+
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
2334+
};
2335+
22632336
AArch64TargetStreamer *TS =
22642337
static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
22652338
// Do any manual lowerings.
22662339
switch (MI->getOpcode()) {
22672340
default:
2341+
assert(!AArch64InstrInfo::isTailCallReturnInst(*MI) &&
2342+
"Unhandled tail call instruction");
22682343
break;
22692344
case AArch64::HINT: {
22702345
// CurrentPatchableFunctionEntrySym can be CurrentFnBegin only for
@@ -2404,6 +2479,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
24042479
? AArch64::X17
24052480
: AArch64::X16;
24062481

2482+
CheckLRInTailCall(MI->getOperand(0).getReg());
2483+
24072484
unsigned DiscReg = AddrDisc;
24082485
if (Disc) {
24092486
if (AddrDisc != AArch64::NoRegister) {
@@ -2434,13 +2511,17 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
24342511
case AArch64::TCRETURNrix17:
24352512
case AArch64::TCRETURNrinotx16:
24362513
case AArch64::TCRETURNriALL: {
2514+
CheckLRInTailCall(MI->getOperand(0).getReg());
2515+
24372516
MCInst TmpInst;
24382517
TmpInst.setOpcode(AArch64::BR);
24392518
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
24402519
EmitToStreamer(*OutStreamer, TmpInst);
24412520
return;
24422521
}
24432522
case AArch64::TCRETURNdi: {
2523+
CheckLRInTailCall(AArch64::NoRegister);
2524+
24442525
MCOperand Dest;
24452526
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
24462527
MCInst TmpInst;

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
107107
unsigned NumBytes = 0;
108108
const MCInstrDesc &Desc = MI.getDesc();
109109

110+
if (!MI.isBundle() && isTailCallReturnInst(MI)) {
111+
NumBytes = Desc.getSize() ? Desc.getSize() : 4;
112+
113+
const auto *MFI = MF->getInfo<AArch64FunctionInfo>();
114+
if (!MFI->shouldSignReturnAddress(MF))
115+
return NumBytes;
116+
117+
auto &STI = MF->getSubtarget<AArch64Subtarget>();
118+
auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
119+
NumBytes += AArch64PAuth::getCheckerSizeInBytes(Method);
120+
return NumBytes;
121+
}
122+
110123
// Size should be preferably set in
111124
// llvm/lib/Target/AArch64/AArch64InstrInfo.td (default case).
112125
// Specific cases handle instructions of variable sizes

llvm/lib/Target/AArch64/AArch64InstrInfo.td

+2
Original file line numberDiff line numberDiff line change
@@ -1903,6 +1903,8 @@ let Predicates = [HasPAuth] in {
19031903
}
19041904

19051905
// Size 16: 4 fixed + 8 variable, to compute discriminator.
1906+
// The size returned by getInstSizeInBytes() is incremented according
1907+
// to the variant of LR check.
19061908
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
19071909
Uses = [SP] in {
19081910
def AUTH_TCRETURN

0 commit comments

Comments
 (0)