Last active
December 30, 2015 20:59
-
-
Save twolfe18/7884701 to your computer and use it in GitHub Desktop.
How to write neat feature functions in Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Small demonstrations of how to compose ("conjoin") feature functions.
  *
  * Three styles are shown:
  *  - `example`:  conjoining two real-valued features with an arbitrary combiner;
  *  - `example2`: conjoining two one-hot features into a single one-hot feature;
  *  - `example3`: generating all pairwise conjunctions and greedily selecting some of them.
  */
object IDontCare {

  import scala.collection.mutable
  import scala.util.Random

  /** Runs the three examples in order, printing their output to stdout. */
  def main(args: Array[String]): Unit = {
    example
    example2
    example3
  }

  /** Builds a feature that applies `f1` and `f2` to the same input and combines
    * the two results with `compose`.
    *
    * @param f1      first feature function
    * @param f2      second feature function (same result type as `f1`)
    * @param compose combiner applied to `(f1(i), f2(i))`
    * @return a function from the input to the combined value
    */
  def conjoin[I, P, O](f1: I => P, f2: I => P, compose: (P, P) => O): I => O =
    (i: I) => {
      // f1 is always evaluated before f2
      val p1 = f1(i)
      val p2 = f2(i)
      compose(p1, p2)
    }

  /** Same as [[conjoin]], but the two features may have different result types.
    *
    * @param f1            first feature function
    * @param f2            second feature function
    * @param composeValues combiner applied to `(f1(i), f2(i))`
    * @return a function from the input to the combined value
    */
  def conjoin2[I, P1, P2, O](f1: I => P1, f2: I => P2, composeValues: (P1, P2) => O): I => O =
    (i: I) => composeValues(f1(i), f2(i))

  /** Conjoins a "length" feature and a "starts with a" indicator by multiplication
    * and prints the result for a couple of words.
    */
  def example: Unit = {
    val f1 = (i: String) => i.size.toDouble
    val f2 = (i: String) => if (i.startsWith("a")) 1d else 0d
    val compose = (d1: Double, d2: Double) => d1 * d2
    val fc: String => Double = conjoin(f1, f2, compose)
    println("### example ###")
    for (i <- Seq("air", "dog"))
      println("fc(\"%s\") = %f".format(i, fc(i)))
  }

  // this is way more slick!

  /** A one-hot encoded value: index `hot` is set out of `range` possible indices.
    *
    * Construction asserts `0 <= hot < range`.
    *
    * @param hot   the index that is "on"
    * @param range the number of possible indices
    */
  final case class OneHot(hot: Int, range: Int) {
    assert(hot >= 0)
    assert(hot < range)

    /** Encodes a boolean as a one-hot value of range 2 (`false` -> 0, `true` -> 1). */
    def this(b: Boolean) = this(if (b) 1 else 0, 2)
  }

  object OneHot {

    /** Conjoins two one-hot values into one whose range is the product of both ranges.
      *
      * The combined index is `oh1.hot * oh2.range + oh2.hot`, so every
      * `(oh1.hot, oh2.hot)` pair maps to a distinct index.
      */
    def compose(oh1: OneHot, oh2: OneHot): OneHot = {
      val newRange = oh1.range * oh2.range
      val newVal = oh1.hot * oh2.range + oh2.hot
      OneHot(newVal, newRange)
    }
  }

  /** Capitalization feature shared by `example2` and `example3`.
    *
    * Returns index 0 when the string equals its upper-cased form, 1 when it
    * equals its lower-cased form, and 2 otherwise. The upper-case check runs
    * first, so a string unchanged by both conversions yields 0.
    */
  private def capitalization(s: String): OneHot = {
    val v =
      if (s == s.toUpperCase) 0 // all caps
      else if (s == s.toLowerCase) 1 // all lower
      else 2 // mixed
    OneHot(v, 3)
  }

