-
-
Save skp33/40aa0897b6547493e6e2 to your computer and use it in GitHub Desktop.
Moving Average on stock prices in Spark with custom partitioner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Width of the sliding window, in elements.
val window = 3
// Sample series: the integers 0 through 100, requested as 10 partitions.
val ts = sc.parallelize(0 to 100, numSlices = 10)
/**
 * Partitioner that sends every record to the partition whose index equals
 * the record's integer key.
 *
 * @param p number of partitions; valid keys are `0` until `p`
 */
class StraightPartitioner(p: Int) extends Partitioner {

  /** Number of partitions, exactly as given to the constructor. */
  def numPartitions: Int = p

  /**
   * Returns the key itself as the target partition index.
   *
   * @param key expected to be an `Int` in the range `[0, p)`
   * @return the partition index, equal to `key`
   * @throws IllegalArgumentException if `key` is not an `Int` or is out of range
   */
  def getPartition(key: Any): Int = key match {
    case index: Int if index >= 0 && index < p => index
    case other =>
      throw new IllegalArgumentException(
        s"StraightPartitioner expects an Int key in [0, $p), got: $other")
  }
}
// Re-key the series so that each partition also receives the first
// `window - 1` elements of the partition that follows it. Every element is
// tagged with the index of the partition it must end up in, routed there by
// StraightPartitioner, and then the tag is dropped again with `.values`.
val partitioned = ts.mapPartitionsWithIndex { (i, p) =>
  // Pull at most `window - 1` leading elements using only hasNext/next, so
  // that `p` can still be legally consumed afterwards (an Iterator must not
  // be reused once a method such as `take` has been called on it).
  val overlap = Iterator.fill(window - 1)(p).takeWhile(_.hasNext).map(_.next()).toArray
  // Copy of the head destined for the previous partition.
  val spill = overlap.iterator.map((i - 1, _))
  // The full content of this partition: the head already pulled, then the rest.
  val keep = (overlap.iterator ++ p).map((i, _))
  // Partition 0 has no predecessor, so its head is not spilled anywhere.
  if (i == 0) keep else keep ++ spill
}.partitionBy(new StraightPartitioner(ts.partitions.length)).values
// Sliding-window totals over each partition of `partitioned`.
// Each emitted value is the SUM of `window` consecutive elements; it is not
// divided by `window`.
val movingAverage = partitioned.mapPartitions { p =>
  // Sorting by value stands in for ordering by position. That reproduces the
  // series order only when the values are increasing, as in the 0 to 100
  // sample. NOTE(review): real price data would need an explicit index to
  // sort on instead.
  val sorted = p.toSeq.sorted
  // Trailing edge of the window: the element that leaves after each step.
  val olds = sorted.iterator
  // Leading edge of the window, starting `window - 1` elements ahead.
  // Built with `drop` on a fresh iterator rather than by reusing an iterator
  // after `take`, which the Iterator contract does not allow.
  val news = sorted.iterator.drop(window - 1)
  // Running total, seeded with the first `window - 1` elements.
  var sum = sorted.iterator.take(window - 1).sum
  (olds zip news).map { case (o, n) =>
    sum += n      // the window is now complete
    val v = sum
    sum -= o      // slide: forget the oldest element
    v
  }
}
// scala> movingAverage.collect.sameElements(3 to 297 by 3)
// res0: Boolean = true
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment