Skip to content

Commit a49bda7

Browse files
committed
[Compiler plugin] Use resolved type argument T of Iterable<T>.toDataFrame() _call_ instead of one from the return type of receiver *iterable*.toDataFrame()
1 parent 9839b5c commit a49bda7

File tree

4 files changed

+69
-28
lines changed

4 files changed

+69
-28
lines changed

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.jetbrains.kotlin.fir.types.ConeKotlinType
3030
import org.jetbrains.kotlin.fir.types.ConeNullability
3131
import org.jetbrains.kotlin.fir.types.ConeStarProjection
3232
import org.jetbrains.kotlin.fir.types.ConeTypeParameterType
33+
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
3334
import org.jetbrains.kotlin.fir.types.canBeNull
3435
import org.jetbrains.kotlin.fir.types.classId
3536
import org.jetbrains.kotlin.fir.types.coneType
@@ -52,7 +53,6 @@ import org.jetbrains.kotlin.name.FqName
5253
import org.jetbrains.kotlin.name.Name
5354
import org.jetbrains.kotlin.name.StandardClassIds
5455
import org.jetbrains.kotlin.name.StandardClassIds.List
55-
import org.jetbrains.kotlin.types.checker.SimpleClassicTypeSystemContext.withNullability
5656
import org.jetbrains.kotlinx.dataframe.codeGen.*
5757
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
5858
import org.jetbrains.kotlinx.dataframe.plugin.extensions.wrap
@@ -77,27 +77,31 @@ import java.util.*
7777
class ToDataFrameDsl : AbstractSchemaModificationInterpreter() {
7878
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
7979
val Arguments.body by dsl()
80+
val Arguments.typeArg0: ConeTypeProjection? by arg(lens = Interpreter.Id)
81+
8082
override fun Arguments.interpret(): PluginDataFrameSchema {
8183
val dsl = CreateDataFrameDslImplApproximation()
82-
body(dsl, mapOf("explicitReceiver" to Interpreter.Success(receiver)))
84+
body(dsl, mapOf("typeArg0" to Interpreter.Success(typeArg0)))
8385
return PluginDataFrameSchema(dsl.columns)
8486
}
8587
}
8688

8789
class ToDataFrame : AbstractSchemaModificationInterpreter() {
8890
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
8991
val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH))
92+
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)
9093

9194
override fun Arguments.interpret(): PluginDataFrameSchema {
92-
return toDataFrame(maxDepth.toInt(), receiver, TraverseConfiguration())
95+
return toDataFrame(maxDepth.toInt(), typeArg0, TraverseConfiguration())
9396
}
9497
}
9598

9699
class ToDataFrameDefault : AbstractSchemaModificationInterpreter() {
97100
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
101+
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)
98102

99103
override fun Arguments.interpret(): PluginDataFrameSchema {
100-
return toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration())
104+
return toDataFrame(DEFAULT_MAX_DEPTH, typeArg0, TraverseConfiguration())
101105
}
102106
}
103107

@@ -115,14 +119,14 @@ private const val DEFAULT_MAX_DEPTH = 0
115119

116120
class Properties0 : AbstractInterpreter<Unit>() {
117121
val Arguments.dsl: CreateDataFrameDslImplApproximation by arg()
118-
val Arguments.explicitReceiver: FirExpression? by arg()
119122
val Arguments.maxDepth: Int by arg()
120123
val Arguments.body by dsl()
124+
val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id)
121125

