Last active
January 30, 2020 03:54
-
-
Save johnynek/48ab7f060dda1bf0edc7be8229576048 to your computer and use it in GitHub Desktop.
A minimal example of a test runner for https://github.com/sbt/test-interface. Any comments or reviews welcome.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* copyright 2020 P. Oscar Boykin <[email protected]> | |
* licensed under Apache V2 license: https://www.apache.org/licenses/LICENSE-2.0.html | |
* or, at your option, GPLv2 or later: https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html | |
*/ | |
package testrunner | |
import io.github.classgraph._ | |
import java.io.PrintWriter | |
import java.net.{URL, URLClassLoader} | |
import java.nio.file.{Path, Paths} | |
import sbt.testing._ | |
import scala.collection.immutable.Queue | |
import scala.collection.JavaConverters._ | |
object TestRunner { | |
/**
 * Which part of a test class a task should run. Each case maps to one sbt test-interface
 * [[Selector]] through `toSelector`.
 */
sealed abstract class Select {
  final def toSelector: Selector =
    this match {
      case Select.EntireSuite => Select.suiteSelector
      case Select.TestNameContains(fragment) => new TestWildcardSelector(fragment)
      case Select.NestedSuite(suite) => new NestedSuiteSelector(suite)
      case Select.NestedTest(suite, test) => new NestedTestSelector(suite, test)
    }
}

object Select {
  // one shared instance is handed out for every EntireSuite conversion
  private val suiteSelector: Selector = new SuiteSelector

  /** Converted to a `TestWildcardSelector` carrying `name`. */
  case class TestNameContains(name: String) extends Select
  /** Converted to a `SuiteSelector` (the whole suite). */
  case object EntireSuite extends Select
  /** Converted to a `NestedTestSelector(suiteName, testName)`. */
  case class NestedTest(suiteName: String, testName: String) extends Select
  /** Converted to a `NestedSuiteSelector(suiteName)`. */
  case class NestedSuite(suiteName: String) extends Select
}
/** Whether a task was explicitly requested; `makeTask` forwards `isExplicit` into the `TaskDef`. */
sealed abstract class Explicitness(val isExplicit: Boolean)

object Explicitness {
  case object Explicit extends Explicitness(isExplicit = true)
  case object NotExplicit extends Explicitness(isExplicit = false)
}

/** Name of a test class; for module classes the scanner strips the trailing `$`. */
case class TestClassName(name: String)
/**
 * The result of a classpath scan. Call `close()` once the scan is no longer needed.
 */
sealed abstract class Scanned {
  /** Every class seen by this scan. */
  def allClasses: ClassInfoList

  /** Test frameworks discovered on the scanned classpath. */
  def findFrameworks: List[Framework]

  /**
   * Classes among `testJars` matched by one of `f`'s fingerprints, paired with the
   * fingerprint that matched.
   */
  def findTestClassesFromJars(f: Framework, testJars: ClassInfoList): List[(TestClassName, Fingerprint)]