  /** Conjoins the capitalization feature with a "contains dog" indicator and
    * prints the combined one-hot value for a handful of words.
    */
  def example2: Unit = {
    val f1 = (i: String) => capitalization(i)
    val f2 = (i: String) => new OneHot(i.toLowerCase.contains("dog"))
    val fc = conjoin2(f1, f2, OneHot.compose)
    println("### example2 ###")
    for (i <- Seq("dog", "DOG", "Dog", "dont", "DONT", "Dont"))
      println("fc(\"%s\") = %s".format(i, fc(i)))
  }

  // let's try to get something that generates all conjunctions of features

  /** A feature function producing [[OneHot]] values.
    *
    * Declared `implicit` so that a plain `T => OneHot` function is converted
    * automatically where a `OneHotFeature[T]` is expected.
    *
    * @param func the wrapped function
    */
  implicit class OneHotFeature[T](val func: T => OneHot) extends Function1[T, OneHot] {
    override def apply(t: T): OneHot = func(t)
  }

  object OneHotFeature {

    /** Builds the feature that applies `f1` and `f2` to the same input and
      * conjoins the results with [[OneHot.compose]].
      */
    def compose[T](f1: OneHotFeature[T], f2: OneHotFeature[T]): OneHotFeature[T] =
      new OneHotFeature((t: T) => OneHot.compose(f1(t), f2(t)))
  }

  /** Returns the conjunction of every ordered pair of distinct features.
    *
    * Both `(f1, f2)` and `(f2, f1)` are produced; a feature is never paired
    * with itself (compared with `!=`).
    */
  def allPairs[T](features: Seq[OneHotFeature[T]]): Seq[OneHotFeature[T]] =
    for (f1 <- features; f2 <- features if f1 != f2)
      yield OneHotFeature.compose(f1, f2)

  /** Greedily adds pairwise conjunctions to the base feature set.
    *
    * Starting from `baseFeatures`, each round picks the candidate pair that
    * maximizes `criterion` when appended to the features selected so far, and
    * moves it from the candidate pool to the selection.
    *
    * @param baseFeatures features to start from; candidates are `allPairs(baseFeatures)`
    * @param criterion    score of a feature set (higher is better)
    * @param howMany      maximum number of conjunctions to add; if fewer candidates
    *                     exist, all of them are added and selection stops early
    * @return the base features followed by the selected conjunctions, in selection order
    */
  def greedyFeatureSelection[T](
      baseFeatures: Seq[OneHotFeature[T]],
      criterion: Seq[OneHotFeature[T]] => Double,
      howMany: Int): Seq[OneHotFeature[T]] = {
    val selected = mutable.ArrayBuffer.empty[OneHotFeature[T]]
    selected ++= baseFeatures
    val candidates = mutable.HashSet.empty[OneHotFeature[T]]
    candidates ++= allPairs(baseFeatures)
    // maxBy throws on an empty collection, so never run more rounds than there
    // are candidates to pick from.
    val rounds = math.min(howMany, candidates.size)
    for (_ <- 1 to rounds) {
      val best = candidates.maxBy(f => criterion((selected :+ f).toSeq))
      candidates -= best
      selected += best
    }
    selected.toSeq
  }

  /** Builds three one-hot features, greedily selects two pairwise conjunctions
    * using a random criterion, and prints the resulting feature list.
    */
  def example3: Unit = {
    val features = mutable.ArrayBuffer.empty[OneHotFeature[String]]
    // plain functions are lifted to OneHotFeature by the implicit class
    features += ((i: String) => capitalization(i))
    features += ((i: String) => new OneHot(i.toLowerCase.contains("dog")))
    features += ((i: String) => new OneHot(i.toLowerCase.contains("cat")))
    // i don't really want to fake data for this...
    def criterion(feats: Seq[OneHotFeature[String]]): Double = Random.nextDouble()
    val pairs = greedyFeatureSelection(features.toSeq, criterion, 2)
    println(pairs)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment