-
-
Save fracaron/6334ef10dac92b0ecbf97235b2e164d8 to your computer and use it in GitHub Desktop.
Batch select in Exposed, inspired by Rails' find_in_batches
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.jetbrains.exposed.sql | |
import org.jetbrains.exposed.dao.EntityID | |
import org.jetbrains.exposed.sql.FieldSet | |
import org.jetbrains.exposed.sql.Op | |
import org.jetbrains.exposed.sql.ResultRow | |
import org.jetbrains.exposed.sql.SortOrder | |
import org.jetbrains.exposed.sql.SqlExpressionBuilder | |
import org.jetbrains.exposed.sql.and | |
import org.jetbrains.exposed.sql.isAutoInc | |
import org.jetbrains.exposed.sql.select | |
/**
 * Selects the rows matching [where] in consecutive batches of at most [batchSize] rows.
 *
 * The [where] builder is evaluated once against [SqlExpressionBuilder]; the resulting
 * condition is then passed to the private `Op`-based `selectBatched` overload, which
 * performs the actual batching.
 *
 * @param batchSize maximum number of rows per batch (defaults to 1000).
 * @param where builder producing the filter condition.
 * @return an iterable of batches, each batch being an iterable of [ResultRow].
 */
fun FieldSet.selectBatched(
    batchSize: Int = 1000,
    where: SqlExpressionBuilder.() -> Op<Boolean>
): Iterable<Iterable<ResultRow>> {
    val condition: Op<Boolean> = SqlExpressionBuilder.where()
    return selectBatched(batchSize, condition)
}
/**
 * Selects all rows in consecutive batches of at most [batchSize] rows.
 *
 * Delegates to the `Op`-based `selectBatched` overload, using [Op.TRUE] as the condition
 * so that no row is filtered out.
 *
 * @param batchSize maximum number of rows per batch (defaults to 1000).
 * @return an iterable of batches, each batch being an iterable of [ResultRow].
 */
fun FieldSet.selectAllBatched(
    batchSize: Int = 1000
): Iterable<Iterable<ResultRow>> = selectBatched(batchSize, Op.TRUE)
/**
 * Core implementation of batched selection: pages through the rows matching [whereOp]
 * using the first auto-increment column of the source as the paging key.
 *
 * Argument validation and the auto-increment column lookup happen eagerly, when this
 * function is called. The queries themselves run lazily: each batch is fetched only when
 * the returned iterator is advanced, and every call to `iterator()` starts over from the
 * beginning.
 *
 * Each batch is read with `whereOp AND key > lastOffset ORDER BY key ASC LIMIT batchSize`,
 * where `lastOffset` starts at 0 and is then set to the key of the last row of the
 * previous batch. Consequently only rows whose key is greater than 0 are visited.
 *
 * @param batchSize maximum number of rows per batch; must be greater than 0.
 * @param whereOp filter condition applied to every batch query.
 * @return an iterable of non-empty batches, each batch being a list of [ResultRow].
 * @throws IllegalArgumentException if [batchSize] is not greater than 0.
 * @throws UnsupportedOperationException if the source has no auto-increment column.
 */
