Created May 25, 2021 12:02
Save jchapuis/8d6f60476ebc2b6813ab28daa9216d98 to your computer and use it in GitHub Desktop.
Generic LCS diffing in Scala (dynamic programming with memoization)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cats.{ Eq, Show } | |
import cats.syntax.eq._ | |
import cats.syntax.show._ | |
import cats.instances.int._ | |
import FunctionHelpers._ | |
/** Generic sequence diffing based on the longest common subsequence (LCS).
  *
  * Import `Differ._` to get `diffWith` on any `Seq[T]` that has a cats `Eq[T]` instance.
  */
object Differ {

  /** One step of an edit script turning a baseline sequence into a revision sequence. */
  sealed trait Diff[T]

  object Diff {

    /** `revision` is present only in the revision sequence. */
    final case class Insert[T](revision: T) extends Diff[T]

    /** `baseline` is present only in the baseline sequence. */
    final case class Delete[T](baseline: T) extends Diff[T]

    /** An element matched between both sequences (part of the longest common subsequence). */
    final case class Keep[T](baseline: T, revision: T) extends Diff[T]

    /** Compact rendering: `+x` for an insertion, `-x` for a deletion and `_` for a kept element. */
    implicit def show[T](implicit elementShow: Show[T]): Show[Diff[T]] =
      Show.show {
        case Diff.Insert(revision) => show"+$revision"
        case Diff.Delete(baseline) => show"-$baseline"
        case Diff.Keep(_, _)       => show"_"
      }
  }

