Exemplo n.º 1
0
 /**
  * <summary>Updates the learning rate (alpha) once more than 10000 words have been processed since the
  * previous update. Alpha decreases linearly with the number of words processed so far, relative to
  * (number of iterations * number of words in the corpus), and is never allowed to fall below
  * 0.0001 times the starting alpha.</summary>
  */
 public void AlphaUpdate()
 {
     // How many newly processed words trigger an update, and the floor for alpha as a fraction of its start value.
     const int UpdateInterval = 10000;
     const double MinimumAlphaRatio = 0.0001;

     if (_wordCount - _lastWordCount > UpdateInterval)
     {
         _wordCountActual += _wordCount - _lastWordCount;
         _lastWordCount    = _wordCount;
         // Multiply in double: if both getters return int, their product can exceed int.MaxValue on a large
         // corpus, and C# integer multiplication wraps silently in the default unchecked context.
         // NOTE(review): _wordCountActual is only ever incremented here; if it is declared as int it can
         // overflow at a similar scale — consider widening it to long.
         var totalWords = (double)_wordToVecParameter.GetNumberOfIterations() * corpus.NumberOfWords();
         _alpha = _startingAlpha * (1 - _wordCountActual / (totalWords + 1.0));
         if (_alpha < _startingAlpha * MinimumAlphaRatio)
         {
             _alpha = _startingAlpha * MinimumAlphaRatio;
         }
     }
 }
        /**
         * <summary>Main method for training the continuous bag-of-words (CBOW) version of the Word2Vec algorithm.
         * For each word position, the vectors of the surrounding context words (within a randomly shrunk
         * window) are averaged. The average is used to predict the current word, via either hierarchical
         * softmax or negative sampling. The accumulated error is then added back to every context word's
         * vector. Training runs until the iteration count reaches the configured number of iterations.</summary>
         */
        private void TrainCbow()
        {
            var iteration    = new Iteration(_corpus, _parameter);
            var random       = new Random();
            var outputs      = new Vector(_parameter.GetLayerSize(), 0);
            var outputUpdate = new Vector(_parameter.GetLayerSize(), 0);
            // Read once instead of on every word; assumes the parameters are not mutated while training runs.
            var window               = _parameter.GetWindow();
            var hierarchicalSoftMax  = _parameter.IsHierarchicalSoftMax();
            var negativeSamplingSize = _parameter.GetNegativeSamplingSize();

            // Shuffle BEFORE fetching the first sentence. Fetching it first made training start on the sentence
            // at the starting index in the pre-shuffle order rather than the one the shuffled order puts there
            // (assuming ShuffleSentences reorders the list that GetSentence indexes into).
            _corpus.ShuffleSentences(1);
            var currentSentence = _corpus.GetSentence(iteration.GetSentenceIndex());

            while (iteration.GetIterationCount() < _parameter.GetNumberOfIterations())
            {
                iteration.AlphaUpdate();
                // The position only changes in SentenceUpdate at the bottom of this loop, so read it once.
                var sentencePosition = iteration.GetSentencePosition();
                var wordIndex        = _vocabulary.GetPosition(currentSentence.GetWord(sentencePosition));
                var currentWord      = _vocabulary.GetWord(wordIndex);
                outputs.Clear();
                outputUpdate.Clear();

                // Randomly shrink the context: only offsets a in [b, 2 * window + 1 - b) are visited.
                var b = random.Next(window);

                // Input -> hidden: sum the context word vectors, then average them.
                var cw = 0;
                for (var a = b; a < window * 2 + 1 - b; a++)
                {
                    var c = sentencePosition - window + a;
                    // a == window is the current word itself; SafeIndex presumably rejects positions outside the sentence.
                    if (a != window && currentSentence.SafeIndex(c))
                    {
                        var contextIndex = _vocabulary.GetPosition(currentSentence.GetWord(c));
                        outputs.Add(_wordVectors.GetRow(contextIndex));
                        cw++;
                    }
                }

                if (cw > 0)
                {
                    outputs.Divide(cw);
                    if (hierarchicalSoftMax)
                    {
                        // One update per node on the current word's code path.
                        for (var d = 0; d < currentWord.GetCodeLength(); d++)
                        {
                            int    l2 = currentWord.GetPoint(d);
                            double f  = outputs.DotProduct(_wordVectorUpdate.GetRow(l2));
                            // Values outside (-MAX_EXP, MAX_EXP) are not covered by the lookup table; skip them.
                            if (f <= -MAX_EXP || f >= MAX_EXP)
                            {
                                continue;
                            }

                            // NOTE(review): _expTable is presumably a precomputed sigmoid table — confirm.
                            f = _expTable[(int)((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
                            double g = (1 - currentWord.GetCode(d) - f) * iteration.GetAlpha();
                            outputUpdate.Add(_wordVectorUpdate.GetRow(l2).Product(g));
                            _wordVectorUpdate.Add(l2, outputs.Product(g));
                        }
                    }
                    else
                    {
                        // d == 0 is the observed word (label 1); the rest are negative samples (label 0).
                        for (var d = 0; d <= negativeSamplingSize; d++)
                        {
                            int target;
                            int label;
                            if (d == 0)
                            {
                                target = wordIndex;
                                label  = 1;
                            }
                            else
                            {
                                target = _vocabulary.GetTableValue(random.Next(_vocabulary.GetTableSize()));
                                // Index 0 is not used as a sample; redraw uniformly from 1 .. Size() - 1.
                                if (target == 0)
                                {
                                    target = random.Next(_vocabulary.Size() - 1) + 1;
                                }
                                // Never use the observed word as its own negative sample.
                                if (target == wordIndex)
                                {
                                    continue;
                                }
                                label = 0;
                            }

                            double f = outputs.DotProduct(_wordVectorUpdate.GetRow(target));
                            double g = CalculateG(f, iteration.GetAlpha(), label);
                            outputUpdate.Add(_wordVectorUpdate.GetRow(target).Product(g));
                            _wordVectorUpdate.Add(target, outputs.Product(g));
                        }
                    }

                    // Hidden -> input: add the accumulated error to every context word vector in the same window.
                    for (var a = b; a < window * 2 + 1 - b; a++)
                    {
                        var c = sentencePosition - window + a;
                        if (a != window && currentSentence.SafeIndex(c))
                        {
                            var contextIndex = _vocabulary.GetPosition(currentSentence.GetWord(c));
                            _wordVectors.Add(contextIndex, outputUpdate);
                        }
                    }
                }

                currentSentence = iteration.SentenceUpdate(currentSentence);
            }
        }