Skip to content

Commit

Permalink
Merge pull request #1055 from Kotlin/group-by-shortcuts
Browse files Browse the repository at this point in the history
[Compiler plugin] Support more GroupBy shortcuts
  • Loading branch information
koperagen authored Feb 13, 2025
2 parents eba950c + 981a47c commit a626961
Show file tree
Hide file tree
Showing 18 changed files with 194 additions and 11 deletions.
2 changes: 1 addition & 1 deletion core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -3637,6 +3637,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConcatKt {
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/DataRow;[Lorg/jetbrains/kotlinx/dataframe/DataRow;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concatRows (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concatT (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
}
Expand Down Expand Up @@ -4938,7 +4939,6 @@ public final class org/jetbrains/kotlinx/dataframe/api/InsertKt {
}

public final class org/jetbrains/kotlinx/dataframe/api/IntoKt {
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun into (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun into (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun into (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;Lorg/jetbrains/kotlinx/dataframe/columns/ColumnAccessor;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ public fun <T, G> GroupBy<T, G>.concat(): DataFrame<G> = groups.concat()

// endregion

// region ReducedGroupBy

public fun <T, G> ReducedGroupBy<T, G>.concat(): DataFrame<G> =
groupBy.groups.values()
.map { reducer(it, it) }
.concat()

// endregion

// region Iterable

public fun <T> Iterable<DataFrame<T>>.concat(): DataFrame<T> = concatImpl(asList())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.Predicate
import org.jetbrains.kotlinx.dataframe.RowFilter
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateValue

// region DataColumn
Expand Down Expand Up @@ -37,9 +39,13 @@ public fun <T> DataFrame<T>.count(predicate: RowFilter<T>): Int = rows().count {

// region GroupBy

@Refine
@Interpretable("GroupByCount0")
public fun <T> Grouped<T>.count(resultName: String = "count"): DataFrame<T> =
aggregateValue(resultName) { count() default 0 }

@Refine
@Interpretable("GroupByCount0")
public fun <T> Grouped<T>.count(resultName: String = "count", predicate: RowFilter<T>): DataFrame<T> =
aggregateValue(resultName) { count(predicate) default 0 }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowFilter
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
Expand Down Expand Up @@ -55,8 +56,10 @@ public fun <T> DataFrame<T>.firstOrNull(predicate: RowFilter<T>): DataRow<T>? =

// region GroupBy

@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.first(): ReducedGroupBy<T, G> = reduce { firstOrNull() }

@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.first(predicate: RowFilter<G>): ReducedGroupBy<T, G> = reduce { firstOrNull(predicate) }

// endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public inline fun <T, G, reified V> ReducedGroupBy<T, G>.into(
noinline expression: RowExpression<G, V>,
): DataFrame<G> = into(column.columnName, expression)

@Refine
@Interpretable("GroupByReduceInto")
public fun <T, G> ReducedGroupBy<T, G>.into(columnName: String): DataFrame<G> = into(columnName) { this }

@AccessApiOverload
Expand All @@ -87,9 +89,4 @@ public fun <T, G> ReducedGroupBy<T, G>.into(column: ColumnAccessor<AnyRow>): Dat
@AccessApiOverload
public fun <T, G> ReducedGroupBy<T, G>.into(column: KProperty<AnyRow>): DataFrame<G> = into(column) { this }

public fun <T, G> ReducedGroupBy<T, G>.concat(): DataFrame<G> =
groupBy.groups.values()
.map { reducer(it, it) }
.concat()

// endregion
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowFilter
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
Expand Down Expand Up @@ -56,8 +57,10 @@ public fun <T> DataFrame<T>.last(): DataRow<T> {

// region GroupBy

@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.last(): ReducedGroupBy<T, G> = reduce { lastOrNull() }

@Interpretable("GroupByReducePredicate")
public fun <T, G> GroupBy<T, G>.last(predicate: RowFilter<G>): ReducedGroupBy<T, G> = reduce { lastOrNull(predicate) }

// endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.columns.values
Expand Down Expand Up @@ -163,11 +165,14 @@ public fun <T, C : Comparable<C>> Grouped<T>.max(
public fun <T, C : Comparable<C>> Grouped<T>.max(vararg columns: KProperty<C?>, name: String? = null): DataFrame<T> =
max(name) { columns.toColumnSet() }

@Refine
@Interpretable("GroupByMaxOf")
public fun <T, C : Comparable<C>> Grouped<T>.maxOf(
name: String? = null,
expression: RowExpression<T, C>,
): DataFrame<T> = Aggregators.max.aggregateOfDelegated(this, name) { maxOfOrNull(expression) }

@Interpretable("GroupByReduceExpression")
public fun <T, G, R : Comparable<R>> GroupBy<T, G>.maxBy(rowExpression: RowExpression<G, R?>): ReducedGroupBy<T, G> =
reduce { maxByOrNull(rowExpression) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.RowExpression
import org.jetbrains.kotlinx.dataframe.aggregation.ColumnsForAggregateSelector
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.columns.values
Expand Down Expand Up @@ -163,11 +165,14 @@ public fun <T, C : Comparable<C>> Grouped<T>.min(
public fun <T, C : Comparable<C>> Grouped<T>.min(vararg columns: KProperty<C?>, name: String? = null): DataFrame<T> =
min(name) { columns.toColumnSet() }

@Refine
@Interpretable("GroupByMinOf")
public fun <T, C : Comparable<C>> Grouped<T>.minOf(
name: String? = null,
expression: RowExpression<T, C>,
): DataFrame<T> = Aggregators.min.aggregateOfDelegated(this, name) { minOfOrNull(expression) }

@Interpretable("GroupByReduceExpression")
public fun <T, G, R : Comparable<R>> GroupBy<T, G>.minBy(rowExpression: RowExpression<G, R?>): ReducedGroupBy<T, G> =
reduce { minByOrNull(rowExpression) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fun KotlinTypeFacade.simpleColumnOf(name: String, type: ConeKotlinType): SimpleC
}
}

private fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol {
internal fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol {
return when (column) {
is SimpleColumnGroup -> {
SimpleColumnGroup(column.name, column.columns().map { makeNullable(it) })
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable

class GroupByReducePredicate : AbstractInterpreter<GroupBy>() {
val Arguments.receiver by groupBy()
val Arguments.predicate by ignore()
override fun Arguments.interpret(): GroupBy {
return receiver
}
}

class GroupByReduceExpression : AbstractInterpreter<GroupBy>() {
val Arguments.receiver by groupBy()
val Arguments.rowExpression by ignore()
override fun Arguments.interpret(): GroupBy {
return receiver
}
}

class GroupByReduceInto : AbstractSchemaModificationInterpreter() {
val Arguments.receiver by groupBy()
val Arguments.columnName: String by arg()
override fun Arguments.interpret(): PluginDataFrameSchema {
val group = makeNullable(SimpleColumnGroup(columnName, receiver.groups.columns()))
return PluginDataFrameSchema(receiver.keys.columns() + group)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
import org.jetbrains.kotlinx.dataframe.plugin.impl.add
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore

class GroupByCount0 : AbstractSchemaModificationInterpreter() {
val Arguments.receiver by groupBy()
val Arguments.resultName: String by arg(defaultValue = Present("count"))
val Arguments.predicate by ignore()

override fun Arguments.interpret(): PluginDataFrameSchema {
return receiver.keys.add(resultName, session.builtinTypes.intType.type, context = this)
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter
import org.jetbrains.kotlinx.dataframe.plugin.interpret
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
import org.jetbrains.kotlin.fir.expressions.FirAnonymousFunctionExpression
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirReturnExpression
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlinx.dataframe.plugin.InterpretationErrorReporter
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
Expand All @@ -23,8 +21,12 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.add
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
import org.jetbrains.kotlinx.dataframe.plugin.impl.makeNullable
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.interpret
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter

class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema)

Expand Down Expand Up @@ -173,9 +175,24 @@ class GroupByToDataFrame : AbstractSchemaModificationInterpreter() {
class GroupByAdd : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: GroupBy by groupBy()
val Arguments.name: String by arg()
val Arguments.infer by ignore()
val Arguments.type: TypeApproximation by type(name("expression"))

override fun Arguments.interpret(): GroupBy {
return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this))
}
}

abstract class GroupByAggregator(val defaultName: String) : AbstractSchemaModificationInterpreter() {
val Arguments.receiver by groupBy()
val Arguments.name: String? by arg(defaultValue = Present(null))
val Arguments.expression by type()

override fun Arguments.interpret(): PluginDataFrameSchema {
val aggregated = makeNullable(simpleColumnOf(name ?: defaultName, expression.type))
return PluginDataFrameSchema(receiver.keys.columns() + aggregated)
}
}

class GroupByMaxOf : GroupByAggregator(defaultName = "max")
class GroupByMinOf : GroupByAggregator(defaultName = "min")
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ fun <T> KotlinTypeFacade.interpret(
assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) {
"'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType"
}

// ok for ReducedGroupBy too
val resolvedType = it.expression.resolvedType.fullyExpandedType(session)
val keys = pluginDataFrameSchema(resolvedType.typeArguments[0])
val groups = pluginDataFrameSchema(resolvedType.typeArguments[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByCount0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByInto
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMaxOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByMinOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceExpression
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReduceInto
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByReducePredicate
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Merge0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MergeId
Expand Down Expand Up @@ -295,6 +301,12 @@ internal inline fun <reified T> String.load(): T {
"MergeBy0" -> MergeBy0()
"MergeBy1" -> MergeBy1()
"ReorderColumnsByName" -> ReorderColumnsByName()
"GroupByCount0" -> GroupByCount0()
"GroupByReducePredicate" -> GroupByReducePredicate()
"GroupByReduceExpression" -> GroupByReduceExpression()
"GroupByReduceInto" -> GroupByReduceInto()
"GroupByMaxOf" -> GroupByMaxOf()
"GroupByMinOf" -> GroupByMinOf()
else -> error("$this")
} as T
}
16 changes: 16 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupBy_count.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.count()
val i: Int = df.count[0]

val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.count { a > 1 }
val i1: Int = df1.count[0]

val df2 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.count("myCol") { a > 1 }
val i2: Int = df2.myCol[0]
return "OK"
}
22 changes: 22 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupBy_maxOfMinOf.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.maxOf { 123 }
val df1 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.minOf { 123 }

val max = df.max[0]
val min = df1.min[0]

df.compareSchemas()
df1.compareSchemas()

val df2 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.maxOf("myMax") { 123 }
val df3 = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }.minOf("myMin") { 123 }

df2.myMax
df3.myMin
return "OK"
}
16 changes: 16 additions & 0 deletions plugins/kotlin-dataframe/testData/box/reducedGroupBy.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val groupBy = dataFrameOf("a")(1, 1, 2, 3, 3).groupBy { a }.add("id") { index() }
groupBy.maxBy { id }.into("group").compareSchemas()
groupBy.maxBy { id }.into("group").compareSchemas()
groupBy.first { id == 1 }.into("group").compareSchemas()
groupBy.first().into("group").compareSchemas()
groupBy.last { id == 1 }.into("group").compareSchemas()
groupBy.last().into("group").compareSchemas()
groupBy.minBy { id == 1 }.into("group").compareSchemas()
return "OK"
}
Loading

0 comments on commit a626961

Please sign in to comment.