Created
August 10, 2020 08:04
-
-
Save choowilson/49dd9295311341956e599bb860c4309f to your computer and use it in GitHub Desktop.
Work in progress.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package global.skymind.training.segmentation.car; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.*; | |
import org.deeplearning4j.nn.conf.graph.MergeVertex; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.*; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import java.util.Random; | |
public class UNetFromScratch { | |
private static int epochs = 1; | |
private static int batchSize = 32; | |
private static int seed = 123; | |
private static int numClasses = 2; | |
private static int height = 224; | |
private static int width = 224; | |
private static int channel = 3; | |
private static final Random randNumGen = new Random(seed); | |
public static void main(String[] args) { | |
CacheMode cacheMode = CacheMode.NONE; | |
WorkspaceMode workspaceMode = WorkspaceMode.ENABLED; | |
ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; | |
ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.updater(new Adam(0.001)) | |
.weightInit(WeightInit.XAVIER) | |
.l2(5.0E-5D) | |
.miniBatch(true) | |
.cacheMode(cacheMode) | |
.trainingWorkspaceMode(workspaceMode) | |
.inferenceWorkspaceMode(workspaceMode) | |
.graphBuilder() | |
.addLayer("conv1-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"input"}) | |
.addLayer("conv1-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv1-1"}) | |
.addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | |
.kernelSize(new int[]{2, 2}).build(), new String[]{"conv1-2"}) | |
.addLayer("conv2-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"pool1"}) | |
.addLayer("conv2-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv2-1"}) | |
.addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | |
.kernelSize(new int[]{2, 2}).build(), new String[]{"conv2-2"}) | |
.addLayer("conv3-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"pool2"}) | |
.addLayer("conv3-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv3-1"}) | |
.addLayer("pool3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | |
.kernelSize(new int[]{2, 2}).build(), new String[]{"conv3-2"}) | |
.addLayer("conv4-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"pool3"}) | |
.addLayer("conv4-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv4-1"}) | |
.addLayer("drop4", new DropoutLayer.Builder(0.5D).build(), new String[]{"conv4-2"}) | |
.addLayer("pool4", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(new int[]{2, 2}).build(), new String[]{"drop4"}) | |
.addLayer("conv5-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(1024) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"pool4"}) | |
.addLayer("conv5-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(1024) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv5-1"}) | |
.addLayer("drop5", new DropoutLayer.Builder(0.5D).build(), new String[]{"conv5-2"}) | |
.addLayer("up6-1", new Upsampling2D.Builder(2).build(), new String[]{"drop5"}) | |
.addLayer("up6-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(512) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"up6-1"}).addVertex("merge6", new MergeVertex(), new String[]{"drop4", "up6-2"}) | |
.addLayer("conv6-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"merge6"}) | |
.addLayer("conv6-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv6-1"}) | |
.addLayer("up7-1", new Upsampling2D.Builder(2).build(), new String[]{"conv6-2"}) | |
.addLayer("up7-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(256) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"up7-1"}).addVertex("merge7", new MergeVertex(), new String[]{"conv3-2", "up7-2"}) | |
.addLayer("conv7-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"merge7"}) | |
.addLayer("conv7-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv7-1"}) | |
.addLayer("up8-1", new Upsampling2D.Builder(2).build(), new String[]{"conv7-2"}) | |
.addLayer("up8-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(128) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"up8-1"}).addVertex("merge8", new MergeVertex(), new String[]{"conv2-2", "up8-2"}) | |
.addLayer("conv8-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"merge8"}) | |
.addLayer("conv8-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv8-1"}) | |
.addLayer("up9-1", new Upsampling2D.Builder(2).build(), new String[]{"conv8-2"}) | |
.addLayer("up9-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(64) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"up9-1"}).addVertex("merge9", new MergeVertex(), new String[]{"conv1-2", "up9-2"}) | |
.addLayer("conv9-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"merge9"}) | |
.addLayer("conv9-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv9-1"}) | |
.addLayer("conv9-3", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(2) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.RELU).build(), new String[]{"conv9-2"}) | |
.addLayer("conv10", new ConvolutionLayer.Builder(new int[]{1, 1}).stride(new int[]{1, 1}).nOut(1) | |
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode) | |
.activation(Activation.IDENTITY).build(), new String[]{"conv9-3"}) | |
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build(), new String[]{"conv10"}) | |
.addInputs(new String[]{"input"}).setInputTypes(new InputType[]{InputType.convolutional(224, 224, 1)}) | |
.setOutputs(new String[]{"output"}) | |
.build(); | |
ComputationGraph model = new ComputationGraph(graph); | |
model.init(); | |
System.out.println(model.summary()); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment