/// <summary>
/// Creates an active-learning instance around a model and the results object it works with.
/// Also builds <c>PredictionData</c>, which holds one datum for every (worker, task) pair
/// known to the results' mapping.
/// </summary>
/// <param name="data">The data. Note: this constructor does not read this argument.</param>
/// <param name="model">The model instance (a <c>CommunityModel</c> selects CBCC behaviour).</param>
/// <param name="results">The results instance. It is stored as both the active-learning and
/// the batch results, and its mapping supplies the worker and task ids.</param>
/// <param name="numCommunities">The number of communities (only for CBCC). Note: this
/// constructor does not read this argument.</param>
public ActiveLearning(IList<Datum> data, BCC model, Results results, int numCommunities)
{
    this.bcc = model;
    IsCommunityModel = model is CommunityModel;

    // Both result slots start out referring to the same Results object.
    ActiveLearningResults = results;
    BatchResults = results;

    WorkerIds = ActiveLearningResults.Mapping.WorkerIdToIndex.Keys.ToArray();
    TaskIds = ActiveLearningResults.Mapping.TaskIdToIndex.Keys.ToArray();

    // Full worker x task matrix, grouped by worker (all tasks for the first worker, then the
    // next worker, ...). Each entry has WorkerLabel = 0 and no gold label.
    PredictionData = WorkerIds
        .SelectMany(
            workerId => TaskIds,
            (workerId, taskId) => new Datum
            {
                TaskId = taskId,
                WorkerId = workerId,
                WorkerLabel = 0,
                GoldLabel = null
            })
        .ToList();
}
/// <summary>
/// Runs inference with a BCC model or, when <paramref name="model"/> is a
/// <c>CommunityModel</c>, a CBCC model. The resulting posteriors are stored via
/// <c>UpdateResults</c>, and they can optionally be scored and serialized.
/// </summary>
/// <param name="modelName">The model name. It is also the base file name for serialized posteriors.</param>
/// <param name="data">The data observed during this run.</param>
/// <param name="fullData">The full data set. It is only used to build <c>Mapping</c> when no
/// mapping exists yet.</param>
/// <param name="model">The model instance (BCC or CBCC).</param>
/// <param name="mode">The run mode. It decides whether current results become the priors or are
/// cleared first.</param>
/// <param name="calculateAccuracy">Whether to call <c>UpdateAccuracy</c> after inference.</param>
/// <param name="numCommunities">The number of communities (community model only; defaults to -1).</param>
/// <param name="serialize">Whether to write the posteriors as XML to <c>modelName + ".xml"</c>.</param>
/// <param name="serializeCommunityPosteriors">Whether to also serialize community posteriors
/// (community model only).</param>
public void RunBCC(string modelName, IList<Datum> data, IList<Datum> fullData, BCC model, RunMode mode, bool calculateAccuracy, int numCommunities = -1, bool serialize = false, bool serializeCommunityPosteriors = false)
{
    var cbcc = model as CommunityModel;
    IsCommunityModel = cbcc != null;

    // The mapping is built once, from the first fullData seen, and is then reused by later runs.
    if (this.Mapping == null)
    {
        this.Mapping = new DataMapping(fullData, numCommunities);
        this.GoldLabels = this.Mapping.GetGoldLabelsPerTaskId();
    }

    // Rebuild the model only when its shape no longer matches the mapping
    // (label count, task count and, for CBCC, community count).
    bool dimensionsChanged = Mapping.LabelCount != model.LabelCount || Mapping.TaskCount != model.TaskCount;
    if (!IsCommunityModel)
    {
        if (dimensionsChanged)
        {
            model.CreateModel(Mapping.TaskCount, Mapping.LabelCount);
        }
    }
    else
    {
        CommunityCount = numCommunities;
        if (dimensionsChanged || numCommunities != cbcc.CommunityCount)
        {
            cbcc.CreateModel(Mapping.TaskCount, Mapping.LabelCount, numCommunities);
        }
    }

    // Online training, look-ahead experiments and prediction continue from the current results.
    // Every other mode clears the results and uses default (null) priors. The exception is a
    // community model in LoadAndUseCommunityPriors mode, which loads stored community priors.
    BCC.Posteriors priors = null;
    bool continueFromCurrentResults =
        mode == RunMode.OnlineTraining ||
        mode == RunMode.LookAheadExperiment ||
        mode == RunMode.Prediction;
    if (continueFromCurrentResults)
    {
        priors = ToPriors();
    }
    else
    {
        ClearResults();
        if (mode == RunMode.LoadAndUseCommunityPriors && IsCommunityModel)
        {
            priors = DeserializeCommunityPosteriors(modelName, numCommunities);
        }
    }

    // In prediction mode every worker's label array is replaced by null (the outer length is
    // kept), which is how the model is told to predict rather than observe.
    var observedLabels = Mapping.GetLabelsPerWorkerIndex(data);
    if (mode == RunMode.Prediction)
    {
        observedLabels = observedLabels.Select(_ => (int[])null).ToArray();
    }

    var taskIndices = Mapping.GetTaskIndicesPerWorkerIndex(data);
    BCC.Posteriors posteriors = model.Infer(taskIndices, observedLabels, priors);
    UpdateResults(posteriors, mode);

    if (calculateAccuracy)
    {
        UpdateAccuracy();
    }

    if (serialize)
    {
        using (var stream = new FileStream(modelName + ".xml", FileMode.Create))
        {
            // The serializer type follows the model kind. Presumably a community model's Infer
            // returns CommunityModel.Posteriors, which a BCC.Posteriors serializer would reject — confirm.
            var posteriorsType = IsCommunityModel ? typeof(CommunityModel.Posteriors) : typeof(BCC.Posteriors);
            new System.Xml.Serialization.XmlSerializer(posteriorsType).Serialize(stream, posteriors);
        }
    }

    if (serializeCommunityPosteriors && IsCommunityModel)
    {
        SerializeCommunityPosteriors(modelName);
    }
}