コード例 #1
0
        /// <summary>
        /// Sets the priors of CBCC.
        /// </summary>
        /// <remarks>
        /// Every prior that is not supplied, either because <paramref name="priors"/> is null or because the
        /// individual field is null, is replaced by a default:
        /// <list type="bullet">
        /// <item>background label probabilities: a uniform Dirichlet over the labels;</item>
        /// <item>community probabilities: <c>CommunityProbPriorObserved</c>;</item>
        /// <item>community score matrix: <c>CommunityScoreMatrixPriorObserved</c>;</item>
        /// <item>true-label constraint: a uniform Discrete for each task.</item>
        /// </list>
        /// The worker count and the (identity-scaled) noise matrix are always observed.
        /// </remarks>
        /// <param name="workerCount">The number of workers.</param>
        /// <param name="priors">
        /// The priors. Must be null or a <see cref="CommunityModel.Posteriors"/>; the explicit cast throws
        /// an <see cref="InvalidCastException"/> for any other runtime type.
        /// </param>
        protected override void SetPriors(int workerCount, BCC.Posteriors priors)
        {
            int labelCount = c.SizeAsInt;

            WorkerCount.ObservedValue = workerCount;
            NoiseMatrix.ObservedValue = PositiveDefiniteMatrix.IdentityScaledBy(labelCount, NoisePrecision);
            CommunityModel.Posteriors cbccPriors = (CommunityModel.Posteriors)priors;

            if (cbccPriors == null || cbccPriors.BackgroundLabelProb == null)
            {
                BackgroundLabelProbPrior.ObservedValue = Dirichlet.Uniform(labelCount);
            }
            else
            {
                BackgroundLabelProbPrior.ObservedValue = cbccPriors.BackgroundLabelProb;
            }

            if (cbccPriors == null || cbccPriors.CommunityProb == null)
            {
                CommunityProbPrior.ObservedValue = CommunityProbPriorObserved;
            }
            else
            {
                CommunityProbPrior.ObservedValue = cbccPriors.CommunityProb;
            }

            if (cbccPriors == null || cbccPriors.CommunityScoreMatrix == null)
            {
                CommunityScoreMatrixPrior.ObservedValue = CommunityScoreMatrixPriorObserved;
            }
            else
            {
                CommunityScoreMatrixPrior.ObservedValue = cbccPriors.CommunityScoreMatrix;
            }

            if (cbccPriors == null || cbccPriors.TrueLabelConstraint == null)
            {
                // No constraint supplied: leave every task's true label unconstrained.
                TrueLabelConstraint.ObservedValue = Util.ArrayInit(TaskCount, t => Discrete.Uniform(labelCount));
            }
            else
            {
                TrueLabelConstraint.ObservedValue = cbccPriors.TrueLabelConstraint;
            }
        }
