コード例 #1
0
ファイル: dist_tests.cs プロジェクト: grouptheory/moadb
        /// <summary>
        /// Builds a two-component mixture of 1-D Gaussians over a single axis "x" spanning [0, 100].
        /// The first Gaussian (mean, std) gets weight 0.75 and the second (mean2, std2) gets 0.25.
        /// </summary>
        /// <returns>The completed mixture (DistributionComplete has already been called).</returns>
        private IDistribution createGaussianMixture(double mean, double std, double mean2, double std2)
        {
            IDistribution first  = create1DGaussian(mean, std);
            IDistribution second = create1DGaussian(mean2, std2);

            IBlauSpace space = BlauSpace.create(1, new string[] { "x" }, new double[] { 0.00 }, new double[] { 100.0 });

            Mixture mixture = new Mixture(space);
            mixture.Add(first, 0.75);
            mixture.Add(second, 0.25);
            mixture.DistributionComplete();

            return mixture;
        }
コード例 #2
0
ファイル: dist_tests.cs プロジェクト: grouptheory/moadb
        /// <summary>
        /// Verifies that iterating the parameter space of a two-component 1-D Gaussian mixture
        /// (means 70 and 20, std 1 each) with N = 3..5 subdivisions per parameter visits exactly
        /// (N+1)^6 distributions, debug-logging every valid one and info-logging the totals.
        /// </summary>
        public void DistributionSpaceIterator_GaussianMixtureTest()
        {
            Console.WriteLine("DistributionSpaceIterator_GaussianMixtureTest");

            // Build the mixture through the shared helper (weights 0.75/0.25 over "x" in [0, 100])
            // so this test and createGaussianMixture cannot drift apart.
            IDistribution d = createGaussianMixture(70.0, 1.0, 20.0, 1.0);

            SingletonLogger.Instance().DebugLog(typeof(dist_tests), "original distribution: " + d);
            DistributionSpace ds = new DistributionSpace(d);

            // The expected count assumes the mixture's parameter space has 6 dimensions
            // (presumably mean and std for each component plus the two weights) -- TODO confirm.
            const int expectedParamDimensions = 6;

            int [] steps = new int[ds.ParamSpace.Dimension];

            for (int N = 3; N <= 5; N++)
            {
                for (int i = 0; i < ds.ParamSpace.Dimension; i++)
                {
                    steps[i] = N;
                }

                IDistributionSpaceIterator it = ds.iterator(steps);

                int count   = 0;
                int validCt = 0;
                foreach (IDistribution diter in it)
                {
                    if (diter.IsValid())
                    {
                        validCt++;
                        SingletonLogger.Instance().DebugLog(typeof(dist_tests), "iterator distribution: " + diter);
                    }
                    count++;
                }

                // (N+1)^expectedParamDimensions: presumably N subdivisions yield N+1 values per parameter.
                int expectedCount = 1;
                for (int k = 0; k < expectedParamDimensions; k++)
                {
                    expectedCount *= N + 1;
                }
                Assert.AreEqual(expectedCount, count);
                SingletonLogger.Instance().InfoLog(typeof(dist_tests), "N=" + N + "  valid distributions: " + validCt + " / total: " + count);
            }
        }
 /// <summary>
 /// Converts the stored component parameters into a Mixture of VectorGaussian:
 /// for every index below weights.Count, a Gaussian with mean means[index] and
 /// variance variances[index] is added with weight weights[index].
 /// </summary>
 public Mixture<VectorGaussian> ToMixture()
 {
     var result = new Mixture<VectorGaussian>();
     int component = 0;
     while (component < weights.Count)
     {
         var mean = MicrosoftResearch.Infer.Maths.Vector.FromArray(this.means[component]);
         var variance = new PositiveDefiniteMatrix(JaggedArrayToMatrix(this.variances[component]));
         result.Add(VectorGaussian.FromMeanAndVariance(mean, variance), this.weights[component]);
         component++;
     }
     return result;
 }
