Skip to content

Commit

Permalink
Merge pull request #782 from Kotlin/compiler-plugin-dev
Browse files Browse the repository at this point in the history
Refactor toDataFrame implementation in compiler plugin
  • Loading branch information
koperagen authored Jul 17, 2024
2 parents 0e97e5b + 642a2ee commit ddb9fac
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,19 @@

package org.jetbrains.kotlinx.dataframe.plugin

import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TraverseConfiguration
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeTypeProjection
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.CreateDataFrameDslImplApproximation
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.aggregate
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names.DF_CLASS_ID

fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: InterpretationErrorReporter): CallResult? {
val callReturnType = call.resolvedType
Expand All @@ -38,44 +32,6 @@ fun KotlinTypeFacade.analyzeRefinedCallShape(call: FirFunctionCall, reporter: In

val newSchema: PluginDataFrameSchema = call.interpreterName(session)?.let { name ->
when (name) {
"toDataFrameDsl" -> {
val list = call.argumentList as FirResolvedArgumentList
val lambda = (list.arguments.singleOrNull() as? FirAnonymousFunctionExpression)?.anonymousFunction
val statements = lambda?.body?.statements
if (statements != null) {
val receiver = CreateDataFrameDslImplApproximation()
statements.filterIsInstance<FirFunctionCall>().forEach {
val schemaProcessor = it.loadInterpreter() ?: return@forEach
interpret(
it,
schemaProcessor,
mapOf("dsl" to Interpreter.Success(receiver), "call" to Interpreter.Success(call)),
reporter
)
}
PluginDataFrameSchema(receiver.columns)
} else {
PluginDataFrameSchema(emptyList())
}
}
"toDataFrame" -> {
val list = call.argumentList as FirResolvedArgumentList
val argument = list.mapping.entries.firstOrNull { it.value.name == Name.identifier("maxDepth") }?.key
val maxDepth = when (argument) {
null -> 0
is FirLiteralExpression -> (argument.value as Number).toInt()
else -> null
}
if (maxDepth != null) {
toDataFrame(maxDepth, call, TraverseConfiguration())
} else {
PluginDataFrameSchema(emptyList())
}
}
"toDataFrameDefault" -> {
val maxDepth = 0
toDataFrame(maxDepth, call, TraverseConfiguration())
}
"Aggregate" -> {
val groupByCall = call.explicitReceiver as? FirFunctionCall
val interpreter = groupByCall?.loadInterpreter(session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,11 @@ fun <T> AbstractInterpreter<T>.kproperty(

internal fun <T> AbstractInterpreter<T>.string(
name: ArgumentName? = null
): ExpectedArgumentProvider<String
> = arg(name, lens = Interpreter.Value)
): ExpectedArgumentProvider<String> =
arg(name, lens = Interpreter.Value)

internal fun <T> AbstractInterpreter<T>.dsl(
name: ArgumentName? = null
): ExpectedArgumentProvider<(Any, Map<String, Interpreter.Success<Any?>>) -> Unit> =
arg(name, lens = Interpreter.Dsl, defaultValue = Present(value = {_, _ -> }))

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl
import org.jetbrains.kotlinx.dataframe.plugin.impl.string
import org.jetbrains.kotlinx.dataframe.plugin.impl.type

Expand Down Expand Up @@ -48,11 +49,11 @@ class AddDslApproximation(val columns: MutableList<SimpleCol>)

class AddWithDsl : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.body: (Any) -> Unit by arg(lens = Interpreter.Dsl)
val Arguments.body by dsl()

override fun Arguments.interpret(): PluginDataFrameSchema {
val addDsl = AddDslApproximation(receiver.columns().toMutableList())
body(addDsl)
body(addDsl, emptyMap())
return PluginDataFrameSchema(addDsl.columns)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import org.jetbrains.kotlin.fir.declarations.utils.effectiveVisibility
import org.jetbrains.kotlin.fir.declarations.utils.isEnumClass
import org.jetbrains.kotlin.fir.declarations.utils.isStatic
import org.jetbrains.kotlin.fir.expressions.FirCallableReferenceAccess
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
import org.jetbrains.kotlin.fir.java.JavaTypeParameterStack
Expand Down Expand Up @@ -46,6 +46,7 @@ import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
Expand All @@ -54,26 +55,56 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.dsl
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import java.util.*

class ToDataFrameDsl : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
val Arguments.body by dsl()
override fun Arguments.interpret(): PluginDataFrameSchema {
val dsl = CreateDataFrameDslImplApproximation()
body(dsl, mapOf("explicitReceiver" to Interpreter.Success(receiver)))
return PluginDataFrameSchema(dsl.columns)
}
}

class ToDataFrame : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)
val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH))

override fun Arguments.interpret(): PluginDataFrameSchema {
return toDataFrame(maxDepth.toInt(), receiver, TraverseConfiguration())
}
}

class ToDataFrameDefault : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id)

override fun Arguments.interpret(): PluginDataFrameSchema {
return toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration())
}
}

private const val DEFAULT_MAX_DEPTH = 0