  /** Release the resources held by the underlying scan. */
  def close(): Unit
}
/**
 * Scan exactly the classpath entries in `cp`, ignoring parent class loaders.
 */
def scanJars(cp: List[Path]): Scanned = {
  val graph = new ClassGraph()
    .overrideClasspath(cp.asJava)
    .ignoreParentClassLoaders()
  fromClassGraph(graph)
}
/**
 * Scan a classpath that includes everything `cl` was registered with.
 */
def scanClassLoader(cl: ClassLoader): Scanned = {
  val graph = new ClassGraph().addClassLoader(cl)
  fromClassGraph(graph)
}
/**
 * Wrap a configured ClassGraph in a [[Scanned]]. The scan is performed immediately, with
 * class, method and annotation info enabled, and is released by `close()`.
 */
private def fromClassGraph(withCL: ClassGraph): Scanned =
  new Scanned {
    val scan = withCL
      .enableClassInfo
      .enableMethodInfo // need this to check for constructors
      .enableAnnotationInfo // need this for annotation scanning
      .ignoreMethodVisibility // we need to see private constructors
      .scan

    def close() = scan.close()

    def allClasses: ClassInfoList = scan.getAllClasses

    private def emptyClassInfoList: ClassInfoList =
      new ClassInfoList(java.util.Collections.emptyList[ClassInfo]())

    /**
     * Subclasses of `nm` (when it is a class) or implementers of `nm` (when it is an interface).
     * `getClassInfo` returns null for a name that is not on the scanned classpath; in that case
     * there are no subclasses, so return an empty list instead of throwing an NPE.
     */
    def subsOf(nm: String): ClassInfoList =
      Option(scan.getClassInfo(nm)) match {
        case None => emptyClassInfoList
        case Some(ci) if ci.isInterface() => scan.getClassesImplementing(nm)
        case Some(_) => scan.getSubclasses(nm)
      }

    /**
     * Instantiate every concrete implementation of sbt.testing.Framework.
     * Abstract classes and interfaces can never be instantiated, so they are skipped.
     */
    def findFrameworks: List[Framework] = {
      val fwclass = classOf[Framework]
      subsOf(fwclass.getName)
        .iterator
        .asScala
        .filterNot { ci => ci.isAbstract || ci.isInterface }
        .map { ci =>
          val cls = ci.loadClass()
          // Class.newInstance is deprecated; this is the documented replacement
          fwclass.cast(cls.getDeclaredConstructor().newInstance())
        }
        .toList
    }

    /**
     * When `mod` is true, keep only module classes (names ending in `$`), reporting them with
     * the `$` removed. Otherwise keep only classes whose names do not end in `$`.
     */
    def moduleFilter(mod: Boolean, names: Iterator[ClassInfo]): Iterator[(ClassInfo, TestClassName)] =
      if (mod) names.collect { case ci if ci.getName().endsWith("$") => (ci, TestClassName(ci.getName().dropRight(1))) }
      else names.collect { case ci if !ci.getName().endsWith("$") => (ci, TestClassName(ci.getName())) }

    /**
     * Any candidate not picked by some fingerprint must be abstract or an interface. A concrete
     * leftover would be a test we silently skip, so fail loudly instead.
     * Returns the (name, fingerprint) pairs that were selected.
     */
    private def requireAllSelected(
        candidates: ClassInfoList,
        selected: Seq[(ClassInfo, TestClassName, Fingerprint)]): List[(TestClassName, Fingerprint)] = {
      val notSelected = candidates.exclude(new ClassInfoList(selected.map(_._1).asJava))
      val errs = notSelected.iterator.asScala.filterNot { ci => ci.isAbstract || ci.isInterface }.toList
      if (errs.nonEmpty) sys.error(s"tests: $errs ignored, this is likely an error with your tests")
      else selected.iterator.map { case (_, t, f) => (t, f) }.toList
    }

    private def findSubClasses(subs: Seq[SubclassFingerprint], tests: ClassInfoList): List[(TestClassName, Fingerprint)] =
      subs.groupBy(_.superclassName)
        .iterator
        .flatMap { case (className, sameSuper) =>
          // these are all candidates, but we need to ensure they match the other requirements
          val candidates = subsOf(className).intersect(tests)
          // there are four possibilities now:
          // reqNoArg = { true, false } x isModule = { true, false }
          def hasNoArg(ci: ClassInfo): Boolean =
            ci.getConstructorInfo
              .asScala
              .exists { mi => mi.getParameterInfo().isEmpty }
          val selectedNames: Seq[(ClassInfo, TestClassName, Fingerprint)] =
            sameSuper.flatMap { sub =>
              val c1 = candidates.iterator.asScala
              val c2 = if (sub.requireNoArgConstructor) c1.filter(hasNoArg) else c1
              moduleFilter(sub.isModule, c2).map { case (ci, t) => (ci, t, sub) }
            }
          requireAllSelected(candidates, selectedNames)
        }
        .toList

    private def findAnnotated(anns: Seq[AnnotatedFingerprint], tests: ClassInfoList): List[(TestClassName, Fingerprint)] =
      anns.groupBy(_.annotationName)
        .iterator
        .flatMap { case (aname, sameAnnotation) =>
          // a class is a candidate if it, or one of its methods, carries the annotation
          val classes = scan.getClassesWithAnnotation(aname)
          val methods = scan.getClassesWithMethodAnnotation(aname)
          val candidates = classes.union(methods).intersect(tests)
          val selectedNames: Seq[(ClassInfo, TestClassName, Fingerprint)] =
            sameAnnotation.flatMap { ann =>
              val c1 = candidates.iterator.asScala
              moduleFilter(ann.isModule, c1).map { case (ci, t) => (ci, t, ann) }
            }
          requireAllSelected(candidates, selectedNames)
        }
        .toList

    def findTestClassesFromJars(f: Framework, tests: ClassInfoList): List[(TestClassName, Fingerprint)] = {
      val fps = f.fingerprints
      fps.foreach {
        case _: SubclassFingerprint => ()
        case _: AnnotatedFingerprint => ()
        case other => sys.error(s"unknown fingerprint type: $other")
      }
      val subclasses = fps.collect { case sc: SubclassFingerprint => sc }
      val anns = fps.collect { case an: AnnotatedFingerprint => an }
      (findSubClasses(subclasses, tests) ::: findAnnotated(anns, tests))
        .distinct
        .sortBy(_._1.name)
    }
  }
/**
 * Build a TaskDef for test class `cn` (matched by `fp`) that runs the parts named by `selectors`.
 */
def makeTask(cn: TestClassName, fp: Fingerprint, explicit: Explicitness, selectors: Seq[Select]): TaskDef = {
  val converted: Array[Selector] = selectors.map(_.toSelector).toArray
  new TaskDef(cn.name, fp, explicit.isExplicit, converted)
}
/**
 * A minimal monoid type class, defined here so we don't depend on cats just to get a few functions.
 */
trait Monoid[A] {
  def empty: A
  def combine(a: A, b: A): A

