Ejemplo n.º 1
0
        /// <summary>
        /// Observes the supplied crowd labels and words (and, when given, the true labels),
        /// then queries the inference engine for the model posteriors.
        /// </summary>
        /// <param name="workerLabel">Labels per worker; forwarded to <c>ObserveCrowdLabels</c>.</param>
        /// <param name="workerTaskIndex">Task indices per worker; forwarded to <c>ObserveCrowdLabels</c>.</param>
        /// <param name="words">Word data; forwarded to <c>ObserveWords</c>.</param>
        /// <param name="wordCounts">Word counts; forwarded to <c>ObserveWords</c>.</param>
        /// <param name="trueLabels">Optional true labels; observed only when not null.</param>
        /// <param name="numIterations">
        /// Upper bound of the iteration schedule: the posteriors are re-queried with
        /// <c>Engine.NumberOfIterations</c> set to 1, 2, ..., <paramref name="numIterations"/>.
        /// </param>
        /// <returns>
        /// The posteriors from the last pass of the loop, plus the model evidence.
        /// </returns>
        public BCCWordsPosteriors InferPosteriors(
            int[][] workerLabel, int[][] workerTaskIndex, int[][] words, int[] wordCounts, int[] trueLabels = null,
            int numIterations = 35)
        {
            // Attach the observed data to the model.
            ObserveCrowdLabels(workerLabel, workerTaskIndex);
            ObserveWords(words, wordCounts);
            if (trueLabels != null)
            {
                ObserveTrueLabels(trueLabels);
            }

            var result = new BCCWordsPosteriors();

            Console.WriteLine("\n***** BCC Words *****\n");

            // Step the iteration count up one at a time so that the first task's
            // true-label posterior can be printed after every step.
            int iteration = 1;
            while (iteration <= numIterations)
            {
                Engine.NumberOfIterations = iteration;

                result.TrueLabel = Engine.Infer<Discrete[]>(TrueLabel);
                result.WorkerConfusionMatrix = Engine.Infer<Dirichlet[][]>(WorkerConfusionMatrix);
                result.BackgroundLabelProb = Engine.Infer<Dirichlet>(BackgroundLabelProb);
                result.ProbWordPosterior = Engine.Infer<Dirichlet[]>(ProbWord);

                Console.WriteLine("Iteration {0}:\t{1:0.0000}", iteration, result.TrueLabel[0]);
                iteration++;
            }

            // Evidence is queried once, after the iteration schedule has finished.
            result.Evidence = Engine.Infer<Bernoulli>(evidence);
            return result;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Runs the BCCWords model on the data and stores the resulting posteriors.
        /// </summary>
        /// <param name="modelName">Name of the model. Not used by this method.</param>
        /// <param name="data">The data to run the model on.</param>
        /// <param name="fullData">The data used to build <c>FullMapping</c> when it has not been created yet.</param>
        /// <param name="model">The BCCWords model on which the model is created and inference is run.</param>
        /// <param name="mode">The run mode; passed through to <c>UpdateResults</c>.</param>
        /// <param name="calculateAccuracy">Compute the accuracy (true).</param>
        /// <param name="useMajorityVote">
        /// When true, the per-task majority votes are passed to inference as true labels
        /// (tasks without a majority vote get a random label) and the data is rebuilt from those votes.
        /// </param>
        /// <param name="useRandomLabel">When true, the data is rebuilt from a random label per task.</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown when <c>Mapping</c> is not a <c>DataMappingWords</c> instance.
        /// </exception>
        public void RunBCCWords(string modelName,
                                IList <Datum> data,
                                IList <Datum> fullData,
                                BCCWords model,
                                RunMode mode,
                                bool calculateAccuracy,
                                bool useMajorityVote = false,
                                bool useRandomLabel  = false)
        {
            if (FullMapping == null)
            {
                FullMapping = new DataMapping(fullData);
            }

            if (Mapping == null)
            {
                // Build vocabulary
                Console.Write("Building vocabulary...");
                Stopwatch stopwatch = new Stopwatch();
                stopwatch.Start();
                string[] corpus = data.Select(d => d.BodyText).Distinct().ToArray();
                Vocabulary = BuildVocabularyFromCorpus(corpus);
                Console.WriteLine("done. Elapsed time: {0}", stopwatch.Elapsed);

                // Build data mapping from the vocabulary that was just built.
                // (The previous code read the vocabulary through a local that was still
                // null at this point, which always threw a NullReferenceException.)
                DataMappingWords newMapping = new DataMappingWords(data, Vocabulary);
                this.Mapping    = newMapping;
                this.GoldLabels = newMapping.GetGoldLabelsPerTaskId();
            }

            // Everything below needs the word-aware mapping, so fail early with a clear
            // message instead of a NullReferenceException further down.
            DataMappingWords mappingWords = Mapping as DataMappingWords;
            if (mappingWords == null)
            {
                throw new InvalidOperationException("RunBCCWords requires Mapping to be a DataMappingWords instance.");
            }

            int[] trueLabels = null;
            if (useMajorityVote)
            {
                var majorityLabel = mappingWords.GetMajorityVotesPerTaskId(data);

                // NOTE(review): the array is sized by FullMapping.TaskCount but indexed through
                // Mapping.TaskIndexToId; this presumably relies on both mappings having the
                // same number of tasks -- confirm against the callers.
                trueLabels = Util.ArrayInit(FullMapping.TaskCount, i => majorityLabel.ContainsKey(Mapping.TaskIndexToId[i]) ? (int)majorityLabel[Mapping.TaskIndexToId[i]] : Rand.Int(Mapping.LabelMin, Mapping.LabelMax + 1));
                data       = mappingWords.BuildDataFromAssignedLabels(majorityLabel, data);
            }

            if (useRandomLabel)
            {
                var randomLabels = mappingWords.GetRandomLabelPerTaskId(data);
                data = mappingWords.BuildDataFromAssignedLabels(randomLabels, data);
            }

            var labelsPerWorkerIndex      = mappingWords.GetLabelsPerWorkerIndex(data);
            var taskIndicesPerWorkerIndex = mappingWords.GetTaskIndicesPerWorkerIndex(data);

            // Create model
            ClearResults();
            model.CreateModel(mappingWords.TaskCount, mappingWords.LabelCount, mappingWords.WordCount);

            // Run model inference
            BCCWordsPosteriors posteriors = model.InferPosteriors(labelsPerWorkerIndex, taskIndicesPerWorkerIndex, mappingWords.WordIndicesPerTaskIndex, mappingWords.WordCountsPerTaskIndex, trueLabels);

            // Update results
            UpdateResults(posteriors, mode);

            // Compute accuracy
            if (calculateAccuracy)
            {
                UpdateAccuracy();
            }
        }