class Properties0 : AbstractInterpreter<Unit>() {
val Arguments.dsl: CreateDataFrameDslImplApproximation by arg()
val Arguments.call: FirFunctionCall by arg()
val Arguments.explicitReceiver: FirExpression? by arg()
val Arguments.maxDepth: Int by arg()
val Arguments.body: (Any) -> Unit by arg(lens = Interpreter.Dsl, defaultValue = Present(value = {}))
val Arguments.body by dsl()

override fun Arguments.interpret() {
dsl.configuration.maxDepth = maxDepth
body(dsl.configuration.traverseConfiguration)
val schema = toDataFrame(dsl.configuration.maxDepth, call, dsl.configuration.traverseConfiguration)
body(dsl.configuration.traverseConfiguration, emptyMap())
val schema = toDataFrame(dsl.configuration.maxDepth, explicitReceiver, dsl.configuration.traverseConfiguration)
dsl.columns.addAll(schema.columns())
}
}

class CreateDataFrameConfiguration {
var maxDepth = 0
var maxDepth = DEFAULT_MAX_DEPTH
var traverseConfiguration: TraverseConfiguration = TraverseConfiguration()
}

Expand Down Expand Up @@ -123,7 +154,7 @@ class Exclude1 : AbstractInterpreter<Unit>() {
@OptIn(SymbolInternals::class)
internal fun KotlinTypeFacade.toDataFrame(
maxDepth: Int,
call: FirFunctionCall,
explicitReceiver: FirExpression?,
traverseConfiguration: TraverseConfiguration
): PluginDataFrameSchema {
fun ConeKotlinType.isValueType() =
Expand Down Expand Up @@ -238,7 +269,7 @@ internal fun KotlinTypeFacade.toDataFrame(
}
}

val receiver = call.explicitReceiver ?: return PluginDataFrameSchema(emptyList())
val receiver = explicitReceiver ?: return PluginDataFrameSchema(emptyList())
val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema(emptyList())
return when {
arg.isStarProjection -> PluginDataFrameSchema(emptyList())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ fun <T> KotlinTypeFacade.interpret(
val actualArgsMap = refinedArguments.associateBy { it.name.identifier }.toSortedMap()
val conflictingKeys = additionalArguments.keys intersect actualArgsMap.keys
if (conflictingKeys.isNotEmpty()) {
error("Conflicting keys: $conflictingKeys")
interpretationFrameworkError("Conflicting keys: $conflictingKeys")
}
val expectedArgsMap = processor.expectedArguments
.filterNot { it.name.startsWith("typeArg") }
.associateBy { it.name }.toSortedMap().minus(additionalArguments.keys)

if (expectedArgsMap.keys - defaultArguments != actualArgsMap.keys - defaultArguments) {
val unexpectedArguments = expectedArgsMap.keys - defaultArguments != actualArgsMap.keys - defaultArguments
if (unexpectedArguments) {
val message = buildString {
appendLine("ERROR: Different set of arguments")
appendLine("Implementation class: $processor")
Expand All @@ -107,8 +108,7 @@ fun <T> KotlinTypeFacade.interpret(
appendLine("add arguments to an interpeter:")
appendLine(diff.map { actualArgsMap[it] })
}
reporter.reportInterpretationError(functionCall, message)
return null
interpretationFrameworkError(message)
}

val arguments = mutableMapOf<String, Interpreter.Success<Any?>>()
Expand Down Expand Up @@ -206,7 +206,8 @@ fun <T> KotlinTypeFacade.interpret(
}

is Interpreter.Dsl -> {
{ receiver: Any ->
{ receiver: Any, dslArguments: Map<String, Interpreter.Success<Any?>> ->
val map = mapOf("dsl" to Interpreter.Success(receiver)) + dslArguments
(it.expression as FirAnonymousFunctionExpression)
.anonymousFunction.body!!
.statements.filterIsInstance<FirFunctionCall>()
Expand All @@ -215,7 +216,7 @@ fun <T> KotlinTypeFacade.interpret(
interpret(
call,
schemaProcessor,
mapOf("dsl" to Interpreter.Success(receiver)),
map,
reporter
)
}
Expand Down Expand Up @@ -271,6 +272,10 @@ fun <T> KotlinTypeFacade.interpret(
}
}

fun interpretationFrameworkError(message: String): Nothing = throw InterpretationFrameworkError(message)

class InterpretationFrameworkError(message: String) : Error(message)

interface InterpretationErrorReporter {
val errorReported: Boolean
fun reportInterpretationError(call: FirFunctionCall, message: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameDsl
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom

internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
Expand Down Expand Up @@ -171,6 +174,9 @@ internal inline fun <reified T> String.load(): T {
"ColsOf1" -> ColsOf1()
"ColsAtAnyDepth0" -> ColsAtAnyDepth0()
"FrameCols0" -> FrameCols0()
"toDataFrameDsl" -> ToDataFrameDsl()
"toDataFrame" -> ToDataFrame()
"toDataFrameDefault" -> ToDataFrameDefault()
else -> error("$this")
} as T
}

0 comments on commit ddb9fac

Please sign in to comment.