  /** Adds LCS-based diffing to any `Seq[T]`; the receiver is the baseline. */
  implicit class RichSeq[T](baseline: Seq[T]) {

    /** Computes an edit script from `baseline` to `revision` that keeps a longest common subsequence.
      *
      * Runs in O(n * m) time and memory (n = baseline size, m = revision size) and does not
      * recurse, so long inputs cannot overflow the stack.
      *
      * @param revision the sequence to compare the baseline against
      * @param eq       equality used to match elements of the two sequences
      * @return the edits in sequence order: the `Keep` and `Delete` entries together spell out
      *         `baseline`, the `Keep` and `Insert` entries together spell out `revision`
      */
    def diffWith(revision: Seq[T])(implicit eq: Eq[T]): Seq[Diff[T]] = {
      // Both phases access elements by index, which is O(i) on a List: work on indexed copies.
      val indexedBaseline = baseline.toIndexedSeq
      val indexedRevision = revision.toIndexedSeq
      val matrix          = computeMatrix(indexedBaseline, indexedRevision)
      generateDiff(matrix, indexedBaseline, indexedRevision)
    }

    /** `matrix(i, j)` is the LCS length of the first `i` baseline and first `j` revision elements. */
    private type LCSMatrix = (Int, Int) => Int

    // baseline is on the vertical axis (rows), revision on the horizontal one (columns).
    // The table is filled bottom-up rather than by memoized recursion, whose call depth
    // grows with the combined input length.
    private def computeMatrix(baseline: IndexedSeq[T], revision: IndexedSeq[T])(implicit
        eq: Eq[T]
    ): LCSMatrix = {
      val lengths = Array.ofDim[Int](baseline.size + 1, revision.size + 1)
      // Row 0 and column 0 stay 0: the LCS with an empty prefix is empty.
      for {
        row    <- 1 to baseline.size
        column <- 1 to revision.size
      } {
        lengths(row)(column) =
          if (baseline(row - 1) === revision(column - 1)) lengths(row - 1)(column - 1) + 1
          else math.max(lengths(row - 1)(column), lengths(row)(column - 1))
      }
      val matrix: LCSMatrix = (row, column) => lengths(row)(column)
      matrix
    }

    // Walks from the bottom-right cell back to (0, 0) following the LCS lengths.
    // Edits are prepended, so the accumulated list is already in sequence order.
    private def generateDiff(matrix: LCSMatrix, baseline: IndexedSeq[T], revision: IndexedSeq[T])(implicit
        eq: Eq[T]
    ): Seq[Diff[T]] = {
      @scala.annotation.tailrec
      def walk(row: Int, column: Int, acc: List[Diff[T]]): List[Diff[T]] = {
        val canMoveUp   = row > 0
        val canMoveLeft = column > 0
        // Matching elements are always kept. Otherwise move towards the neighbour with the
        // longer LCS; on a tie between the two neighbours the insertion is preferred.
        if (canMoveUp && canMoveLeft && baseline(row - 1) === revision(column - 1))
          walk(row - 1, column - 1, Diff.Keep(baseline(row - 1), revision(column - 1)) :: acc)
        else if (canMoveLeft && (!canMoveUp || matrix(row, column - 1) >= matrix(row - 1, column)))
          walk(row, column - 1, Diff.Insert(revision(column - 1)) :: acc)
        else if (canMoveUp)
          walk(row - 1, column, Diff.Delete(baseline(row - 1)) :: acc)
        else acc
      }
      walk(baseline.size, revision.size, acc = Nil)
    }
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.scalatest.matchers.should.Matchers | |
import org.scalatest.wordspec.AnyWordSpec | |
import Differ._ | |
import cats.instances.char._ | |
import cats.syntax.show._ | |
class DifferSpec extends AnyWordSpec with Matchers {

  /** Diffs two strings character by character and renders the edits with the compact `Show`. */
  private def renderDiff(baseline: String, revision: String): String =
    baseline.toSeq.diffWith(revision.toSeq).map(_.show).mkString

  "Differ" should {
    "produce diff for simple example" in new SimpleFixture {
      renderDiff(baseline, revision) shouldBe diff
    }
    "produce optimal diff for dna sequence" in new DnaFixture {
      renderDiff(baseline, revision) shouldBe diff
    }
    "reproduce wikipedia example correctly" in new WikipediaFixture {
      renderDiff(baseline, revision) shouldBe diff
    }
    "keep every element of identical sequences" in {
      renderDiff("abc", "abc") shouldBe "___"
    }
    "insert everything when the baseline is empty" in {
      renderDiff("", "abc") shouldBe "+a+b+c"
    }
    "delete everything when the revision is empty" in {
      renderDiff("abc", "") shouldBe "-a-b-c"
    }
    "produce an empty diff for two empty sequences" in {
      renderDiff("", "") shouldBe ""
    }
  }

  trait SimpleFixture {
    val baseline = "A B C D".filterNot(_.isWhitespace)
    val revision = "A D".filterNot(_.isWhitespace)
    val diff     = "_-B-C_"
  }

  // DNA strands: checks that the diff keeps a longest common subsequence.
  trait DnaFixture {
    val baseline = "c t c a t g g a g c".filterNot(_.isWhitespace)
    val revision = "t c a a t g g a".filterNot(_.isWhitespace)
    val diff     = "-c__+a_____-g-c"
  }

  // Example from the Wikipedia article on the longest common subsequence problem.
  trait WikipediaFixture {
    val baseline = "X M J Y A U Z".filterNot(_.isWhitespace)
    val revision = "M Z J A W X U".filterNot(_.isWhitespace)
    val diff     = "-X_+Z_-Y_+W+X_-Z"
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Small combinators for plain Scala functions. */
object FunctionHelpers {

  /** Adds memoization to two-argument functions. */
  @SuppressWarnings(Array("org.wartremover.warts.MutableDataStructures"))
  implicit class RichFunction2[A1, A2, R](f: (A1, A2) => R) {

    /** Returns a wrapper that evaluates `f` at most once per distinct argument pair.
      *
      * Every call to `memoize` creates its own empty cache. The cache is a plain mutable map:
      * it is never evicted and has no synchronization.
      */
    def memoize: (A1, A2) => R = {
      val cache = scala.collection.mutable.Map.empty[(A1, A2), R]
      (first: A1, second: A2) => cache.getOrElseUpdate((first, second), f(first, second))
    }
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.