コード例 #1
0
ファイル: SOFNN.cs プロジェクト: yuxi214/QuantSys
        /// <summary>
        /// Trains the network on the observations in <paramref name="X"/>, growing the structure one neuron
        /// (one column of the centre and width matrices) at a time and adjusting neuron widths as required.
        /// </summary>
        /// <param name="X">Observation matrix; each column is one observation with <c>X.RowCount</c> inputs.</param>
        /// <param name="d">Desired output for each observation, indexed by the column index of <paramref name="X"/>.</param>
        /// <param name="Kd">
        /// Per-input distance thresholds. When a neuron is added, input <c>j</c> reuses the nearest existing
        /// centre/width if the distance to it is below <c>Kd[j]</c>.
        /// </param>
        /// <remarks>
        /// The resulting centres, consequent parameters and widths are stored in <c>out_C</c>, <c>out_O</c>
        /// and <c>out_Sigma</c>. Progress is written to the console.
        /// </remarks>
        public void Train(DenseMatrix X, DenseVector d, DenseVector Kd)
        {
            int R = X.RowCount;    // number of inputs per observation
            int N = X.ColumnCount; // number of observations
            int U = 0;             // number of neurons currently in the structure

            var c     = new DenseMatrix(R, 1); // neuron centres, one column per neuron
            var sigma = new DenseMatrix(R, 1); // neuron widths, one column per neuron

            var Q    = new DenseMatrix((R + 1), (R + 1));
            var O    = new DenseMatrix(1, (R + 1));
            var pT_n = new DenseMatrix((R + 1), 1);

            double maxPhi = 0;
            int    maxIndex;

            var Psi = new DenseMatrix(N, 1);

            Console.WriteLine("Running...");
            // Process every observation (column of X) in order.
            for (int i = 0; i < N; i++)
            {
                Console.WriteLine(100 * (i / (double)N) + "%");

                var x = new DenseVector(R);
                X.Column(i, x);

                if (U == 0)
                {
                    // No neurons yet: create the first one, centred on this observation with the initial width.
                    c     = (DenseMatrix)x.ToColumnMatrix();
                    sigma = new DenseMatrix(R, 1, SigmaZero);
                    U     = 1;
                    Psi   = CalculatePsi(X, c, sigma);
                    UpdateStructure(X, Psi, d, ref Q, ref O);
                    pT_n =
                        (DenseMatrix)
                        (CalculateGreatPsi((DenseMatrix)x.ToColumnMatrix(), (DenseMatrix)Psi.Row(i).ToRowMatrix()))
                        .Transpose();
                }
                else
                {
                    // Neurons already exist: update the structure recursively for this observation.
                    StructureRecurse(X, Psi, d, i, ref Q, ref O, ref pT_n);
                }

                bool keepSpinning = true;
                while (keepSpinning)
                {
                    // Error criterion. Presumably pT_n is a row and O a column here, so the product is 1x1 — confirm.
                    double ee = pT_n.Multiply(O)[0, 0];
                    double approximationError = Math.Abs(d[i] - ee);

                    // If-part criterion: the strongest neuron activation for this observation.
                    DenseVector Phi;
                    double      SumPhi;
                    CalculatePhi(x, c, sigma, out Phi, out SumPhi);

                    maxPhi   = Phi.Maximum();
                    maxIndex = Phi.MaximumIndex();

                    if (maxPhi < threshold)
                    {
                        // Width adjustment. The original applied this identical step whether or not the error
                        // exceeded delta, so it is written once here. The smallest width of the most-activated
                        // neuron is replaced by k_sigma times itself, and the structure is re-estimated.
                        // NOTE(review): if this never lifts maxPhi to threshold (e.g. k_sigma <= 1), this loop
                        // does not terminate — confirm k_sigma's value.
                        var tempSigma = new DenseVector(R);
                        sigma.Column(maxIndex, tempSigma);

                        double minSigma = tempSigma.Minimum();
                        int    minIndex = tempSigma.MinimumIndex();
                        sigma[minIndex, maxIndex] = k_sigma * minSigma;
                        Psi = CalculatePsi(X, c, sigma);
                        UpdateStructure(X, Psi, d, ref Q, ref O);
                        var psi = new DenseVector(Psi.ColumnCount);
                        Psi.Row(i, psi);

                        pT_n =
                            (DenseMatrix)
                            CalculateGreatPsi((DenseMatrix)x.ToColumnMatrix(), (DenseMatrix)psi.ToRowMatrix())
                            .Transpose();
                    }
                    else
                    {
                        if (approximationError > delta)
                        {
                            // Activation is high enough but the error is too large: add a new neuron.
                            var cTemp     = new DenseVector(R);
                            var sigmaTemp = new DenseVector(R);

                            for (int j = 0; j < R; j++)
                            {
                                // Find the existing neuron whose centre is nearest to x[j] along input j.
                                double distance      = Math.Abs(x[j] - c[j, 0]);
                                int    distanceIndex = 0;

                                for (int k = 1; k < U; k++)
                                {
                                    double candidate = Math.Abs(x[j] - c[j, k]);
                                    if (candidate < distance)
                                    {
                                        distanceIndex = k;
                                        distance      = candidate;
                                    }
                                }

                                if (distance < Kd[j])
                                {
                                    // Close enough: reuse the nearest neuron's centre and width for this input.
                                    cTemp[j]     = c[j, distanceIndex];
                                    sigmaTemp[j] = sigma[j, distanceIndex];
                                }
                                else
                                {
                                    // Too far: centre on the observation, with width equal to the nearest distance.
                                    cTemp[j]     = x[j];
                                    sigmaTemp[j] = distance;
                                }
                            }

                            // NOTE(review): InsertColumn at ColumnCount - 1 places the new neuron before the current
                            // last one rather than appending it. c and sigma stay aligned and Psi is recomputed from
                            // both, so only neuron ordering is affected — confirm this is intended.
                            c     = (DenseMatrix)c.InsertColumn(c.ColumnCount - 1, cTemp);
                            sigma = (DenseMatrix)sigma.InsertColumn(sigma.ColumnCount - 1, sigmaTemp);
                            Psi   = CalculatePsi(X, c, sigma);
                            UpdateStructure(X, Psi, d, ref Q, ref O);
                            U++;
                        }

                        // Either a neuron was added or both criteria are satisfied: move to the next observation.
                        keepSpinning = false;
                    }
                }
            }

            out_C     = c;
            out_O     = O;
            out_Sigma = sigma;

            Console.WriteLine("Done.");
        }
コード例 #2
0
        /// <summary>
        /// For each row (direction) of <paramref name="dmFieldData"/>, finds the column index and value of the
        /// row minimum after masking out unusable samples. Samples are masked if they are zero, closer than
        /// 1.5 sun radii, or beyond <paramref name="imageCircleCropFactor"/> of the distance to the image circle
        /// margin. Directions in which the frame edge cuts the image circle are skipped.
        /// </summary>
        /// <param name="dmFieldData">
        /// Field data. Row <c>i</c> corresponds to the angle <c>2π·i/(RowCount − 1)</c> (radians), and the column
        /// index is treated as a distance from the sun centre. Presumably the distance is in the same units
        /// (pixels) as the radii — confirm. The matrix itself is not modified.
        /// </param>
        /// <param name="sunDiskData">Sun disk: its centre is the origin of the directions, its radius is used for exclusion.</param>
        /// <param name="imageRoundData">Image circle: centre and radius of the usable image area.</param>
        /// <param name="imageHeight">Image height, used as the y coordinate of the bottom frame edge.</param>
        /// <param name="imageCircleCropFactor">
        /// Fraction of the sun-centre-to-image-margin distance beyond which samples are ignored.
        /// </param>
        /// <returns>
        /// One <c>Point3D(angle, minimumColumnIndex, minimumValue)</c> per kept direction. A direction is kept
        /// only when the frame edge does not cut the image circle there and the minimum lies beyond the sun
        /// disk radius.
        /// </returns>
        public static List <Point3D> GetLocalMinimumsDistribution(DenseMatrix dmFieldData, RoundData sunDiskData, RoundData imageRoundData, int imageHeight, double imageCircleCropFactor = 0.9d)
        {
            List <Point3D> lRetPoints = new List <Point3D>();

            double     imageRadius      = imageRoundData.DRadius;
            PointD     imageCenterPoint = imageRoundData.pointDCircleCenter();
            PointD     sunCenterPoint   = sunDiskData.pointDCircleCenter();
            PointPolar imageCenterPointRelatedToSunCenter = new PointPolar(imageCenterPoint - sunCenterPoint, true);
            double     distanceSunCenterToImageCenter     = PointD.Distance(imageCenterPoint, sunCenterPoint);

            // Samples closer than this to the sun centre (in column-index units) are masked out.
            double sunExclusionRadius = sunDiskData.DRadius * 1.5d;
            int    rowCount           = dmFieldData.RowCount;

            for (int i = 0; i < rowCount; i++)
            {
                double currentAngle = ((double)i / (double)(rowCount - 1)) * 2.0d * Math.PI;

                // If this direction points at a cropped part of the frame, it is not considered. The check
                // intersects the direction line through the sun centre with the top frame edge (y = 0) for
                // angles below π, or the bottom frame edge (y = imageHeight) otherwise.
                LineDescription2D line, lineMargin;
                if (currentAngle < Math.PI)
                {
                    // Upper half: look towards y = 0.
                    line = new LineDescription2D(sunCenterPoint,
                                                 new Vector2D(Math.Cos(currentAngle), -Math.Sin(currentAngle)));
                    lineMargin = new LineDescription2D(new PointD(0.0d, 0.0d), new Vector2D(1.0d, 0.0d));
                }
                else
                {
                    // Lower half: look towards y = imageHeight.
                    // NOTE(review): unlike the upper half, the y component is not negated here — confirm this
                    // matches the angle convention used to build dmFieldData.
                    line = new LineDescription2D(sunCenterPoint,
                                                 new Vector2D(Math.Cos(currentAngle), Math.Sin(currentAngle)));
                    lineMargin = new LineDescription2D(new PointD(0.0d, imageHeight), new Vector2D(1.0d, 0.0d));
                }

                PointD crossPointD = LineDescription2D.CrossPoint(line, lineMargin);
                if (crossPointD.Distance(imageCenterPoint) < imageRadius)
                {
                    // Crop case: the frame edge cuts the image circle in this direction.
                    continue;
                }

                // Row(i) copies the row, so masking below does not touch dmFieldData. Calling it directly avoids
                // re-enumerating (and copying) all preceding rows on every iteration.
                DenseVector dvRowDataVector = (DenseVector)dmFieldData.Row(i);

                // Distance from the sun centre to the image circle along this direction:
                // D·cos(φ) + sqrt(r² − D²·sin²(φ)), where φ is the angle between the direction and the image centre.
                // NOTE(review): if the sun centre lies outside the image circle, the root argument can be
                // negative. The result is then NaN and no samples are masked by the crop factor.
                double phiFromImageCenterToDirection = imageCenterPointRelatedToSunCenter.Phi - currentAngle;
                double sinPhi                        = Math.Sin(phiFromImageCenterToDirection);
                double distanceToImageMargin         = distanceSunCenterToImageCenter * Math.Cos(phiFromImageCenterToDirection) +
                                                       Math.Sqrt(imageRadius * imageRadius -
                                                                 distanceSunCenterToImageCenter * distanceSunCenterToImageCenter *
                                                                 sinPhi * sinPhi);

                // Mask samples by setting them to 1.0, presumably so they are not picked as the minimum
                // (assumes valid data is below 1 — confirm). Masked samples are: zero (no data), too close to the
                // sun, or too close to the image edge (beyond the crop fraction of the distance to the margin).
                dvRowDataVector.MapIndexedInplace(
                    (idx, x) => ((x == 0.0d) ||
                                 (idx < sunExclusionRadius) ||
                                 ((double)idx / distanceToImageMargin >= imageCircleCropFactor))
                        ? (1.0d)
                        : (x));

                double minValue      = dvRowDataVector.Minimum();
                int    minValueIndex = dvRowDataVector.MinimumIndex();

                // Keep only minima that lie outside the sun disk.
                if ((double)minValueIndex > sunDiskData.DRadius)
                {
                    lRetPoints.Add(new Point3D(currentAngle, minValueIndex, minValue));
                }
            }

            return(lRetPoints);
        }