Java Spark gist for linear regression after a group by: converting a JavaPairRDD<Something, Iterable<Something>> into a Map<Something, LinearRegressionModel>.
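The core conversion the gist demonstrates, isolated as a helper first (a minimal sketch: the method name splitByKey is hypothetical, the types are the same ones imported in the full file below, and the Spark 1.x flatMap signature that accepts an Iterable is assumed, matching the file's own getRddByKey):

// Split a grouped pair RDD into one JavaRDD per distinct key; this is the
// pattern the full file's regressionModels() and getRddByKey() methods use.
static Map<String, JavaRDD<Row>> splitByKey(JavaPairRDD<String, Iterable<Row>> pairs) {
    Map<String, JavaRDD<Row>> byKey = new HashMap<>();
    for (String key : pairs.keys().distinct().collect()) {
        byKey.put(key, pairs.filter(t -> t._1().equals(key))
                            .values()
                            .flatMap(rows -> rows)); // flatten each Iterable<Row> into individual rows
    }
    return byKey;
}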
package cdapp.services;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import weatherapp.models.CD;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * The main goal of this gist is to show how to convert a JavaPairRDD<Something, Iterable<Something>>
 * into a Map<String, JavaRDD<Something>> and then run machine learning on each group.
 * It also shows usage of pipelines, Spark SQL, one-hot encoding, etc.
 * I am still adding to this gist.
 */
public class SparkCDGist {

    // Local two-thread context for development; point the master at a real cluster to scale out.
    final SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("cds");
    final JavaSparkContext sparkContext = new JavaSparkContext(conf);
    final SQLContext sqlContext = new SQLContext(sparkContext);
    /**
     * Creates a DataFrame from the incoming JSON, which looks like
     * {"id": String, "condition": String, "location": String, "price": int}
     * (note the schema reads price as an integer, not a string).
     * @param rows RDD of rows matching the schema below
     * @return a DataFrame with the id/condition/location/price schema applied
     */
    private DataFrame createDataFrame(JavaRDD<Row> rows) {
        StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.StringType, false, Metadata.empty()),
                new StructField("condition", DataTypes.StringType, false, Metadata.empty()),
                new StructField("location", DataTypes.StringType, false, Metadata.empty()),
                new StructField("price", DataTypes.IntegerType, false, Metadata.empty())
        });
        return sqlContext.createDataFrame(rows, schema);
    }
    /**
     * Builds one StringIndexer per categorical column so the strings can be one-hot encoded;
     * each indexer writes its output to "<column>_index".
     */
    private List<StringIndexer> getStringIndexers(List<String> keys) {
        return keys.stream()
                .map(item -> new StringIndexer()
                        .setInputCol(item)
                        .setOutputCol(item + "_index"))
                .collect(Collectors.toList());
    }
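    // Illustration (made-up values, not from the gist's data): a "condition" column
    // holding {"used", "new", "used"} gains a "condition_index" column holding
    // {0.0, 1.0, 0.0}; StringIndexer assigns indices by descending label frequency.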
    /**
     * Builds one OneHotEncoder per indexed column for the pipeline.
     */
    private List<OneHotEncoder> getOneHotEncoders(List<String> keys, String suffix) {
        return keys.stream().map(name -> new OneHotEncoder()
                .setInputCol(name + suffix)
                .setOutputCol(name + suffix + "_vec")) // output column is originalName + _index + _vec
                .collect(Collectors.toList());
    }
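    // Illustration (assuming three categories): an index of 1.0 becomes the sparse
    // vector (2,[1],[1.0]); by default Spark's OneHotEncoder drops the last
    // category, so the vector length is numCategories - 1.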
    /**
     * Assembles the pipeline. Order matters: all StringIndexers run first, then the
     * OneHotEncoders (which read the indexers' output columns), and the VectorAssembler last.
     */
    private Pipeline createPipeline(List<StringIndexer> stringIndexers, List<OneHotEncoder> oneHotEncoders, VectorAssembler vectorAssembler) {
        int totalSize = stringIndexers.size() + oneHotEncoders.size();
        PipelineStage[] pipelineStages = new PipelineStage[totalSize + 1];
        for (int i = 0; i < totalSize; i++) {
            if (i < stringIndexers.size()) {
                pipelineStages[i] = stringIndexers.get(i);
            } else {
                pipelineStages[i] = oneHotEncoders.get(i - stringIndexers.size());
            }
        }
        pipelineStages[totalSize] = vectorAssembler;
        return new Pipeline().setStages(pipelineStages);
    }
    /**
     * Wires the whole pipeline together for the two categorical columns, ending with a
     * VectorAssembler that merges their one-hot vectors into a single "features" column.
     */
    private Pipeline handleCds() {
        List<String> names = Arrays.asList("condition", "location");
        List<StringIndexer> stringIndexers = getStringIndexers(names);
        List<OneHotEncoder> oneHotEncoders = getOneHotEncoders(names, "_index");
        return createPipeline(stringIndexers, oneHotEncoders, new VectorAssembler()
                .setInputCols(new String[]{"condition_index_vec", "location_index_vec"})
                .setOutputCol("features"));
    }
    /**
     * Fits the feature pipeline, converts the result to labeled points (price as the label,
     * the assembled vector as the features), and trains a linear regression model with SGD.
     */
    public LinearRegressionModel performMachineLearning(DataFrame df) {
        Pipeline pipeline = handleCds();
        PipelineModel pm = pipeline.fit(df);
        JavaRDD<Row> vecFeatures = pm.transform(df).select("features", "price").javaRDD();
        JavaRDD<LabeledPoint> labeledPoints = vecFeatures.map(item ->
                new LabeledPoint((double) item.getInt(item.size() - 1), (Vector) item.get(0)));
        labeledPoints.cache();
        int numIterations = 300;
        double stepSize = 0.000001;
        return LinearRegressionWithSGD.train(JavaRDD.toRDD(labeledPoints), numIterations, stepSize);
    }
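    // Note (general MLlib guidance, not from the original gist): SGD-based linear
    // regression is sensitive to step size and feature scaling, so numIterations
    // and stepSize above are worth tuning; a step this small can converge slowly.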
    /**
     * Gets an RDD of rows for a single key out of the JavaPairRDD.
     */
    private JavaRDD<Row> getRddByKey(JavaPairRDD<String, Iterable<Row>> pairRDD, String key) {
        return pairRDD.filter(v -> v._1().equals(key)).values().flatMap(tuples -> tuples);
    }
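    // Note: this scans the whole pair RDD once per key, so building every group
    // costs one pass over the data per distinct key; fine for a handful of keys,
    // expensive when there are many.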
    /**
     * Builds the final map: each distinct group key paired with the regression model
     * trained on that key's rows.
     */
    public Map<String, LinearRegressionModel> regressionModels() {
        JavaPairRDD<String, Iterable<Row>> cds = getCds();
        Map<String, LinearRegressionModel> models = new HashMap<>();
        try {
            List<String> keys = cds.keys().distinct().collect();
            for (String key : keys) {
                JavaRDD<Row> rddByKey = getRddByKey(cds, key);
                DataFrame df = createDataFrame(rddByKey);
                LinearRegressionModel linearRegressionModel = performMachineLearning(df);
                models.put(key, linearRegressionModel);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return models;
    }
    /**
     * Reads all the CD data from HDFS (one JSON document per line), parses each line
     * into a Row, and groups the rows by id (the first column).
     */
    public JavaPairRDD<String, Iterable<Row>> getCds() {
        return sparkContext.textFile("hdfs://localhost:9000/users/cd/*/")
                .map(item -> {
                    Row row = null;
                    try {
                        // The mapper is created inside the lambda so nothing non-serializable is captured.
                        ObjectMapper objectMapper = new ObjectMapper();
                        CD cd = objectMapper.readValue(item, CD.class);
                        row = RowFactory.create(cd.getId(), cd.getCondition(), cd.getLocation(), cd.getPrice());
                    } catch (Exception e) {
                        // Malformed line: leave row null and drop it below.
                    }
                    return row;
                })
                .filter(row -> row != null) // drop unparseable lines, otherwise groupBy would hit a null row
                .groupBy(item -> item.getString(0));
    }
}
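A hypothetical driver showing how the class might be used; the demo class name and the printout are illustrative, not part of the original gist:

import java.util.Map;
import org.apache.spark.mllib.regression.LinearRegressionModel;

public class SparkCDGistDemo {
    public static void main(String[] args) {
        SparkCDGist gist = new SparkCDGist();
        // One fitted model per group key (here: per CD id).
        Map<String, LinearRegressionModel> models = gist.regressionModels();
        models.forEach((key, model) ->
                System.out.println(key + " -> weights " + model.weights()));
    }
}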