Skip to content

Commit

Permalink
Merge pull request #779 from Kotlin/sort-by-columnreference
Browse files Browse the repository at this point in the history
Make sortBy(ColumnReference) accept pathOf without extra cast
  • Loading branch information
koperagen authored Jul 19, 2024
2 parents 0daccc8 + ee753e0 commit 89db8e6
Show file tree
Hide file tree
Showing 4 changed files with 672 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public fun <T, C> DataFrame<T>.sortBy(columns: SortColumnsSelector<T, C>): DataF
UnresolvedColumnsPolicy.Fail, columns
)

public fun <T> DataFrame<T>.sortBy(vararg cols: ColumnReference<Comparable<*>?>): DataFrame<T> =
/**
 * Sorts this [DataFrame] by the columns identified by the given references.
 *
 * Delegates to the selector-based `sortBy`, passing `cols.toColumnSet()` as the selection.
 * The parameter is typed `ColumnReference<*>`, so the element type of each reference is not
 * constrained at compile time.
 * NOTE(review): the accompanying test expects the message "Can not use ColumnGroup as sort column"
 * when a reference resolves to a column group — presumably raised by the delegated overload; confirm.
 *
 * @param cols references to the columns to sort by.
 * @return the [DataFrame] produced by the delegated `sortBy` call.
 */
public fun <T> DataFrame<T>.sortBy(vararg cols: ColumnReference<*>): DataFrame<T> =
sortBy { cols.toColumnSet() }

/**
 * Sorts this [DataFrame] by the columns with the given names.
 *
 * Delegates to the selector-based `sortBy`, passing `cols.toColumnSet()` as the selection.
 *
 * @param cols names of the columns to sort by.
 * @return the [DataFrame] produced by the delegated `sortBy` call.
 */
public fun <T> DataFrame<T>.sortBy(vararg cols: String): DataFrame<T> = sortBy { cols.toColumnSet() }
Expand All @@ -132,7 +132,7 @@ public fun <T, C> DataFrame<T>.sortByDesc(vararg columns: KProperty<Comparable<C

/**
 * `sortByDesc` overload of [DataFrame] that takes column names.
 *
 * Delegates to the selector-based `sortByDesc`, passing `columns.toColumnSet()` as the selection.
 *
 * @param columns names of the columns to sort by.
 * @return the [DataFrame] produced by the delegated `sortByDesc` call.
 */
public fun <T> DataFrame<T>.sortByDesc(vararg columns: String): DataFrame<T> = sortByDesc { columns.toColumnSet() }

public fun <T, C> DataFrame<T>.sortByDesc(vararg columns: ColumnReference<Comparable<C>?>): DataFrame<T> =
/**
 * `sortByDesc` overload of [DataFrame] that takes column references.
 *
 * Delegates to the selector-based `sortByDesc`, passing `columns.toColumnSet()` as the selection.
 * The parameter is typed `ColumnReference<*>`, so the element type of each reference is not
 * constrained at compile time.
 *
 * @param columns references to the columns to sort by.
 * @return the [DataFrame] produced by the delegated `sortByDesc` call.
 */
public fun <T> DataFrame<T>.sortByDesc(vararg columns: ColumnReference<*>): DataFrame<T> =
sortByDesc { columns.toColumnSet() }

// endregion
Expand All @@ -141,7 +141,7 @@ public fun <T, C> DataFrame<T>.sortByDesc(vararg columns: ColumnReference<Compar

/**
 * `sortBy` overload of [GroupBy] that takes column names.
 *
 * Delegates to the selector-based `sortBy` on [GroupBy], passing `cols.toColumnSet()` as the selection.
 *
 * @param cols names of the columns to sort by.
 * @return the [GroupBy] produced by the delegated `sortBy` call.
 */
public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: String): GroupBy<T, G> = sortBy { cols.toColumnSet() }

public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: ColumnReference<Comparable<*>?>): GroupBy<T, G> =
/**
 * `sortBy` overload of [GroupBy] that takes column references.
 *
 * Delegates to the selector-based `sortBy` on [GroupBy], passing `cols.toColumnSet()` as the selection.
 * The parameter is typed `ColumnReference<*>`, so the element type of each reference is not
 * constrained at compile time.
 *
 * @param cols references to the columns to sort by.
 * @return the [GroupBy] produced by the delegated `sortBy` call.
 */
public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: ColumnReference<*>): GroupBy<T, G> =
sortBy { cols.toColumnSet() }

public fun <T, G> GroupBy<T, G>.sortBy(vararg cols: KProperty<Comparable<*>?>): GroupBy<T, G> =
Expand All @@ -151,7 +151,7 @@ public fun <T, G, C> GroupBy<T, G>.sortBy(selector: SortColumnsSelector<G, C>):