コード例 #2
0
ファイル: Results.cs プロジェクト: kant2002/infer
        /// <summary>
        /// Deserializes the parameters of CBCC from an xml file (used in the LoadAndUseCommunityPriors mode).
        /// </summary>
        /// <remarks>
        /// The file is read as a <c>NonTaskWorkerParameters</c> instance with a <see cref="DataContractSerializer"/>
        /// that uses <c>InferDataContractResolver</c>. All dimensions are validated before any member is assigned.
        /// A malformed file therefore leaves <c>BackgroundLabelProb</c>, <c>CommunityScoreMatrix</c> and
        /// <c>CommunityProb</c> unchanged.
        /// </remarks>
        /// <param name="fileName">The file name.</param>
        /// <param name="numCommunities">The number of communities.</param>
        /// <returns>
        /// A <see cref="CommunityModel.Posteriors"/> whose BackgroundLabelProb, CommunityScoreMatrix and
        /// CommunityProb are taken from the file.
        /// </returns>
        /// <exception cref="ApplicationException">
        /// Thrown when the number of labels or communities in the file does not match the expected counts.
        /// </exception>
        CommunityModel.Posteriors DeserializeCommunityPosteriors(string fileName, int numCommunities)
        {
            DataContractSerializer serializer = new DataContractSerializer(typeof(NonTaskWorkerParameters), new DataContractSerializerSettings {
                DataContractResolver = new InferDataContractResolver()
            });

            NonTaskWorkerParameters ntwp;

            // The FileStream has its own using so it is released even if CreateTextReader throws.
            using (FileStream stream = new FileStream(fileName, FileMode.Open))
            using (XmlDictionaryReader reader = XmlDictionaryReader.CreateTextReader(stream, new XmlDictionaryReaderQuotas()))
            {
                ntwp = (NonTaskWorkerParameters)serializer.ReadObject(reader);
            }

            // Validate everything first so that a bad file cannot leave this instance half-updated.
            if (ntwp.BackgroundLabelProb.Dimension != Mapping.LabelCount)
            {
                throw new ApplicationException("Unexpected number of labels");
            }

            if (ntwp.CommunityScoreMatrix.Length != numCommunities)
            {
                throw new ApplicationException("Unexpected number of communities");
            }

            // Check every score vector rather than only [0][0]. This also avoids an
            // IndexOutOfRangeException when the matrix or one of its rows is empty.
            foreach (var communityScores in ntwp.CommunityScoreMatrix)
            {
                if (communityScores.Length != Mapping.LabelCount)
                {
                    throw new ApplicationException("Unexpected number of labels");
                }

                foreach (var labelScores in communityScores)
                {
                    if (labelScores.Dimension != Mapping.LabelCount)
                    {
                        throw new ApplicationException("Unexpected number of labels");
                    }
                }
            }

            if (ntwp.CommunityProb.Dimension != numCommunities)
            {
                throw new ApplicationException("Unexpected number of communities");
            }

            BackgroundLabelProb  = ntwp.BackgroundLabelProb;
            CommunityScoreMatrix = ntwp.CommunityScoreMatrix;
            CommunityProb        = ntwp.CommunityProb;

            CommunityModel.Posteriors cbccPriors = new CommunityModel.Posteriors();
            cbccPriors.BackgroundLabelProb  = ntwp.BackgroundLabelProb;
            cbccPriors.CommunityScoreMatrix = ntwp.CommunityScoreMatrix;
            cbccPriors.CommunityProb        = ntwp.CommunityProb;

            return(cbccPriors);
        }
コード例 #3
0
ファイル: Results.cs プロジェクト: ScriptBox21/dotnet-infer
        /// <summary>
        /// Deserializes the parameters of CBCC from an xml file (used in the LoadAndUseCommunityPriors mode).
        /// </summary>
        /// <remarks>
        /// The file <c>{modelName}CommunityPriors.xml</c> is read as a <c>NonTaskWorkerParameters</c> instance
        /// with an <see cref="System.Xml.Serialization.XmlSerializer"/>. All dimensions are validated before any
        /// member is assigned. A malformed file therefore leaves <c>BackgroundLabelProb</c>,
        /// <c>CommunityScoreMatrix</c> and <c>CommunityProb</c> unchanged.
        /// </remarks>
        /// <param name="modelName">The model name.</param>
        /// <param name="numCommunities">The number of communities.</param>
        /// <returns>
        /// A <see cref="CommunityModel.Posteriors"/> whose BackgroundLabelProb, CommunityScoreMatrix and
        /// CommunityProb are taken from the file.
        /// </returns>
        /// <exception cref="ApplicationException">
        /// Thrown when the number of labels or communities in the file does not match the expected counts.
        /// </exception>
        CommunityModel.Posteriors DeserializeCommunityPosteriors(string modelName, int numCommunities)
        {
            NonTaskWorkerParameters ntwp;
            using (FileStream stream = new FileStream(modelName + "CommunityPriors.xml", FileMode.Open))
            {
                var serializer = new System.Xml.Serialization.XmlSerializer(typeof(NonTaskWorkerParameters));
                ntwp = (NonTaskWorkerParameters)serializer.Deserialize(stream);
            }

            // Validate everything first so that a bad file cannot leave this instance half-updated.
            if (ntwp.BackgroundLabelProb.Dimension != Mapping.LabelCount)
            {
                throw new ApplicationException("Unexpected number of labels");
            }

            if (ntwp.CommunityScoreMatrix.Length != numCommunities)
            {
                throw new ApplicationException("Unexpected number of communities");
            }

            // Check every score vector rather than only [0][0]. This also avoids an
            // IndexOutOfRangeException when the matrix or one of its rows is empty.
            foreach (var communityScores in ntwp.CommunityScoreMatrix)
            {
                if (communityScores.Length != Mapping.LabelCount)
                {
                    throw new ApplicationException("Unexpected number of labels");
                }

                foreach (var labelScores in communityScores)
                {
                    if (labelScores.Dimension != Mapping.LabelCount)
                    {
                        throw new ApplicationException("Unexpected number of labels");
                    }
                }
            }

            if (ntwp.CommunityProb.Dimension != numCommunities)
            {
                throw new ApplicationException("Unexpected number of communities");
            }

            BackgroundLabelProb  = ntwp.BackgroundLabelProb;
            CommunityScoreMatrix = ntwp.CommunityScoreMatrix;
            CommunityProb        = ntwp.CommunityProb;

            CommunityModel.Posteriors cbccPriors = new CommunityModel.Posteriors();
            cbccPriors.BackgroundLabelProb  = ntwp.BackgroundLabelProb;
            cbccPriors.CommunityScoreMatrix = ntwp.CommunityScoreMatrix;
            cbccPriors.CommunityProb        = ntwp.CommunityProb;

            return(cbccPriors);
        }
