Example 1
0
        /// <summary>
        /// Constructs an active learning instance with a specified data set and model instance.
        /// </summary>
        /// <param name="data">The data. Not read by this constructor; kept for interface compatibility.</param>
        /// <param name="model">The model instance. Passing a <c>CommunityModel</c> sets <c>IsCommunityModel</c>.</param>
        /// <param name="results">The results instance; its <c>Mapping</c> supplies the worker and task ids.</param>
        /// <param name="numCommunities">The number of communities (only for CBCC). Not read by this constructor.</param>
        /// <remarks>
        /// <c>results</c> is stored as both <c>ActiveLearningResults</c> and <c>BatchResults</c>, and
        /// <c>PredictionData</c> is filled with one placeholder datum per (worker, task) pair
        /// (<c>WorkerLabel</c> = 0, no gold label).
        /// </remarks>
        public ActiveLearning(IList <Datum> data, BCC model, Results results, int numCommunities)
        {
            this.bcc = model;
            CommunityModel communityModel = model as CommunityModel;

            IsCommunityModel      = (communityModel != null);
            ActiveLearningResults = results;
            BatchResults          = results;

            var workerIds = ActiveLearningResults.Mapping.WorkerIdToIndex.Keys.ToArray();
            var taskIds   = ActiveLearningResults.Mapping.TaskIdToIndex.Keys.ToArray();
            WorkerIds = workerIds;
            TaskIds   = taskIds;

            // Build the full worker x task matrix of placeholder data used for prediction.
            // (A plain comment: "///" XML doc comments inside a method body raise CS1587.)
            // The final size is known up front, so pre-size the list.
            PredictionData = new List <Datum>(workerIds.Length * taskIds.Length);
            foreach (var workerId in WorkerIds)
            {
                foreach (var task in TaskIds)
                {
                    PredictionData.Add(new Datum
                    {
                        TaskId      = task,
                        WorkerId    = workerId,
                        WorkerLabel = 0,
                        GoldLabel   = null
                    });
                }
            }
        }
Example 2
0
        /// <summary>
        /// Runs a model with the full gold set.
        /// </summary>
        /// <param name="dataSet">The data set name; the file <c>Data/&lt;dataSet&gt;.csv</c> is loaded.</param>
        /// <param name="runType">The model run type.</param>
        /// <param name="model">The model instance (used only for the BCC/CBCC run types).</param>
        /// <param name="communityCount">The number of communities (only for CBCC).</param>
        /// <returns>The inference results</returns>
        /// <remarks>
        /// Appends one line <c>modelName,accuracy,negativeLogProb</c> to <c>endpoints.csv</c> in <c>ResultsDir</c>.
        /// </remarks>
        static Results RunGold(string dataSet, RunType runType, BCC model, int communityCount = 3)
        {
            // Reset the random seed so results can be duplicated for the paper
            Rand.Restart(12347);

            // Path.Combine yields ".\Data\<set>.csv" on Windows, as before, but also resolves on other platforms.
            var data = Datum.LoadData(System.IO.Path.Combine(".", "Data", dataSet + ".csv"));

            string  modelName = GetModelName(dataSet, runType, TaskSelectionMethod.EntropyTask, WorkerSelectionMethod.RandomWorker);
            Results results   = new Results();

            switch (runType)
            {
            case RunType.VoteDistribution:
                results.RunMajorityVote(data, true, true);
                break;

            case RunType.MajorityVote:
                results.RunMajorityVote(data, true, false);
                break;

            case RunType.DawidSkene:
                results.RunDawidSkene(data, true);
                break;

            default:
                results.RunBCC(ResultsDir + modelName, data, data, model, Results.RunMode.ClearResults, false, communityCount, false, false);
                break;
            }

            // Write the inference results on a csv file.
            // Format with the invariant culture: under a comma-decimal culture (e.g. it-IT) the
            // current-culture default would write "0,123" and corrupt the comma-separated columns.
            using (StreamWriter writer = new StreamWriter(ResultsDir + "endpoints.csv", true))
            {
                writer.WriteLine(string.Format(
                    System.Globalization.CultureInfo.InvariantCulture,
                    "{0},{1:0.000},{2:0.0000}",
                    modelName,
                    results.Accuracy,
                    results.NegativeLogProb));
            }

            return results;
        }
Example 3
0
        /// <summary>
        /// Runs the standard active learning procedure on a model instance and an input data set.
        /// </summary>
        /// <param name="data">The data.</param>
        /// <param name="modelName">The model name.</param>
        /// <param name="runType">The model run type.</param>
        /// <param name="model">The model instance.</param>
        /// <param name="taskSelectionMethod">The method for selecting tasks (Random / Entropy).</param>
        /// <param name="workerSelectionMethod">The method for selecting workers (only Random is implemented; not read by this method).</param>
        /// <param name="resultsDir">The directory to save the log files.</param>
        /// <param name="communityCount">The number of communities (only for CBCC).</param>
        /// <param name="initialNumLabelsPerTask">The initial number of exploratory labels that are randomly selected for each task.</param>
        /// <remarks>
        /// Each round re-runs inference on the labels gathered so far, scores the tasks, and asks
        /// <c>GetNextData</c> for more labels. The loop stops when no further labels are returned or
        /// after a fixed maximum number of rounds. Accuracy, NLPD, average recall and the value of
        /// the selected task are logged every round via <c>DoSnapshot</c>, plus a final snapshot.
        /// </remarks>
        public static void RunActiveLearning(IList <Datum> data, string modelName, RunType runType, BCC model, TaskSelectionMethod taskSelectionMethod, WorkerSelectionMethod workerSelectionMethod, string resultsDir, int communityCount = -1, int initialNumLabelsPerTask = 1)
        {
            // Upper bound on the number of active learning rounds
            const int maxIterations = 500;

            // Number of labels requested from GetNextData in each round
            const int numIncremData = 3;

            // Count elapsed time
            Stopwatch stopWatchTotal = Stopwatch.StartNew();

            // Dictionary keyed by task Id, with the labels of each task in random order.
            // Rand.Perm is invoked lazily (once per task group, in GroupBy order) when ToDictionary
            // runs; keep this expression as-is so seeded runs remain reproducible.
            var groupedRandomisedData =
                data.GroupBy(d => d.TaskId).
                Select(g =>
            {
                var arr  = g.ToArray();
                int cnt  = arr.Length;
                var perm = Rand.Perm(cnt);
                return(new
                {
                    key = g.Key,
                    arr = g.Select((t, i) => arr[perm[i]]).ToArray()
                });
            }).ToDictionary(a => a.key, a => a.arr);

            // Dictionaries keyed by task Id, with the total and currently used label counts
            Dictionary <string, int> totalCounts   = groupedRandomisedData.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.Length);
            Dictionary <string, int> currentCounts = groupedRandomisedData.ToDictionary(kvp => kvp.Key, kvp => initialNumLabelsPerTask);

            // Keyed by task, value is a HashSet containing all the remaining workers with a label - workers are removed after adding a new datum
            Dictionary <string, HashSet <string> > remainingWorkersPerTask = groupedRandomisedData.ToDictionary(kvp => kvp.Key, kvp => new HashSet <string>(kvp.Value.Select(dat => dat.WorkerId)));
            int numTaskIds     = totalCounts.Count;
            int totalInstances = data.Count - initialNumLabelsPerTask * numTaskIds;

            // Log structures
            List <double> accuracy  = new List <double>();
            List <double> nlpd      = new List <double>();
            List <double> avgRecall = new List <double>();
            List <ActiveLearningResult> taskValueList = new List <ActiveLearningResult>();
            int index = 0;

            Console.WriteLine("Active Learning: {0}", modelName);
            Console.WriteLine("\t\tAcc\tAvgRec");

            // Get initial data
            Results        results        = new Results();
            List <Datum>   subData        = GetSubdata(groupedRandomisedData, currentCounts, remainingWorkersPerTask);
            List <Datum>   nextData       = null;
            ActiveLearning activeLearning = null;

            for (int iter = 0; iter < maxIterations; iter++)
            {
                bool calculateAccuracy = true;
                bool doSnapShot        = true; // snapshot after every round

                if (subData != null || nextData != null)
                {
                    switch (runType)
                    {
                    case RunType.VoteDistribution:
                        results.RunMajorityVote(subData, calculateAccuracy, true);
                        break;

                    case RunType.MajorityVote:
                        results.RunMajorityVote(subData, calculateAccuracy, false);
                        break;

                    case RunType.DawidSkene:
                        results.RunDawidSkene(subData, calculateAccuracy);
                        break;

                    default:     // Run BCC models
                        results.RunBCC(modelName, subData, data, model, Results.RunMode.ClearResults, calculateAccuracy, communityCount, false);
                        break;
                    }
                }

                if (activeLearning == null)
                {
                    activeLearning = new ActiveLearning(data, model, results, communityCount);
                }
                else
                {
                    activeLearning.UpdateActiveLearningResults(results);
                }

                // Score the tasks to select the next ones
                Dictionary <string, ActiveLearningResult> taskValues;
                switch (taskSelectionMethod)
                {
                case TaskSelectionMethod.RandomTask:
                    taskValues = data.GroupBy(d => d.TaskId).ToDictionary(a => a.Key, a => new ActiveLearningResult
                    {
                        TaskValue = Rand.Double()
                    });
                    break;

                case TaskSelectionMethod.EntropyTask:
                default:     // Entropy task selection
                    taskValues = activeLearning.EntropyTrueLabelPosterior();
                    break;
                }

                nextData = GetNextData(groupedRandomisedData, taskValues, currentCounts, totalCounts, numIncremData);

                if (nextData == null || nextData.Count == 0)
                {
                    break;
                }

                index += nextData.Count;
                subData.AddRange(nextData);

                // Logs
                if (calculateAccuracy)
                {
                    accuracy.Add(results.Accuracy);
                    nlpd.Add(results.NegativeLogProb);
                    avgRecall.Add(results.AvgRecall);

                    // Value of the task whose labels were selected first in this round.
                    // (The former per-label branch read a list that was never assigned, so it could only throw.)
                    taskValueList.Add(taskValues[nextData.First().TaskId]);

                    if (doSnapShot)
                    {
                        Console.WriteLine("{0} of {1}:\t{2:0.000}\t{3:0.0000}", index, totalInstances, accuracy.Last(), avgRecall.Last());
                        DoSnapshot(accuracy, nlpd, avgRecall, taskValueList, results, modelName, "interim", resultsDir);
                    }
                }
            }

            stopWatchTotal.Stop();
            DoSnapshot(accuracy, nlpd, avgRecall, taskValueList, results, modelName, "final", resultsDir);
            Console.WriteLine("Elapsed time: {0}\n", stopWatchTotal.Elapsed);
        }