122126
override fun Arguments.interpret() {
123127
dsl.configuration.maxDepth = maxDepth
124128
body(dsl.configuration.traverseConfiguration, emptyMap())
125-
val schema = toDataFrame(dsl.configuration.maxDepth, explicitReceiver, dsl.configuration.traverseConfiguration)
129+
val schema = toDataFrame(dsl.configuration.maxDepth, typeArg0, dsl.configuration.traverseConfiguration)
126130
dsl.columns.addAll(schema.columns())
127131
}
128132
}
@@ -178,8 +182,8 @@ class Exclude1 : AbstractInterpreter<Unit>() {
178182
@OptIn(SymbolInternals::class)
179183
internal fun KotlinTypeFacade.toDataFrame(
180184
maxDepth: Int,
181-
explicitReceiver: FirExpression?,
182-
traverseConfiguration: TraverseConfiguration
185+
arg: ConeTypeProjection,
186+
traverseConfiguration: TraverseConfiguration,
183187
): PluginDataFrameSchema {
184188
fun ConeKotlinType.isValueType() =
185189
this.isArrayTypeOrNullableArrayType ||
@@ -290,8 +294,6 @@ internal fun KotlinTypeFacade.toDataFrame(
290294
}
291295
}
292296

293-
val receiver = explicitReceiver ?: return PluginDataFrameSchema.EMPTY
294-
val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema.EMPTY
295297
return when {
296298
arg.isStarProjection -> PluginDataFrameSchema.EMPTY
297299
else -> {

plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,29 +90,43 @@ fun <T> KotlinTypeFacade.interpret(
9090
val refinedArguments: RefinedArguments = functionCall.collectArgumentExpressions()
9191

9292
val defaultArguments = processor.expectedArguments.filter { it.defaultValue is Present }.map { it.name }.toSet()
93-
val actualArgsMap = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
94-
val conflictingKeys = additionalArguments.keys intersect actualArgsMap.keys
93+
val actualValueArguments = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
94+
val conflictingKeys = additionalArguments.keys intersect actualValueArguments.keys
9595
if (conflictingKeys.isNotEmpty()) {
9696
if (isTest) {
9797
interpretationFrameworkError("Conflicting keys: $conflictingKeys")
9898
}
9999
return null
100100
}
101101
val expectedArgsMap = processor.expectedArguments
102-
.filterNot { it.name.startsWith("typeArg") }
103102
.associateBy { it.name }.toSortedMap().minus(additionalArguments.keys)
104103

105-
val unexpectedArguments = expectedArgsMap.keys - defaultArguments != actualArgsMap.keys - defaultArguments
104+
val typeArguments = buildMap {
105+
functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
106+
val key = "typeArg$index"
107+
val lens = expectedArgsMap[key]?.lens ?: return@forEachIndexed
108+
val value: Any = if (lens == Interpreter.Id) {
109+
firTypeProjection.toConeTypeProjection()
110+
} else {
111+
val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
112+
if (type is ConeIntersectionType) return@forEachIndexed
113+
Marker(type)
114+
}
115+
put(key, Interpreter.Success(value))
116+
}
117+
}
118+
119+
val unexpectedArguments = (expectedArgsMap.keys - defaultArguments) != (actualValueArguments.keys + typeArguments.keys - defaultArguments)
106120
if (unexpectedArguments) {
107121
if (isTest) {
108122
val message = buildString {
109123
appendLine("ERROR: Different set of arguments")
110124
appendLine("Implementation class: $processor")
111-
appendLine("Not found in actual: ${expectedArgsMap.keys - actualArgsMap.keys}")
112-
val diff = actualArgsMap.keys - expectedArgsMap.keys
125+
appendLine("Not found in actual: ${expectedArgsMap.keys - actualValueArguments.keys}")
126+
val diff = actualValueArguments.keys - expectedArgsMap.keys
113127
appendLine("Passed, but not expected: ${diff}")
114128
appendLine("add arguments to an interpeter:")
115-
appendLine(diff.map { actualArgsMap[it] })
129+
appendLine(diff.map { actualValueArguments[it] })
116130
}
117131
interpretationFrameworkError(message)
118132
}
@@ -121,6 +135,7 @@ fun <T> KotlinTypeFacade.interpret(
121135

122136
val arguments = mutableMapOf<String, Interpreter.Success<Any?>>()
123137
arguments += additionalArguments
138+
arguments += typeArguments
124139
val interpretationResults = refinedArguments.refinedArguments.mapNotNull {
125140
val name = it.name.identifier
126141
val expectedArgument = expectedArgsMap[name] ?: error("$processor $name")
@@ -269,17 +284,6 @@ fun <T> KotlinTypeFacade.interpret(
269284
value?.let { value1 -> it.name.identifier to value1 }
270285
}
271286

272-
functionCall.typeArguments.forEachIndexed { index, firTypeProjection ->
273-
val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type
274-
if (type is ConeIntersectionType) return@forEachIndexed
275-
// val approximation = TypeApproximationImpl(
276-
// type.classId!!.asFqNameString(),
277-
// type.isMarkedNullable
278-
// )
279-
val approximation = Marker(type)
280-
arguments["typeArg$index"] = Interpreter.Success(approximation)
281-
}
282-
283287
return if (interpretationResults.size == refinedArguments.refinedArguments.size) {
284288
arguments.putAll(interpretationResults)
285289
when (val res = processor.interpret(arguments, this)) {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import org.jetbrains.kotlinx.dataframe.*
2+
import org.jetbrains.kotlinx.dataframe.annotations.*
3+
import org.jetbrains.kotlinx.dataframe.api.*
4+
import org.jetbrains.kotlinx.dataframe.io.*
5+
6+
@DataSchema
7+
data class D(
8+
val s: String
9+
)
10+
11+
class Subtree(
12+
val p: Int,
13+
val l: List<Int>,
14+
val ld: List<D>,
15+
)
16+
17+
class Root(val a: Subtree)
18+
19+
class MyList(val l: List<Root?>): List<Root?> by l
20+
21+
fun box(): String {
22+
val l = listOf(
23+
Root(Subtree(123, listOf(1), listOf(D("ff")))),
24+
null
25+
)
26+
val df = MyList(l).toDataFrame(maxDepth = 2)
27+
df.compareSchemas(strict = true)
28+
return "OK"
29+
}

plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,12 @@ public void testToDataFrame_column() {
418418
runTest("testData/box/toDataFrame_column.kt");
419419
}
420420

421+
@Test
422+
@TestMetadata("toDataFrame_customIterable.kt")
423+
public void testToDataFrame_customIterable() {
424+
runTest("testData/box/toDataFrame_customIterable.kt");
425+
}
426+
421427
@Test
422428
@TestMetadata("toDataFrame_dataSchema.kt")
423429
public void testToDataFrame_dataSchema() {

0 commit comments

Comments
 (0)