コード例 #1
0
        /// <summary>
        /// Builds the <see cref="CorpusInformation"/> for a set of documents, optionally pruning rare terms from the
        /// vocabulary. Progress and the resulting vocabulary sizes are written to the console.
        /// </summary>
        /// <param name="corpus">
        /// The texts that make up the corpus.
        /// </param>
        /// <param name="vocabularyThreshold">
        /// Optional vocabulary threshold. If set, terms need to be seen at least this many times to be kept as part
        /// of the vocabulary; if null, the full vocabulary is returned.
        /// </param>
        /// <returns>
        /// The corpus information, thresholded when <paramref name="vocabularyThreshold"/> has a value.
        /// </returns>
        public static CorpusInformation BuildCorpusInformation(string[] corpus, int? vocabularyThreshold = null)
        {
            Console.WriteLine(@"Building vocabulary... ");
            var fullInfo = CorpusInformation.FromDocs(corpus);
            Console.WriteLine($@"Vocabulary size: {fullInfo.NumberOfWords}");

            // No threshold requested: the full vocabulary is the result.
            if (!vocabularyThreshold.HasValue)
            {
                return fullInfo;
            }

            Console.WriteLine($@"Deleting words seen less than {vocabularyThreshold.Value} times...");
            var prunedInfo = fullInfo.ConvertToThresholdedVocabulary(vocabularyThreshold);
            Console.WriteLine($@"Vocabulary size after deletion: {prunedInfo.NumberOfWords}");

            return prunedInfo;
        }
コード例 #2
0
        /// <summary>
        /// Initializes a new instance of the <see cref="CrowdDataWithTextMapping" /> class.
        /// </summary>
        /// <param name="data">
        /// The crowd data, including the text of every tweet (looked up via <c>TweetTexts</c>).
        /// </param>
        /// <param name="labelValueToString">
        /// The mapping from label values in file to label strings.
        /// </param>
        /// <param name="corpusInfo">
        /// The corpus information used to convert each tweet's text into word indices.
        /// </param>
        public CrowdDataWithTextMapping(
            CrowdDataWithText data,
            Dictionary<int, string> labelValueToString,
            CorpusInformation corpusInfo)
            : base(data, labelValueToString)
        {
            this.CorpusInfo = corpusInfo;

            // One entry per tweet, in the order of TweetIds (presumably matching the tweet indices of the base
            // mapping — confirm against CrowdDataMapping).
            this.WordIndicesPerTweetIndex = this.TweetIds
                .Select(tid => corpusInfo.GetWordIndices(data.TweetTexts[tid]).ToArray())
                .ToArray();

            // Number of word indices found in each tweet, aligned with WordIndicesPerTweetIndex.
            this.WordCountsPerTweetIndex = this.WordIndicesPerTweetIndex.Select(arr => arr.Length).ToArray();
        }