コード例 #4
0
        /// <summary>
        /// Builds a mixture of vector Gaussians from the stored per-component arrays.
        /// Component k has mean means[k] and covariance variances[k] (converted via
        /// JaggedArrayToMatrix) and enters the mixture with weight weights[k].
        /// </summary>
        public Mixture <VectorGaussian> ToMixture()
        {
            Mixture <VectorGaussian> gmm = new Mixture <VectorGaussian>();
            for (int k = 0; k < weights.Count; k++)
            {
                MicrosoftResearch.Infer.Maths.Vector componentMean = MicrosoftResearch.Infer.Maths.Vector.FromArray(this.means[k]);
                PositiveDefiniteMatrix componentVariance = new PositiveDefiniteMatrix(JaggedArrayToMatrix(this.variances[k]));
                VectorGaussian component = VectorGaussian.FromMeanAndVariance(componentMean, componentVariance);
                gmm.Add(component, this.weights[k]);
            }
            return gmm;
        }
コード例 #5
0
        /// <summary>
        /// Fits a Gaussian mixture model to <paramref name="data"/> with the expectation-maximization
        /// (EM) algorithm, starting from components produced by GenerateRandomMixtureComponent from
        /// the data's per-axis bounding box and from uniform weights.
        /// </summary>
        /// <param name="data">Sample points; each must have the same dimension as data[0] (checked only by Debug.Assert).</param>
        /// <param name="componentCount">Number of mixture components; Debug.Assert requires it to be greater than 1
        /// and data.Length to exceed 3 * componentCount.</param>
        /// <param name="retryCount">How many times a degenerate component (covariance log-determinant below -30)
        /// may be regenerated at random before fitting is abandoned.</param>
        /// <param name="tolerance">Iteration stops once the likelihood estimate improves by no more than this amount.</param>
        /// <returns>The fitted mixture, one VectorGaussian per component with its estimated weight.</returns>
        /// <exception cref="InvalidOperationException">A component degenerated after all retries were used up.</exception>
        public static Mixture <VectorGaussian> Fit(MicrosoftResearch.Infer.Maths.Vector[] data, int componentCount, int retryCount, double tolerance = 1e-4)
        {
            Debug.Assert(data != null);
            Debug.Assert(data.Length > componentCount * 3);
            Debug.Assert(componentCount > 1);
            Debug.Assert(retryCount >= 0);

            int dimensions = data[0].Count;

            // Find the per-axis bounding box of the data (used when generating random components).
            MicrosoftResearch.Infer.Maths.Vector min = data[0].Clone();
            MicrosoftResearch.Infer.Maths.Vector max = min.Clone();
            for (int i = 1; i < data.Length; ++i)
            {
                Debug.Assert(dimensions == data[i].Count);
                for (int j = 0; j < dimensions; ++j)
                {
                    min[j] = Math.Min(min[j], data[i][j]);
                    max[j] = Math.Max(max[j], data[i][j]);
                }
            }

            // Initialize solution: random components, uniform weights.
            MicrosoftResearch.Infer.Maths.Vector[] means = new MicrosoftResearch.Infer.Maths.Vector[componentCount];
            PositiveDefiniteMatrix[] covariances         = new PositiveDefiniteMatrix[componentCount];
            for (int i = 0; i < componentCount; ++i)
            {
                GenerateRandomMixtureComponent(min, max, out means[i], out covariances[i]);
            }
            double[] weights = Enumerable.Repeat(1.0 / componentCount, componentCount).ToArray();

            // EM algorithm for GMM
            double[,] expectations = new double[data.Length, componentCount];
            double       lastEstimate;
            const double negativeInfinity = -1e+20;
            bool         convergenceDetected;
            double       currentEstimate = negativeInfinity;

            // Scratch buffer: log(weight_j * density_j(x)) for the point being processed in the E-step.
            double[] logTerms = new double[componentCount];

            do
            {
                lastEstimate        = currentEstimate;
                convergenceDetected = false;

                // E-step: estimate expectations (responsibilities) of the hidden component labels.
                // Work in log space and subtract the per-point maximum before exponentiating
                // (log-sum-exp). Exponentiating raw log densities underflows to 0 for points far
                // from every component, which made the normalizing sum 0 and the expectations NaN.
                for (int i = 0; i < data.Length; ++i)
                {
                    double maxLogTerm = double.NegativeInfinity;
                    for (int j = 0; j < componentCount; ++j)
                    {
                        logTerms[j] = VectorGaussian.GetLogProb(data[i], means[j], covariances[j]) + Math.Log(weights[j]);
                        if (logTerms[j] > maxLogTerm)
                        {
                            maxLogTerm = logTerms[j];
                        }
                    }

                    double sum = 0;
                    for (int j = 0; j < componentCount; ++j)
                    {
                        expectations[i, j] = Math.Exp(logTerms[j] - maxLogTerm);
                        sum += expectations[i, j];
                    }
                    for (int j = 0; j < componentCount; ++j)
                    {
                        expectations[i, j] /= sum;
                    }
                }

                // M-step:

                // Re-estimate means as expectation-weighted averages of the data.
                // NOTE(review): a component whose expectations are all 0 gets sum == 0 here.
                for (int j = 0; j < componentCount; ++j)
                {
                    means[j] = MicrosoftResearch.Infer.Maths.Vector.Zero(dimensions);
                    double sum = 0;
                    for (int i = 0; i < data.Length; ++i)
                    {
                        means[j] += data[i] * expectations[i, j];
                        sum      += expectations[i, j];
                    }
                    means[j] *= 1.0 / sum;
                }

                // Re-estimate covariances; a near-singular covariance (log-determinant < -30) marks a
                // degenerate component, which is regenerated at random while retries remain.
                // ("Convergence detected" in the debug text refers to this collapse, not EM convergence.)
                for (int j = 0; j < componentCount; ++j)
                {
                    Matrix covariance = new Matrix(dimensions, dimensions);
                    double sum        = 0;
                    for (int i = 0; i < data.Length; ++i)
                    {
                        MicrosoftResearch.Infer.Maths.Vector dataDiff = data[i] - means[j];
                        covariance += dataDiff.Outer(dataDiff) * expectations[i, j];
                        sum        += expectations[i, j];
                    }
                    covariance    *= 1.0 / sum;
                    covariances[j] = new PositiveDefiniteMatrix(covariance);

                    if (covariances[j].LogDeterminant() < -30)
                    {
                        DebugConfiguration.WriteDebugText("Convergence detected for component {0}", j);
                        if (retryCount == 0)
                        {
                            throw new InvalidOperationException("Can't fit GMM. Retry number exceeded.");
                        }

                        retryCount -= 1;
                        GenerateRandomMixtureComponent(min, max, out means[j], out covariances[j]);
                        DebugConfiguration.WriteDebugText("Component {0} regenerated", j);

                        convergenceDetected = true;
                    }
                }

                if (convergenceDetected)
                {
                    // Restart the likelihood tracking; the loop condition keeps iterating.
                    currentEstimate = negativeInfinity;
                    continue;
                }

                // Re-estimate weights as each component's share of the total expectation mass.
                double expectationSum = 0;
                for (int j = 0; j < componentCount; ++j)
                {
                    weights[j] = 0;
                    for (int i = 0; i < data.Length; ++i)
                    {
                        weights[j]     += expectations[i, j];
                        expectationSum += expectations[i, j];
                    }
                }
                for (int j = 0; j < componentCount; ++j)
                {
                    weights[j] /= expectationSum;
                }

                // Compute the expected complete-data log-likelihood under the current expectations;
                // it is used only as the convergence signal.
                currentEstimate = 0;
                for (int i = 0; i < data.Length; ++i)
                {
                    for (int j = 0; j < componentCount; ++j)
                    {
                        // Zero expectations contribute nothing; skipping them also avoids 0 * log(0) = NaN
                        // when a component's weight has dropped to 0.
                        if (expectations[i, j] > 0)
                        {
                            currentEstimate +=
                                expectations[i, j] * (VectorGaussian.GetLogProb(data[i], means[j], covariances[j]) + Math.Log(weights[j]));
                        }
                    }
                }

                DebugConfiguration.WriteDebugText("L={0:0.000000}", currentEstimate);
            } while (convergenceDetected || (currentEstimate - lastEstimate > tolerance));

            Mixture <VectorGaussian> result = new Mixture <VectorGaussian>();

            for (int j = 0; j < componentCount; ++j)
            {
                result.Add(VectorGaussian.FromMeanAndVariance(means[j], covariances[j]), weights[j]);
            }

            DebugConfiguration.WriteDebugText("GMM successfully fitted.");

            return(result);
        }
コード例 #6
0
        /// <summary>
        /// Console driver that builds a mixture of two 1-D Gaussians (means 70 and 20, std 1,
        /// weights 0.75 and 0.25) over the axis "x" in [0, 100]. It then walks the mixture's
        /// parameter space for N = 3..6 subdivisions per parameter. Every iterated distribution
        /// is printed as valid or invalid, and the count is asserted to equal (N+1)^6.
        /// </summary>
        /// <param name="args">Not used.</param>
        public static void Main1DMixture(string[] args)
        {
            Console.WriteLine("console_tests");

            Type logSource = typeof(console_tests.MainClass);
            LoggerInitialization.SetThreshold(logSource, LogLevel.Debug);

            IDistribution high = create1DGaussian(70.0, 1.0);
            IDistribution low  = create1DGaussian(20.0, 1.0);

            IBlauSpace sampleSpace = BlauSpace.create(1, new string[] { "x" }, new double[] { 0.00 }, new double[] { 100.0 });

            Mixture mixture = new Mixture(sampleSpace);
            mixture.Add(high, 0.75);
            mixture.Add(low, 0.25);
            mixture.DistributionComplete();

            SingletonLogger.Instance().DebugLog(logSource, "original distribution: " + mixture);
            Console.WriteLine("original distribution: " + mixture);

            DistributionSpace paramSpace = new DistributionSpace(mixture);
            int[] subdivisions = new int[paramSpace.ParamSpace.Dimension];

            // n = subdivisions of each of the parameter values
            for (int n = 3; n <= 6; n++)
            {
                for (int axis = 0; axis < paramSpace.ParamSpace.Dimension; axis++)
                {
                    subdivisions[axis] = n;
                }

                int seen  = 0;
                int valid = 0;
                foreach (IDistribution candidate in paramSpace.iterator(subdivisions))
                {
                    if (!candidate.IsValid())
                    {
                        Console.WriteLine("invalid distribution: " + candidate);
                    }
                    else
                    {
                        valid++;
                        SingletonLogger.Instance().DebugLog(logSource, "iterator distribution: " + candidate);
                        Console.WriteLine("valid distribution: " + candidate);
                    }
                    seen++;
                }

                Console.WriteLine("# of valid distributions: " + valid);
                Console.WriteLine("# of total distributions: " + seen);

                int expectedTotal = 1;
                for (int power = 0; power < 6; power++)
                {
                    expectedTotal *= n + 1;
                }
                Assert.AreEqual(expectedTotal, seen);
                SingletonLogger.Instance().InfoLog(logSource, "N=" + n + "  valid distributions: " + valid + " / total: " + seen);
            }
        }
コード例 #7
0
        /// <summary>
        /// Command-line tool that reads two SOAP-serialized distributions (file1, file2) and mixes
        /// them with weights weight1 and 1 - weight1. The resulting mixture is SOAP-serialized to
        /// outfile. All argument problems are collected and printed together; in that case no file
        /// is read or written. The two distributions must have mutually contained sample spaces
        /// (checked with BlauSpace.contains in both directions).
        /// </summary>
        /// <param name="args">Raw command-line arguments, parsed by <c>Arguments</c>; file1, weight1, file2 and outfile are required.</param>
        public static void MakeMixture_Main(string[] args)
        {
            Console.WriteLine("MakeMixture");

            // Command line parsing
            Arguments CommandLine = new Arguments(args);

            bool   err       = false;
            string errString = "";

            string file1   = "unassigned";
            string file2   = "unassigned";
            string outfile = "unassigned";

            // file1: required, must exist on disk.
            if (CommandLine["file1"] != null)
            {
                file1 = CommandLine["file1"];
                if (!File.Exists(file1))
                {
                    errString += ("The specified 'file1' was not found: " + file1 + "  ");
                    err        = true;
                }
            }
            else
            {
                errString += ("The 'file1' was not specified.  ");
                err        = true;
            }

            double weight1 = -1.0;

            // weight1: required, must parse as a number in [0, 1].
            if (CommandLine["weight1"] != null)
            {
                double parsedWeight;
                if (!Double.TryParse(CommandLine["weight1"], out parsedWeight))
                {
                    errString += ("The specified 'weight1' was not valid.  ");
                    err        = true;
                }
                else
                {
                    weight1 = parsedWeight;
                    // Negated closed-range test so NaN (which Double parsing accepts) is rejected too;
                    // "weight1 < 0 || weight1 > 1" is false for NaN and would let it through.
                    if (!(weight1 >= 0.0 && weight1 <= 1.0))
                    {
                        errString += ("The specified 'weight1' was not in the range [0,1].  ");
                        err        = true;
                    }
                }
            }
            else
            {
                errString += ("The 'weight1' was not specified.  ");
                err        = true;
            }

            // file2: required, must exist on disk.
            if (CommandLine["file2"] != null)
            {
                file2 = CommandLine["file2"];
                if (!File.Exists(file2))
                {
                    errString += ("The specified 'file2' was not found: " + file2 + "  ");
                    err        = true;
                }
            }
            else
            {
                errString += ("The 'file2' was not specified.  ");
                err        = true;
            }

            // outfile: required; created or overwritten on success.
            if (CommandLine["outfile"] != null)
            {
                outfile = CommandLine["outfile"];
            }
            else
            {
                errString += ("The 'outfile' was not specified.  ");
                err        = true;
            }

            if (err)
            {
                Console.Out.WriteLine("Arguments parsing failed.");
                Console.Out.WriteLine("  " + errString);
            }
            else
            {
                Console.Out.WriteLine("Arguments parsing successful.");
                Console.Out.WriteLine("  file1 = " + file1);
                Console.Out.WriteLine("  weight1 = " + weight1);
                Console.Out.WriteLine("  file2 = " + file2);
                Console.Out.WriteLine("  weight2 = " + (1 - weight1));
                Console.Out.WriteLine("  outfile = " + outfile);

                SoapFormatter formatter = new SoapFormatter();

                // NOTE(review): SoapFormatter instantiates whatever types the input files name,
                // so only run this tool on trusted files.
                // Inputs are opened read-only (FileMode.Open alone requests read/write access and
                // fails on read-only files); 'using' closes them even if deserialization throws.
                IDistribution d1;
                using (FileStream input1 = new FileStream(file1, FileMode.Open, FileAccess.Read))
                {
                    d1 = (IDistribution)formatter.Deserialize(input1);
                }

                IDistribution d2;
                using (FileStream input2 = new FileStream(file2, FileMode.Open, FileAccess.Read))
                {
                    d2 = (IDistribution)formatter.Deserialize(input2);
                }

                if (!BlauSpace.contains(d1.SampleSpace, d2.SampleSpace) || !BlauSpace.contains(d2.SampleSpace, d1.SampleSpace))
                {
                    Console.Out.WriteLine("The sample spaces of the two distributions are not identical intersection.");
                    Console.Out.WriteLine("  d1: " + d1);
                    Console.Out.WriteLine("  d2: " + d2);
                    Console.Out.WriteLine("The mixture must be constructed over identical sample spaces.");
                    Console.Out.WriteLine("This is a fatal error, preventing the construction of the mixture distribution.");
                }
                else
                {
                    // Rebuild the sample space of d1 axis by axis (name, min, max) for the mixture.
                    int       dim3   = d1.SampleSpace.Dimension;
                    string [] names3 = new string [dim3];
                    double [] mins3  = new double [dim3];
                    double [] maxs3  = new double [dim3];
                    for (int i = 0; i < d1.SampleSpace.Dimension; i++)
                    {
                        names3[i] = d1.SampleSpace.getAxis(i).Name;
                        mins3[i]  = d1.SampleSpace.getAxis(i).MinimumValue;
                        maxs3[i]  = d1.SampleSpace.getAxis(i).MaximumValue;
                    }
                    IBlauSpace s3 = BlauSpace.create(dim3, names3, mins3, maxs3);

                    Mixture d3 = new Mixture(s3);
                    d3.Add(d1, weight1);
                    d3.Add(d2, 1.0 - weight1);
                    d3.DistributionComplete();

                    Console.Out.WriteLine("Distribution: " + d3);

                    using (FileStream output = new FileStream(outfile, FileMode.Create))
                    {
                        formatter.Serialize(output, d3);
                    }
                }
            }
        }