Skip to content

Instantly share code, notes, and snippets.

@choowilson
Created August 10, 2020 08:04
Show Gist options
  • Save choowilson/49dd9295311341956e599bb860c4309f to your computer and use it in GitHub Desktop.
Save choowilson/49dd9295311341956e599bb860c4309f to your computer and use it in GitHub Desktop.
Work in progress.
package global.skymind.training.segmentation.car;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Random;
public class UNetFromScratch {
private static int epochs = 1;
private static int batchSize = 32;
private static int seed = 123;
private static int numClasses = 2;
private static int height = 224;
private static int width = 224;
private static int channel = 3;
private static final Random randNumGen = new Random(seed);
public static void main(String[] args) {
CacheMode cacheMode = CacheMode.NONE;
WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder()
.seed(seed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Adam(0.001))
.weightInit(WeightInit.XAVIER)
.l2(5.0E-5D)
.miniBatch(true)
.cacheMode(cacheMode)
.trainingWorkspaceMode(workspaceMode)
.inferenceWorkspaceMode(workspaceMode)
.graphBuilder()
.addLayer("conv1-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"input"})
.addLayer("conv1-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv1-1"})
.addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(new int[]{2, 2}).build(), new String[]{"conv1-2"})
.addLayer("conv2-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"pool1"})
.addLayer("conv2-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv2-1"})
.addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(new int[]{2, 2}).build(), new String[]{"conv2-2"})
.addLayer("conv3-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"pool2"})
.addLayer("conv3-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv3-1"})
.addLayer("pool3", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(new int[]{2, 2}).build(), new String[]{"conv3-2"})
.addLayer("conv4-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"pool3"})
.addLayer("conv4-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv4-1"})
.addLayer("drop4", new DropoutLayer.Builder(0.5D).build(), new String[]{"conv4-2"})
.addLayer("pool4", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(new int[]{2, 2}).build(), new String[]{"drop4"})
.addLayer("conv5-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(1024)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"pool4"})
.addLayer("conv5-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(1024)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv5-1"})
.addLayer("drop5", new DropoutLayer.Builder(0.5D).build(), new String[]{"conv5-2"})
.addLayer("up6-1", new Upsampling2D.Builder(2).build(), new String[]{"drop5"})
.addLayer("up6-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"up6-1"}).addVertex("merge6", new MergeVertex(), new String[]{"drop4", "up6-2"})
.addLayer("conv6-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"merge6"})
.addLayer("conv6-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(512)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv6-1"})
.addLayer("up7-1", new Upsampling2D.Builder(2).build(), new String[]{"conv6-2"})
.addLayer("up7-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"up7-1"}).addVertex("merge7", new MergeVertex(), new String[]{"conv3-2", "up7-2"})
.addLayer("conv7-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"merge7"})
.addLayer("conv7-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(256)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv7-1"})
.addLayer("up8-1", new Upsampling2D.Builder(2).build(), new String[]{"conv7-2"})
.addLayer("up8-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"up8-1"}).addVertex("merge8", new MergeVertex(), new String[]{"conv2-2", "up8-2"})
.addLayer("conv8-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"merge8"})
.addLayer("conv8-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(128)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv8-1"})
.addLayer("up9-1", new Upsampling2D.Builder(2).build(), new String[]{"conv8-2"})
.addLayer("up9-2", new ConvolutionLayer.Builder(new int[]{2, 2}).stride(new int[]{1, 1}).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"up9-1"}).addVertex("merge9", new MergeVertex(), new String[]{"conv1-2", "up9-2"})
.addLayer("conv9-1", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"merge9"})
.addLayer("conv9-2", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(64)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv9-1"})
.addLayer("conv9-3", new ConvolutionLayer.Builder(new int[]{3, 3}).stride(new int[]{1, 1}).nOut(2)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.RELU).build(), new String[]{"conv9-2"})
.addLayer("conv10", new ConvolutionLayer.Builder(new int[]{1, 1}).stride(new int[]{1, 1}).nOut(1)
.convolutionMode(ConvolutionMode.Same).cudnnAlgoMode(cudnnAlgoMode)
.activation(Activation.IDENTITY).build(), new String[]{"conv9-3"})
.addLayer("output", new CnnLossLayer.Builder(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build(), new String[]{"conv10"})
.addInputs(new String[]{"input"}).setInputTypes(new InputType[]{InputType.convolutional(224, 224, 1)})
.setOutputs(new String[]{"output"})
.build();
ComputationGraph model = new ComputationGraph(graph);
model.init();
System.out.println(model.summary());
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment