From 34ec5446a1d6a5045f512f42c7a19d14f457e4eb Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Thu, 9 Jan 2025 13:52:38 +0100 Subject: [PATCH] Add support for writing DataFrames to SQL tables Introduce functionality to write DataFrames to SQL, including table creation, batch insertion, and handling existing tables via `IfExists` strategies. Extend `DbType` with methods for Kotlin-to-SQL type conversion and nullability handling. --- .../kotlinx/dataframe/io/db/DbType.kt | 15 +++ .../kotlinx/dataframe/io/writeJdbc.kt | 103 ++++++++++++++++++ 2 files changed, 118 insertions(+) create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/writeJdbc.kt diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt index def6e815c3..4a9edce156 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/DbType.kt @@ -51,6 +51,14 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { */ public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? + /** + * Converts a Kotlin type ([KType]) to its corresponding SQL type as a String. + * + * @param kType The Kotlin type to be converted. + * @return The corresponding SQL type as a String, or null if no matching SQL type exists. + */ + public abstract fun convertKTypeToSqlType(kType: KType): String? + /** * Constructs a SQL query with a limit clause. * @@ -59,4 +67,11 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) { * @return A new SQL query with the limit clause added. */ public open fun sqlQueryLimit(sqlQuery: String, limit: Int = 1): String = "$sqlQuery LIMIT $limit" + + /** + * Handles optional type conversion for nullable values. + */ + public open fun handleNullable(sqlType: String, isNullable: Boolean): String { + return if (isNullable) "$sqlType NULL" else "$sqlType NOT NULL" + } } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/writeJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/writeJdbc.kt new file mode 100644 index 0000000000..ee01be8814 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/writeJdbc.kt @@ -0,0 +1,103 @@ +package org.jetbrains.kotlinx.dataframe.io + +import java.sql.Connection +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.io.db.DbType + +enum class IfExists { + FAIL, + REPLACE, + APPEND +} + +public fun DataFrame.writeToSqlTable( + con: Connection, + name: String, + schema: String? = null, + inferNullability: Boolean = true, + ifExists: IfExists = IfExists.FAIL, + batchSize: Int = 1000, + dbType: DbType? = null, +) { + val qualifiedName = if (schema != null) "$schema.$name" else name + val tableExists = doesTableExist(con, qualifiedName, dbType) + + when (ifExists) { + IfExists.FAIL -> if (tableExists) throw IllegalArgumentException("Table $qualifiedName already exists.") + IfExists.REPLACE -> { + if (tableExists) dropTable(con, qualifiedName) + createTable(this, con, qualifiedName, dbType) + } + IfExists.APPEND -> if (!tableExists) createTable(this, con, qualifiedName, dbType) + } + + val batchLimit = batchSize ?: this.rowsCount() + val insertQuery = buildInsertQuery(qualifiedName, columnNames()) + val preparedStatement: PreparedStatement = con.prepareStatement(insertQuery) + + con.autoCommit = false + try { + forEachBatch(this, batchLimit) { batch -> + batch.forEach { row -> + columnNames().forEachIndexed { index, columnName -> + val value = row[columnName] + preparedStatement.setObject(index + 1, value) + } + preparedStatement.addBatch() + } + preparedStatement.executeBatch() + } + con.commit() + } catch (exception: Exception) { + con.rollback() + throw exception + } finally { + preparedStatement.close() + } +} + + +public fun doesTableExist(connection: Connection, tableName: String, dbType: DbType?): Boolean { + val query = "SELECT 1 FROM information_schema.tables WHERE table_name = ?" + connection.prepareStatement(query).use { statement -> + statement.setString(1, tableName) + statement.executeQuery().use { resultSet -> + return resultSet.next() + } + } +} + +public fun forEachBatch(dataFrame: DataFrame, batchSize: Int, action: (DataFrame) -> Unit) { + val totalRows = dataFrame.rowsCount() + for (start in 0 until totalRows step batchSize) { + val end = minOf(start + batchSize, totalRows) + val batch = dataFrame[start until end] + action(batch) + } +} + +public fun createTable(dataFrame: DataFrame, connection: Connection, tableName: String, dbType: DbType?) { + val columnsDefinition = dataFrame.columnNames().zip(dataFrame.columnTypes()).joinToString(", ") { (name, type) -> + val sqlType = dbType.convertKTypeToSqlType(type) + dbType?.handleNullable(sqlType, type.isMarkedNullable) ?: throw IllegalArgumentException ("dbType is not specified") + "$name $sqlType" + } + val createQuery = "CREATE TABLE $tableName ($columnsDefinition)" + connection.prepareStatement(createQuery).use { it.executeUpdate() } +} + +public fun dropTable(connection: Connection, tableName: String) { + val query = "DROP TABLE $tableName" + connection.prepareStatement(query).use { it.executeUpdate() } +} + +public fun buildInsertQuery(tableName: String, columnNames: List): String { + val placeholders = columnNames.joinToString(", ") { "?" } + val columns = columnNames.joinToString(", ") + return "INSERT INTO $tableName ($columns) VALUES ($placeholders)" +} + + + + +