Beispiel #1
0
        /// <inheritdoc />
        /// <remarks>
        /// After the base implementation has run, rebuilds the per-label word priors from
        /// <c>ProbWordPosterior</c>: every new word whose map entry is non-negative takes over the
        /// posterior pseudo-count of the old word it points to; every other word keeps
        /// <c>ProbWordInitialCount</c>. The word priors are published in a single assignment once
        /// all labels have been remapped.
        /// </remarks>
        /// <exception cref="System.ArgumentNullException">
        /// <paramref name="newWordToOldWordMap"/> or <paramref name="modelPosteriors"/> is <c>null</c>.
        /// </exception>
        public override void SetPriorsFromPosteriors(
            int[] newWorkerToOldWorkerMap,
            int[] newWordToOldWordMap,
            ModelPosteriors modelPosteriors)
        {
            // Fail fast with a descriptive exception before any prior is touched; without these
            // checks a missing argument only surfaces later as a bare NullReferenceException.
            if (newWordToOldWordMap == null)
            {
                throw new System.ArgumentNullException(nameof(newWordToOldWordMap));
            }

            if (modelPosteriors == null)
            {
                throw new System.ArgumentNullException(nameof(modelPosteriors));
            }

            base.SetPriorsFromPosteriors(newWorkerToOldWorkerMap, newWordToOldWordMap, modelPosteriors);
            var biasedCommunityWordsModelPosteriors = (BiasedCommunityWordsPosteriors)modelPosteriors;
            var probWordPosteriors = biasedCommunityWordsModelPosteriors.ProbWordPosterior;

            AssertWhenDebugging.IsTrue(this.LabelValueCount == probWordPosteriors.Length);

            // Start every label from the symmetric prior so that a label without a posterior entry
            // still ends up with a valid prior.
            var wordPriors = Util.ArrayInit(this.LabelValueCount, i => Dirichlet.Symmetric(this.VocabularySize, this.ProbWordInitialCount));

            for (var labIndex = 0; labIndex < probWordPosteriors.Length; labIndex++)
            {
                var oldPseudoCounts = probWordPosteriors[labIndex].PseudoCount;

                // New vector sized for the current vocabulary, pre-filled with the initial count
                // and using the same sparsity setting as the old pseudo-counts.
                var newPseudoCounts = Vector.Constant(this.VocabularySize, this.ProbWordInitialCount, oldPseudoCounts.Sparsity);
                for (var wordIndex = 0; wordIndex < newWordToOldWordMap.Length; wordIndex++)
                {
                    var oldIdx = newWordToOldWordMap[wordIndex];
                    if (oldIdx >= 0)
                    {
                        // Non-negative entry: carry the learned count over to the new word index.
                        newPseudoCounts[wordIndex] = oldPseudoCounts[oldIdx];
                    }
                }

                wordPriors[labIndex] = new Dirichlet(newPseudoCounts);
            }

            this.ProbWordsPrior.ObservedValue = wordPriors;
        }
Beispiel #2
0
        /// <inheritdoc />
        /// <remarks>
        /// Observes labels and words, selects default or posterior-derived priors, initialises each
        /// worker's community to a point mass on a randomly sampled community, and then queries the
        /// engine with <c>NumberOfIterations</c> set to 1, 2, ... up to <paramref name="numIterations"/>.
        /// When <c>HasEvidence</c> is set, the log evidence is printed per iteration and the loop stops
        /// as soon as <c>ModelBase.HasConverged</c> returns true. An exception thrown inside the loop
        /// is written to the console and whatever posteriors were collected so far are returned.
        /// </remarks>
        public override ModelPosteriors InferPosteriors(
            int[][] workerLabel,
            int[][] workerJudgedTweetIndex,
            int[][] words = null,
            int[] wordCounts = null,
            int[] newWorkerToOldWorkerMap = null,
            int[] newWordToOldWordMap = null,
            int?[] goldLabels = null,
            ModelPosteriors oldPosteriors = null,
            int numIterations = 20)
        {
            this.ObserveLabels(workerLabel, workerJudgedTweetIndex, goldLabels);
            this.ObserveWords(words, wordCounts);

            // Previous posteriors are only reused when both the worker map and the posteriors are given.
            var canReusePosteriors = newWorkerToOldWorkerMap != null && oldPosteriors != null;
            if (canReusePosteriors)
            {
                this.SetPriorsFromPosteriors(newWorkerToOldWorkerMap, newWordToOldWordMap, oldPosteriors);
            }
            else
            {
                this.SetDefaultPriors();
            }

            // Initialize messages: one point mass per worker on a community drawn from the uniform distribution.
            var communitySampler = Discrete.Uniform(this.NumberOfCommunities);
            var initialCommunities = Util.ArrayInit(
                workerLabel.Length,
                w => Discrete.PointMass(communitySampler.Sample(), this.NumberOfCommunities));
            this.WorkerCommunityInitializer.ObservedValue = Distribution<int>.Array(initialCommunities);

            var posteriors = new BiasedCommunityWordsPosteriors();
            var logEvidenceHistory = new List<double>();

            try
            {
                for (var it = 1; it <= numIterations; it++)
                {
                    this.Engine.NumberOfIterations = it;

                    posteriors.TrueLabel = this.Engine.Infer<Discrete[]>(this.TrueLabel);
                    posteriors.CommunityCpt = this.Engine.Infer<Dirichlet[][]>(this.ProbWorkerLabel);
                    posteriors.WorkerCommunities = this.Engine.Infer<Discrete[]>(this.Community);
                    posteriors.BackgroundLabelProb = this.Engine.Infer<Dirichlet>(this.ProbLabel);
                    posteriors.ProbWordPosterior = this.Engine.Infer<Dirichlet[]>(this.ProbWords);

                    if (!this.HasEvidence)
                    {
                        // Without an evidence variable there is no convergence test; run all iterations.
                        continue;
                    }

                    posteriors.Evidence = this.Engine.Infer<Bernoulli>(this.Evidence);
                    Console.WriteLine($"Iteration {it} log evidence:\t{posteriors.Evidence.LogOdds:0.0000}");
                    logEvidenceHistory.Add(posteriors.Evidence.LogOdds);
                    if (ModelBase.HasConverged(logEvidenceHistory))
                    {
                        break;
                    }
                }
            }
            catch (Exception ex)
            {
                // Best effort: report the failure and fall through to return what has been inferred so far.
                Console.WriteLine(ex);
            }

            return posteriors;
        }
        /// <inheritdoc />
        /// <remarks>
        /// Observes the labels, selects default or posterior-derived priors, and then queries the
        /// engine with <c>NumberOfIterations</c> set to 1, 2, ... up to <paramref name="numIterations"/>.
        /// When <c>HasEvidence</c> is set, the log evidence is printed per iteration and the loop stops
        /// as soon as <c>ModelBase.HasConverged</c> returns true. The <paramref name="words"/> and
        /// <paramref name="wordCounts"/> arguments are not read by this implementation.
        /// </remarks>
        public override ModelPosteriors InferPosteriors(
            int[][] workerLabel,
            int[][] workerJudgedTweetIndex,
            int[][] words = null,
            int[] wordCounts = null,
            int[] newWorkerToOldWorkerMap = null,
            int[] newWordToOldWordMap = null,
            int?[] goldLabels = null,
            ModelPosteriors oldPosteriors = null,
            int numIterations = 20)
        {
            this.ObserveLabels(workerLabel, workerJudgedTweetIndex, goldLabels);

            // Previous posteriors are only reused when both the worker map and the posteriors are given.
            var canReusePosteriors = newWorkerToOldWorkerMap != null && oldPosteriors != null;
            if (canReusePosteriors)
            {
                this.SetPriorsFromPosteriors(newWorkerToOldWorkerMap, newWordToOldWordMap, oldPosteriors);
            }
            else
            {
                this.SetDefaultPriors();
            }

            var posteriors = new HonestWorkerModelPosteriors();
            var logEvidenceHistory = new List<double>();

            for (var it = 1; it <= numIterations; it++)
            {
                this.Engine.NumberOfIterations = it;

                posteriors.TrueLabel = this.Engine.Infer<Discrete[]>(this.TrueLabel);
                posteriors.BackgroundLabelProb = this.Engine.Infer<Dirichlet>(this.ProbLabel);
                posteriors.RandomGuessProbability = this.Engine.Infer<Dirichlet>(this.RandomGuessProbability);
                posteriors.WorkerAbility = this.Engine.Infer<Beta[]>(this.Ability);

                if (!this.HasEvidence)
                {
                    // Without an evidence variable there is no convergence test; run all iterations.
                    continue;
                }

                posteriors.Evidence = this.Engine.Infer<Bernoulli>(this.Evidence);
                Console.WriteLine($"Iteration {it} log evidence:\t{posteriors.Evidence.LogOdds:0.0000}");
                logEvidenceHistory.Add(posteriors.Evidence.LogOdds);
                if (ModelBase.HasConverged(logEvidenceHistory))
                {
                    break;
                }
            }

            return posteriors;
        }
        /// <inheritdoc />
        /// <remarks>
        /// After the base implementation has run, gives every worker the initial CPT prior built by
        /// <c>GetCptPrior</c>, then replaces it with the old posterior <c>WorkerCpt</c> entry for each
        /// worker whose map entry is non-negative.
        /// </remarks>
        public override void SetPriorsFromPosteriors(
            int[] newWorkerToOldWorkerMap,
            int[] newWordToOldWordMap,
            ModelPosteriors modelPosteriors)
        {
            base.SetPriorsFromPosteriors(newWorkerToOldWorkerMap, newWordToOldWordMap, modelPosteriors);
            var previousPosteriors = (BiasedWorkerModelPosteriors)modelPosteriors;

            // Baseline: the same initial prior for all workers.
            this.ProbWorkerLabelPrior.ObservedValue = Util.ArrayInit(
                this.WorkerCount,
                worker => GetCptPrior(InitialOnDiagonalPseudoCount, InitialOffDiagonalPseudoCount, this.LabelValueCount));

            for (var newIdx = 0; newIdx < newWorkerToOldWorkerMap.Length; newIdx++)
            {
                var oldIdx = newWorkerToOldWorkerMap[newIdx];
                if (oldIdx < 0)
                {
                    // Negative entry: no old worker to copy from, keep the baseline prior.
                    continue;
                }

                this.ProbWorkerLabelPrior.ObservedValue[newIdx] = previousPosteriors.WorkerCpt[oldIdx];
            }
        }
        /// <inheritdoc />
        /// <remarks>
        /// After the base implementation has run, reuses the old <c>RandomGuessProbability</c> posterior
        /// as the random-guess prior, gives every worker a <c>Beta(2, 1)</c> ability prior, and then
        /// replaces it with the old posterior <c>WorkerAbility</c> entry for each worker whose map
        /// entry is non-negative.
        /// </remarks>
        public override void SetPriorsFromPosteriors(
            int[] newWorkerToOldWorkerMap,
            int[] newWordToOldWordMap,
            ModelPosteriors modelPosteriors)
        {
            base.SetPriorsFromPosteriors(newWorkerToOldWorkerMap, newWordToOldWordMap, modelPosteriors);
            var previousPosteriors = (HonestWorkerModelPosteriors)modelPosteriors;

            this.RandomGuessPrior.ObservedValue = previousPosteriors.RandomGuessProbability;

            // Baseline: the same ability prior for all workers.
            this.AbilityPrior.ObservedValue = Util.ArrayInit(this.WorkerCount, worker => new Beta(2, 1));

            for (var newIdx = 0; newIdx < newWorkerToOldWorkerMap.Length; newIdx++)
            {
                var oldIdx = newWorkerToOldWorkerMap[newIdx];
                if (oldIdx < 0)
                {
                    // Negative entry: no old worker to copy from, keep the baseline prior.
                    continue;
                }

                this.AbilityPrior.ObservedValue[newIdx] = previousPosteriors.WorkerAbility[oldIdx];
            }
        }