  /** Combine all items left to right; `empty` when there are none. */
  def combineAll(i: Iterable[A]): A =
    i.reduceOption(combine(_, _)).getOrElse(empty)
}

object Monoid {
  /** Summon the implicit instance for `A`. */
  def apply[A](implicit monoid: Monoid[A]): Monoid[A] = monoid

  /**
   * Key-wise union of maps. Values present under the same key in both maps are merged with
   * `B.combine(leftValue, rightValue)`. The smaller map is folded into the larger one.
   */
  implicit def mapMonoid[A, B](implicit B: Monoid[B]): Monoid[Map[A, B]] =
    new Monoid[Map[A, B]] {
      def empty: Map[A, B] = Map.empty

      def combine(a: Map[A, B], b: Map[A, B]): Map[A, B] =
        if (a.size >= b.size)
          b.foldLeft(a) { case (acc, (k, rightValue)) =>
            acc.updated(k, acc.get(k).fold(rightValue)(leftValue => B.combine(leftValue, rightValue)))
          }
        else
          a.foldLeft(b) { case (acc, (k, leftValue)) =>
            acc.updated(k, acc.get(k).fold(leftValue)(rightValue => B.combine(leftValue, rightValue)))
          }
    }
}
/**
 * Adds `foldMap` to `List`: map each element with `fn` and combine the results left to right.
 */
implicit class ListFoldMap[A](val list: List[A]) extends AnyVal {
  /**
   * `m.empty` for an empty list. Otherwise `fn` is applied to the head first, and each later
   * result is combined onto the running accumulator in list order.
   */
  def foldMap[B](fn: A => B)(implicit m: Monoid[B]): B =
    list match {
      case Nil => m.empty
      case first :: rest =>
        rest.foldLeft(fn(first)) { (acc, a) => m.combine(acc, fn(a)) }
    }
}
/**
 * Aggregate over test events.
 *
 * @param count number of events
 * @param totalDuration summed duration; None as soon as any contributing event had none
 * @param failures every throwable reported by the events, in order
 */
case class Statistics(count: Long, totalDuration: Option[Long], failures: List[Throwable]) {
  def +(that: Statistics): Statistics =
    Statistics(
      count = count + that.count,
      totalDuration = totalDuration.flatMap(mine => that.totalDuration.map(mine + _)),
      failures = failures ::: that.failures)
}

object Statistics {
  // reused for the common single event with no duration and no throwable
  private val emptyNone: Statistics = Statistics(1L, None, Nil)

  /** Statistics for one event; a negative duration `d` is recorded as unknown (None). */
  def fromDur(d: Long, failure: Option[Throwable]): Statistics =
    if (d >= 0) Statistics(1L, Some(d), failure.toList)
    else failure.fold(emptyNone)(t => Statistics(1L, None, t :: Nil))

  implicit val statisticsMonoid: Monoid[Statistics] =
    new Monoid[Statistics] {
      val empty: Statistics = Statistics(0L, Some(0L), Nil)
      def combine(a: Statistics, b: Statistics): Statistics = a + b
    }

  /** Statistics for a single sbt test-interface event. */
  def fromEvent(e: Event): Statistics = {
    val thrown = e.throwable()
    val failure: Option[Throwable] =
      if (thrown.isEmpty) None
      else Some(thrown.get())
    fromDur(e.duration(), failure)
  }
}
/** Severity of a logged message; `kind` is the tag printed in front of it. */
sealed abstract class LogKind(val kind: String)

object LogKind {
  case object Debug extends LogKind(kind = "DEBUG")
  case object Error extends LogKind(kind = "ERROR")
  case object Info extends LogKind(kind = "INFO")
  case object Warn extends LogKind(kind = "WARN")
}
/**
 * A test-interface Logger that records every message (Right) and trace (Left) in arrival order.
 * All access to the queue goes through one lock; `getLog()` returns the current immutable queue.
 */
class AggregatingLogger extends Logger {
  private type Entry = Either[Throwable, (LogKind, String)]

  private val lock = new AnyRef
  private var entries: Queue[Entry] = Queue.empty

  def getLog(): Queue[Either[Throwable, (LogKind, String)]] =
    lock.synchronized(entries)

  private def append(entry: Entry): Unit =
    lock.synchronized {
      entries = entries.enqueue(entry)
    }

  private def put(lk: LogKind, s: String): Unit =
    append(Right((lk, s)))

  def ansiCodesSupported(): Boolean = false

  def debug(d: String): Unit = put(LogKind.Debug, d)
  def error(e: String): Unit = put(LogKind.Error, e)
  def info(i: String): Unit = put(LogKind.Info, i)
  def warn(w: String): Unit = put(LogKind.Warn, w)

  def trace(t: Throwable): Unit = append(Left(t))
}
/**
 * An EventHandler that folds each event into per-Status [[Statistics]], guarded by a lock.
 */
class AggregatingEventHandler extends EventHandler {
  private val lock = new AnyRef
  private var byStatus: Map[Status, Statistics] = Map.empty

  def getStats(): Map[Status, Statistics] =
    lock.synchronized(byStatus)

  def handle(e: Event): Unit =
    lock.synchronized {
      val single = Map(e.status() -> Statistics.fromEvent(e))
      byStatus = Monoid[Map[Status, Statistics]].combine(byStatus, single)
    }
}
/** Log entries (in arrival order) and per-Status statistics gathered while running tasks. */
case class TaskResult(log: Queue[Either[Throwable, (LogKind, String)]], stats: Map[Status, Statistics])

object TaskResult {
  /** Concatenates logs (left first) and merges statistics per Status. */
  implicit val taskResultMonoid: Monoid[TaskResult] =
    new Monoid[TaskResult] {
      val empty: TaskResult = TaskResult(log = Queue.empty, stats = Map.empty)

      def combine(a: TaskResult, b: TaskResult): TaskResult = {
        val mergedLog = a.log ++ b.log
        val mergedStats = Monoid[Map[Status, Statistics]].combine(a.stats, b.stats)
        TaskResult(mergedLog, mergedStats)
      }
    }
}
/**
 * An EventHandler and Loggers to hand to `Task.execute`, plus a way to read back what they
 * collected.
 */
trait HandlerLogger {
  def loggers: Seq[Logger]
  def eventHandler: EventHandler

