Properly pass IORuntime everywhere instead of the old ExecutionContext
pomadchin committed Jun 27, 2022
1 parent 9766100 commit 301c401
Showing 38 changed files with 255 additions and 320 deletions.
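The recurring change across these files is mechanical: a `scala.concurrent.ExecutionContext` parameter (previously defaulting to `BlockingThreadPool.executionContext`) becomes a by-name `cats.effect.unsafe.IORuntime` defaulting to `IORuntimeTransient.IORuntime`, and the hard-coded `import cats.effect.unsafe.implicits.global` is replaced by making the passed runtime implicit where effects are run. A minimal sketch of that pattern, assuming only the `IORuntimeTransient.IORuntime` member visible in the hunks below (`ExampleWriter` is hypothetical):

```scala
import cats.effect._
import geotrellis.store.util.IORuntimeTransient

// Hypothetical class showing the before/after shape of this commit.
class ExampleWriter(
  // By-name, so the runtime is only materialized where it is used
  // (e.g. inside a Spark task) rather than captured by the closure.
  runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
) {
  def write(values: List[String]): Unit = {
    // Previously: implicit val ec = executionContext
    //             import cats.effect.unsafe.implicits.global
    implicit val ioRuntime: unsafe.IORuntime = runtime

    fs2.Stream
      .emits(values)
      .covary[IO]
      .evalMap(v => IO.blocking(println(s"writing $v")))
      .compile
      .drain
      .unsafeRunSync()
  }
}
```

The by-name parameter plus a transient holder presumably lets each Spark executor build its own runtime lazily instead of receiving one serialized from the driver.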
@@ -19,7 +19,7 @@ package geotrellis.spark.store.accumulo
import geotrellis.store.accumulo._
import geotrellis.store.hadoop.util._
import geotrellis.spark.util._
import geotrellis.store.util.BlockingThreadPool
import geotrellis.store.util.IORuntimeTransient

import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.fs.Path
@@ -28,14 +28,11 @@ import org.apache.accumulo.core.data.{Key, Mutation, Value}
import org.apache.accumulo.core.client.mapreduce.AccumuloFileOutputFormat
import org.apache.accumulo.core.client.BatchWriterConfig

import cats.effect.IO
import cats.syntax.apply._
import cats.effect._
import cats.syntax.either._

import java.util.UUID

import scala.concurrent.ExecutionContext

object AccumuloWriteStrategy {
def DEFAULT = HdfsWriteStrategy("/geotrellis-ingest")
}
@@ -110,30 +107,28 @@ object HdfsWriteStrategy {
* @param config Configuration for the BatchWriters
*/
class SocketWriteStrategy(
@transient config: BatchWriterConfig = new BatchWriterConfig().setMaxMemory(128*1024*1024).setMaxWriteThreads(BlockingThreadPool.threads),
executionContext: => ExecutionContext = BlockingThreadPool.executionContext
@transient config: BatchWriterConfig = new BatchWriterConfig().setMaxMemory(128*1024*1024).setMaxWriteThreads(IORuntimeTransient.ThreadsNumber),
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
) extends AccumuloWriteStrategy {
val kwConfig = KryoWrapper(config) // BatchWriterConfig is not java serializable

def write(kvPairs: RDD[(Key, Value)], instance: AccumuloInstance, table: String): Unit = {
kvPairs.foreachPartition { partition =>
if(partition.nonEmpty) {
implicit val ec = executionContext
// TODO: runime should be configured
import cats.effect.unsafe.implicits.global
implicit val ioRuntime: unsafe.IORuntime = runtime

val writer = instance.connector.createBatchWriter(table, kwConfig.value)

try {
val mutations: fs2.Stream[IO, Mutation] = fs2.Stream.fromIterator[IO](
val mutations: fs2.Stream[IO, Mutation] = fs2.Stream.fromBlockingIterator[IO](
partition.map { case (key, value) =>
val mutation = new Mutation(key.getRow)
mutation.put(key.getColumnFamily, key.getColumnQualifier, System.currentTimeMillis(), value)
mutation
}, 1
}, chunkSize = 1
)

val write = { mutation: Mutation => fs2.Stream eval IO { writer.addMutation(mutation) } }
val write = { mutation: Mutation => fs2.Stream eval IO.blocking { writer.addMutation(mutation) } }

(mutations map write)
.parJoinUnbounded
@@ -150,7 +145,7 @@

object SocketWriteStrategy {
def apply(
config: BatchWriterConfig = new BatchWriterConfig().setMaxMemory(128*1024*1024).setMaxWriteThreads(BlockingThreadPool.threads),
executionContext: => ExecutionContext = BlockingThreadPool.executionContext
): SocketWriteStrategy = new SocketWriteStrategy(config, executionContext)
config: BatchWriterConfig = new BatchWriterConfig().setMaxMemory(128*1024*1024).setMaxWriteThreads(IORuntimeTransient.ThreadsNumber),
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
): SocketWriteStrategy = new SocketWriteStrategy(config, runtime)
}
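Two smaller changes in the hunk above: the partition iterator is now lifted with `fs2.Stream.fromBlockingIterator` instead of `fromIterator`, and the Accumulo write is wrapped in `IO.blocking` instead of plain `IO`, so both the iterator pulls and the `addMutation` calls run on the runtime's blocking pool rather than the compute pool. A standalone sketch of that behaviour (fs2 3.x / cats-effect 3 assumed; names are illustrative, and the global runtime is used here only for illustration, whereas the code above passes one in):

```scala
import cats.effect._
import cats.effect.unsafe.implicits.global // illustration only

object BlockingStreamDemo extends App {
  // fromBlockingIterator pulls from the iterator via IO.blocking, which matters
  // when next() itself does I/O (e.g. a Spark partition backed by a shuffle).
  val mutations: fs2.Stream[IO, String] =
    fs2.Stream.fromBlockingIterator[IO](Iterator("m1", "m2", "m3"), chunkSize = 1)

  // IO.blocking marks the write itself as blocking work.
  def addMutation(m: String): IO[Unit] = IO.blocking(println(s"addMutation($m)"))

  mutations
    .map(m => fs2.Stream.eval(addMutation(m)))
    .parJoinUnbounded // mirrors the (mutations map write).parJoinUnbounded shape above
    .compile
    .drain
    .unsafeRunSync()
}
```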
@@ -19,18 +19,16 @@ package geotrellis.store.accumulo
import geotrellis.layer._
import geotrellis.store.avro.codecs.KeyValueRecordCodec
import geotrellis.store.avro.{AvroEncoder, AvroRecordCodec}
import geotrellis.store.util.BlockingThreadPool
import geotrellis.store.util.IORuntimeTransient
import org.apache.accumulo.core.data.{Range => AccumuloRange}
import org.apache.accumulo.core.security.Authorizations
import org.apache.avro.Schema
import org.apache.hadoop.io.Text

import cats.effect._
import cats.syntax.apply._
import cats.syntax.either._

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext
import scala.reflect.ClassTag

object AccumuloCollectionReader {
@@ -41,7 +39,7 @@
decomposeBounds: KeyBounds[K] => Seq[AccumuloRange],
filterIndexOnly: Boolean,
writerSchema: Option[Schema] = None,
executionContext: => ExecutionContext = BlockingThreadPool.executionContext
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
)(implicit instance: AccumuloInstance): Seq[(K, V)] = {
if(queryKeyBounds.isEmpty) return Seq.empty[(K, V)]

@@ -50,13 +48,11 @@

val ranges = queryKeyBounds.flatMap(decomposeBounds).iterator

implicit val ec = executionContext
// TODO: runime should be configured
import cats.effect.unsafe.implicits.global
implicit val ioRuntime: unsafe.IORuntime = runtime

val range: fs2.Stream[IO, AccumuloRange] = fs2.Stream.fromIterator[IO](ranges, 1)
val range: fs2.Stream[IO, AccumuloRange] = fs2.Stream.fromIterator[IO](ranges, chunkSize = 1)

val read = { range: AccumuloRange => fs2.Stream eval IO {
val read = { range: AccumuloRange => fs2.Stream eval IO.blocking {
val scanner = instance.connector.createScanner(table, new Authorizations())
scanner.setRange(range)
scanner.fetchColumnFamily(columnFamily)
4 changes: 2 additions & 2 deletions build.sbt
@@ -1,9 +1,9 @@
import sbt.Keys._

ThisBuild / versionScheme := Some("semver-spec")
ThisBuild / scalaVersion := "2.12.15"
ThisBuild / scalaVersion := "2.12.16"
ThisBuild / organization := "org.locationtech.geotrellis"
ThisBuild / crossScalaVersions := List("2.12.15", "2.13.8")
ThisBuild / crossScalaVersions := List("2.12.16", "2.13.8")

lazy val root = Project("geotrellis", file("."))
.aggregate(
@@ -22,9 +22,10 @@ import geotrellis.store.cassandra._
import geotrellis.store.avro.codecs.KeyValueRecordCodec
import geotrellis.store.avro.{AvroEncoder, AvroRecordCodec}
import geotrellis.store.index.{IndexRanges, MergeQueue}
import geotrellis.store.util.{BlockingThreadPool, IOUtils}
import geotrellis.store.util.{IORuntimeTransient, IOUtils}
import geotrellis.spark.util.KryoWrapper

import cats.effect._
import com.datastax.driver.core.querybuilder.QueryBuilder
import com.datastax.driver.core.querybuilder.QueryBuilder.{eq => eqs}
import org.apache.avro.Schema
@@ -35,8 +36,6 @@ import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import java.math.BigInteger

import scala.concurrent.ExecutionContext

object CassandraRDDReader {
def read[K: Boundable : AvroRecordCodec : ClassTag, V: AvroRecordCodec : ClassTag](
instance: CassandraInstance,
@@ -48,7 +47,7 @@ object CassandraRDDReader {
filterIndexOnly: Boolean,
writerSchema: Option[Schema] = None,
numPartitions: Option[Int] = None,
executionContext: => ExecutionContext = BlockingThreadPool.executionContext
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
)(implicit sc: SparkContext): RDD[(K, V)] = {
if (queryKeyBounds.isEmpty) return sc.emptyRDD[(K, V)]

@@ -73,7 +72,7 @@ object CassandraRDDReader {
sc.parallelize(bins, bins.size)
.mapPartitions { partition: Iterator[Seq[(BigInt, BigInt)]] =>
instance.withSession { session =>
implicit val ec = executionContext
implicit val ioRuntime: unsafe.IORuntime = runtime
val statement = session.prepare(query)

val result = partition map { seq =>
@@ -22,15 +22,14 @@ import geotrellis.store.avro.codecs._
import geotrellis.store.cassandra._
import geotrellis.spark.store._
import geotrellis.spark.util.KryoWrapper
import geotrellis.store.util.BlockingThreadPool
import geotrellis.store.util.IORuntimeTransient

import com.datastax.driver.core.DataType._
import com.datastax.driver.core.querybuilder.QueryBuilder
import com.datastax.driver.core.querybuilder.QueryBuilder.{eq => eqs}
import com.datastax.driver.core.ResultSet
import com.datastax.driver.core.schemabuilder.SchemaBuilder
import cats.effect.IO
import cats.syntax.apply._
import cats.effect._
import cats.syntax.either._
import org.apache.avro.Schema
import org.apache.spark.rdd.RDD
@@ -39,7 +38,6 @@ import java.nio.ByteBuffer
import java.math.BigInteger

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext

object CassandraRDDWriter {
def write[K: AvroRecordCodec, V: AvroRecordCodec](
@@ -49,8 +47,8 @@ object CassandraRDDWriter {
decomposeKey: K => BigInt,
keyspace: String,
table: String,
executionContext: => ExecutionContext = BlockingThreadPool.executionContext
): Unit = update(rdd, instance, layerId, decomposeKey, keyspace, table, None, None, executionContext)
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
): Unit = update(rdd, instance, layerId, decomposeKey, keyspace, table, None, None, runtime)

private[cassandra] def update[K: AvroRecordCodec, V: AvroRecordCodec](
raster: RDD[(K, V)],
@@ -61,7 +59,7 @@ object CassandraRDDWriter {
table: String,
writerSchema: Option[Schema],
mergeFunc: Option[(V,V) => V],
executionContext: => ExecutionContext = BlockingThreadPool.executionContext
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
): Unit = {
implicit val sc = raster.sparkContext

@@ -110,15 +108,13 @@ object CassandraRDDWriter {

val rows: fs2.Stream[IO, (BigInt, Vector[(K,V)])] =
fs2.Stream.fromIterator[IO](
partition.map { case (key, value) => (key, value.toVector) }, 1
partition.map { case (key, value) => (key, value.toVector) }, chunkSize = 1
)

implicit val ec = executionContext
// TODO: runime should be configured
import cats.effect.unsafe.implicits.global
implicit val ioRuntime: unsafe.IORuntime = runtime

def elaborateRow(row: (BigInt, Vector[(K,V)])): fs2.Stream[IO, (BigInt, Vector[(K,V)])] = {
fs2.Stream eval IO {
fs2.Stream eval IO.blocking {
val (key, current) = row
val updated = LayerWriter.updateRecords(mergeFunc, current, existing = {
val oldRow = session.execute(readStatement.bind(key: BigInteger))
@@ -143,7 +139,7 @@

def retire(row: (BigInt, ByteBuffer)): fs2.Stream[IO, ResultSet] = {
val (id, value) = row
fs2.Stream eval IO {
fs2.Stream eval IO.blocking {
session.execute(writeStatement.bind(id: BigInteger, value))
}
}
@@ -154,7 +150,7 @@
.map(retire)
.parJoinUnbounded
.onComplete {
fs2.Stream eval IO {
fs2.Stream eval IO.blocking {
session.closeAsync()
session.getCluster.closeAsync()
}
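The Cassandra writer above follows the same shape: per-row work and the final session shutdown are wrapped in `IO.blocking`, the per-row streams run under `parJoinUnbounded`, and cleanup is appended with `onComplete` so it executes whether the writes succeed or fail. A reduced sketch under the same fs2 3.x assumptions (`FakeSession` and `CleanupDemo` are stand-ins, not the real API):

```scala
import cats.effect._
import cats.effect.unsafe.implicits.global // illustration only

object CleanupDemo extends App {
  final class FakeSession { def close(): Unit = println("session closed") } // stand-in
  val session = new FakeSession

  val rows: fs2.Stream[IO, Int] =
    fs2.Stream.fromBlockingIterator[IO](Iterator(1, 2, 3), chunkSize = 1)

  rows
    .map(i => fs2.Stream.eval(IO.blocking(println(s"write row $i"))))
    .parJoinUnbounded
    // onComplete appends the cleanup stream once the writes terminate,
    // successfully or with an error (the error is then re-raised).
    .onComplete(fs2.Stream.eval(IO.blocking(session.close())))
    .compile
    .drain
    .unsafeRunSync()
}
```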
@@ -22,8 +22,9 @@ import geotrellis.store.avro.codecs.KeyValueRecordCodec
import geotrellis.store.avro.{AvroEncoder, AvroRecordCodec}
import geotrellis.store.index.MergeQueue
import geotrellis.store.LayerId
import geotrellis.store.util.BlockingThreadPool
import geotrellis.store.util.IORuntimeTransient

import cats.effect._
import org.apache.avro.Schema
import com.datastax.driver.core.querybuilder.QueryBuilder
import com.datastax.driver.core.querybuilder.QueryBuilder.{eq => eqs}
@@ -32,8 +33,6 @@ import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import java.math.BigInteger

import scala.concurrent.ExecutionContext

object CassandraCollectionReader {
def read[K: Boundable : AvroRecordCodec : ClassTag, V: AvroRecordCodec : ClassTag](
instance: CassandraInstance,
@@ -44,7 +43,7 @@ object CassandraCollectionReader {
decomposeBounds: KeyBounds[K] => Seq[(BigInt, BigInt)],
filterIndexOnly: Boolean,
writerSchema: Option[Schema] = None,
executionContext: ExecutionContext = BlockingThreadPool.executionContext
runtime: => unsafe.IORuntime = IORuntimeTransient.IORuntime
): Seq[(K, V)] = {
if (queryKeyBounds.isEmpty) return Seq.empty[(K, V)]

@@ -56,7 +55,7 @@
else
queryKeyBounds.flatMap(decomposeBounds)

implicit val ec = executionContext
implicit val ioRuntime: unsafe.IORuntime = runtime

val query = QueryBuilder.select("value")
.from(keyspace, table)
@@ -68,7 +67,7 @@
instance.withSessionDo { session =>
val statement = session.prepare(query)

IOUtils.parJoin[K, V](ranges.iterator){ index: BigInt =>
IOUtils.parJoin[K, V](ranges.iterator) { index: BigInt =>
val row = session.execute(statement.bind(index: BigInteger))
if (row.asScala.nonEmpty) {
val bytes = row.one().getBytes("value").array()
(Diffs for the remaining changed files are not loaded here.)