コード例 #4
0
        /// <summary>
        /// Runs CBCC inference on the supplied worker labels and collects the resulting posteriors.
        /// </summary>
        /// <param name="taskIndices">For each worker (row), the indices of the tasks that worker labelled (columns).</param>
        /// <param name="workerLabels">For each worker (row), the labels that worker gave (columns).</param>
        /// <param name="priors">The priors; cast to <see cref="CommunityModel.Posteriors"/> (null is allowed).</param>
        /// <returns>
        /// A <see cref="CommunityModel.Posteriors"/> holding the inferred marginals. It also holds the
        /// marginal-divided-by-prior constraints for the true labels, worker score matrices and worker communities.
        /// </returns>
        public override BCC.Posteriors Infer(int[][] taskIndices, int[][] workerLabels, BCC.Posteriors priors)
        {
            var communityPriors = (CommunityModel.Posteriors)priors;

            // Per-worker constraints carried over from a previous run, if any.
            VectorGaussian[][] workerScoreConstraint     = null;
            Discrete[]         workerCommunityConstraint = null;
            if (communityPriors != null)
            {
                workerScoreConstraint     = communityPriors.WorkerScoreMatrixConstraint;
                workerCommunityConstraint = communityPriors.WorkerCommunityConstraint;
            }

            SetPriors(workerLabels.Length, priors);
            AttachData(taskIndices, workerLabels, workerScoreConstraint, workerCommunityConstraint);

            Engine.NumberOfIterations = NumberOfIterations;

            // The queries are issued in a fixed order, one per posterior field.
            var posteriors = new CommunityModel.Posteriors
            {
                Evidence                    = Engine.Infer<Bernoulli>(Evidence),
                BackgroundLabelProb         = Engine.Infer<Dirichlet>(BackgroundLabelProb),
                WorkerConfusionMatrix       = Engine.Infer<Dirichlet[][]>(WorkerConfusionMatrix),
                TrueLabel                   = Engine.Infer<Discrete[]>(TrueLabel),
                TrueLabelConstraint         = Engine.Infer<Discrete[]>(TrueLabel, QueryTypes.MarginalDividedByPrior),
                CommunityScoreMatrix        = Engine.Infer<VectorGaussian[][]>(CommunityScoreMatrix),
                CommunityConfusionMatrix    = Engine.Infer<Dirichlet[][]>(CommunityConfusionMatrix),
                WorkerScoreMatrixConstraint = Engine.Infer<VectorGaussian[][]>(ScoreMatrix, QueryTypes.MarginalDividedByPrior),
                CommunityProb               = Engine.Infer<Dirichlet>(CommunityProb),
                Community                   = Engine.Infer<Discrete[]>(Community),
                WorkerCommunityConstraint   = Engine.Infer<Discrete[]>(Community, QueryTypes.MarginalDividedByPrior)
            };

            return posteriors;
        }
