Created
October 16, 2019 14:41
-
-
Save ABeltramo/51eb2bdea3a34028b157078d6f43463d to your computer and use it in GitHub Desktop.
KerasJavaPredictor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<dependencies>
  <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta3</version>
  </dependency>
  <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-modelimport</artifactId>
    <version>1.0.0-beta3</version>
  </dependency>
  <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta3</version>
  </dependency>
</dependencies>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package co.chatterbox.xai.image; | |
import co.chatterbox.image_ablation.ModelOutput; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import java.util.List; | |
public class KerasJavaOutput implements ModelOutput { | |
private INDArray predictions; | |
private List<String> classes; | |
KerasJavaOutput(List<String> classes, INDArray predictions) { | |
this.classes = classes; | |
this.predictions = predictions; | |
} | |
@Override | |
public double getScoreForClass(String classId) { | |
int pos = classes.indexOf(classId); | |
return predictions.getDouble(pos); | |
} | |
@Override | |
public String getPredictedClass() { | |
int maxAt = 0; | |
for (int i = 0; i < predictions.columns(); i++) { // search the maximum predicted score | |
maxAt = predictions.getDouble(i) > predictions.getDouble(maxAt) ? i : maxAt; | |
} | |
return classes.get(maxAt); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package co.chatterbox.xai.image; | |
import co.chatterbox.image_ablation.ModelOutput; | |
import co.chatterbox.image_ablation.PredictionInterface; | |
import co.chatterbox.xai.spi.ImagePredictor; | |
import org.datavec.image.loader.NativeImageLoader; | |
import org.datavec.image.transform.ColorConversionTransform; | |
import org.deeplearning4j.nn.conf.CacheMode; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import javax.imageio.ImageIO; | |
import java.awt.image.BufferedImage; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.*; | |
import java.util.stream.Collectors; | |
import static org.bytedeco.javacpp.opencv_imgproc.COLOR_BGR2RGB; | |
public class KerasJavaPredictor implements PredictionInterface { | |
private List<String> classes; | |
private ComputationGraph model; | |
// Hard coded params in Keras models | |
private int width = 224; | |
private int height = 224; | |
private int channels = 3; | |
private NativeImageLoader imgLoader = new NativeImageLoader(height, width, channels, new ColorConversionTransform(COLOR_BGR2RGB)); | |
private INDArray preProcess(BufferedImage image) { | |
// convert to matrix | |
INDArray img; | |
try { | |
img = imgLoader.asMatrix(image); | |
} catch (IOException e) { | |
throw new RuntimeException(e); | |
} | |
return img; | |
} | |
@Override | |
public ModelOutput predict(BufferedImage bufferedImage) { | |
INDArray image = preProcess(bufferedImage); | |
return new KerasJavaOutput(classes, model.outputSingle(false, image)); | |
} | |
@Override | |
public List<ModelOutput> predictBatch(List<BufferedImage> images) { | |
return images.stream().map(this::predict).collect(Collectors.toList()); | |
} | |
public KerasJavaPredictor(List<String> classes, String modelPath) { | |
this.classes = classes; | |
try { | |
model = KerasModelImport.importKerasModelAndWeights(modelPath, false); | |
model.setCacheMode(CacheMode.DEVICE); | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
return this; | |
} | |
// TEST | |
public static void main(String[] args) { | |
KerasJavaPredictor predictor = new KerasJavaPredictor(); | |
Map params = new HashMap<String, Object>(); | |
params.put("model-path", "/.../cats_vs_dogs.hdf5"); | |
params.put("classes", Arrays.asList("dogs", "cats")); | |
predictor = (KerasJavaPredictor) predictor.initialize(params); | |
List<BufferedImage> images = new ArrayList<>(); | |
try { | |
images.add(ImageIO.read(new File("/.../dogs_vs_cats/dog.jpg"))); | |
images.add(ImageIO.read(new File("/.../dogs_vs_cats/dog.jpg"))); | |
images.add(ImageIO.read(new File("/.../dogs_vs_cats/dog.jpg"))); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
List res = predictor.predictBatch(images); | |
System.out.println(res); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment