示例#1
0
        /// <summary>
        /// Trains a decision tree on the tomograms returned by
        /// <c>LabeledTomogramsFromPaintedFiles</c> and serializes the resulting
        /// root node to "serialized.dat" in the current working directory.
        /// </summary>
        static void Train()
        {
            List<LabeledTomogram> trainingSet = LabeledTomogramsFromPaintedFiles();

            // Hyper-parameters for tree construction.
            DecisionTreeOptions trainingOptions = new DecisionTreeOptions
            {
                MaximumNumberOfRecursionLevels = 25,
                NumberOfFeatures               = 300,
                NumberOfThresholds             = 35,
                OffsetXMin                     = -40,
                OffsetXMax                     = 40,
                OffsetYMin                     = -40,
                OffsetYMax                     = 40,
                OutOfRangeValue                = 1000000,
                SplittingThresholdMax          = 0.2f,
                SufficientGainLevel            = 0,
                PercentageOfPixelsToUse        = 0.9f,
            };

            // Fixed seed so that sampling and question generation draw the same random sequence each run.
            Random seededRandom = new Random(1234);
            DecisionTreeNode trainedRoot = DecisionTreeBuilder.Train(trainingSet, seededRandom, trainingOptions);

            // NOTE(review): BinaryFormatter is obsolete in current .NET; consider another serializer.
            BinaryFormatter formatter = new BinaryFormatter();

            using (FileStream output = File.Create("serialized.dat"))
            {
                formatter.Serialize(output, trainedRoot);
            }
        }
示例#2
0
        /// <summary>
        /// Flattens every pixel of the given tomograms into a list of points,
        /// visiting each tomogram row by row (y outer, x inner).
        /// </summary>
        /// <param name="tomograms">Tomograms whose pixels are converted.</param>
        /// <param name="random">
        /// When null, every pixel is emitted and <c>Label</c> is left at its default.
        /// When non-null, pixels labeled 1 are always emitted and every other pixel is
        /// emitted with probability <c>options.PercentageOfPixelsToUse</c>.
        /// </param>
        /// <param name="options">Only read when <paramref name="random"/> is non-null.</param>
        /// <returns>The generated points, in visiting order.</returns>
        private static List <LabeledPoint> TomogramsToPoints(List <LabeledTomogram> tomograms,
                                                             Random random, DecisionTreeOptions options)
        {
            List<LabeledPoint> points = new List<LabeledPoint>();

            foreach (LabeledTomogram tomogram in tomograms)
            {
                // Decide once per tomogram whether labels exist, *before* indexing into
                // them, so an unlabeled tomogram yields label -1 instead of throwing.
                bool hasLabels = tomogram.Labels != null;

                for (int y = 0, i = 0; y < tomogram.Height; y++)
                {
                    for (int x = 0; x < tomogram.Width; x++, i++)
                    {
                        float value = tomogram.Data[i];

                        if (random == null)
                        {
                            // No sampling requested: keep every pixel, label untouched.
                            points.Add(new LabeledPoint
                            {
                                X              = x,
                                Y              = y,
                                Z              = value,
                                SourceTomogram = tomogram,
                            });
                            continue;
                        }

                        int label = hasLabels ? (int)tomogram.Labels[i] : -1;

                        // Short-circuit keeps the random stream untouched for label-1 pixels,
                        // so only the other pixels consume a random draw.
                        if (label == 1 || random.NextDouble() < options.PercentageOfPixelsToUse)
                        {
                            points.Add(new LabeledPoint
                            {
                                X              = x,
                                Y              = y,
                                Z              = value,
                                Label          = label,
                                SourceTomogram = tomogram,
                            });
                        }
                    }
                }
            }

            return points;
        }
示例#3
0
        /// <summary>
        /// Classifies every pixel of <paramref name="image"/> with the given tree.
        /// </summary>
        /// <param name="image">Tomogram to classify.</param>
        /// <param name="node">Root of the decision tree to evaluate.</param>
        /// <param name="options">Options forwarded to the split evaluation.</param>
        /// <returns>
        /// One predicted class per point produced by <c>TomogramsToPoints</c>,
        /// in the same order.
        /// </returns>
        public static float[] Predict(LabeledTomogram image, DecisionTreeNode node, DecisionTreeOptions options)
        {
            List<LabeledTomogram> singleImage = new List<LabeledTomogram> { image };

            // No Random is passed, so no sub-sampling is requested.
            List<LabeledPoint> pixels = TomogramsToPoints(singleImage, null, null);

            List<float> predictions = new List<float>();

            for (int index = 0; index < pixels.Count; index++)
            {
                predictions.Add(RecurseAndPredict(pixels[index], node, options));
            }

            return predictions.ToArray();
        }
示例#4
0
        /// <summary>
        /// Evaluates a splitting question for one point: compares the tomogram values
        /// at two probe pixels (offsets U and V from the point) and checks whether
        /// their difference is below the question's threshold.
        /// </summary>
        /// <param name="point">Point whose neighborhood is probed.</param>
        /// <param name="question">Probe offsets and threshold.</param>
        /// <param name="options">Supplies <c>OutOfRangeValue</c> for probes outside the frame.</param>
        /// <returns>
        /// <c>SplitDirection.Left</c> when (uVal - vVal) is less than the threshold,
        /// otherwise <c>SplitDirection.Right</c>.
        /// </returns>
        private static SplitDirection ComputeSplitDirection(LabeledPoint point, SplittingQuestion question,
                                                            DecisionTreeOptions options)
        {
            int frameHeight = point.SourceTomogram.Height;
            int frameWidth  = point.SourceTomogram.Width;

            int uX = point.X + question.OffsetUX;
            int uY = point.Y + question.OffsetUY;
            int vX = point.X + question.OffsetVX;
            int vY = point.Y + question.OffsetVY;

            // Each coordinate is checked against both edges. Testing only the flattened
            // index is not enough: an x that falls off a row still produces a valid
            // index and would silently alias a pixel in an adjacent row.
            bool uInside = uX >= 0 && uX < frameWidth && uY >= 0 && uY < frameHeight;
            bool vInside = vX >= 0 && vX < frameWidth && vY >= 0 && vY < frameHeight;

            float uVal, vVal;

            if (uInside && vInside)
            {
                uVal = point.SourceTomogram.Data[uY * frameWidth + uX];
                vVal = point.SourceTomogram.Data[vY * frameWidth + vX];
            }
            else
            {
                // If either probe leaves the frame, both take the sentinel value,
                // so the compared difference is zero.
                uVal = vVal = options.OutOfRangeValue;
            }

            if ((uVal - vVal) < question.Threshold)
            {
                return SplitDirection.Left;
            }

            return SplitDirection.Right;
        }
示例#5
0
        /// <summary>
        /// Walks the tree from <paramref name="node"/> down to a leaf, choosing the
        /// left or right branch at each inner node according to its question, and
        /// returns the class stored in the leaf that is reached.
        /// </summary>
        /// <param name="point">Point being classified.</param>
        /// <param name="node">Node at which the descent starts.</param>
        /// <param name="options">Options forwarded to the split evaluation.</param>
        /// <returns>The <c>Class</c> of the reached leaf.</returns>
        private static int RecurseAndPredict(LabeledPoint point, DecisionTreeNode node, DecisionTreeOptions options)
        {
            DecisionTreeNode current = node;

            // Iterative descent; equivalent to recursing into the chosen branch.
            while (!current.IsLeaf)
            {
                SplitDirection taken = ComputeSplitDirection(point, current.Question, options);

                current = taken == SplitDirection.Left
                    ? current.LeftBranch
                    : current.RightBranch;
            }

            return current.Class;
        }
示例#6
0
        /// <summary>
        /// Builds a decision tree from labeled tomograms.
        /// </summary>
        /// <param name="trainingImages">Tomograms to draw training points from.</param>
        /// <param name="random">Random source used for point sampling and question generation.</param>
        /// <param name="options">Tree-building options.</param>
        /// <returns>The root node of the built tree.</returns>
        public static DecisionTreeNode Train(List <LabeledTomogram> trainingImages, Random random, DecisionTreeOptions options)
        {
            // Step 1: flatten the images into individual points.
            Console.WriteLine("Building points...");
            var samples = TomogramsToPoints(trainingImages, random, options);

            // Step 2: create the pool of candidate questions shared by every node.
            Console.WriteLine("Build splitting questions...");
            var candidateQuestions = GenerateSplittingQuestions(random, options);

            // Step 3: grow the tree, starting at recursion level 1.
            var treeRoot = new DecisionTreeNode();

            Console.WriteLine("Recurse and partition...");
            RecurseAndPartition(samples, candidateQuestions, 1, options, treeRoot, random);

            return treeRoot;
        }
示例#7
0
        /// <summary>
        /// Recursively grows the tree below <paramref name="currentNode"/>: scores every
        /// splitting question by information gain over <paramref name="trainingPoints"/>,
        /// splits on the best one if its gain exceeds <c>options.SufficientGainLevel</c>,
        /// and otherwise (or once the maximum recursion level is reached) turns the node
        /// into a leaf.
        /// </summary>
        /// <param name="trainingPoints">Points that reached this node.</param>
        /// <param name="splittingQuestions">Candidate questions, shared by all nodes.</param>
        /// <param name="currentRecursionLevel">Depth of this node; the caller starts at 1.</param>
        /// <param name="options">Tree-building options.</param>
        /// <param name="currentNode">Node to fill in.</param>
        /// <param name="random">Not used here; only passed on to the recursive calls.</param>
        private static void RecurseAndPartition(List <LabeledPoint> trainingPoints, List <SplittingQuestion> splittingQuestions,
                                                int currentRecursionLevel, DecisionTreeOptions options, DecisionTreeNode currentNode, Random random)
        {
            Console.WriteLine($"{new String('-', currentRecursionLevel)}{currentRecursionLevel}");

            if (currentRecursionLevel >= options.MaximumNumberOfRecursionLevels)
            {
                MakeLeafNode(currentNode, trainingPoints);
                return;
            }

            double            currentShannonEntropy = ComputeShannonEntropy(trainingPoints);
            double            highestGain           = double.MinValue;
            int               bestQuestionIndex     = int.MaxValue;
            SplittingQuestion bestSplittingQuestion = null;

            // Private gate for the shared "best so far" state of this call. Locking on a
            // Type object is avoided because any other code can lock on the same object.
            object bestGate = new object();

            Parallel.For(0, splittingQuestions.Count, s =>
            {
                // Per-class counters, indexed by label (0 or 1), for each side of the split.
                List<LabeledPointGroup> leftBucket = new List<LabeledPointGroup>
                {
                    new LabeledPointGroup { Count = 0, Class = 0 },
                    new LabeledPointGroup { Count = 0, Class = 1 },
                };
                List<LabeledPointGroup> rightBucket = new List<LabeledPointGroup>
                {
                    new LabeledPointGroup { Count = 0, Class = 0 },
                    new LabeledPointGroup { Count = 0, Class = 1 },
                };

                SplittingQuestion splittingQuestion = splittingQuestions[s];

                foreach (LabeledPoint trainingPoint in trainingPoints)
                {
                    SplitDirection split = ComputeSplitDirection(trainingPoint, splittingQuestion, options);

                    if (split == SplitDirection.Left)
                    {
                        leftBucket[trainingPoint.Label].Count++;
                    }
                    else
                    {
                        rightBucket[trainingPoint.Label].Count++;
                    }
                }

                double gain = ComputeGain(currentShannonEntropy, leftBucket, rightBucket);

                lock (bestGate)
                {
                    // Exact ties go to the lowest question index, so the chosen question
                    // does not depend on thread scheduling (same winner as a sequential loop).
                    bool isBetter = gain > highestGain ||
                                    (gain == highestGain && s < bestQuestionIndex);

                    if (isBetter)
                    {
                        highestGain           = gain;
                        bestQuestionIndex     = s;
                        bestSplittingQuestion = splittingQuestion;
                    }
                }
            });

            if (!(highestGain > options.SufficientGainLevel))
            {
                // No question separates the points well enough.
                MakeLeafNode(currentNode, trainingPoints);
                return;
            }

            // Re-run the winning question to materialize the two child point sets.
            List<LabeledPoint> bestLeftBucket  = new List<LabeledPoint>();
            List<LabeledPoint> bestRightBucket = new List<LabeledPoint>();

            foreach (LabeledPoint trainingPoint in trainingPoints)
            {
                SplitDirection split = ComputeSplitDirection(trainingPoint, bestSplittingQuestion, options);

                if (split == SplitDirection.Left)
                {
                    bestLeftBucket.Add(trainingPoint);
                }
                else
                {
                    bestRightBucket.Add(trainingPoint);
                }
            }

            currentNode.Question    = bestSplittingQuestion;
            currentNode.LeftBranch  = new DecisionTreeNode();
            currentNode.RightBranch = new DecisionTreeNode();
            currentNode.IsLeaf      = false;

            RecurseAndPartition(bestLeftBucket, splittingQuestions,
                                currentRecursionLevel + 1, options, currentNode.LeftBranch, random);

            RecurseAndPartition(bestRightBucket, splittingQuestions,
                                currentRecursionLevel + 1, options, currentNode.RightBranch, random);
        }
示例#8
0
        /// <summary>
        /// Generates the pool of candidate splitting questions:
        /// <c>options.NumberOfFeatures</c> random (U, V) offset pairs, each combined with
        /// <c>options.NumberOfThresholds</c> random thresholds.
        /// </summary>
        /// <param name="random">Random source for offsets and thresholds.</param>
        /// <param name="options">Offset ranges, threshold magnitude and counts.</param>
        /// <returns>NumberOfFeatures * NumberOfThresholds questions.</returns>
        private static List <SplittingQuestion> GenerateSplittingQuestions(Random random, DecisionTreeOptions options)
        {
            List<SplittingQuestion> questions = new List<SplittingQuestion>();

            for (int feature = 0; feature < options.NumberOfFeatures; feature++)
            {
                // Random.Next(min, max) excludes max, so offsets fall in [Min, Max).
                // X offsets use the X range and Y offsets use the Y range.
                int uX = random.Next(options.OffsetXMin, options.OffsetXMax);
                int uY = random.Next(options.OffsetYMin, options.OffsetYMax);
                int vX = random.Next(options.OffsetXMin, options.OffsetXMax);
                int vY = random.Next(options.OffsetYMin, options.OffsetYMax);

                for (int t = 0; t < options.NumberOfThresholds; t++)
                {
                    float magnitude = (float)random.NextDouble() * options.SplittingThresholdMax;

                    // Random sign of -1 or +1. Next(-1, 1) can only return -1 or 0 because
                    // the upper bound is exclusive, which would zero out half the thresholds
                    // and never yield a positive one.
                    int sign = random.Next(2) == 0 ? -1 : 1;

                    questions.Add(new SplittingQuestion
                    {
                        OffsetUX  = uX,
                        OffsetUY  = uY,
                        OffsetVX  = vX,
                        OffsetVY  = vY,
                        Threshold = magnitude * sign
                    });
                }
            }

            return questions;
        }
示例#9
0
        /// <summary>
        /// Loads the tree stored in "serialized.dat", classifies one frame of a real MRC
        /// tomogram and one simulated tomogram, and saves a painted PNG for each.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// Thrown when "serialized.dat" does not contain a <c>DecisionTreeNode</c>.
        /// </exception>
        static void Test()
        {
            DecisionTreeNode node;

            // NOTE(review): BinaryFormatter deserialization is unsafe on untrusted input and the
            // type is obsolete in current .NET; only load files this program produced itself.
            BinaryFormatter bf = new BinaryFormatter();

            // The stream is only needed while deserializing, so close it right away.
            using (FileStream fs = File.OpenRead("serialized.dat"))
            {
                node = bf.Deserialize(fs) as DecisionTreeNode;
            }

            if (node == null)
            {
                // Fail here with a clear message rather than with a null dereference during prediction.
                throw new InvalidDataException("serialized.dat does not contain a DecisionTreeNode.");
            }

            DecisionTreeOptions options = new DecisionTreeOptions
            {
                MaximumNumberOfRecursionLevels = 25,
                NumberOfFeatures        = 300,
                NumberOfThresholds      = 35,
                OffsetXMax              = 40,
                OffsetXMin              = -40,
                OffsetYMax              = 40,
                OffsetYMin              = -40,
                OutOfRangeValue         = 1000000,
                SplittingThresholdMax   = .2f,
                SufficientGainLevel     = 0,
                PercentageOfPixelsToUse = .9f,
            };

            MRCFile file = MRCParser.Parse("/home/brush/tomography2_fullsirtcliptrim.mrc");

            MRCFrame frame = file.Frames[145];

            // Copy the frame into a tomogram without labels.
            LabeledTomogram tom = new LabeledTomogram();
            tom.Width  = frame.Width;
            tom.Height = frame.Height;
            tom.Data   = new float[frame.Width * frame.Height];

            for (int i = 0; i < frame.Data.Length; i++)
            {
                tom.Data[i] = frame.Data[i];
            }

            float[] labels = DecisionTreeBuilder.Predict(tom, node, options);

            // Each Bitmap is disposed as soon as it has been saved.
            using (Bitmap bmp = Drawing.TomogramDrawing.PaintClassifiedPixelsOnTomogram(tom, labels))
            {
                bmp.Save("/var/www/html/static/labeled_real.png", System.Drawing.Imaging.ImageFormat.Png);
            }

            LabeledTomogram tom2 = DataReader.ReadDatFile("/home/brush/tom4/0.dat");

            labels = DecisionTreeBuilder.Predict(tom2, node, options);

            using (Bitmap bmp = Drawing.TomogramDrawing.PaintClassifiedPixelsOnTomogram(tom2, labels))
            {
                bmp.Save("/var/www/html/static/labeled_simulated.png", System.Drawing.Imaging.ImageFormat.Png);
            }
        }