Example #1
0
        /// <summary>
        /// Updates the scores for a dataset by applying the split learned at the given
        /// iteration for every feature, in parallel across document blocks.
        /// </summary>
        /// <param name="dataset">The dataset to use.</param>
        /// <param name="scores">The current scores for this dataset; updated in place.</param>
        /// <param name="iteration">The iteration of the algorithm.
        /// Used to look up the sub-graph to use to update the score.</param>
        private void UpdateScoresForSet(Dataset dataset, double[] scores, int iteration)
        {
            DefineDocumentThreadBlocks(dataset.NumDocs, BlockingThreadPool.NumThreads, out int[] threadBlocks);

            var updateTask = ThreadTaskManager.MakeTask(
                (threadIndex) =>
            {
                // Each thread owns a contiguous, disjoint range of documents, so the
                // writes to scores[doc] below do not race across threads.
                int startIndexInclusive = threadBlocks[threadIndex];
                int endIndexExclusive   = threadBlocks[threadIndex + 1];
                for (int featureIndex = 0; featureIndex < _subGraph.Splits.Length; featureIndex++)
                {
                    var featureIndexer = dataset.GetIndexer(featureIndex);
                    // The split for this (feature, iteration) pair is loop-invariant:
                    // hoist the double lookup out of the per-document loop.
                    var split = _subGraph.Splits[featureIndex][iteration];
                    for (int doc = startIndexInclusive; doc < endIndexExclusive; doc++)
                    {
                        scores[doc] += featureIndexer[doc] <= split.SplitPoint
                            ? split.LteValue
                            : split.GtValue;
                    }
                }
            }, BlockingThreadPool.NumThreads);

            updateTask.RunTask();
        }
Example #2
0
        /// <summary>
        /// Runs one training iteration for every feature of the training set in parallel,
        /// one task invocation per feature.
        /// </summary>
        /// <param name="gradient">The gradient values for the current scores.</param>
        /// <param name="scores">The current scores.</param>
        /// <param name="sumTargets">The sum of the targets.</param>
        /// <param name="sumWeights">The sum of the document weights.</param>
        /// <param name="iteration">The current iteration of the algorithm.</param>
        private void TrainOnEachFeature(double[] gradient, double[] scores, double sumTargets, double sumWeights, int iteration)
        {
            var trainTask = ThreadTaskManager.MakeTask(
                feature => TrainingIteration(feature, gradient, scores, sumTargets, sumWeights, iteration),
                TrainSet.NumFeatures);
            trainTask.RunTask();
        }
Example #3
0
        /// <summary>
        /// Center the graph using the mean response per feature on the training set.
        /// Accumulates the total effect per feature across all training documents,
        /// then shifts each feature's per-bin effects by its mean so the mean moves
        /// into the shared intercept (<c>MeanEffect</c>).
        /// </summary>
        private void CenterGraph()
        {
            // Define this once
            DefineDocumentThreadBlocks(TrainSet.NumDocs, BlockingThreadPool.NumThreads, out int[] trainThreadBlocks);

            // Compute the mean of each Effect
            // (meanEffects holds the running SUM here; it is divided by NumDocs below.)
            var meanEffects = new double[BinEffects.Length];
            var updateTask  = ThreadTaskManager.MakeTask(
                (threadIndex) =>
            {
                // Each thread sums over its own contiguous block of documents, but all
                // threads accumulate into the same meanEffects[featureIndex] slot,
                // hence the lock-free CAS loop below.
                int startIndexInclusive = trainThreadBlocks[threadIndex];
                int endIndexExclusive   = trainThreadBlocks[threadIndex + 1];
                for (int featureIndex = 0; featureIndex < BinEffects.Length; featureIndex++)
                {
                    var featureIndexer = TrainSet.GetIndexer(featureIndex);
                    for (int doc = startIndexInclusive; doc < endIndexExclusive; doc++)
                    {
                        var bin = featureIndexer[doc];
                        double totalEffect;
                        double newTotalEffect;
                        // Classic compare-and-swap retry loop: re-read the shared total,
                        // compute the proposed value, and retry if another thread
                        // changed the slot between the read and the exchange.
                        do
                        {
                            totalEffect    = meanEffects[featureIndex];
                            newTotalEffect = totalEffect + BinEffects[featureIndex][bin];
                        } while (totalEffect !=
                                 Interlocked.CompareExchange(ref meanEffects[featureIndex], newTotalEffect, totalEffect));
                        // Update the shared effect, being careful of threading
                    }
                }
            }, BlockingThreadPool.NumThreads);

            updateTask.RunTask();

            // Compute the intercept and center each graph
            MeanEffect = 0.0;
            for (int featureIndex = 0; featureIndex < BinEffects.Length; featureIndex++)
            {
                // Compute the mean effect
                meanEffects[featureIndex] /= TrainSet.NumDocs;

                // Shift the mean from the bins into the intercept
                MeanEffect += meanEffects[featureIndex];
                for (int bin = 0; bin < BinEffects[featureIndex].Length; ++bin)
                {
                    BinEffects[featureIndex][bin] -= meanEffects[featureIndex];
                }
            }
        }
Example #4
0
        /// <summary>
        /// Performs the histogram sum-up for every feature flock of the training set,
        /// one parallel task invocation per flock.
        /// </summary>
        /// <param name="gradient">The gradient values for the current scores.</param>
        /// <param name="sumTargets">The sum of the targets.</param>
        /// <param name="sumWeights">The sum of the document weights.</param>
        private void SumUpsAcrossFlocks(double[] gradient, double sumTargets, double sumWeights)
        {
            var sumupTask = ThreadTaskManager.MakeTask(
                (flock) =>
            {
                int firstFeatureInFlock = TrainSet.FlockToFirstFeature(flock);
                _histogram[flock].Sumup(
                    firstFeatureInFlock,
                    null,
                    TrainSet.NumDocs,
                    sumTargets,
                    sumWeights,
                    gradient,
                    TrainSet.SampleWeights,
                    null);
            }, TrainSet.NumFlocks);

            sumupTask.RunTask();
        }
Example #5
0
        /// <summary>
        /// Sets up single-machine parallel training and initializes the thread task
        /// manager with the requested thread count, capped by the host's concurrency
        /// factor (with a warning when the cap applies).
        /// </summary>
        private void InitializeThreads()
        {
            ParallelTraining = new SingleTrainer();

            int numThreads = Args.NumThreads ?? Environment.ProcessorCount;
            bool exceedsHostConcurrency = Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor;

            if (exceedsHostConcurrency)
            {
                numThreads = Host.ConcurrencyFactor;
                using (var ch = Host.Start("GamTrainer"))
                {
                    ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
                               + "setting of the environment. Using {0} training threads instead.", numThreads);
                }
            }

            ThreadTaskManager.Initialize(numThreads);
        }
Example #6
0
 /// <summary>
 /// Sets up single-machine parallel training and initializes the thread task
 /// manager with the configured thread count (defaulting to the processor count).
 /// </summary>
 private void InitializeThreads()
 {
     ParallelTraining = new SingleTrainer();
     int threadCount = GamTrainerOptions.NumberOfThreads ?? Environment.ProcessorCount;
     ThreadTaskManager.Initialize(threadCount);
 }