Skip to content

Commit 2ce1358

Browse files
committed
Add bridges for overridden methods in lambda indy call
If a SAM trait's abstract method overrides a method in a supertrait while changing the return type, the generated invokedynamic instruction needs to pass the types of the overridden methods to `LambdaMetaFactory` so that bridge methods can be added to the generated lambda. SAM `Function` trees now have a synthetic `ClassSymbol`, into which we enter an overriding concrete method symbol that represents the runtime LMF-generated implementation of that function. `erasure` then computes bridge methods for that class, which are added to the `LMF.metafactory` call in `jvm`. Java does this instead by generating a default method in the subinterface overriding the superinterface's method. Theoretically, we could generate the bridges in the same way, but that has the downside that the project with the interfaces would need recompiled, which might not be the project that creates the lambdas. Generating bridges the LMF way means that only the classes affected by this bug need to be recompiled to get the fix. Honestly I'm surprised that this hasn't come up already. Fixes scala/bug#10512.
1 parent bfa2f8b commit 2ce1358

File tree

10 files changed

+276
-84
lines changed

10 files changed

+276
-84
lines changed

src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala

+25-9
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
626626
case Apply(fun, args) if app.hasAttachment[delambdafy.LambdaMetaFactoryCapable] =>
627627
val attachment = app.attachments.get[delambdafy.LambdaMetaFactoryCapable].get
628628
genLoadArguments(args, paramTKs(app))
629-
genInvokeDynamicLambda(attachment.target, attachment.arity, attachment.functionalInterface, attachment.sam, attachment.isSerializable, attachment.addScalaSerializableMarker)
629+
genInvokeDynamicLambda(attachment)
630630
generatedType = methodBTypeFromSymbol(fun.symbol).returnType
631631

632632
case Apply(fun, List(expr)) if currentRun.runDefinitions.isBox(fun.symbol) =>
@@ -1332,7 +1332,9 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
13321332
def genSynchronized(tree: Apply, expectedType: BType): BType
13331333
def genLoadTry(tree: Try): BType
13341334

1335-
def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, isSerializable: Boolean, addScalaSerializableMarker: Boolean) {
1335+
def genInvokeDynamicLambda(canLMF: delambdafy.LambdaMetaFactoryCapable) = {
1336+
import canLMF._
1337+
13361338
val isStaticMethod = lambdaTarget.hasFlag(Flags.STATIC)
13371339
def asmType(sym: Symbol) = classBTypeFromSymbol(sym).toASMType
13381340

@@ -1349,23 +1351,37 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder {
13491351
val constrainedType = MethodBType(lambdaParams.map(p => typeToBType(p.tpe)), typeToBType(lambdaTarget.tpe.resultType)).toASMType
13501352
val samMethodType = methodBTypeFromSymbol(sam).toASMType
13511353
val markers = if (addScalaSerializableMarker) classBTypeFromSymbol(definitions.SerializableClass).toASMType :: Nil else Nil
1352-
visitInvokeDynamicInsnLMF(bc.jmethod, sam.name.toString, invokedType, samMethodType, implMethodHandle, constrainedType, isSerializable, markers)
1354+
val overriddenMethods =
1355+
bridges.map(b => methodBTypeFromSymbol(b).toASMType)
1356+
visitInvokeDynamicInsnLMF(bc.jmethod, sam.name.toString, invokedType, samMethodType, implMethodHandle, constrainedType, overriddenMethods, isSerializable, markers)
13531357
if (isSerializable)
13541358
addIndyLambdaImplMethod(cnode.name, implMethodHandle)
13551359
}
13561360
}
13571361

13581362
private def visitInvokeDynamicInsnLMF(jmethod: MethodNode, samName: String, invokedType: String, samMethodType: asm.Type,
1359-
implMethodHandle: asm.Handle, instantiatedMethodType: asm.Type,
1360-
serializable: Boolean, markerInterfaces: Seq[asm.Type]) = {
1363+
implMethodHandle: asm.Handle, instantiatedMethodType: asm.Type, overriddenMethodTypes: Seq[asm.Type],
1364+
serializable: Boolean, markerInterfaces: Seq[asm.Type]): Unit = {
13611365
import java.lang.invoke.LambdaMetafactory.{FLAG_BRIDGES, FLAG_MARKERS, FLAG_SERIALIZABLE}
13621366
// scala/bug#10334: make sure that a lambda object for `T => U` has a method `apply(T)U`, not only the `(Object)Object`
13631367
// version. Using the lambda a structural type `{def apply(t: T): U}` causes a reflective lookup for this method.
1364-
val needsBridge = samMethodType != instantiatedMethodType
1365-
val bridges = if (needsBridge) Seq(Int.box(1), instantiatedMethodType) else Nil
1368+
val needsGenericBridge = samMethodType != instantiatedMethodType
1369+
// scala/bug#10512: any methods which `samMethod` overrides need bridges made for them
1370+
// this is done automatically during erasure for classes we generate, but LMF needs to have them explicitly mentioned
1371+
// so we have to compute them at this relatively late point.
1372+
val bridges = (
1373+
if (needsGenericBridge)
1374+
instantiatedMethodType +: overriddenMethodTypes
1375+
else overriddenMethodTypes
1376+
).distinct.filterNot(_ == samMethodType)
1377+
1378+
/* We're saving on precious BSM arg slots by not passing 0 as the bridge count */
1379+
val bridgeArgs = if (bridges.nonEmpty) Int.box(bridges.length) +: bridges else Nil
1380+
13661381
def flagIf(b: Boolean, flag: Int): Int = if (b) flag else 0
1367-
val flags = FLAG_MARKERS | flagIf(serializable, FLAG_SERIALIZABLE) | flagIf(needsBridge, FLAG_BRIDGES)
1368-
val bsmArgs = Seq(samMethodType, implMethodHandle, instantiatedMethodType, Int.box(flags), Int.box(markerInterfaces.length)) ++ markerInterfaces ++ bridges
1382+
val flags = FLAG_MARKERS | flagIf(serializable, FLAG_SERIALIZABLE) | flagIf(bridges.nonEmpty, FLAG_BRIDGES)
1383+
val bsmArgs = Seq(samMethodType, implMethodHandle, instantiatedMethodType, Int.box(flags), Int.box(markerInterfaces.length)) ++ markerInterfaces ++ bridgeArgs
1384+
13691385
jmethod.visitInvokeDynamicInsn(samName, invokedType, lambdaMetaFactoryAltMetafactoryHandle, bsmArgs: _*)
13701386
}
13711387

src/compiler/scala/tools/nsc/transform/Delambdafy.scala

+19-5
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
2828
/** the following two members override abstract members in Transform */
2929
val phaseName: String = "delambdafy"
3030

31-
final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, isSerializable: Boolean, addScalaSerializableMarker: Boolean)
31+
final case class LambdaMetaFactoryCapable(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, bridges: List[Symbol], isSerializable: Boolean, addScalaSerializableMarker: Boolean)
3232