コード例 #3
0
        /// <summary>
        /// Restricts the current instance to a sub-vocabulary.
        /// </summary>
        /// <remarks>
        /// Duplicate words in <paramref name="subVocabulary"/> are ignored; sub-vocabulary indices follow the order of
        /// first occurrence. Tokens of each document whose word is not in the sub-vocabulary are dropped, and the
        /// document counts and IDFs of the kept words are copied from this instance. Every word of the
        /// sub-vocabulary must be present in this instance's <c>VocabularyToVocabularyIndex</c>; otherwise that
        /// lookup fails (a <see cref="KeyNotFoundException"/> for a standard dictionary).
        /// </remarks>
        /// <param name="subVocabulary">The sub-vocabulary.</param>
        /// <returns>The document and vocabulary information for the sub-vocabulary.</returns>
        public virtual CorpusInformation ConvertToSubVocabulary(List<string> subVocabulary)
        {
            var result           = new CorpusInformation();
            var subVocabDistinct = subVocabulary.Distinct().ToList();
            var subVocabSize     = subVocabDistinct.Count;

            // subVocabIndexToVocabIndex[i] is the index in this vocabulary of the i-th sub-vocabulary word.
            var subVocabIndexToVocabIndex =
                subVocabDistinct.Select(word => this.VocabularyToVocabularyIndex[word]).ToArray();

            var subVocabToSubVocabIndex = subVocabDistinct
                                          .Select((word, index) => (Word: word, Index: index))
                                          .ToDictionary(pair => pair.Word, pair => pair.Index);

            result.Documents = new int[this.NumberOfDocuments][];

            for (var docIndex = 0; docIndex < this.NumberOfDocuments; docIndex++)
            {
                // Keep only tokens whose word is in the sub-vocabulary, re-indexed into it. A single TryGetValue
                // replaces the former ContainsKey + indexer double lookup per token.
                var subDocument = new List<int>();
                foreach (var vocabIndex in this.Documents[docIndex])
                {
                    if (subVocabToSubVocabIndex.TryGetValue(this.Vocabulary[vocabIndex], out var subVocabIndex))
                    {
                        subDocument.Add(subVocabIndex);
                    }
                }

                result.Documents[docIndex] = subDocument.ToArray();
            }

            result.DocumentCounts = new int[subVocabSize];
            result.VocabularyIdfs = new double[subVocabSize];

            // Per-word statistics are carried over from the corresponding word of this vocabulary.
            for (var subVocabIndex = 0; subVocabIndex < subVocabSize; subVocabIndex++)
            {
                var vocabIndex = subVocabIndexToVocabIndex[subVocabIndex];
                result.DocumentCounts[subVocabIndex] = this.DocumentCounts[vocabIndex];
                result.VocabularyIdfs[subVocabIndex] = this.VocabularyIdfs[vocabIndex];
            }

            result.Vocabulary                  = subVocabDistinct.ToArray();
            result.VocabularyToVocabularyIndex = subVocabToSubVocabIndex;

            return result;
        }
