/// <summary>
        ///   Builds the hidden Markov model that will be learned. It has one hidden state
        ///   per class and is constructed with the <c>Emissions</c> distribution of this learner.
        /// </summary>
        ///
        /// <param name="x">The training observation sequences.</param>
        /// <param name="numberOfClasses">The number of hidden states the new model should have.</param>
        ///
        /// <returns>The newly constructed model.</returns>
        ///
        protected override HiddenMarkovModel<TDistribution, TObservation> Create(TObservation[][] x, int numberOfClasses)
        {
            var model = new HiddenMarkovModel<TDistribution, TObservation>(
                states: numberOfClasses,
                emissions: Emissions);

            // NOTE(review): presumably validates that the observations in x are
            // compatible with the model's emission dimensions — confirm in helper.
            MarkovHelperMethods.CheckObservationDimensions(x, model);

            return model;
        }
예제 #2
0
        /// <summary>
        ///   Learns a model that can map the given inputs to the given outputs.
        /// </summary>
        ///
        /// <param name="x">The model inputs (one observation sequence per sample).</param>
        /// <param name="y">The desired outputs (one state path per sample) associated with each <paramref name="x">input</paramref>.</param>
        /// <param name="weights">Not supported by this learning algorithm; must be <c>null</c>.</param>
        ///
        /// <returns>A model that has learned how to produce <paramref name="y" /> given <paramref name="x" />.</returns>
        ///
        /// <exception cref="ArgumentException">Thrown when <paramref name="weights"/> is not <c>null</c>.</exception>
        ///
        /// <remarks>
        ///   Parameters are estimated by counting. Initial-state and transition frequencies
        ///   come from the labelled state paths in <paramref name="y"/>. Each state's emission
        ///   distribution is fitted to the observations labelled with that state; states that
        ///   receive no observations keep their current emission distribution. Empty state
        ///   paths contribute nothing to the initial-state counts.
        /// </remarks>
        ///
        public TModel Learn(TObservation[][] x, int[][] y, double[] weights = null)
        {
            if (weights != null)
            {
                throw new ArgumentException(Accord.Properties.Resources.NotSupportedWeights, "weights");
            }

            if (Model == null)
            {
                Model = Create(x, numberOfClasses: y.Max() + 1);
            }

            MarkovHelperMethods.CheckObservationDimensions(x, Model);

            // Grab model information
            var model          = Model;
            var fittingOptions = FittingOptions;

            int states = model.NumberOfStates;

            int[] initial = new int[states];
            int[,] transitions = new int[states, states];

            // 1. Count first state occurrences (an empty path has no first state)
            foreach (int[] path in y)
            {
                if (path.Length > 0)
                {
                    initial[path[0]]++;
                }
            }

            // 2. Count all state transitions
            foreach (int[] path in y)
            {
                for (int j = 1; j < path.Length; j++)
                {
                    transitions[path[j - 1], path[j]]++;
                }
            }

            // 3. Estimate the emission distribution of each state
            if (UseWeights)
            {
                int totalObservations = 0;
                for (int i = 0; i < x.Length; i++)
                {
                    totalObservations += x[i].Length;
                }

                // One indicator weight vector per state over the concatenated observations
                double[][] totalWeights = new double[states][];
                for (int i = 0; i < totalWeights.Length; i++)
                {
                    totalWeights[i] = new double[totalObservations];
                }

                // Number of observations labelled with each state. A state with zero
                // observations would have an all-zero weight vector.
                int[] stateCounts = new int[states];

                var all = new TObservation[totalObservations];

                for (int i = 0, c = 0; i < y.Length; i++)
                {
                    for (int t = 0; t < y[i].Length; t++, c++)
                    {
                        int state = y[i][t];
                        all[c] = x[i][t];
                        totalWeights[state][c] = 1;
                        stateCounts[state]++;
                    }
                }

                for (int i = 0; i < states; i++)
                {
                    // Same policy as the unweighted branch: do not fit a state that
                    // received no observations (fitting with all-zero weights is
                    // ill-defined), keep its current emission distribution instead.
                    if (stateCounts[i] > 0)
                    {
                        model.Emissions[i].Fit(all, totalWeights[i], fittingOptions);
                    }
                }
            }
            else
            {
                // Group the observations by the state they are labelled with
                var clusters = new List<TObservation>[states];
                for (int i = 0; i < clusters.Length; i++)
                {
                    clusters[i] = new List<TObservation>();
                }

                for (int i = 0; i < y.Length; i++)
                {
                    for (int t = 0; t < y[i].Length; t++)
                    {
                        int state  = y[i][t];
                        var symbol = x[i][t];

                        clusters[state].Add(symbol);
                    }
                }

                // Estimate probability distributions for states that have data
                for (int i = 0; i < states; i++)
                {
                    if (clusters[i].Count > 0)
                    {
                        model.Emissions[i].Fit(clusters[i].ToArray(), fittingOptions);
                    }
                }
            }

            // 4. Form log-probabilities, optionally applying the Laplace
            //    correction so that no probability is exactly zero
            if (UseLaplaceRule)
            {
                // Use Laplace's rule of succession correction
                // http://en.wikipedia.org/wiki/Rule_of_succession
                for (int i = 0; i < states; i++)
                {
                    initial[i]++;

                    for (int j = 0; j < states; j++)
                    {
                        transitions[i, j]++;
                    }
                }
            }

            // Normalizing constants (row sums for the transition matrix)
            int initialCount = initial.Sum();
            int[] transitionCount = transitions.Sum(1);

            // Avoid division by zero for empty data / never-left states; the
            // resulting log-probabilities are then -Infinity rather than NaN.
            if (initialCount == 0)
            {
                initialCount = 1;
            }

            for (int i = 0; i < transitionCount.Length; i++)
            {
                if (transitionCount[i] == 0)
                {
                    transitionCount[i] = 1;
                }
            }

            for (int i = 0; i < initial.Length; i++)
            {
                model.LogInitial[i] = Math.Log(initial[i] / (double)initialCount);
            }

            for (int i = 0; i < transitionCount.Length; i++)
            {
                for (int j = 0; j < states; j++)
                {
                    model.LogTransitions[i][j] = Math.Log(transitions[i, j] / (double)transitionCount[i]);
                }
            }

            Accord.Diagnostics.Debug.Assert(!model.LogInitial.HasNaN());
            Accord.Diagnostics.Debug.Assert(!model.LogTransitions.HasNaN());

            return Model;
        }
        /// <summary>
        ///   Learns a model that can map the given inputs to the desired outputs.
        /// </summary>
        /// 
        /// <param name="x">The model inputs.</param>
        /// <param name="weights">The weight of importance for each input sample.</param>
        /// 
        /// <returns>A model that has learned how to produce suitable outputs
        ///   given the input data <paramref name="x"/>.</returns>
        /// 
        /// <remarks>
        ///   Runs Baum-Welch (expectation-maximization) iterations. Each iteration runs an
        ///   expectation step, then re-estimates the initial, transition and emission
        ///   parameters. Learning stops when the convergence criterion reports convergence
        ///   (after at least one update) or when cancellation is requested through <c>Token</c>.
        ///   If <c>Model</c> is already set, learning continues from it.
        /// </remarks>
        /// 
        public TModel Learn(TObservation[][] x, double[] weights = null)
        {
            // Initial argument checks
            CheckArgs(x, weights);

            // Baum-Welch algorithm.

            // The Baum–Welch algorithm is a particular case of a generalized expectation-maximization
            // (GEM) algorithm. It can compute maximum likelihood estimates and posterior mode estimates
            // for the parameters (transition and emission probabilities) of an HMM, when given only
            // emissions as training data.

            // The algorithm has two steps:
            //  - Calculating the forward probability and the backward probability for each HMM state;
            //  - On the basis of this, determining the frequency of the transition-emission pair values
            //    and dividing it by the probability of the entire string. This amounts to calculating
            //    the expected count of the particular transition-emission pair. Each time a particular
            //    transition is found, the value of the quotient of the transition divided by the probability
            //    of the entire string goes up, and this value can then be made the new value of the transition.

            this.samples = x.Concatenate();
            this.vectorObservations = x;

            if (Model == null)
                Model = Create(x);

            MarkovHelperMethods.CheckObservationDimensions(x, Model);

            // Nothing left to do if the iteration budget is already exhausted
            if (MaxIterations > 0 && CurrentIteration >= MaxIterations)
                return Model;

            // Grab model information (updated in place below)
            int states = Model.NumberOfStates;
            var logA = Model.LogTransitions;
            var logP = Model.LogInitial;

            // Initialize the algorithm
            int N = x.Length;
            double logN = Math.Log(N);
            LogKsi = new double[N][][,];
            LogGamma = new double[N][,];
            LogWeights = new double[N];

            // Without explicit weights every sample keeps log-weight 0 (weight 1)
            if (weights != null)
                weights.Log(result: LogWeights);

            for (int i = 0; i < x.Length; i++)
            {
                int T = x[i].Length;

                LogKsi[i] = new double[T][,];
                LogGamma[i] = new double[T, states];

                for (int t = 0; t < LogKsi[i].Length; t++)
                    LogKsi[i][t] = new double[states, states];
            }

            int TMax = x.Max(x_i => x_i.Length);
#if SERIAL
            lnFwd = new double[TMax, states];
            lnBwd = new double[TMax, states];
            sampleWeights = new double[samples.Length];
#endif

            // NOTE(review): presumably compensates for the iteration counter being
            // advanced when the first log-likelihood is assigned to NewValue — confirm.
            convergence.CurrentIteration--;
            bool hasUpdated = false;

            do
            {
                if (Token.IsCancellationRequested)
                    break;

                // Expectation step; returns the log-likelihood of the data under the current model
                LogLikelihood = Expectation(x, TMax);
                convergence.NewValue = LogLikelihood;

                // Check for convergence (only once the parameters were updated at least once)
                if (hasUpdated && convergence.HasConverged)
                    break;

                if (Token.IsCancellationRequested)
                    break;

                // 3. Continue with parameter re-estimation
                // 3.1 Re-estimation of initial state probabilities 
#if SERIAL
                for (int i = 0; i < logP.Length; i++)
#else
                Parallel.For(0, logP.Length, ParallelOptions, i =>
#endif
                {
                    double lnsum = Double.NegativeInfinity;
                    for (int k = 0; k < LogGamma.Length; k++)
                        lnsum = Special.LogSum(lnsum, LogGamma[k][0, i]);
                    logP[i] = lnsum - logN;
                }
#if !SERIAL
                );
#endif

                // 3.2 Re-estimation of transition probabilities 
#if SERIAL
                for (int i = 0; i < states; i++)
#else
                Parallel.For(0, states, ParallelOptions, i =>
#endif
                {
                    // The denominator depends only on the source state i, not on the
                    // destination j, so it is accumulated once per row. The summation
                    // order is the same as before, so the value is unchanged.
                    double lnden = Double.NegativeInfinity;
                    for (int k = 0; k < LogGamma.Length; k++)
                    {
                        int T = x[k].Length;

                        for (int t = 0; t < T - 1; t++)
                            lnden = Special.LogSum(lnden, LogGamma[k][t, i]);
                    }

                    for (int j = 0; j < states; j++)
                    {
                        double lnnum = Double.NegativeInfinity;

                        for (int k = 0; k < LogGamma.Length; k++)
                        {
                            int T = x[k].Length;

                            for (int t = 0; t < T - 1; t++)
                                lnnum = Special.LogSum(lnnum, LogKsi[k][t][i, j]);
                        }

                        // When both sums are -Infinity the plain difference would be NaN;
                        // the guard yields 0 instead (for finite equal values it is 0 anyway).
                        logA[i][j] = (lnnum == lnden) ? 0 : lnnum - lnden;
                        Accord.Diagnostics.Debug.Assert(!Double.IsNaN(logA[i][j]));
                    }
                }
#if !SERIAL
                );
#endif

                // 3.3 Re-estimation of emission probabilities
                UpdateEmissions(); // discrete and continuous
                hasUpdated = true;

            } while (true);

            return Model;
        }