/**
 * `sortByDesc` overload of [GroupBy] that takes column names.
 *
 * Delegates to the selector-based `sortByDesc` on [GroupBy], passing `cols.toColumnSet()` as the selection.
 *
 * @param cols names of the columns to sort by.
 * @return the [GroupBy] produced by the delegated `sortByDesc` call.
 */
public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: String): GroupBy<T, G> = sortByDesc { cols.toColumnSet() }

public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: ColumnReference<Comparable<*>?>): GroupBy<T, G> =
/**
 * `sortByDesc` overload of [GroupBy] that takes column references.
 *
 * Delegates to the selector-based `sortByDesc` on [GroupBy], passing `cols.toColumnSet()` as the selection.
 * The parameter is typed `ColumnReference<*>`, so the element type of each reference is not
 * constrained at compile time.
 *
 * @param cols references to the columns to sort by.
 * @return the [GroupBy] produced by the delegated `sortByDesc` call.
 */
public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: ColumnReference<*>): GroupBy<T, G> =
sortByDesc { cols.toColumnSet() }

public fun <T, G> GroupBy<T, G>.sortByDesc(vararg cols: KProperty<Comparable<*>?>): GroupBy<T, G> =
Expand Down
28 changes: 28 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package org.jetbrains.kotlinx.dataframe.api

import io.kotest.assertions.throwables.shouldThrowMessage
import io.kotest.matchers.shouldBe
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.io.readDataFrame
import org.jetbrains.kotlinx.dataframe.nrow
import org.jetbrains.kotlinx.dataframe.testResource
import org.jetbrains.kotlinx.dataframe.testSets.*
import org.jetbrains.kotlinx.dataframe.testSets.DsSalaries
import org.junit.Test

class SortDataColumn {
Expand Down Expand Up @@ -67,4 +72,27 @@ class SortDataColumn {
col.sortWith { df1, df2 -> df1[a] - df2[a] } shouldBe sortedCol
col.sortWith(compareBy { it[a] }) shouldBe sortedCol
}

@Test
fun `sort by nested column`() {
    // Column path ("L", "salary"), used both as the sort key and to read the value back.
    val salaryPath = pathOf("L", "salary")

    // Pivot by company size, group by company location, and aggregate two values per cell.
    val pivoted = testResource("ds_salaries.csv")
        .readDataFrame()
        .cast<DsSalaries>()
        .pivot(false) { companySize }
        .groupBy { companyLocation }
        .aggregate {
            maxOf { salaryInUsd } into "salary"
            maxBy { salaryInUsd } into "extra"
        }

    val firstSorted = pivoted.sortBy(salaryPath)[0]
    firstSorted[salaryPath] shouldBe null

    val firstSortedDesc = pivoted.sortByDesc(salaryPath)[0]
    firstSortedDesc[salaryPath] shouldBe 600_000
}

@Test
fun `sort by invalid nested column`() {
    // Same pivot/groupBy/aggregate setup as the nested-column sort test.
    val pivoted = testResource("ds_salaries.csv")
        .readDataFrame()
        .cast<DsSalaries>()
        .pivot(false) { companySize }
        .groupBy { companyLocation }
        .aggregate {
            maxOf { salaryInUsd } into "salary"
            maxBy { salaryInUsd } into "extra"
        }

    // Sorting by the ("L", "extra") path must fail with the exact message below.
    val expectedMessage = "Can not use ColumnGroup as sort column"
    shouldThrowMessage(expectedMessage) {
        pivoted.sortBy(pathOf("L", "extra"))
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.jetbrains.kotlinx.dataframe.testSets

import org.jetbrains.kotlinx.dataframe.annotations.ColumnName
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema

// Dataset from https://www.kaggle.com/datasets/ruchi798/data-science-job-salaries
/**
 * Typed [DataSchema] describing the columns of the data-science salaries dataset.
 *
 * Properties whose Kotlin name differs from the column name carry an explicit [ColumnName]
 * holding the snake_case column name; `salary` and `untitled` have no [ColumnName] annotation.
 * Declared in the `testSets` package; `@Suppress("unused")` is applied because not every
 * property needs to be referenced by code.
 */
@Suppress("unused")
@DataSchema
interface DsSalaries {
@ColumnName("company_location")
val companyLocation: String
@ColumnName("company_size")
val companySize: String
@ColumnName("employee_residence")
val employeeResidence: String
@ColumnName("employment_type")
val employmentType: String
@ColumnName("experience_level")
val experienceLevel: String
@ColumnName("job_title")
val jobTitle: String
@ColumnName("remote_ratio")
val remoteRatio: Int
val salary: Int
@ColumnName("salary_currency")
val salaryCurrency: String
@ColumnName("salary_in_usd")
val salaryInUsd: Int
// NOTE(review): presumably the unnamed index column of the CSV — confirm against the file.
val untitled: Int
@ColumnName("work_year")
val workYear: Int
}
Loading

0 comments on commit 89db8e6

Please sign in to comment.