コード例 #4
0
ファイル: ModelRuns.cs プロジェクト: raudut/Infernet-asthme
        // RunSweep: trains `model` on each of `trainingDataSets` in turn and evaluates every trained model on
        // `validationDataSet` and, when given, on `validationDataSetNoLabels`.
        //   runnerCreator          - builds a runner from (data mapping, model, runner to start from). Training runners
        //                            are given previousTrainingRunners?[i]; validation runners are given the training runner.
        //   trainingResultGetter   - optional; extracts additional results from a trained runner (null -> empty results).
        //   validationResultGetter - builds validation results from a validation runner and the metrics returned by
        //                            RunModel; it is invoked without a null check, so it must not be null.
        //   numIterations          - iterations for each training run; validation runs use NumIterationsForValidation.
        //   resultStorage          - optional dictionary that is filled in place and returned instead of a new one.
        // Returns the results keyed "TrainingPercent_<p>" with p = (i + 1) * 100 / trainingDataSets.Length, together
        // with the training runners (one per training set).
        // NOTE(review): the modifiers and return type of this method are not part of this excerpt; callers in this
        // file read `.Results` from the returned tuple.
        RunSweep(
            CrowdDataWithText[] trainingDataSets,
            CrowdDataWithText validationDataSet,
            ModelBase model,
            CorpusInformation corpusInformation,
            Func <CrowdDataMapping, ModelBase, ModelRunnerBase, ModelRunnerBase> runnerCreator,
            Func <ModelRunnerBase, Dictionary <string, object> > trainingResultGetter,
            Func <ModelRunnerBase, Dictionary <string, object>, Dictionary <string, object> > validationResultGetter,
            ExperimentParameters experimentParameters,
            CrowdDataWithText validationDataSetNoLabels    = null,
            List <ModelRunnerBase> previousTrainingRunners = null,
            int numIterations = 200,
            Dictionary <string, Dictionary <string, object> > resultStorage = null)
        {
            // Results are written into the caller's storage when provided, so a caller can register that dictionary
            // for output before the sweep fills it.
            var       allResultsForThisModel     = resultStorage ?? new Dictionary <string, Dictionary <string, object> >();
            var       trainingRunners            = new List <ModelRunnerBase>();
            const int NumIterationsForValidation = 5;

            for (var i = 0; i < trainingDataSets.Length; i++)
            {
                // NOTE(review): the suffix below is an underscore then a number, not a dash as the original comment
                // says — confirm which form the chart code expects.
                var currentModelName =
                    $"{model.Name}_{i}"; // Chart code assumes suffix of dash then number for ordering chart points
                var trainingData    = trainingDataSets[i];
                var trainingMapping = new CrowdDataWithTextMapping(
                    trainingData,
                    LabelValuesToString,
                    corpusInformation);

                // When previousTrainingRunners is supplied it must have an entry for every training set.
                var trainingRunner = runnerCreator.Invoke(trainingMapping, model, previousTrainingRunners?[i]);
                trainingRunners.Add(trainingRunner);
                // Every run restarts the global random generator from the same seed, for reproducibility.
                Rand.Restart(experimentParameters.RandomSeed);
                // The last argument is presumably whether gold labels are used; it is always false for validation below.
                RunModel(trainingRunner, currentModelName + "_Training", numIterations,
                         experimentParameters.UseGoldLabelsInTraining);
                var trainingResults = trainingResultGetter?.Invoke(trainingRunner) ?? new Dictionary <string, object>();

                trainingResults[ErrorsKey]        = trainingRunner.GetErrors();
                trainingResults[WorkerMetricsKey] = GetWorkerMetrics(trainingRunner.DataMapping.Data, trainingRunner, experimentParameters.MaximumNumberWorkers);

                var currentResults =
                    new Dictionary <string, object> {
                    [TrainingKey] = trainingResults
                };

                // The validation mapping is rebuilt for each training set; the validation runner starts from the
                // runner just trained.
                var validationMapping = new CrowdDataWithTextMapping(
                    validationDataSet,
                    LabelValuesToString,
                    corpusInformation);

                var validationRunner = runnerCreator.Invoke(validationMapping, model, trainingRunner);
                Rand.Restart(experimentParameters.RandomSeed);
                var validationMetrics = RunModel(validationRunner, currentModelName + "_Validation",
                                                 NumIterationsForValidation, false);

                // NOTE(review): prints every validation true-label posterior to the console — looks like debug output.
                foreach (var prediction in validationRunner.Posteriors.TrueLabel)
                {
                    Console.WriteLine(prediction);
                }

                var validationResults = validationResultGetter.Invoke(validationRunner, validationMetrics);
                validationResults[ErrorsKey]        = validationRunner.GetErrors();
                validationResults[WorkerMetricsKey] = GetWorkerMetrics(validationRunner.DataMapping.Data,
                                                                       validationRunner, experimentParameters.MaximumNumberWorkers);
                currentResults[ValidationKey] = validationResults;

                // Optional extra evaluation on validation data without worker labels.
                if (validationDataSetNoLabels != null)
                {
                    var validationMappingNoLabels = new CrowdDataWithTextMapping(
                        validationDataSetNoLabels,
                        LabelValuesToString,
                        corpusInformation);

                    var validationNoLabelsRunner =
                        runnerCreator.Invoke(validationMappingNoLabels, model, trainingRunner);
                    Rand.Restart(experimentParameters.RandomSeed);
                    var validationMetricsNoLabels = RunModel(validationNoLabelsRunner,
                                                             currentModelName + "_ValidationNoLabels", NumIterationsForValidation, false);
                    currentResults[ValidationNoLabelsKey] =
                        validationResultGetter.Invoke(validationNoLabelsRunner, validationMetricsNoLabels);
                }

                // Key is the percentage of the sweep reached by training set i (the last set maps to 100).
                allResultsForThisModel[$"TrainingPercent_{Math.Min((i + 1) * 100 / trainingDataSets.Length, 100)}"] =
                    currentResults;
            }

            return(allResultsForThisModel, trainingRunners);
        }