3333
/**
3434
* Get the symbol of the target lifted lambda body method from a function. I.e. if
@@ -59,7 +59,10 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
5959
private[this] lazy val methodReferencesThis: Set[Symbol] =
6060
(new ThisReferringMethodsTraverser).methodReferencesThisIn(unit.body)
6161

62-
private def mkLambdaMetaFactoryCall(fun: Function, target: Symbol, functionalInterface: Symbol, samUserDefined: Symbol, isSpecialized: Boolean): Tree = {
62+
private def mkLambdaMetaFactoryCall(fun: Function, target: Symbol, functionalInterface: Symbol, samUserDefined: Symbol, userSamCls: Symbol, isSpecialized: Boolean): Tree = {
63+
/* user-defined SAM types should have gotten a class symbol made for them in `typer` */
64+
assert(isFunctionType(fun.tpe) || (samUserDefined.exists && userSamCls.isClass), s"$fun / ${fun.symbol} / ${fun.tpe}")
65+
6366
val pos = fun.pos
6467
def isSelfParam(p: Symbol) = p.isSynthetic && p.name == nme.SELF
6568
val hasSelfParam = isSelfParam(target.firstParam)
@@ -98,13 +101,21 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
98101
val isSerializable = samUserDefined == NoSymbol || samUserDefined.owner.isNonBottomSubClass(definitions.JavaSerializableClass)
99102
val addScalaSerializableMarker = samUserDefined == NoSymbol
100103

104+
val samBridges = logResultIf[List[Symbol]](s"will add SAM bridges for $fun", _.nonEmpty) {
105+
userSamCls.fold[List[Symbol]](Nil) {
106+
_.info.decls.collect {
107+
case bridge: MethodSymbol if bridge hasFlag BRIDGE => bridge
108+
}.toList
109+
}
110+
}
111+
101112
// The backend needs to know the target of the lambda and the functional interface in order
102113
// to emit the invokedynamic instruction. We pass this information as tree attachment.
103114
//
104115
// see https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html
105116
// instantiatedMethodType is derived from lambdaTarget's signature
106117
// samMethodType is derived from samOf(functionalInterface)'s signature
107-
apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam, isSerializable, addScalaSerializableMarker))
118+
apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam, samBridges, isSerializable, addScalaSerializableMarker))
108119