コード例 #5
0
ファイル: Results.cs プロジェクト: ScriptBox21/dotnet-infer
        /// <summary>
        /// Loads the priors of BCC and CBCC from the results stored in this instance.
        /// </summary>
        /// <remarks>
        /// Workers and tasks that have no stored entry get uniform distributions. When <c>IsCommunityModel</c>
        /// is true, the returned object is a <see cref="CommunityModel.Posteriors"/> that also carries the community
        /// confusion matrix, worker score-matrix constraints, community probabilities, community score matrix and
        /// worker community constraints.
        /// </remarks>
        /// <returns>A BCC posterior instance with the loaded priors.</returns>
        BCC.Posteriors ToPriors()
        {
            int numClasses = Mapping.LabelCount;
            int numTasks   = Mapping.TaskCount;
            int numWorkers = Mapping.WorkerCount;

            CommunityModel.Posteriors cbccPriors = new CommunityModel.Posteriors();
            BCC.Posteriors            priors     = IsCommunityModel ? cbccPriors : new BCC.Posteriors();

            // Loads the prior of the background probabilities of the tasks
            priors.BackgroundLabelProb = BackgroundLabelProb;

            // Loads the prior of the confusion matrix of each worker (uniform rows for unknown workers)
            priors.WorkerConfusionMatrix = Util.ArrayInit(numWorkers, w =>
            {
                Dirichlet[] storedMatrix;
                if (WorkerConfusionMatrix.TryGetValue(Mapping.WorkerIndexToId[w], out storedMatrix))
                {
                    return(Util.ArrayInit(numClasses, lab => storedMatrix[lab]));
                }

                return(Util.ArrayInit(numClasses, lab => Dirichlet.Uniform(numClasses)));
            });

            // Loads the true label constraint of each task (uniform for unknown tasks)
            priors.TrueLabelConstraint = Util.ArrayInit(numTasks, t =>
            {
                Discrete storedConstraint;
                if (TrueLabelConstraint.TryGetValue(Mapping.TaskIndexToId[t], out storedConstraint))
                {
                    return(storedConstraint);
                }

                return(Discrete.Uniform(numClasses));
            });

            // Loads the priors of the parameters of CBCC
            if (IsCommunityModel)
            {
                cbccPriors.CommunityConfusionMatrix    = CommunityConfusionMatrix;
                cbccPriors.WorkerScoreMatrixConstraint = Util.ArrayInit(numWorkers, w =>
                {
                    VectorGaussian[] storedScores;
                    if (WorkerScoreMatrixConstraint.TryGetValue(Mapping.WorkerIndexToId[w], out storedScores))
                    {
                        return(Util.ArrayInit(numClasses, lab => storedScores[lab]));
                    }

                    return(Util.ArrayInit(numClasses, lab => VectorGaussian.Uniform(numClasses)));
                });
                cbccPriors.CommunityProb             = CommunityProb;
                cbccPriors.CommunityScoreMatrix      = CommunityScoreMatrix;
                cbccPriors.WorkerCommunityConstraint = Util.ArrayInit(numWorkers, w =>
                {
                    Discrete storedCommunity;
                    if (CommunityConstraint.TryGetValue(Mapping.WorkerIndexToId[w], out storedCommunity))
                    {
                        return(storedCommunity);
                    }

                    return(Discrete.Uniform(CommunityCount));
                });
            }

            priors.Evidence = ModelEvidence;

            return(priors);
        }