コード例 #5
0
ファイル: Program.cs プロジェクト: raudut/Infernet-asthme
        /// <summary>
        /// Runs the crowdsourcing experiments selected by <c>experimentParameters.ModelTypes</c>. The method loads the
        /// labels, gold labels and tweet texts from <paramref name="dataPath"/>, splits and limits the data, and builds
        /// the corpus information from the training tweets. It then runs each selected model (majority vote, honest
        /// worker, biased worker, biased community, biased community with words) and writes the inputs, per-model
        /// results and comparison chart data through <paramref name="outputter"/>.
        /// </summary>
        /// <param name="outputter">Receives the input description, experiment info, per-model results and comparisons.</param>
        /// <param name="experimentParameters">The experiment configuration (model types, data sizes, random seed, ...).</param>
        /// <param name="dataPath">The directory containing the three HarnessingTheCrowd*.tsv data files.</param>
        /// <exception cref="FileNotFoundException">One of the required data files is missing.</exception>
        public static void RunExperiments(Outputter outputter, ExperimentParameters experimentParameters, string dataPath)
        {
            // Starts with Expectation Propagation; the community sections below switch the shared engine to VMP.
            var engine = new InferenceEngine(new ExpectationPropagation())
            {
                ShowFactorGraph = false,
                ShowWarnings    = true,
                ShowProgress    = false
            };

            // Set engine flags
            engine.Compiler.WriteSourceFiles    = true;
            engine.Compiler.UseParallelForLoops = false;

            string labelsFileName     = Path.Combine(dataPath, "HarnessingTheCrowdLabels.tsv");
            string goldLabelsFileName = Path.Combine(dataPath, "HarnessingTheCrowdGoldLabels.tsv");
            string textsFileName      = Path.Combine(dataPath, "HarnessingTheCrowdTexts.tsv");

            // All three files are required; fail fast before any work is done.
            if (!File.Exists(labelsFileName) ||
                !File.Exists(goldLabelsFileName) ||
                !File.Exists(textsFileName))
            {
                throw new FileNotFoundException("Unfortunately, we were not able to ship the data necessary to run the code for this chapter.");
            }

            var crowdData = CrowdDataWithText.LoadData(
                labelsFileName,
                goldLabelsFileName,
                textsFileName,
                new HashSet <int>(LabelValuesToString.Keys));

            // Reset the global random generator so the split below is reproducible.
            Rand.Restart(experimentParameters.RandomSeed);

            // Preparing the input data
            var split              = crowdData.SplitData(experimentParameters.FractionGoldLabelsReservedForTraining);
            var fullTrainingData   = split[CrowdData.Mode.Training];
            var fullValidationData = split[CrowdData.Mode.Validation];

            var trainingData = (CrowdDataWithText)fullTrainingData.LimitData(
                maxNumTweets: experimentParameters.UseOnlyTweetsWithGoldLabels
                    ? fullTrainingData.NumGoldTweets
                    : experimentParameters.NumberOfTrainingTweets,
                randomSeed: experimentParameters.RandomSeed);

            // When balanced training sets are requested, the limited training data above is replaced by a
            // label-balanced selection with the same size limit.
            if (experimentParameters.UseBalancedTrainingSets)
            {
                trainingData = (CrowdDataWithText)fullTrainingData.LimitData(
                    maxNumTweets: experimentParameters.UseOnlyTweetsWithGoldLabels
                        ? fullTrainingData.NumGoldTweets
                        : experimentParameters.NumberOfTrainingTweets,
                    balanceTweetsByLabel: true,
                    randomSeed: experimentParameters.RandomSeed);
            }

            var validationData = (CrowdDataWithText)fullValidationData.LimitData(
                experimentParameters.NumberOfValidationJudgments,
                randomSeed: experimentParameters.RandomSeed);

            // Validation data with worker labels removed; used by the words model below (validationDataSetNoLabels).
            var validationDataWithNoWorkerLabels = (CrowdDataWithText)validationData.LimitData(maxNumWorkers: 0);

            var trainingWorkerMetrics   = GetWorkerMetrics(trainingData, null, experimentParameters.MaximumNumberWorkers);
            var validationWorkerMetrics =
                GetWorkerMetrics(validationData, null, experimentParameters.MaximumNumberWorkers);
            // The vocabulary is built from the distinct training tweet texts only.
            var corpus            = trainingData.TweetTexts.Values.Distinct().ToArray();
            var corpusInformation = CorpusInformation.BuildCorpusInformation(
                corpus,
                experimentParameters.VocabularyThreshold);

            // NumDataSizes evenly spaced judgment counts, the last one equal to the total number of training judgments.
            var totalNumTrainingJudgments = trainingData.CrowdLabels.Count;
            var numTrainingJudgments      = Util.ArrayInit(
                experimentParameters.NumDataSizes,
                i => Math.Round(totalNumTrainingJudgments * (i + 1) / (double)experimentParameters.NumDataSizes));

            // One training data set per data size, each limited to the corresponding number of judgments.
            var trainingDataSets = Util.ArrayInit(
                experimentParameters.NumDataSizes,
                i => (CrowdDataWithText)trainingData.LimitData((int)numTrainingJudgments[i],
                                                               randomSeed: experimentParameters.RandomSeed));

            var trainingTweets         = trainingData.Tweets;
            var trainingCrowdWorkers   = trainingData.Workers;
            var validationTweets       = validationData.Tweets;
            var validationCrowdWorkers = validationData.Workers;

            var allDataInfo = new Dictionary <string, object>
            {
                ["Full training data"]           = fullTrainingData,
                ["Full validation data"]         = fullValidationData,
                ["Validation data"]              = validationData,
                ["Training data"]                = trainingData,
                ["Training data sets"]           = trainingDataSets,
                ["Validation data set"]          = validationData,
                ["Training set tweets"]          = trainingTweets,
                ["Training set crowd workers"]   = trainingCrowdWorkers,
                ["Validation set tweets"]        = validationTweets,
                ["Validation set crowd workers"] = validationCrowdWorkers,
                ["CorpusInformation"]            = corpusInformation
            };

            outputter.Out(allDataInfo, "Inputs");

            WriteCrowdDataDetailsToConsole(fullTrainingData, "Full training data");
            WriteCrowdDataDetailsToConsole(fullValidationData, "Full validation data");
            WriteCrowdDataDetailsToConsole(trainingData, "Training set");
            WriteCrowdDataDetailsToConsole(validationData, "Validation set");

            // Local-time timestamp in the form yyyyMMddTHHmmss.
            var now    = DateTime.Now;
            var nowStr = $"{now.Year:D4}{now.Month:D2}{now.Day:D2}T{now.Hour:D2}{now.Minute:D2}{now.Second:D2}";

            var experimentInfo = new Dictionary <string, object>
            {
                ["Date"]       = nowStr,
                ["Parameters"] = experimentParameters,
                [$"Training {WorkerMetricsKey}"]         = trainingWorkerMetrics,
                [$"Validation {WorkerMetricsKey}"]       = validationWorkerMetrics,
                [$"Training {LabelCountHistogramKey}"]   = GetWorkerLabelCountHistogram(trainingData),
                [$"Validation {LabelCountHistogramKey}"] = GetWorkerLabelCountHistogram(validationData)
            };

            outputter.Out(experimentInfo, "Experiment");

            // Data structures for comparing results
            var metricsForPlots = new List <Metric> {
                Metric.Accuracy, Metric.AverageLogProb
            };
            // Keyed by metric name, then by model name.
            var metricsStripCharts = new Dictionary <string, Dictionary <string, PointWithBounds[]> >();
            var metricsBarCharts   = new Dictionary <string, Dictionary <string, double> >();

            foreach (var metric in metricsForPlots)
            {
                var metricString = metric.ToString();
                metricsStripCharts[metricString] = new Dictionary <string, PointWithBounds[]>();
                metricsBarCharts[metricString]   = new Dictionary <string, double>();
            }

            // Creates a snapshot of currently accumulated comparison metrics
            // (the per-metric dictionaries are copied, so models added later do not appear in earlier snapshots).
            Dictionary <string, object> getCurrentComparisonSnapshot()
            {
                var metricsStripChartsSnapshot =
                    new Dictionary <string, Dictionary <string, PointWithBounds[]> >();

                var metricsBarChartsSnapshot =
                    new Dictionary <string, Dictionary <string, double> >();

                foreach (var metric in metricsForPlots)
                {
                    var metricString = metric.ToString();
                    metricsStripChartsSnapshot[metricString] = new Dictionary <string, PointWithBounds[]>(metricsStripCharts[metricString]);
                    metricsBarChartsSnapshot[metricString]   = new Dictionary <string, double>(metricsBarCharts[metricString]);
                }
                return(new Dictionary <string, object>
                {
                    ["Strip Charts"] = metricsStripChartsSnapshot,
                    ["Bar Charts"] = metricsBarChartsSnapshot
                });
            }

            // Running models as specified in experiment parameters
            // Majority vote and honest worker share section 2; the flag avoids printing its header twice.
            bool section2Started = false;
            Dictionary <string, object> section2Comparison = null;

            if ((experimentParameters.ModelTypes & ModelTypes.MajorityVote) != 0)
            {
                Console.WriteLine($"\n{Contents.S2TryingOutTheWorkerModel.NumberedName}\n");
                section2Started = true;
                var accuracyString      = Metric.Accuracy.ToString();
                var majorityVoteString  = "Majority Vote";
                var majorityVoteMetrics = RunMajorityVoteModel(validationData, majorityVoteString);
                outputter.Out(majorityVoteMetrics, Contents.S2TryingOutTheWorkerModel.NumberedName, majorityVoteString);

                // Majority vote is not trained, so the same metrics are repeated for every training data size.
                metricsStripCharts[accuracyString][majorityVoteString] = GetChartPointsForMetric(
                    Enumerable.Repeat(majorityVoteMetrics, trainingDataSets.Length).ToList(),
                    Metric.Accuracy.ToString(),
                    numTrainingJudgments);

                // Bar charts show the value at the largest training data size.
                metricsBarCharts[accuracyString][majorityVoteString] =
                    metricsStripCharts[accuracyString][majorityVoteString].Last().Y;

                section2Comparison = getCurrentComparisonSnapshot();
                outputter.Out(section2Comparison, Contents.S2TryingOutTheWorkerModel.NumberedName, "Comparison");
            }

            if ((experimentParameters.ModelTypes & ModelTypes.HonestWorker) != 0)
            {
                if (!section2Started)
                {
                    Console.WriteLine($"\n{Contents.S2TryingOutTheWorkerModel.NumberedName}\n");
                }
                var currentModel        = new HonestWorkerModel(engine);
                // The empty dictionary is handed to the outputter first and then filled in place by RunSweep
                // (resultStorage), which returns that same instance as Results.
                var honestWorkerMetrics = new Dictionary <string, Dictionary <string, object> >();
                outputter.Out(honestWorkerMetrics, Contents.S2TryingOutTheWorkerModel.NumberedName, currentModel.Name);

                honestWorkerMetrics = RunSweep(
                    trainingDataSets,
                    validationData,
                    currentModel,
                    corpusInformation,
                    (map, model, runner) => new HonestWorkerRunner(map, (HonestWorkerModel)model, runner),
                    GetHonestWorkerTrainingResults,
                    (runner, metrics) =>
                    GetHonestWorkerValidationResults(runner, metrics, experimentParameters.MaximumNumberWorkers),
                    experimentParameters,
                    resultStorage: honestWorkerMetrics).Results;


                foreach (var metric in metricsForPlots)
                {
                    metricsStripCharts[metric.ToString()][currentModel.Name] =
                        GetChartPointsForMetric(honestWorkerMetrics, metric, numTrainingJudgments);
                    metricsBarCharts[metric.ToString()][currentModel.Name] =
                        metricsStripCharts[metric.ToString()][currentModel.Name].Last().Y;
                }

                section2Comparison = getCurrentComparisonSnapshot();
                outputter.Out(section2Comparison, Contents.S2TryingOutTheWorkerModel.NumberedName, "Comparison");
            }

            if ((experimentParameters.ModelTypes & ModelTypes.BiasedWorker) != 0)
            {
                Console.WriteLine($"\n{Contents.S3CorrectingForWorkerBiases.NumberedName}\n");
                var currentModel        = new BiasedWorkerModel(engine);
                var biasedWorkerMetrics = new Dictionary <string, Dictionary <string, object> >();
                outputter.Out(biasedWorkerMetrics, Contents.S3CorrectingForWorkerBiases.NumberedName, currentModel.Name);

                biasedWorkerMetrics = RunSweep(
                    trainingDataSets,
                    validationData,
                    currentModel,
                    corpusInformation,
                    (map, model, runner) => new BiasedWorkerModelRunner(map, (BiasedWorkerModel)model, runner),
                    runner => GetBiasedWorkerTrainingResults(runner, experimentParameters.MaximumNumberWorkers),
                    GetValidationResults,
                    experimentParameters,
                    resultStorage: biasedWorkerMetrics).Results;

                foreach (var metric in metricsForPlots)
                {
                    metricsStripCharts[metric.ToString()][currentModel.Name] =
                        GetChartPointsForMetric(biasedWorkerMetrics, metric, numTrainingJudgments);
                    metricsBarCharts[metric.ToString()][currentModel.Name] =
                        metricsStripCharts[metric.ToString()][currentModel.Name].Last().Y;
                }

                outputter.Out(getCurrentComparisonSnapshot(), Contents.S3CorrectingForWorkerBiases.NumberedName, "Comparison");
            }

            if ((experimentParameters.ModelTypes & ModelTypes.BiasedCommunity) != 0)
            {
                Console.WriteLine($"\n{Contents.S4CommunitiesOfWorkers.NumberedName}\n");
                // Switches the shared engine to VMP; the engine is not switched back afterwards.
                engine.Algorithm = new VariationalMessagePassing();
                // One model run per number of communities in the sweep.
                foreach (int numCommunities in experimentParameters.NumberOfCommunitiesSweep)
                {
                    var currentModel = new BiasedCommunityModel(engine)
                    {
                        NumberOfCommunities = numCommunities
                    };
                    var result = new Dictionary <string, Dictionary <string, object> >();
                    outputter.Out(result, Contents.S4CommunitiesOfWorkers.NumberedName, currentModel.Name);

                    result = RunSweep(
                        trainingDataSets,
                        validationData,
                        currentModel,
                        corpusInformation,
                        (map, model, runner) => new BiasedCommunityModelRunner(
                            map,
                            (BiasedCommunityModel)model,
                            runner),
                        GetBiasedCommunityTrainingResults,
                        GetValidationResults,
                        experimentParameters,
                        resultStorage: result).Results;

                    foreach (var metric in metricsForPlots)
                    {
                        metricsStripCharts[metric.ToString()][currentModel.Name] =
                            GetChartPointsForMetric(result, metric, numTrainingJudgments);
                        metricsBarCharts[metric.ToString()][currentModel.Name] =
                            metricsStripCharts[metric.ToString()][currentModel.Name].Last().Y;
                    }
                }

                outputter.Out(getCurrentComparisonSnapshot(), Contents.S4CommunitiesOfWorkers.NumberedName, "Comparison");
            }

            if ((experimentParameters.ModelTypes & ModelTypes.BiasedCommunityWords) != 0)
            {
                Console.WriteLine($"\n{Contents.S5MakingUseOfTheTweets.NumberedName}\n");
                // VMP is an order of magnitude faster than EP here, while producing almost the same result
                engine.Algorithm = new VariationalMessagePassing();
                foreach (int numCommunities in experimentParameters.NumberOfCommunitiesSweep)
                {
                    var currentModel = new BiasedCommunityWordsModel(engine)
                    {
                        NumberOfCommunities = numCommunities
                    };
                    var result = new Dictionary <string, Dictionary <string, object> >();
                    outputter.Out(result, Contents.S5MakingUseOfTheTweets.NumberedName, currentModel.Name);

                    // Positional arguments after experimentParameters: validation data without worker labels,
                    // no previous training runners, and 50 training iterations.
                    result = RunSweep(
                        trainingDataSets,
                        validationData,
                        currentModel,
                        corpusInformation,
                        (map, model, runner) => new BiasedCommunityWordsRunner(
                            (CrowdDataWithTextMapping)map,
                            (BiasedCommunityWordsModel)model,
                            runner),
                        GetBiasedCommunityWordsRunnerResults,
                        GetValidationResults,
                        experimentParameters,
                        validationDataWithNoWorkerLabels,
                        null,
                        50,
                        resultStorage: result).Results;

                    foreach (var metric in metricsForPlots)
                    {
                        metricsStripCharts[metric.ToString()][currentModel.Name] =
                            GetChartPointsForMetric(result, metric, numTrainingJudgments);
                        metricsBarCharts[metric.ToString()][currentModel.Name] =
                            metricsStripCharts[metric.ToString()][currentModel.Name].Last().Y;
                    }
                }

                outputter.Out(getCurrentComparisonSnapshot(), Contents.S5MakingUseOfTheTweets.NumberedName, "Comparison");
            }

            Console.WriteLine("\nCompleted all experiments.");
        }