109120
apply
110121
}
@@ -254,8 +265,11 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
254265
(functionalInterface, isSpecialized)
255266
}
256267

257-
val sam = originalFunction.attachments.get[SAMFunction].map(_.sam).getOrElse(NoSymbol)
258-
mkLambdaMetaFactoryCall(originalFunction, target, functionalInterface, sam, isSpecialized)
268+
val (sam, samCls) = originalFunction.attachments.get[SAMFunction] match {
269+
case Some(SAMFunction(_, sam, samCls)) => (sam, samCls)
270+
case None => (NoSymbol, NoSymbol)
271+
}
272+
mkLambdaMetaFactoryCall(originalFunction, target, functionalInterface, sam, samCls, isSpecialized)
259273
}
260274

261275
// here's the main entry point of the transform

src/compiler/scala/tools/nsc/transform/Erasure.scala

+48-19
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ abstract class Erasure extends InfoTransform
467467

468468
override def newTyper(context: Context) = new Eraser(context)
469469

470-
class ComputeBridges(unit: CompilationUnit, root: Symbol) {
470+
class EnterBridges(unit: CompilationUnit, root: Symbol) {
471471

472472
class BridgesCursor(root: Symbol) extends overridingPairs.Cursor(root) {
473473
override def parents = root.info.firstParent :: Nil
@@ -478,22 +478,19 @@ abstract class Erasure extends InfoTransform
478478
override def exclude(sym: Symbol) = !sym.isMethod || super.exclude(sym)
479479
}
480480

481-
var toBeRemoved = immutable.Set[Symbol]()
482481
val site = root.thisType
483482
val bridgesScope = newScope
484483
val bridgeTarget = mutable.HashMap[Symbol, Symbol]()
485-
var bridges = List[Tree]()
486484

487485
val opc = enteringExplicitOuter { new BridgesCursor(root) }
488486

489-
def compute(): (List[Tree], immutable.Set[Symbol]) = {
487+
def computeAndEnter(includeDeferred: Boolean): Unit = {
490488
while (opc.hasNext) {
491-
if (enteringExplicitOuter(!opc.low.isDeferred))
492-
checkPair(opc.currentPair)
489+
if (includeDeferred || enteringExplicitOuter(!opc.low.isDeferred))
490+
checkPair(opc. currentPair)
493491

494492
opc.next()
495493
}
496-
(bridges, toBeRemoved)
497494
}
498495

499496
/** Check that a bridge only overrides members that are also overridden by the original member.
@@ -597,17 +594,35 @@ abstract class Erasure extends InfoTransform
597594

598595
if (shouldAdd) {
599596
exitingErasure(root.info.decls enter bridge)
600-
if (other.owner == root) {
601-
exitingErasure(root.info.decls.unlink(other))
602-
toBeRemoved += other
603-
}
604597

605598
bridgesScope enter bridge
606-
bridges ::= makeBridgeDefDef(bridge, member, other)
599+
addBridge(bridge, member, other)
600+
//bridges ::= makeBridgeDefDef(bridge, member, other)
607601
}
608602
}
609603

610-
def makeBridgeDefDef(bridge: Symbol, member: Symbol, other: Symbol) = exitingErasure {
604+
protected def addBridge(bridge: Symbol, member: Symbol, other: Symbol) {} // hook for GenerateBridges
605+
}
606+
607+
class GenerateBridges(unit: CompilationUnit, root: Symbol) extends EnterBridges(unit, root) {
608+
609+
var bridges = List.empty[Tree]
610+
var toBeRemoved = immutable.Set.empty[Symbol]
611+
612+
def generate(includeDeferred: Boolean): (List[Tree], immutable.Set[Symbol]) = {
613+
super.computeAndEnter(includeDeferred)
614+
(bridges, toBeRemoved)
615+
}
616+
617+
override def addBridge(bridge: Symbol, member: Symbol, other: Symbol): Unit = {
618+
if (other.owner == root) {
619+
exitingErasure(root.info.decls.unlink(other))
620+
toBeRemoved += other
621+
}
622+
bridges ::= makeBridgeDefDef(bridge, member, other)
623+
}
624+
625+
final def makeBridgeDefDef(bridge: Symbol, member: Symbol, other: Symbol) = exitingErasure {
611626
// type checking ensures we can safely call `other`, but unless `member.tpe <:< other.tpe`,
612627
// calling `member` is not guaranteed to succeed in general, there's
613628
// nothing we can do about this, except for an unapply: when this subtype test fails,
@@ -617,8 +632,8 @@ abstract class Erasure extends InfoTransform
617632
// does the first argument list have exactly one argument -- for user-defined unapplies we can't be sure
618633
def maybeWrap(bridgingCall: Tree): Tree = {
619634
val guardExtractor = ( // can't statically know which member is going to be selected, so don't let this depend on member.isSynthetic
620-
(member.name == nme.unapply || member.name == nme.unapplySeq)
621-
&& !exitingErasure((member.tpe <:< other.tpe))) // no static guarantees (TODO: is the subtype test ever true?)
635+
(member.name == nme.unapply || member.name == nme.unapplySeq)
636+
&& !exitingErasure((member.tpe <:< other.tpe))) // no static guarantees (TODO: is the subtype test ever true?)
622637

623638
import CODE._
624639
val _false = FALSE
@@ -643,6 +658,7 @@ abstract class Erasure extends InfoTransform
643658
}
644659
DefDef(bridge, rhs)
645660
}
661+
646662
}
647663

648664
/** The modifier typer which retypes with erased types. */
@@ -806,7 +822,7 @@ abstract class Erasure extends InfoTransform
806822
tree1 match {
807823
case fun: Function =>
808824
fun.attachments.get[SAMFunction] match {
809-
case Some(SAMFunction(samTp, _)) => fun setType specialScalaErasure(samTp)
825+
case Some(SAMFunction(samTp, _, _)) => fun setType specialScalaErasure(samTp)
810826
case _ => fun
811827
}
812828

@@ -934,17 +950,24 @@ abstract class Erasure extends InfoTransform
934950
*/
935951
private def bridgeDefs(owner: Symbol): (List[Tree], immutable.Set[Symbol]) = {
936952
assert(phase == currentRun.erasurePhase, phase)
937-
new ComputeBridges(unit, owner) compute()
953+
new GenerateBridges(unit, owner).generate(includeDeferred = false)
938954
}
939955

940-
def addBridges(stats: List[Tree], base: Symbol): List[Tree] =
956+
def addBridgesToTemplate(stats: List[Tree], base: Symbol): List[Tree] =
941957
if (base.isTrait) stats
942958
else {
943959
val (bridges, toBeRemoved) = bridgeDefs(base)
944960
if (bridges.isEmpty) stats
945961
else (stats filterNot (stat => toBeRemoved contains stat.symbol)) ::: bridges
946962
}
947963

964+
def addBridgesToLambda(lambdaClass: Symbol): Unit = {
965+
assert(phase == currentRun.erasurePhase, phase)
966+
assert(lambdaClass.isClass, lambdaClass)
967+
val eb = new EnterBridges(unit, lambdaClass)
968+
eb.computeAndEnter(includeDeferred = true)
969+
}
970+
948971
/** Transform tree at phase erasure before retyping it.
949972
* This entails the following:
950973
*
@@ -1221,7 +1244,7 @@ abstract class Erasure extends InfoTransform
12211244
case Template(parents, self, body) =>
12221245
//Console.println("checking no dble defs " + tree)//DEBUG
12231246
checkNoDoubleDefs(tree.symbol.owner)
1224-
treeCopy.Template(tree, parents, noSelfType, addBridges(body, currentOwner))
1247+
treeCopy.Template(tree, parents, noSelfType, addBridgesToTemplate(body, currentOwner))
12251248

12261249
case Match(selector, cases) =>
12271250
Match(Typed(selector, TypeTree(selector.tpe)), cases)
@@ -1242,6 +1265,12 @@ abstract class Erasure extends InfoTransform
12421265
case TypeDef(_, _, _, _) =>
12431266
EmptyTree
12441267

1268+
case fun: Function =>
1269+
fun.attachments.get[SAMFunction] foreach {
1270+
samf => addBridgesToLambda(samf.samCls)
1271+
}
1272+
fun
1273+
12451274
case _ =>
12461275
tree
12471276
}

src/compiler/scala/tools/nsc/transform/UnCurry.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ abstract class UnCurry extends InfoTransform
8181
private def mustExpandFunction(fun: Function) = {
8282
// (TODO: Can't use isInterface, yet, as it hasn't been updated for the new trait encoding)
8383
val canUseLambdaMetaFactory = (fun.attachments.get[SAMFunction] match {
84-
case Some(SAMFunction(userDefinedSamTp, sam)) =>
84+
case Some(SAMFunction(userDefinedSamTp, sam, _)) =>
8585
// LambdaMetaFactory cannot mix in trait members for us, or instantiate classes -- only pure interfaces need apply
8686
erasure.compilesToPureInterface(erasure.javaErasure(userDefinedSamTp).typeSymbol) &&
8787
// impl restriction -- we currently use the boxed apply, so not really useful to allow specialized sam types (https://github.com/scala/scala/pull/4971#issuecomment-198119167)

0 commit comments

Comments
 (0)