Ejemplo n.º 1
0
        /// <summary>
        /// Trains a fully connected 4-node HMM with Baum-Welch on shuffled observation
        /// sequences, then checks that for every node the outgoing transition probabilities
        /// and the emission probabilities each sum to 1 within <c>PRECISION</c>.
        /// </summary>
        /// <remarks>
        /// The observations come from a fixed-seed <see cref="Random"/> so that a failing run can be
        /// reproduced; an unseeded generator produced different input on every run.
        /// </remarks>
        public void Learn_HMMWithRandomDistributedProbs_LearnsAConsistentModelBasedOnData()
        {
            // Arrange
            const int RANDOM_SEED = 1234;
            const int NUMBER_OF_NODES = 4;
            const int NUMBER_OF_SEQUENCES = 10;
            const int NUMBER_OF_ITERATIONS = 5;

            int[] symbols = Enumerable.Range(1, 42).ToArray();
            Random rnd = new Random(RANDOM_SEED);

            // Each observation sequence is a random permutation of all symbols.
            List<int[]> obs = new List<int[]>();
            for (int i = 0; i < NUMBER_OF_SEQUENCES; i++)
            {
                obs.Add(symbols.OrderBy(x => rnd.Next()).ToArray());
            }

            HMMGraph hmm = new HMMGraph(symbols.Length);
            for (int i = 0; i < NUMBER_OF_NODES; i++)
            {
                hmm.AddNode(new Node());
            }

            // Give every transition and every emission the same weight; Normalize() is presumably
            // what turns these weights into probability distributions — TODO confirm.
            foreach (Node n in hmm.Nodes)
            {
                foreach (Node m in hmm.Nodes)
                {
                    n.SetTransition(m, 0.5);
                }

                foreach (int symbol in symbols)
                {
                    n.SetEmission(symbol, 0.5);
                }
            }

            hmm.Normalize();

            // NOTE(review): the first argument here is the number of sequences (10), not the length
            // of one sequence (42); elsewhere `new BaumWelch(combinedTrainData.Length, graph)` passes
            // a sequence length — confirm which value BaumWelch expects.
            BaumWelch bw = new BaumWelch(obs.Count, hmm);
            int[][] observations = obs.ToArray();

            // Act
            for (int i = 0; i < NUMBER_OF_ITERATIONS; i++)
            {
                bw.Learn(hmm, observations);
            }

            // Assert
            const double PRECISION = .00000000001;
            foreach (Node n in hmm.Nodes)
            {
                // Math.Abs(...) < PRECISION is false for NaN, so a NaN sum still fails the test.
                double transitionSum = n.Transitions.Values.Sum();
                Assert.IsTrue(Math.Abs(transitionSum - 1.0) < PRECISION,
                    "Transition probabilities sum to " + transitionSum + ", expected 1.");

                double emissionSum = n.Emissions.Values.Sum();
                Assert.IsTrue(Math.Abs(emissionSum - 1.0) < PRECISION,
                    "Emission probabilities sum to " + emissionSum + ", expected 1.");
            }
        }
Ejemplo n.º 2
0
		/// <summary>
		/// Smoke test: runs <c>BaumWelch.Learn</c> on a 4-node graph with shuffled observation
		/// sequences and checks only that a model is returned. It is deliberately marked
		/// inconclusive because it does not yet verify that the learned model describes the data.
		/// </summary>
		public void LearnTest_validInput_ModelDescribingTheData ()
		{
			// Arrange
			const int RANDOM_SEED = 1234;
			const int NUMBER_OF_NODES = 4;
			const int NUMBER_OF_SEQUENCES = 4;

			HMMGraph hmm = new HMMGraph(NUMBER_OF_SYMBOLS_IN_HMMGRAPH);

			// Fixed seed so the generated observations (and any failure) are reproducible.
			Random rnd = new Random(RANDOM_SEED);

			// Each sequence is a random permutation of the symbols 1..49.
			// NOTE(review): assumes NUMBER_OF_SYMBOLS_IN_HMMGRAPH is large enough for symbol 49 — confirm.
			List<int[]> obs = new List<int[]>();
			for (int i = 0; i < NUMBER_OF_SEQUENCES; i++)
			{
				obs.Add(Enumerable.Range(1, 49).OrderBy(x => rnd.Next()).ToArray());
			}

			// NOTE(review): no transitions or emissions are set and Normalize() is not called before
			// Learn — confirm Learn is meant to handle a graph without initial probabilities.
			for (int i = 0; i < NUMBER_OF_NODES; i++)
			{
				hmm.AddNode(new Node());
			}

			BaumWelch baumWelch = new BaumWelch();

			// Act
			HMMGraph result = baumWelch.Learn(hmm, obs.ToArray());

			// Assert
			Assert.IsNotNull(result);
			Assert.Inconclusive("Only checks that Learn returns a model; it does not yet verify that the model describes the data.");
		}
        /// <summary>
        /// Returns the node of <paramref name="graph"/> that gets the highest
        /// <c>BaumWelch.ComputeGamma</c> score for <paramref name="combinedTrainData"/>.
        /// </summary>
        /// <param name="graph">Graph whose nodes are scored.</param>
        /// <param name="combinedTrainData">Observation sequence used to precompute and score.</param>
        /// <returns>
        /// The earliest node with the strictly highest score. If no node scores strictly above 0,
        /// this is <c>graph.Nodes[0]</c>.
        /// </returns>
        /// <remarks>Indexing <c>graph.Nodes[0]</c> throws if the graph has no nodes.</remarks>
        private static Node FindQPrime(HMMGraph graph, int[] combinedTrainData)
        {
            BaumWelch baumWelch = new BaumWelch(combinedTrainData.Length, graph);
            baumWelch.PreCompute(graph, combinedTrainData);

            Node bestNode = graph.Nodes[0];
            double bestScore = 0.0;

            foreach (Node candidate in graph.Nodes)
            {
                // relative (unscaled)
                double candidateScore = baumWelch.ComputeGamma(candidate, graph, combinedTrainData);

                // Strict comparison: ties keep the earlier node, and a NaN score never wins.
                if (candidateScore > bestScore)
                {
                    bestScore = candidateScore;
                    bestNode = candidate;
                }
            }

            return bestNode;
        }