Example 4
0
        /// <summary>
        /// Runs the active learning experiment presented in Venanzi et al. (WWW14) on a single data set.
        /// </summary>
        /// <param name="dataSet">The data set name; the file <c>Data/&lt;dataSet&gt;.csv</c> is loaded.</param>
        /// <param name="runType">The model run type.</param>
        /// <param name="taskSelectionMethod">The method for selecting tasks (Random / Entropy).</param>
        /// <param name="model">The model instance.</param>
        /// <param name="communityCount">The number of communities (only for CBCC).</param>
        /// <remarks>Workers are always selected with <c>WorkerSelectionMethod.RandomWorker</c>.</remarks>
        static void RunWWWActiveLearning(string dataSet, RunType runType, TaskSelectionMethod taskSelectionMethod, BCC model, int communityCount = 4)
        {
            // Reset the random seed so results can be duplicated for the paper
            Rand.Restart(12347);
            var workerSelectionMethod = WorkerSelectionMethod.RandomWorker;

            // Path.Combine yields "Data\<set>.csv" on Windows, as before, and a valid path on other platforms
            var    data      = Datum.LoadData(System.IO.Path.Combine("Data", dataSet + ".csv"));
            string modelName = GetModelName(dataSet, runType, taskSelectionMethod, workerSelectionMethod, communityCount);

            ActiveLearning.RunActiveLearning(data, modelName, runType, model, taskSelectionMethod, workerSelectionMethod, ResultsDir, communityCount);
        }