private fun FieldSet.selectBatched(
    batchSize: Int = 1000,
    whereOp: Op<Boolean>
): Iterable<Iterable<ResultRow>> {
    require(batchSize > 0) { "Batch size should be greater than 0" }

    val autoIncColumn = source.columns.firstOrNull { it.columnType.isAutoInc }
        ?: throw UnsupportedOperationException("Batched select only works on tables with an autoincrementing column")

    return object : Iterable<Iterable<ResultRow>> {
        override fun iterator(): Iterator<Iterable<ResultRow>> {
            return iterator {
                var lastOffset = 0L
                while (true) {
                    val query =
                        select { whereOp and (autoIncColumn greater lastOffset) }
                            .limit(batchSize)
                            .orderBy(autoIncColumn, SortOrder.ASC)

                    // query.iterator() executes the query
                    val results = query.iterator().asSequence().toList()
                    if (results.isEmpty()) break

                    yield(results)

                    // The key of the last row becomes the exclusive lower bound of the next batch.
                    // `!!`: the paging key is read from a row that was just fetched by that key.
                    lastOffset = toLong(results.last()[autoIncColumn]!!)
                }
            }
        }

        /**
         * Converts a paging-key value to [Long].
         *
         * Unwraps [EntityID] values and widens any numeric key, so that `Int`-typed
         * auto-increment columns work as well as `Long`-typed ones.
         */
        private fun toLong(autoIncVal: Any): Long = when (autoIncVal) {
            is EntityID<*> -> toLong(autoIncVal.value)
            is Number -> autoIncVal.toLong()
            else -> throw ClassCastException(
                "Auto-increment value of type ${autoIncVal::class} cannot be converted to Long"
            )
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.jetbrains.exposed.sql | |
import org.jetbrains.exposed.config.KMySqlContainer | |
import io.kotlintest.properties.Gen | |
import io.kotlintest.properties.forAll | |
import io.kotlintest.shouldBe | |
import io.kotlintest.shouldThrow | |
import java.security.SecureRandom | |
import java.util.* | |
import org.jetbrains.exposed.dao.LongIdTable | |
import org.jetbrains.exposed.sql.Database | |
import org.jetbrains.exposed.sql.ResultRow | |
import org.jetbrains.exposed.sql.SchemaUtils | |
import org.jetbrains.exposed.sql.Table | |
import org.jetbrains.exposed.sql.batchInsert | |
import org.jetbrains.exposed.sql.transactions.transaction | |
import org.junit.jupiter.api.AfterEach | |
import org.junit.jupiter.api.BeforeEach | |
import org.junit.jupiter.api.Test | |
/**
 * Tests for the `selectBatched` / `selectAllBatched` extensions.
 *
 * The companion object starts a [KMySqlContainer] and connects Exposed to it through the
 * MySQL JDBC driver; the [Cities] table is created before and dropped after every test.
 */
internal class BatchSelectTest {

    /** Table the batched selects are run against. */
    object Cities : LongIdTable() {
        val name = varchar("name", length = 50)
        val type = integer("type")
    }

    /** Table whose only column is a varchar primary key, used for the "no autoinc column" case. */
    object TableWithNoAutoIncCol : Table() {
        val name = varchar("name", length = 50).primaryKey()
    }

    /** Plain representation of a [Cities] row, without its id. */
    data class City(
        val name: String,
        val type: CityType
    )

    /** City category; stored in [Cities.type] as the enum ordinal. */
    enum class CityType {
        SMALL,
        BIG
    }

    @BeforeEach
    fun setUp() {
        transaction { SchemaUtils.create(Cities) }
    }

    @AfterEach
    fun tearDown() {
        transaction { SchemaUtils.drop(Cities) }
    }

    @Test
    fun `should respect 'where' expression and the provided batch size`() {
        transaction {
            val small = generateCities(count = 50, cityType = CityType.SMALL)
            val big = generateCities(count = 50, cityType = CityType.BIG)
            insert(small)
            insert(big)

            val fetched = Cities
                .selectBatched(batchSize = 25) { Cities.type eq CityType.SMALL.ordinal }
                .toList()
            val batches = fetched.map { rows -> rows.map { it.toCity() } }

            // 50 small cities in chunks of 25: the first 25 followed by the last 25.
            batches shouldBe small.chunked(25)
        }
    }

    @Test
    fun `when batch size is greater than the amount of available items, should return 1 batch`() {
        transaction {
            val expected = generateCities(count = 25, cityType = CityType.SMALL)
            insert(expected)

            val fetched = Cities
                .selectBatched(batchSize = 100) { Cities.type eq CityType.SMALL.ordinal }
                .toList()
            val batches = fetched.map { rows -> rows.map { it.toCity() } }

            batches shouldBe listOf(expected)
        }
    }

    @Test
    fun `when selecting all by batches, should return all available items`() {
        transaction {
            val small = generateCities(count = 30, cityType = CityType.SMALL)
            val big = generateCities(count = 30, cityType = CityType.BIG)
            insert(small)
            insert(big)

            val fetched = Cities.selectAllBatched(batchSize = 30).toList()
            val batches = fetched.map { rows -> rows.map { it.toCity() } }

            // Small cities were inserted first, so they make up the first batch.
            batches shouldBe listOf(small, big)
        }
    }

    @Test
    fun `when there are no items, should return an empty iterable`() {
        transaction {
            val fetched = Cities.selectAllBatched().toList()

            fetched shouldBe emptyList()
        }
    }

    @Test
    fun `when there are no items of the given condition, should return an empty iterable`() {
        transaction {
            insert(generateCities(count = 25, cityType = CityType.SMALL))

            val fetched = Cities
                .selectBatched(batchSize = 100) { Cities.type eq CityType.BIG.ordinal }
                .toList()

            fetched shouldBe emptyList()
        }
    }

    @Test
    fun `when the table doesn't have an autoinc column, should throw an exception`() {
        shouldThrow<UnsupportedOperationException> {
            TableWithNoAutoIncCol.selectAllBatched()
        }
    }

    @Test
    fun `when batch size is 0 or less, should throw an exception`() {
        // Generates 0 as the constant case and random values in [-Int.MAX_VALUE, -1].
        val nonPositiveSizes = object : Gen<Int> {
            override fun constants(): Iterable<Int> = listOf(0)

            override fun random(): Sequence<Int> =
                generateSequence { secureRandom.nextInt(Int.MAX_VALUE) - Int.MAX_VALUE }
        }

        forAll(nonPositiveSizes) { size ->
            val failure = runCatching { Cities.selectAllBatched(batchSize = size) }.exceptionOrNull()
            failure is IllegalArgumentException
        }
    }

    /** Builds [count] cities of the given [cityType], each named with a random UUID. */
    private fun generateCities(count: Int = 50, cityType: CityType = CityType.SMALL): List<City> =
        List(count) { City(name = UUID.randomUUID().toString(), type = cityType) }

    /** Batch-inserts [cities] into [Cities], storing the type as its enum ordinal. */
    private fun insert(cities: List<City>) {
        Cities.batchInsert(cities) { city ->
            this[Cities.name] = city.name
            this[Cities.type] = city.type.ordinal
        }
    }

    /** Maps a [Cities] row back to a [City], resolving the stored ordinal to a [CityType]. */
    private fun ResultRow.toCity(): City {
        val cityType = CityType.values()[this[Cities.type]]
        return City(name = this[Cities.name], type = cityType)
    }

    companion object {
        private val mySqlContainer = KMySqlContainer()
        private val secureRandom = SecureRandom()

        init {
            // Start the container first, then point Exposed at it.
            with(mySqlContainer) {
                start()
                Database.connect(
                    url = jdbcUrl,
                    driver = "com.mysql.cj.jdbc.Driver",
                    user = username,
                    password = password
                )
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment