Пример #1
0
        /// <summary>
        /// Frequency-based subsampling check used while building training data:
        /// frequent words are dropped more often than rare ones.
        /// </summary>
        /// <param name="word">Word to test; it is looked up in <c>Frequencies</c>.</param>
        /// <returns>True when the word should be removed from the training stream.</returns>
        public bool SubSampling(string word)
        {
            double frequency = Frequencies[word];

            // Keep score = (sqrt(f / rate) + 1) * (rate / f); larger means "more likely to keep".
            double rootTerm  = Math.Sqrt(frequency / SubSamplingRate) + 1;
            double ratio     = SubSamplingRate / frequency;
            double keepScore = rootTerm * ratio;

            // Remove the word when the uniform draw exceeds the keep score.
            var draw = UniformRandom.Next();
            return draw > keepScore;
        }
Пример #2
0
        /// <summary>
        /// Negative-sampling check used while training: a word is accepted with a
        /// probability proportional to its frequency raised to <c>SubSamplingPower</c>.
        /// </summary>
        /// <param name="word">Word to test; it is looked up in <c>Frequencies</c>.</param>
        /// <returns>True when the word is selected as a negative sample.</returns>
        public bool NegativeSampling(string word)
        {
            double weight      = Math.Pow(Frequencies[word], SubSamplingPower);
            double probability = weight / NegativeSamplingSum;

            var draw = UniformRandom.Next();
            return draw < probability;
        }
Пример #3
0
        /// <summary>
        /// Trains a two-layer classifier (sigmoid hidden layer, softmax output) on
        /// weighted-average document embeddings, saves the learned parameters, and
        /// prints per-article results plus overall accuracy on a held-out split.
        /// </summary>
        /// <param name="bias">
        /// Target bias. Articles with this bias are labelled class 0 and all others class 1.
        /// Only training articles with this bias update <c>CorpusSize</c> and <c>CorpusAppearances</c>.
        /// </param>
        /// <param name="path">Destination passed to <c>Save</c> after training finishes.</param>
        /// <exception cref="InvalidOperationException">The corpus yields no training articles.</exception>
        public void Train(Article.BiasType bias, string path)
        {
            // NOTE(review): if Corpus is IDisposable, wrap it in a using block.
            Corpus corpus = new Corpus("corpus.db");

            // Shuffle with a random sort key. Ordering by the boolean "0.5 > random" only
            // partitioned the articles into two groups and kept the original order in each.
            var articles = corpus.Articles.OrderBy(p => UniformRandom.Next()).ToList();

            // The first quarter is held out for testing; the rest is used for training.
            int            testCount = articles.Count / 4;
            List <Article> testing   = new List <Article>(articles.Take(testCount));
            List <Article> training  = new List <Article>(articles.Skip(testCount));

            if (training.Count == 0)
            {
                throw new InvalidOperationException("No training articles available in the corpus.");
            }

            // Word counts are gathered only from training articles that have the target bias.
            foreach (Article a in training)
            {
                if (a.Bias != bias)
                {
                    continue;
                }
                CorpusSize++;
                foreach (string word in a.Document())
                {
                    string stemmed = Stemmer.Stem(word);
                    CorpusAppearances[stemmed] = CorpusAppearances.ContainsKey(stemmed) ? CorpusAppearances[stemmed] + 1 : 1;
                }
            }

            // Design matrix: one weighted-average embedding per training article, plus its class label.
            Matrix <double> X       = Matrices.Dense(training.Count, Embedding.EmbeddingSize);
            int[]           correct = new int[training.Count];

            for (int i = 0; i < training.Count; i++)
            {
                Article a        = training[i];
                var     document = a.Document();
                correct[i] = (a.Bias == bias) ? 0 : 1;
                X.SetRow(i, WeightedAverage(document));
            }

            for (int iteration = 0; iteration < 100000; iteration++)
            {
                // Forward pass, hidden layer: sigmoid(X * W1 + B1), with B1 added to every row.
                var hidden_layer = X * W1;
                for (int i = 0; i < hidden_layer.RowCount; i++)
                {
                    var row = hidden_layer.Row(i);
                    hidden_layer.SetRow(i, (row + B1).Map(p => 1 / (1 + Math.Exp(-p))));
                }

                // Forward pass, output scores: hidden * W2 + B2.
                var scores = hidden_layer * W2;
                for (int i = 0; i < scores.RowCount; i++)
                {
                    scores.SetRow(i, scores.Row(i) + B2);
                }

                // Numerically stable softmax: subtracting each row's maximum before exponentiating
                // keeps Math.Exp from overflowing to infinity, for any number of output columns.
                var probs = scores.Clone();
                for (int i = 0; i < probs.RowCount; i++)
                {
                    var    row = probs.Row(i);
                    double max = row.Maximum();
                    var    exp = row.Map(s => Math.Exp(s - max));
                    probs.SetRow(i, exp / exp.Sum());
                }

                // Gradient of the mean cross-entropy with respect to the scores: (probs - one_hot(correct)) / N.
                var dscores = probs.Clone();
                for (int i = 0; i < dscores.RowCount; i++)
                {
                    dscores[i, correct[i]] -= 1;
                }
                dscores /= training.Count;

                // Loss, computed for logging only: mean negative log-likelihood plus L2 regularization.
                double data_loss = 0;
                for (int i = 0; i < probs.RowCount; i++)
                {
                    data_loss += -Math.Log(probs[i, correct[i]]);
                }
                data_loss /= training.Count;
                var reg_loss = 0.5 * Reg * (W1.PointwiseMultiply(W1)).ColumnSums().Sum();
                reg_loss += 0.5 * Reg * (W2.PointwiseMultiply(W2)).ColumnSums().Sum();
                var loss = data_loss + reg_loss;
                if (iteration % 10 == 0)
                {
                    Console.WriteLine(string.Format("Iter. {0}: {1}", iteration, loss));
                }

                // Backward pass, output layer.
                var dW2 = hidden_layer.Transpose() * dscores;
                var dB2 = dscores.ColumnSums();
                dW2 += Reg * W2;

                // Backward pass through the sigmoid. Its derivative h * (1 - h) is evaluated on the
                // hidden activations and multiplied element-wise with the incoming gradient.
                var dhidden = (dscores * W2.Transpose()).PointwiseMultiply(hidden_layer.Map(h => h * (1 - h)));

                var dW1 = X.Transpose() * dhidden;
                var dB1 = dhidden.ColumnSums();
                dW1 += Reg * W1;

                // Plain gradient-descent parameter update.
                W1 += -StepSize * dW1;
                B1 += -StepSize * dB1;

                W2 += -StepSize * dW2;
                B2 += -StepSize * dB2;
            }

            Save(path);

            // Evaluation: p is treated as the probability that the article has the target bias
            // (assumes Probability returns the class-0 probability; confirm against its definition).
            int right = 0;

            foreach (Article a in testing)
            {
                double p  = Probability(a.Document());
                bool   ok = (a.Bias == bias && p > 0.5) || (a.Bias != bias && p < 0.5);
                if (ok)
                {
                    right++;
                }
                Console.WriteLine(string.Format("{0:0.00%} Chance of match... {1}", p, ok ? "Ok." : "Wrong!"));
            }
            double pc = right / (double)testing.Count;

            Console.WriteLine(string.Format("\n{0:0.00%} accuracy.", pc));
        }