/// <summary>
/// Entry point for the layer-wise training of the stacked RBM stroma classifier.
/// RBM0 and RBM1 are restored from earlier runs; only the top machine (RBM2) is trained here,
/// on batches fed through RBM0 and RBM1 by <c>RBM2Input</c>. The earlier stages are kept below
/// as disabled code so they can be re-enabled.
/// </summary>
/// <param name="args">Not used.</param>
static void Main(string[] args)
{
    // Training samples and output / checkpoint locations.
    const string positiveSamplesPath = "D:\\StromaSet\\S-114-HE_64\\training\\stroma";
    const string negativeSamplesPath = "D:\\StromaSet\\S-114-HE_64\\training\\not-stroma";
    const string rbmSavePath = "D:\\StromaSet\\weights";
    const string rbm0WeightsPath = "D:\\StromaSet\\weights\\RBM0_T1_769_500_139_0,04191353.weights";
    const string rbm1WeightsPath = "D:\\StromaSet\\weights\\RBM1_T1_500_75_375_0,07180008.weights";
    const string rbm2WeightsPath = "D:\\StromaSet\\weights\\RBM2_TOP_T3_76_40_16_0,1688143.weights";

    const int batchSize = 100;
    const int patchWidth = 16;
    const int patchHeight = 16;

    // Only consumed by the disabled WeightsHelper.generateWeights(...) initialisation below.
    var random = new Random();

    // Layer sizes. RBM0's "+ 1" input column: meaning not visible here (bias or label?) - TODO confirm.
    // RBM2's "+ 1" visible unit presumably carries the stroma label on top of RBM1's hidden units.
    const int rbm0Visible = patchWidth * patchHeight * 3 + 1;
    const int rbm0Hidden = 500;
    const int rbm1Visible = rbm0Hidden;
    const int rbm1Hidden = 75;
    const int rbm2Visible = rbm1Hidden + 1;
    const int rbm2Hidden = 40;

    IBatchGenerator generator = new ScaleBatchGenerator(positiveSamplesPath, negativeSamplesPath);

    // Restore all three weight matrices. To start a layer from scratch, replace its loadWeights call with
    // WeightsHelper.generateWeights(rbmNVisible, rbmNHidden, random).
    Matrix<float> weights0 = WeightsHelper.loadWeights(rbm0WeightsPath);
    Matrix<float> weights1 = WeightsHelper.loadWeights(rbm1WeightsPath);
    Matrix<float> weights2 = WeightsHelper.loadWeights(rbm2WeightsPath);

    RBM rbm0 = new RBM(weights0, false);
    RBM rbm1 = new RBM(weights1, false);
    RBM rbm2 = new RBM(weights2, false);

    // Stage 0 (disabled): train RBM0 directly on image patches.
    //RBMTrainer.IRBMInput rbm0Input = new RBM0Input(generator, batchSize, patchWidth, patchHeight);
    //RBMTrainer.trainRBM(rbm0, rbm0Input, 0.01f, 100000, 100, rbmSavePath, "RBM0_LABEL_T1", rbm0Visible, rbm0Hidden);

    // Stage 1 (disabled): train RBM1 on batches passed through RBM0.
    //RBMTrainer.IRBMInput rbm1Input = new RBM1Input(generator, batchSize, patchWidth, patchHeight, rbm0);
    //RBMTrainer.trainRBM(rbm1, rbm1Input, 0.01f, 2000000, 1000, rbmSavePath, "RBM1_LABELS_T2", rbm1Visible, rbm1Hidden);

    // Stage 2: train the top RBM on batches passed through RBM0 and RBM1.
    RBMTrainer.IRBMInput topInput = new RBM2Input(generator, batchSize, patchWidth, patchHeight, rbm0, rbm1);
    RBMTrainer.trainRBM(rbm2, topInput, 0.03f, 1000000, 1000, rbmSavePath, "RBM2_TOP_T3", rbm2Visible, rbm2Hidden);
}
/// <summary>
/// Sanity check for the trained RBM stack on the cross-validation set.
/// One batch is pushed up through RBM0, RBM1 and RBM2 (with empty label columns added for RBM2)
/// and back down again. The originals and reconstructions are written to disk, and the
/// reconstruction error and label prediction quality are printed. Waits for a key press before exiting.
/// </summary>
/// <param name="args">Not used.</param>
static void Main(string[] args)
{
    const string positiveSamplesPath = "D:\\StromaSet\\S-114-HE_64\\crossvalidation\\stroma";
    const string negativeSamplesPath = "D:\\StromaSet\\S-114-HE_64\\crossvalidation\\not-stroma";
    const string outputPath = "D:\\StromaSet\\reconstructions";
    const string rbm0WeightsPath = "D:\\StromaSet\\weights\\RBM0_T1_769_500_139_0,04191353.weights";
    const string rbm1WeightsPath = "D:\\StromaSet\\weights\\RBM1_T1_500_75_375_0,07180008.weights";
    const string rbm2WeightsPath = "D:\\StromaSet\\weights\\RBM2_TOP_T3_76_40_16_0,1688143.weights";

    const int batchSize = 100;
    const int patchWidth = 16;
    const int patchHeight = 16;

    IBatchGenerator generator = new ScaleBatchGenerator(positiveSamplesPath, negativeSamplesPath);

    Matrix<float> weights0 = WeightsHelper.loadWeights(rbm0WeightsPath);
    Matrix<float> weights1 = WeightsHelper.loadWeights(rbm1WeightsPath);
    Matrix<float> weights2 = WeightsHelper.loadWeights(rbm2WeightsPath);

    var rbm0 = new RBM(weights0, false);
    var rbm1 = new RBM(weights1, false);
    var rbm2 = new RBM(weights2, false);

    Matrix<float> original = generator.nextBatch(batchSize, patchWidth, patchHeight);

    // Upward pass; RBM2 expects label columns on its input, which start out empty.
    Matrix<float> up0 = rbm0.getHidden(original);
    Matrix<float> up1 = rbm1.getHidden(up0);
    Matrix<float> up2 = rbm2.getHidden(MatrixHelper.addEmptyLabels(up1));

    // Downward pass. RBM2's visible layer presumably now holds the predicted labels, which are
    // stripped before handing the features back to RBM1.
    Matrix<float> down2 = rbm2.getVisible(up2);
    Matrix<float> down1 = rbm1.getVisible(MatrixHelper.removeLabels(down2));
    Matrix<float> reconstruction = rbm0.getVisible(down1);

    ImageHelper.persistOriginalAndReconstruction(patchWidth, patchHeight, original, reconstruction, outputPath);
    Console.WriteLine("Image Reconstruction: " + RBMTrainer.reconstructionError(original, reconstruction));
    Console.WriteLine("Prediction Quality: " + RBMTrainer.predictionQuality(down2));

    Console.WriteLine("press key to exit: ");
    Console.ReadKey();
}
/// <summary>
/// Trains <paramref name="rbm"/> for <paramref name="epochs"/> steps (one batch per step). After every
/// <paramref name="saveInterval"/> steps, it saves the weights with the lowest training error seen
/// during that interval.
/// </summary>
/// <param name="rbm">The machine to train.</param>
/// <param name="input">Batch source. The next batch is generated on a background thread while the current one is trained on.</param>
/// <param name="learningRate">Learning rate forwarded to <c>RBM.train</c>.</param>
/// <param name="epochs">Total number of training steps. A trailing partial interval is trained and saved too.</param>
/// <param name="saveInterval">Number of steps per save interval; must be positive.</param>
/// <param name="saveDir">Directory the <c>.weights</c> files are written to.</param>
/// <param name="trainingName">Prefix for log lines and output file names.</param>
/// <param name="visibleLayer">Visible layer size; only used in the output file name.</param>
/// <param name="hiddenLayer">Hidden layer size; only used in the output file name.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="saveInterval"/> is zero or negative.</exception>
public static void trainRBM(RBM rbm, IRBMInput input, float learningRate, int epochs, int saveInterval, String saveDir, String trainingName, int visibleLayer, int hiddenLayer)
{
    // Previously saveInterval == 0 crashed with DivideByZeroException and a negative value silently trained nothing.
    if (saveInterval <= 0)
    {
        throw new ArgumentOutOfRangeException("saveInterval", saveInterval, "saveInterval must be positive.");
    }

    input.generateInput();
    Matrix<float> currentInput = input.getInput();

    int epoch = 0;
    for (int interval = 0; epoch < epochs; ++interval)
    {
        // The final interval is shorter when epochs is not a multiple of saveInterval
        // (those remaining epochs used to be skipped silently).
        int intervalLength = Math.Min(saveInterval, epochs - epoch);

        // Reset per interval so a file only ever contains weights observed in that interval.
        float minError = float.MaxValue;
        Matrix<float> minWeights = null;

        for (int j = 0; j < intervalLength; ++j, ++epoch)
        {
            // Prefetch the next batch while training on the current one.
            // NOTE(review): assumes getInput() hands out a matrix that generateInput() does not mutate in place - confirm.
            Thread prefetch = new Thread(input.generateInput);
            prefetch.Start();
            try
            {
                float error = rbm.train(currentInput, learningRate);
                Console.WriteLine(trainingName + "; Epoche: " + epoch + "; Error: " + error);
                if (error < minError)
                {
                    minError = error;
                    // NOTE(review): assumes getWeights() returns a snapshot, not the live matrix - confirm.
                    minWeights = rbm.getWeights();
                }
            }
            finally
            {
                // Never leave the prefetch thread running, even if training throws.
                prefetch.Join();
            }
            currentInput = input.getInput();
        }

        if (minWeights == null)
        {
            // No error in this interval beat float.MaxValue (e.g. all NaN); there is nothing meaningful to save.
            Console.WriteLine(trainingName + "; no weights saved for interval " + interval);
            continue;
        }

        // Save best weights from this interval. minError is formatted with the current culture
        // (e.g. "0,04191353"); existing weight file names rely on that format, so it is kept.
        String outputFile = System.IO.Path.Combine(saveDir, trainingName + "_" + visibleLayer + "_" + hiddenLayer + "_" + interval + "_" + minError + ".weights");
        WeightsHelper.saveWeights(minWeights, outputFile);
        Console.WriteLine("weights saved");
    }
}
/// <summary>
/// Creates the batch source used to train RBM2: batches come from <paramref name="generator"/> and are
/// passed through the already trained <paramref name="rbm0"/> and <paramref name="rbm1"/>.
/// </summary>
/// <param name="generator">Source of raw image-patch batches.</param>
/// <param name="batchSize">Number of patches per batch.</param>
/// <param name="patchWidth">Patch width requested from the generator.</param>
/// <param name="patchHeight">Patch height requested from the generator.</param>
/// <param name="rbm0">First (bottom) RBM of the stack.</param>
/// <param name="rbm1">Second RBM of the stack.</param>
public RBM2Input(IBatchGenerator generator, int batchSize, int patchWidth, int patchHeight, RBM rbm0, RBM rbm1)
{
    this.generator = generator;
    this.batchSize = batchSize;
    this.patchWidth = patchWidth;
    // Previously assigned patchWidth here, so the patchHeight argument was ignored.
    this.patchHeight = patchHeight;
    this.rbm0 = rbm0;
    this.rbm1 = rbm1;
}
/// <summary>
/// Entry point of the classifier. It builds the three RBMs from the weights supplied via
/// <see cref="InOut"/>, classifies every parsed object, then writes the results back through
/// <see cref="InOut"/>.
/// </summary>
/// <param name="args">Command-line arguments, interpreted by <see cref="InOut"/>.</param>
static void Main(string[] args)
{
    var io = new InOut(args);

    var rbm0 = new RBM(io.getRBM0Weights(), false);
    var rbm1 = new RBM(io.getRBM1Weights(), false);
    var rbm2 = new RBM(io.getRBM2Weights(), false);

    foreach (var parseObject in io.getParseObjects())
    {
        classifyImage(parseObject, rbm0, rbm1, rbm2);
    }

    io.writeOuput();
}
/// <summary>
/// Classifies the image of <paramref name="o"/> as stroma or not-stroma.
/// </summary>
/// <remarks>
/// A <c>patchWidth</c> x <c>patchHeight</c> window is moved over the image in steps of <c>scanIncrement</c>.
/// Each window is scaled by <c>ImageHelper.generateScaledPatch</c>, and a <c>null</c> result is counted as a
/// white patch. All other patches are pushed through RBM0 and RBM1, then up and down RBM2 with empty label
/// columns. The last column of RBM2's visible layer is read as the stroma score (&gt; 0.5 counts as stroma).
/// The ratio of stroma patches to all scanned patches (white ones included) is compared with
/// <c>classificationThreshold</c>. Both the decision and the ratio are stored on <paramref name="o"/>.
/// </remarks>
/// <param name="o">Object providing the image and receiving the result.</param>
/// <param name="rbm0">Bottom RBM of the stack.</param>
/// <param name="rbm1">Middle RBM of the stack.</param>
/// <param name="rbm2">Top RBM, whose visible layer includes the label column.</param>
private static void classifyImage(ParseObject o, RBM rbm0, RBM rbm1, RBM rbm2)
{
    Bitmap image = o.getImage();
    LinkedList<float[]> scaledPatches = new LinkedList<float[]>();
    int classWhite = 0;
    int classStroma = 0;
    int classNotStroma = 0;

    // "<=" so the last window that still fits (flush with the right/bottom edge) is scanned too.
    // With "<" that window was skipped, and an image exactly one patch in size was not scanned at all.
    for (int y = 0; y <= image.Height - patchHeight; y += scanIncrement)
    {
        for (int x = 0; x <= image.Width - patchWidth; x += scanIncrement)
        {
            // Bitmap wraps a GDI+ handle: dispose each clone instead of leaking one per window.
            using (Bitmap subImage = image.Clone(new Rectangle(x, y, patchWidth, patchHeight), image.PixelFormat))
            {
                float[] scaledPatch = ImageHelper.generateScaledPatch(subImage, scaleWidth, scaleHeight, whiteThreshold);
                if (scaledPatch == null)
                {
                    ++classWhite;
                }
                else
                {
                    scaledPatches.AddLast(scaledPatch);
                }
            }
        }
    }

    if (scaledPatches.Count > 0)
    {
        // One row per non-white patch. The "+ 1" column matches RBM0's extra input column (meaning not visible here - TODO confirm).
        int columnCount = scaleWidth * scaleHeight * 3 + 1;
        Matrix<float> batch = Matrix<float>.Build.Dense(scaledPatches.Count, columnCount);
        int row = 0;
        foreach (float[] scaledPatch in scaledPatches)
        {
            batch.SetRow(row++, scaledPatch);
        }

        Matrix<float> rbm0Hidden = rbm0.getHidden(batch);
        Matrix<float> rbm1Hidden = rbm1.getHidden(rbm0Hidden);
        Matrix<float> rbm1HiddenWithEmptyLabels = MatrixHelper.addEmptyLabels(rbm1Hidden);
        Matrix<float> rbm2Hidden = rbm2.getHidden(rbm1HiddenWithEmptyLabels);
        Matrix<float> rbm2Visible = rbm2.getVisible(rbm2Hidden);

        // The reconstructed last column is used as the per-patch stroma score.
        int lastColumn = rbm2Visible.ColumnCount - 1;
        for (int i = 0; i < rbm2Visible.RowCount; ++i)
        {
            if (rbm2Visible.At(i, lastColumn) > 0.5f)
            {
                ++classStroma;
            }
            else
            {
                ++classNotStroma;
            }
        }
    }

    int patchCount = classNotStroma + classWhite + classStroma;
    // No window fits when the image is smaller than one patch; report 0 instead of 0/0 = NaN.
    float stroma = patchCount > 0 ? classStroma / (float)patchCount : 0f;
    Boolean isStroma = stroma > classificationThreshold;
    Console.WriteLine("Is Stroma: " + isStroma + ", " + stroma);
    Console.WriteLine("Stroma: " + classStroma + ", NotStroma: " + classNotStroma + ", White: " + classWhite);
    o.setStroma(isStroma);
    o.setStromaRatio(stroma);
}