Ejemplo n.º 1
0
        /// <summary>
        /// Appends the vocabulary indices of <paramref name="words"/> to <paramref name="sentence"/>,
        /// starting at <paramref name="sentenceLength"/>, optionally dropping frequent words at random.
        /// </summary>
        /// <param name="reader">Reader the words came from; only its <c>EndOfStream</c> flag is consulted.</param>
        /// <param name="wordCount">Incremented once for every word that is present in <paramref name="wordCollection"/>.</param>
        /// <param name="sentence">Buffer receiving the word indices.</param>
        /// <param name="nextRandom">State of the pseudo-random generator; advanced once per subsampling decision.</param>
        /// <param name="sentenceLength">Number of filled slots in <paramref name="sentence"/>; incremented for every stored word.</param>
        /// <param name="words">Words to process, in order.</param>
        /// <param name="wordCollection">Vocabulary used to look up indices and occurrence counts.</param>
        /// <param name="thresholdForOccurrenceOfWords">Subsampling threshold; subsampling is skipped when this is not positive.</param>
        /// <returns>
        /// <c>true</c> when the sentence buffer is full or the reader has reached the end of its stream;
        /// otherwise <c>false</c>.
        /// </returns>
        private static bool HandleWords(StreamReader reader, ref long wordCount, long?[] sentence, ref ulong nextRandom,
                                        ref long sentenceLength, IEnumerable<string> words, WordCollection wordCollection, float thresholdForOccurrenceOfWords)
        {
            var totalNumberOfWords = wordCollection.GetTotalNumberOfWords();

            // Loop-invariant part of the subsampling formula.
            var scaledThreshold = thresholdForOccurrenceOfWords * totalNumberOfWords;

            foreach (var word in words)
            {
                var wordIndex = wordCollection[word];
                if (!wordIndex.HasValue)
                {
                    // Word is not in the vocabulary: ignore it entirely.
                    continue;
                }
                wordCount++;

                // Subsampling of frequent words
                if (thresholdForOccurrenceOfWords > 0)
                {
                    var occurrence = wordCollection.GetOccurrenceOfWord(word);
                    var random     = ((float)Math.Sqrt(occurrence / scaledThreshold) + 1) * scaledThreshold / occurrence;
                    nextRandom = LinearCongruentialGenerator(nextRandom);

                    // Compare against a fraction in [0, 1) built from the low 16 bits of the generator state.
                    if (random < (nextRandom & 0xFFFF) / (float)65536)
                    {
                        continue;
                    }
                }

                sentence[sentenceLength] = wordIndex.Value;
                sentenceLength++;

                // Stop as soon as the buffer is full. The previous '>' comparison could never be true
                // before the next write ran past the end of the array.
                if (sentenceLength >= sentence.Length)
                {
                    return true;
                }
            }

            return reader.EndOfStream;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Training loop for the worker with the given <paramref name="id"/>. It starts reading at offset
        /// <c>FileSize / _numberOfThreads * id</c>, fills a sentence buffer through <c>SetSentence</c> and
        /// calls <c>SkipGram</c> for every position of each sentence, for <c>_numberOfIterations</c> passes.
        /// </summary>
        /// <param name="id">Worker index; selects the start offset in the file and seeds the random generator.</param>
        /// <remarks>
        /// NOTE(review): <c>_wordCountActual</c> and <c>_alpha</c> are instance fields updated here without
        /// synchronization. If several workers run this method concurrently (as the method name and
        /// <c>_numberOfThreads</c> suggest), updates may be lost - confirm whether that is acceptable.
        /// </remarks>
        private void TrainModelThreadStart(int id)
        {
            long sentenceLength = 0;
            long sentencePosition = 0;
            long wordCount = 0, lastWordCount = 0;
            var  sentence  = new long?[_maxSentenceLength]; //Sentence elements will not be null to my understanding
            var  localIter = _numberOfIterations;

            var nextRandom = (ulong)id;
            var sum        = _wordCollection.GetTotalNumberOfWords();

            string[] lastLine = null;
            using (var reader = _fileHandler.GetReader())
            {
                reader.BaseStream.Seek(_fileHandler.FileSize / _numberOfThreads * id, SeekOrigin.Begin);
                // Keep the reader's internal buffer in step with the repositioned base stream.
                reader.DiscardBufferedData();
                while (true)
                {
                    // Every 10000 processed words: publish progress and decay the learning rate.
                    if (wordCount - lastWordCount > 10000)
                    {
                        _wordCountActual += wordCount - lastWordCount;
                        lastWordCount     = wordCount;
                        _alpha            = _startingAlpha * (1 - _wordCountActual / (float)(_numberOfIterations * sum + 1));
                        // Never let the learning rate fall below 0.01% of its starting value.
                        if (_alpha < _startingAlpha * (float)0.0001)
                        {
                            _alpha = _startingAlpha * (float)0.0001;
                        }
                    }

                    // Previous sentence is exhausted (or none was read yet): fetch the next one.
                    if (sentenceLength == 0)
                    {
                        wordCount        = SetSentence(reader, wordCount, sentence, ref nextRandom, ref sentenceLength, ref lastLine, _wordCollection, _thresholdForOccurrenceOfWords);
                        sentencePosition = 0;
                    }

                    // End of this pass: either the file ended or this worker has consumed its share of words.
                    if (reader.EndOfStream || wordCount > sum / _numberOfThreads)
                    {
                        _wordCountActual += wordCount - lastWordCount;
                        localIter--;
                        if (localIter == 0)
                        {
                            break;
                        }
                        wordCount      = 0;
                        lastWordCount  = 0;
                        sentenceLength = 0;
                        reader.BaseStream.Seek(_fileHandler.FileSize / _numberOfThreads * id, SeekOrigin.Begin);
                        // Without this the reader would keep serving characters buffered from the old
                        // position when the pass ended on the word-count limit rather than at end of stream.
                        reader.DiscardBufferedData();
                        Console.WriteLine($"Iterations remaining: {localIter} Thread: {id}");
                        continue;
                    }

                    var wordIndex = sentence[sentencePosition];
                    if (wordIndex.HasValue)
                    {
                        nextRandom = LinearCongruentialGenerator(nextRandom);
                        var randomWindowPosition = (long)(nextRandom % (ulong)_windowSize);

                        nextRandom = SkipGram(randomWindowPosition, sentencePosition, sentenceLength, sentence, wordIndex.Value, nextRandom);
                    }

                    // Always advance, even past an empty slot; a bare 'continue' here would re-read the
                    // same slot forever because nothing in the loop state would change.
                    sentencePosition++;
                    if (sentencePosition >= sentenceLength)
                    {
                        sentenceLength = 0;
                    }
                }
            }
            GC.Collect();
        }
Ejemplo n.º 3
0
        // TODO:  refactor this method
        /// <summary>
        /// Training loop for the worker with the given <paramref name="id"/>. It starts reading at offset
        /// <c>FileSize / _numberOfThreads * id</c>, fills a sentence buffer through <c>SetSentence</c> and,
        /// when <c>_negativeSamples</c> is positive, calls <c>TrainNetwork</c> for every position of each
        /// sentence, for <c>_numberOfIterations</c> passes.
        /// </summary>
        /// <param name="id">Worker index; selects the start offset in the file.</param>
        /// <remarks>
        /// NOTE(review): <c>_wordCountActual</c> and <c>_alpha</c> are instance fields updated here without
        /// synchronization. If several workers run this method concurrently (as the method name and
        /// <c>_numberOfThreads</c> suggest), updates may be lost - confirm whether that is acceptable.
        /// </remarks>
        private void TrainModelThreadStart(int id)
        {
            var sentenceLength   = 0;
            var sentencePosition = 0;
            var wordCount        = 0;
            var lastWordCount    = 0;
            var sentence         = new int?[_maxSentenceLength];
            var localIter        = _numberOfIterations;
            var localNetwork     = _neuralNetwork.CloneWithSameWeightValueReferences();

            var random = new Random();
            var sum    = _wordCollection.GetTotalNumberOfWords();

            string[] lastLine = null;
            using (var reader = _fileHandler.GetReader())
            {
                reader.BaseStream.Seek(_fileHandler.FileSize / _numberOfThreads * id, SeekOrigin.Begin);
                // Keep the reader's internal buffer in step with the repositioned base stream.
                reader.DiscardBufferedData();
                while (true)
                {
                    // Every 10000 processed words: publish progress and decay the learning rate.
                    if (wordCount - lastWordCount > 10000)
                    {
                        _wordCountActual += wordCount - lastWordCount;
                        lastWordCount     = wordCount;
                        _alpha            = _startingAlpha * (1 - _wordCountActual / (float)(_numberOfIterations * sum + 1));
                        // Never let the learning rate fall below 0.01% of its starting value.
                        if (_alpha < _startingAlpha * (float)0.0001)
                        {
                            _alpha = _startingAlpha * (float)0.0001;
                        }
                    }

                    // Previous sentence is exhausted (or none was read yet): fetch the next one.
                    if (sentenceLength == 0)
                    {
                        wordCount        = SetSentence(reader, wordCount, sentence, random, ref sentenceLength, ref lastLine);
                        sentencePosition = 0;
                    }

                    // End of this pass: either the file ended or this worker has consumed its share of words.
                    if (reader.EndOfStream || wordCount > sum / _numberOfThreads)
                    {
                        _wordCountActual += wordCount - lastWordCount;
                        localIter--;
                        if (localIter == 0)
                        {
                            break;
                        }
                        wordCount      = 0;
                        lastWordCount  = 0;
                        sentenceLength = 0;
                        reader.BaseStream.Seek(_fileHandler.FileSize / _numberOfThreads * id, SeekOrigin.Begin);
                        // Without this the reader would keep serving characters buffered from the old
                        // position when the pass ended on the word-count limit rather than at end of stream.
                        reader.DiscardBufferedData();
                        Console.WriteLine($"Iterations remaining: {localIter} Thread: {id}");
                        continue;
                    }

                    var wordIndex = sentence[sentencePosition];
                    if (wordIndex.HasValue && _negativeSamples > 0)
                    {
                        TrainNetwork(localNetwork, sentencePosition, sentenceLength, sentence, wordIndex.Value, random);
                    }

                    // Always advance, even past an empty slot; a bare 'continue' here would re-read the
                    // same slot forever because nothing in the loop state would change.
                    sentencePosition++;

                    if (sentencePosition >= sentenceLength)
                    {
                        sentenceLength = 0;
                    }
                }
            }
            GC.Collect();
        }