  /** The log entries and statistics collected so far. */
  def results(): TaskResult
}
/**
 * A fresh HandlerLogger backed by one AggregatingLogger and one AggregatingEventHandler.
 */
def makeHandlerLogger(): HandlerLogger = {
  val logger = new AggregatingLogger
  val handler = new AggregatingEventHandler
  new HandlerLogger {
    val eventHandler: EventHandler = handler
    val loggers: Seq[Logger] = List(logger)
    def results(): TaskResult = TaskResult(logger.getLog(), handler.getStats())
  }
}
/**
 * Run `task`, then recursively every task it returns. Each task gets its own HandlerLogger.
 * Results are keyed by TaskDef; results for equal TaskDefs are merged.
 */
def executeAll(task: Task): Map[TaskDef, TaskResult] = {
  val collector = makeHandlerLogger()
  val children: Array[Task] = task.execute(collector.eventHandler, collector.loggers.toArray)
  // snapshot this task's own results before running its children
  val own: Map[TaskDef, TaskResult] = Map(task.taskDef -> collector.results())
  val descendants: Map[TaskDef, TaskResult] = children.toList.foldMap(child => executeAll(child))
  Monoid[Map[TaskDef, TaskResult]].combine(own, descendants)
}
/**
 * Write a report for one framework's run to `out`. The report has:
 *  - a header line naming the framework;
 *  - for each test class, in fully-qualified-name order, its log followed by per-Status counts;
 *  - `doneMessage`, last, when it is non-empty.
 */
def complete(fw: Framework, doneMessage: String, results: Map[TaskDef, TaskResult], out: PrintWriter): Unit = {
  // durations are divided by 1000 and reported as seconds; nothing is printed when unknown
  def timeSuffix(total: Option[Long]): String =
    total match {
      case Some(t) => s", ${t.toDouble / 1000.0} seconds"
      case None => ""
    }

  def printLogEntry(entry: Either[Throwable, (LogKind, String)]): Unit =
    entry match {
      case Right((k, msg)) =>
        out.println(s"${k.kind}: $msg")
      case Left(err) =>
        out.println("")
        err.printStackTrace(out)
        out.println("")
    }

  def printResult(fqn: String, tr: TaskResult): Unit = {
    out.println(s"test class: $fqn")
    tr.log.foreach(printLogEntry)
    // finally the stats, ordered by the Status name
    out.println("")
    val statsByName = tr.stats.toList.map { case (status, stat) => (status.toString, stat) }.sortBy(_._1)
    statsByName.foreach { case (k, Statistics(cnt, optTime, errs)) =>
      out.println(s"$k: count = $cnt${timeSuffix(optTime)}, ${errs.size} errors")
    }
  }

  out.println(s"results from: ${fw.name}")
  for ((td, res) <- results.toList.sortBy(_._1.fullyQualifiedName))
    printResult(td.fullyQualifiedName, res)
  if (doneMessage.nonEmpty) out.println(doneMessage)
}
/**
 * Discover every test framework on `cl` and the test classes in `testJars`, run each framework's
 * tests, and write a report per framework to `out`.
 *
 * @param cl the full classpath, including all the test frameworks and the tests
 * @param testJars the paths to the jars that contain the tests for this run
 * @param args arguments passed to each framework's `runner`
 * @param remoteArgs remote arguments passed to each framework's `runner`
 * @param out destination for the human-readable report
 *
 * Fails via `sys.error` when no tests are found, or when any test reports a Failure or Error.
 */
def runAll(cl: ClassLoader, testJars: List[Path], args: Array[String], remoteArgs: Array[String], out: PrintWriter): Unit = {
  val (fws, tests) = discoverTests(cl, testJars)
  if (tests.isEmpty) sys.error(s"found 0 tests. Frameworks: $fws, test jars: $testJars")
  else {
    val allResults = tests.foldMap {
      case (_, Nil) => Monoid[TaskResult].empty
      case (_, group @ ((_, _, fw) :: _)) =>
        val runner = fw.runner(args, remoteArgs, cl)
        val taskDefs = group.map { case (cn, fp, _) => makeTask(cn, fp, Explicitness.NotExplicit, Select.EntireSuite :: Nil) }
        val tasks = runner.tasks(taskDefs.toArray)
        val results: Map[TaskDef, TaskResult] = tasks.toList.foldMap(executeAll(_))
        val message = runner.done()
        complete(fw, message, results, out)
        Monoid[TaskResult].combineAll(results.values)
    }
    def countOf(status: Status): Long =
      allResults.stats.getOrElse(status, Monoid[Statistics].empty).count
    val failed = countOf(Status.Failure)
    val errored = countOf(Status.Error)
    if (failed != 0L || errored != 0L) sys.error(s"test had $failed failures, $errored errors")
  }
}

/**
 * Scan `cl` for frameworks and `testJars` for test classes. Returns the frameworks found, plus
 * the matching test classes grouped by framework name (sorted by that name). Both scans are
 * closed before returning, even if discovery throws.
 */
private def discoverTests(
    cl: ClassLoader,
    testJars: List[Path]): (List[Framework], List[(String, List[(TestClassName, Fingerprint, Framework)])]) = {
  val scannedFW = scanClassLoader(cl)
  try {
    val fws: List[Framework] = scannedFW.findFrameworks
    val scannedTests = scanJars(testJars)
    try {
      // computed once and shared by every framework
      val testClasses = scannedTests.allClasses
      val grouped: List[(String, List[(TestClassName, Fingerprint, Framework)])] =
        fws
          .iterator
          .flatMap { fw =>
            scannedFW.findTestClassesFromJars(fw, testClasses)
              .map { case (tn, fp) => (tn, fp, fw) }
          }
          .toList
          .groupBy(_._3.name)
          .toList
          .sortBy(_._1)
      (fws, grouped)
    }
    finally scannedTests.close()
  }
  finally scannedFW.close()
}
/**
 * Build a URLClassLoader over `paths`. Its parent is the current thread's context class loader.
 *
 * Each path is converted with `Path.toUri`, which resolves relative paths against the working
 * directory and percent-encodes characters such as spaces. Concatenating "file://" with the raw
 * path instead makes the first segment of a relative path the URL host and leaves spaces unencoded.
 */
def makeClassLoader(paths: List[Path]): ClassLoader = {
  val parent = Thread.currentThread.getContextClassLoader()
  val urls: Array[URL] = paths.iterator.map(_.toUri.toURL).toArray
  new URLClassLoader(urls, parent)
}
/**
 * Command-line entry point. Arguments come in `--flag value` pairs:
 *   --dep_jar: a jar with a dependency of the test framework
 *   --test_jar: a jar containing a test
 *   --arg: an arg to the test framework
 *   --remote_arg: an arg to remotely started JVMs
 *   --output: the path to write into (required; if repeated, the last value wins)
 *
 * Unknown flags are errors. So is a trailing flag with no value; it is not silently ignored.
 */
def main(args: Array[String]): Unit = {
  val deps = List.newBuilder[Path]
  val tests = List.newBuilder[Path]
  val testArgs = List.newBuilder[String]
  val remoteArgs = List.newBuilder[String]
  var output: String = ""

  @annotation.tailrec
  def loop(idx: Int): Unit =
    if (idx >= args.length) ()
    else if (idx + 1 >= args.length) {
      sys.error(s"missing value for argument ${args(idx)} in ${args.toList}")
    }
    else {
      val value = args(idx + 1)
      args(idx) match {
        case "--dep_jar" =>
          deps += Paths.get(value)
        case "--test_jar" =>
          tests += Paths.get(value)
        case "--arg" =>
          testArgs += value
        case "--remote_arg" =>
          remoteArgs += value
        case "--output" =>
          output = value
        case other => sys.error(s"unexpected arg $other in ${args.toList}")
      }
      loop(idx + 2)
    }

  loop(0)
  if (output == "") sys.error(s"expected to get a non-empty output argument, didn't find one in: ${args.toList}")
  val testJars = tests.result()
  val cl = makeClassLoader(deps.result() ::: testJars)
  val out = new PrintWriter(Paths.get(output).toFile, "UTF-8")
  try {
    runAll(cl, testJars, testArgs.result().toArray, remoteArgs.result().toArray, out)
  }
  finally {
    out.close()
  }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment