Skip to content

Commit

Permalink
Merge pull request #1050 from Kotlin/plugin-group-by
Browse files Browse the repository at this point in the history
Enhanced GroupBy support in compiler plugin
  • Loading branch information
koperagen authored Feb 3, 2025
2 parents 05d911b + 16f5d51 commit a68e18a
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ public fun <T> DataFrame<T>.add(body: AddDsl<T>.() -> Unit): DataFrame<T> {
return dataFrameOf(this@add.columns() + dsl.columns).cast()
}

@Refine
@Interpretable("GroupByAdd")
public inline fun <reified R, T, G> GroupBy<T, G>.add(
name: String,
infer: Infer = Infer.Nulls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ class FunctionCallTransformer(
val groupMarker = rootMarkers[1]

val (keySchema, groupSchema) = if (groupBy != null) {
val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop)
val groupSchema = PluginDataFrameSchema(groupBy.df.columns())
val keySchema = groupBy.keys
val groupSchema = groupBy.groups
keySchema to groupSchema
} else {
PluginDataFrameSchema.EMPTY to PluginDataFrameSchema.EMPTY
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter.*
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
import kotlin.properties.PropertyDelegateProvider
Expand Down Expand Up @@ -35,3 +36,7 @@ internal fun <T> AbstractInterpreter<T>.ignore(
): ExpectedArgumentProvider<Nothing?> =
arg(name, lens = Interpreter.Id, defaultValue = Present(null))

internal fun <T> AbstractInterpreter<T>.groupBy(
name: ArgumentName? = null
): ExpectedArgumentProvider<GroupBy> = arg(name, lens = Interpreter.GroupBy)

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ interface Interpreter<T> {

data object Schema : Lens

data object GroupBy : Lens

data object Id : Lens

// required to compute whether resulting schema should be inheritor of previous class or a new class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ data class PluginDataFrameSchema(
}
}

fun PluginDataFrameSchema.add(name: String, type: ConeKotlinType, context: KotlinTypeFacade): PluginDataFrameSchema {
return PluginDataFrameSchema(columns() + context.simpleColumnOf(name, type))
}

private fun List<SimpleCol>.asString(indent: String = ""): String {
return joinToString("\n") {
val col = when (it) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,22 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.add
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type

class GroupBy(val df: PluginDataFrameSchema, val keys: List<ColumnWithPathApproximation>, val moveToTop: Boolean)
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema)

class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.moveToTop: Boolean by arg(defaultValue = Present(true))
val Arguments.cols: ColumnsResolver by arg()

override fun Arguments.interpret(): GroupBy {
return GroupBy(receiver, cols.resolve(receiver), moveToTop)
return GroupBy(keys = createPluginDataFrameSchema(cols.resolve(receiver), moveToTop), groups = receiver)
}
}

Expand All @@ -52,7 +55,7 @@ class GroupByInto : AbstractInterpreter<Unit>() {
}

class Aggregate : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: GroupBy by arg()
val Arguments.receiver: GroupBy by groupBy()
val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id)
override fun Arguments.interpret(): PluginDataFrameSchema {
return aggregate(
Expand Down Expand Up @@ -87,7 +90,7 @@ fun KotlinTypeFacade.aggregate(
)
}

val cols = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop).columns() + dsl.columns.map {
val cols = groupBy.keys.columns() + dsl.columns.map {
simpleColumnOf(it.name, it.type)
}
PluginDataFrameSchema(cols)
Expand Down Expand Up @@ -144,13 +147,23 @@ fun KotlinTypeFacade.createPluginDataFrameSchema(keys: List<ColumnWithPathApprox
}

class GroupByToDataFrame : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: GroupBy by arg()
val Arguments.receiver: GroupBy by groupBy()
val Arguments.groupedColumnName: String? by arg(defaultValue = Present(null))

override fun Arguments.interpret(): PluginDataFrameSchema {
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.df.columns()))
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.groups.columns()))
return PluginDataFrameSchema(
createPluginDataFrameSchema(receiver.keys, receiver.moveToTop).columns() + grouped
receiver.keys.columns() + grouped
)
}
}

class GroupByAdd : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: GroupBy by groupBy()
val Arguments.name: String by arg()
val Arguments.type: TypeApproximation by type(name("expression"))

override fun Arguments.interpret(): GroupBy {
return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.jetbrains.kotlin.fir.references.resolved
import org.jetbrains.kotlin.fir.references.symbol
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.scopes.collectAllProperties
import org.jetbrains.kotlin.fir.scopes.getProperties
import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope
Expand Down Expand Up @@ -78,6 +79,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation

Expand Down Expand Up @@ -277,6 +279,17 @@ fun <T> KotlinTypeFacade.interpret(
}
}

is Interpreter.GroupBy -> {
assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) {
"'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType"
}

val resolvedType = it.expression.resolvedType.fullyExpandedType(session)
val keys = pluginDataFrameSchema(resolvedType.typeArguments[0])
val groups = pluginDataFrameSchema(resolvedType.typeArguments[1])
Interpreter.Success(GroupBy(keys, groups))
}

is Interpreter.Id -> {
Interpreter.Success(it.expression)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MoveAfter0
Expand Down Expand Up @@ -275,6 +276,7 @@ internal inline fun <reified T> String.load(): T {
"MoveToLeft1" -> MoveToLeft1()
"MoveToRight0" -> MoveToRight0()
"MoveAfter0" -> MoveAfter0()
"GroupByAdd" -> GroupByAdd()
else -> error("$this")
} as T
}
42 changes: 42 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupByAdd.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.io.*

enum class State {
Idle,
Productive,
Maintenance,
}

class Event(val toolId: String, val state: State, val timestamp: Long)

fun box(): String {
val tool1 = "tool_1"
val tool2 = "tool_2"
val tool3 = "tool_3"

val events = listOf(
Event(tool1, State.Idle, 0),
Event(tool1, State.Productive, 5),
Event(tool2, State.Idle, 0),
Event(tool2, State.Maintenance, 10),
Event(tool2, State.Idle, 20),
Event(tool3, State.Idle, 0),
Event(tool3, State.Productive, 25),
).toDataFrame()

val lastTimestamp = events.maxOf { timestamp }
val groupBy = events
.groupBy { toolId }
.sortBy { timestamp }
.add("stateDuration") {
(next()?.timestamp ?: lastTimestamp) - timestamp
}.toDataFrame()

groupBy.group[0].stateDuration

groupBy.compareSchemas(strict = true)
return "OK"
}
14 changes: 14 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupBy_extractSchema.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf("a", "b", "c")(1, 2, 3)

val groupBy = df.groupBy { a }

val df1 = groupBy.updateGroups { it.remove { a } }.toDataFrame()
df1.compileTimeSchema().print()
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,24 @@ public void testGroupBy() {
runTest("testData/box/groupBy.kt");
}

@Test
@TestMetadata("groupByAdd.kt")
public void testGroupByAdd() {
runTest("testData/box/groupByAdd.kt");
}

@Test
@TestMetadata("groupBy_DataRow.kt")
public void testGroupBy_DataRow() {
runTest("testData/box/groupBy_DataRow.kt");
}

@Test
@TestMetadata("groupBy_extractSchema.kt")
public void testGroupBy_extractSchema() {
runTest("testData/box/groupBy_extractSchema.kt");
}

@Test
@TestMetadata("groupBy_refine.kt")
public void testGroupBy_refine() {
Expand Down

0 comments on commit a68e18a

Please sign in to comment.