Example 5
0
        /// <summary>
        /// Runs the BCC or CBCC model.
        /// </summary>
        /// <param name="modelName">The model name; also the file-name stem for the serialized posteriors (<c>modelName.xml</c>).</param>
        /// <param name="data">The data that will be used for this run.</param>
        /// <param name="fullData">The full data set of data; used to build the task/worker mapping on the first call only.</param>
        /// <param name="model">The model instance (BCC or CBCC).</param>
        /// <param name="mode">The mode (for example training, prediction, etc.).</param>
        /// <param name="calculateAccuracy">Whether to calculate accuracy.</param>
        /// <param name="numCommunities">The number of communities (community model only).</param>
        /// <param name="serialize">Whether to serialize all posteriors.</param>
        /// <param name="serializeCommunityPosteriors">Whether to serialize community posteriors.</param>
        /// <remarks>
        /// The underlying model is recreated only when the label count or task count (and, for CBCC,
        /// the community count) differs from the current model. In <c>RunMode.Prediction</c> the
        /// labels are replaced by nulls before inference.
        /// </remarks>
        public void RunBCC(string modelName, IList <Datum> data, IList <Datum> fullData, BCC model, RunMode mode, bool calculateAccuracy, int numCommunities = -1, bool serialize = false, bool serializeCommunityPosteriors = false)
        {
            CommunityModel communityModel = model as CommunityModel;

            IsCommunityModel = communityModel != null;

            // Build the task/worker mapping and gold labels once, from the full data set
            if (this.Mapping == null)
            {
                this.Mapping    = new DataMapping(fullData, numCommunities);
                this.GoldLabels = this.Mapping.GetGoldLabelsPerTaskId();
            }

            // A new model is created if the label count or the task count has changed.
            // (Plain "//" comments: "///" XML doc comments inside a method body raise CS1587.)
            bool createModel = (Mapping.LabelCount != model.LabelCount) || (Mapping.TaskCount != model.TaskCount);

            if (IsCommunityModel)
            {
                // Creates a new CBCC model instance; a changed community count also forces re-creation
                CommunityCount = numCommunities;
                createModel    = createModel || (numCommunities != communityModel.CommunityCount);
                if (createModel)
                {
                    communityModel.CreateModel(Mapping.TaskCount, Mapping.LabelCount, numCommunities);
                }
            }
            else if (createModel)
            {
                // Creates a new BCC model instance
                model.CreateModel(Mapping.TaskCount, Mapping.LabelCount);
            }

            // Selects the prior according to the run mode
            BCC.Posteriors priors = null;
            switch (mode)
            {
            // Use existing priors derived from the current results
            case RunMode.OnlineTraining:
            case RunMode.LookAheadExperiment:
            case RunMode.Prediction:
                priors = ToPriors();
                break;

            default:

                // Use default priors: results are cleared and priors stays null,
                // unless community priors are explicitly loaded for a community model
                ClearResults();
                if (mode == RunMode.LoadAndUseCommunityPriors && IsCommunityModel)
                {
                    priors = DeserializeCommunityPosteriors(modelName, numCommunities);
                }
                break;
            }

            // Get data to observe
            var labelsPerWorkerIndex = Mapping.GetLabelsPerWorkerIndex(data);

            if (mode == RunMode.Prediction)
            {
                // Signal prediction mode by setting all labels to null
                labelsPerWorkerIndex = labelsPerWorkerIndex.Select(arr => (int[])null).ToArray();
            }

            // Run model inference
            BCC.Posteriors posteriors = model.Infer(
                Mapping.GetTaskIndicesPerWorkerIndex(data),
                labelsPerWorkerIndex, priors);
            UpdateResults(posteriors, mode);

            // Compute accuracy
            if (calculateAccuracy)
            {
                UpdateAccuracy();
            }

            // Serialize parameters
            if (serialize)
            {
                using (FileStream stream = new FileStream(modelName + ".xml", FileMode.Create))
                {
                    var serializer = new System.Xml.Serialization.XmlSerializer(IsCommunityModel ? typeof(CommunityModel.Posteriors) : typeof(BCC.Posteriors));
                    serializer.Serialize(stream, posteriors);
                }
            }

            if (serializeCommunityPosteriors && IsCommunityModel)
            {
                SerializeCommunityPosteriors(modelName);
            }
        }