コード例 #6
0
ファイル: Results.cs プロジェクト: ScriptBox21/dotnet-infer
        /// <summary>
        /// Updates the stored results with the new posteriors.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item>In <c>RunMode.LookAheadExperiment</c>, only the look-ahead true labels and worker confusion matrices are updated.</item>
        /// <item>In <c>RunMode.Prediction</c>, only the per-worker task predictions are replaced.</item>
        /// <item>Every other mode updates the background label probability, worker confusion matrices, true labels,
        /// true-label constraints and model evidence. When <paramref name="posteriors"/> is a
        /// <see cref="CommunityModel.Posteriors"/>, the community results are updated too.</item>
        /// </list>
        /// </remarks>
        /// <param name="posteriors">The posteriors.</param>
        /// <param name="mode">The mode (for example training, prediction, etc.).</param>
        void UpdateResults(BCC.Posteriors posteriors, RunMode mode)
        {
            if (mode == RunMode.LookAheadExperiment)
            {
                // Update only the look-ahead results
                for (int t = 0; t < posteriors.TrueLabel.Length; t++)
                {
                    LookAheadTrueLabel[Mapping.TaskIndexToId[t]] = posteriors.TrueLabel[t];
                }

                for (int w = 0; w < posteriors.WorkerConfusionMatrix.Length; w++)
                {
                    LookAheadWorkerConfusionMatrix[Mapping.WorkerIndexToId[w]] = posteriors.WorkerConfusionMatrix[w];
                }
            }
            else if (mode == RunMode.Prediction)
            {
                // Update only the worker prediction results. The loop is bounded by WorkerPrediction, the
                // array actually indexed, so a length mismatch with WorkerConfusionMatrix cannot cause an
                // out-of-range read or silently skip workers.
                for (int w = 0; w < posteriors.WorkerPrediction.Length; w++)
                {
                    var taskPredictions = new Dictionary<string, Discrete>();
                    for (int tw = 0; tw < posteriors.WorkerPrediction[w].Length; tw++)
                    {
                        taskPredictions[Mapping.TaskIndexToId[tw]] = posteriors.WorkerPrediction[w][tw];
                    }

                    WorkerPrediction[Mapping.WorkerIndexToId[w]] = taskPredictions;
                }
            }
            else
            {
                // In all other modes, update all the results
                BackgroundLabelProb = posteriors.BackgroundLabelProb;
                for (int w = 0; w < posteriors.WorkerConfusionMatrix.Length; w++)
                {
                    WorkerConfusionMatrix[Mapping.WorkerIndexToId[w]] = posteriors.WorkerConfusionMatrix[w];
                }

                for (int t = 0; t < posteriors.TrueLabel.Length; t++)
                {
                    TrueLabel[Mapping.TaskIndexToId[t]] = posteriors.TrueLabel[t];
                }

                for (int t = 0; t < posteriors.TrueLabelConstraint.Length; t++)
                {
                    TrueLabelConstraint[Mapping.TaskIndexToId[t]] = posteriors.TrueLabelConstraint[t];
                }

                // Community results exist only when the posteriors come from the community model
                CommunityModel.Posteriors communityPosteriors = posteriors as CommunityModel.Posteriors;
                if (communityPosteriors != null)
                {
                    CommunityConfusionMatrix = communityPosteriors.CommunityConfusionMatrix;
                    for (int w = 0; w < communityPosteriors.WorkerScoreMatrixConstraint.Length; w++)
                    {
                        string workerId = Mapping.WorkerIndexToId[w];
                        WorkerScoreMatrixConstraint[workerId] = communityPosteriors.WorkerScoreMatrixConstraint[w];
                        CommunityConstraint[workerId]         = communityPosteriors.WorkerCommunityConstraint[w];
                        WorkerCommunity[workerId]             = communityPosteriors.Community[w];
                    }

                    CommunityProb        = communityPosteriors.CommunityProb;
                    CommunityScoreMatrix = communityPosteriors.CommunityScoreMatrix;
                }

                this.ModelEvidence = posteriors.Evidence;
            }
        }