/// <summary>
/// Replays the classifier's best action over a fresh <c>Clusterer.State</c> built from
/// <paramref name="doc"/> and records every mention pair for which the action reported a merge.
/// </summary>
/// <param name="doc">Document used to construct the initial clustering state.</param>
/// <returns>
/// The mention pairs, in the order they were visited, for which
/// <c>DoBestAction</c> returned <c>true</c>.
/// </returns>
public virtual IList<Pair<int, int>> GetClusterMerges(ClustererDataLoader.ClustererDoc doc)
{
    IList<Pair<int, int>> accepted = new List<Pair<int, int>>();
    Clusterer.State state = new Clusterer.State(doc);

    while (!state.IsComplete())
    {
        // Read the pair at the current index *before* acting: the pair must be
        // captured up front so it can be recorded if the action turns out to be a merge.
        Pair<int, int> candidate = state.mentionPairs[state.currentIndex];
        bool merged = state.DoBestAction(classifier);
        if (merged)
        {
            accepted.Add(candidate);
        }
    }

    return accepted;
}
/// <summary>
/// Builds the two candidate actions (merge / do not merge) for the current state,
/// each annotated with a cost relative to the better of the two outcomes.
/// </summary>
/// <remarks>
/// As a side effect, stores in <c>hashedScores</c> (keyed by <c>c1</c>, <c>c2</c> and
/// <c>currentIndex</c>) whether the exponentiated classifier score of the merge
/// features exceeds 0.5.
/// </remarks>
/// <param name="classifier">Classifier used to score features and to compute final costs.</param>
/// <returns>
/// A pair whose first element is the merge action (carrying the merge features) and whose
/// second element is the no-merge action (carrying an empty counter). Each cost is
/// <c>(doc.mentions.Count / 100.0) * (max - own)</c>, where <c>own</c> is the value
/// <c>GetFinalCost</c> returned for that action and <c>max</c> is the larger of the two.
/// </returns>
public virtual Pair<Clusterer.CandidateAction, Clusterer.CandidateAction> GetActions(Clusterer.ClustererClassifier classifier)
{
    ICounter<string> featuresForMerge = GetFeatures(doc, c1, c2, globalFeatures[currentIndex]);
    double expScore = Math.Exp(classifier.WeightFeatureProduct(featuresForMerge));
    Clusterer.MergeKey cacheKey = new Clusterer.MergeKey(c1, c2, currentIndex);
    hashedScores[cacheKey] = expScore > 0.5;

    // Simulate both choices on copies of this state; merge is evaluated first, then no-merge.
    double mergeB3 = FinalCostAfter(true, classifier);
    double noMergeB3 = FinalCostAfter(false, classifier);

    double weight = doc.mentions.Count / 100.0;
    double maxB3 = Math.Max(mergeB3, noMergeB3);

    Clusterer.CandidateAction mergeAction =
        new Clusterer.CandidateAction(featuresForMerge, weight * (maxB3 - mergeB3));
    Clusterer.CandidateAction noMergeAction =
        new Clusterer.CandidateAction(new ClassicCounter<string>(), weight * (maxB3 - noMergeB3));

    return new Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>(mergeAction, noMergeAction);
}

/// <summary>
/// Copies this state, applies the given action to the copy, and returns the copy's final cost.
/// </summary>
private double FinalCostAfter(bool doMerge, Clusterer.ClustererClassifier classifier)
{
    Clusterer.State branch = new Clusterer.State(this);
    branch.DoAction(doMerge);
    return branch.GetFinalCost(classifier);
}
/// <summary>
/// Runs the current policy (always taking the classifier's best action) over every document
/// and returns the F1 reported by a <c>EvalUtils.B3Evaluator</c> updated with each final state.
/// </summary>
/// <remarks>
/// <c>isTraining</c> is set to 0 while the documents are processed and set to 1 afterwards,
/// including when processing throws.
/// </remarks>
/// <param name="docs">Documents to evaluate on.</param>
/// <param name="training">
/// Only affects the log message: <c>true</c> labels the score "train", <c>false</c> "validate".
/// </param>
/// <returns>The value of <c>evaluator.GetF1()</c> after all documents were processed.</returns>
private double EvaluatePolicy(IList<ClustererDataLoader.ClustererDoc> docs, bool training)
{
    isTraining = 0;
    EvalUtils.B3Evaluator evaluator = new EvalUtils.B3Evaluator();
    try
    {
        foreach (ClustererDataLoader.ClustererDoc doc in docs)
        {
            Clusterer.State currentState = new Clusterer.State(doc);
            while (!currentState.IsComplete())
            {
                currentState.DoBestAction(classifier);
            }
            currentState.UpdateEvaluator(evaluator);
        }
    }
    finally
    {
        // Restore the flag even if evaluation fails, so a failure here
        // does not leave the flag stuck at 0.
        isTraining = 1;
    }

    double score = evaluator.GetF1();
    // .NET composite formatting uses {index[:format]} placeholders; the Java-style
    // "%s" / "%.4f" specifiers would be emitted literally and the arguments dropped.
    // "F4" gives the same four-decimal fixed-point output that "%.4f" was meant to produce.
    string label = training ? "train" : "validate";
    Redwood.Log("scoref.train", string.Format("B3 F1 score on {0}: {1:F4}", label, score));
    return score;
}
/// <summary>
/// Rolls a policy out over <paramref name="doc"/> and collects the candidate-action pair
/// produced at each step.
/// </summary>
/// <remarks>
/// At each step one draw from <c>random.NextDouble()</c> is compared against
/// <paramref name="beta"/>. When the draw is below <paramref name="beta"/>, the two actions
/// are scored by their negated cost; otherwise they are scored by the classifier on their
/// features. The first action is taken when its score is greater than or equal to the second's.
/// </remarks>
/// <param name="doc">Document used to construct the initial clustering state.</param>
/// <param name="beta">Threshold for the per-step random draw that selects cost-based scoring.</param>
/// <returns>The action pairs gathered during the rollout, in visiting order.</returns>
private IList<Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>> RunPolicy(ClustererDataLoader.ClustererDoc doc, double beta)
{
    IList<Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>> collected =
        new List<Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>>();
    Clusterer.State state = new Clusterer.State(doc);

    while (!state.IsComplete())
    {
        Pair<Clusterer.CandidateAction, Clusterer.CandidateAction> candidates = state.GetActions(classifier);

        // NOTE(review): a null result skips the step without calling DoAction; this only
        // terminates if GetActions itself advanced the state — confirm against overrides.
        if (candidates != null)
        {
            collected.Add(candidates);

            double firstScore;
            double secondScore;
            if (random.NextDouble() < beta)
            {
                // Cost-based scoring: lower cost means higher score.
                firstScore = -candidates.first.cost;
                secondScore = -candidates.second.cost;
            }
            else
            {
                firstScore = classifier.WeightFeatureProduct(candidates.first.features);
                secondScore = classifier.WeightFeatureProduct(candidates.second.features);
            }

            // Ties go to the first action.
            state.DoAction(firstScore >= secondScore);
        }
    }

    return collected;
}
/// <summary>
/// Copy constructor. Document, mention pairs, global features, both hash caches, the hash
/// value and the current index are taken over by plain assignment; the cluster list and the
/// mention-to-cluster map are rebuilt from fresh <c>Clusterer.Cluster</c> copies.
/// </summary>
/// <param name="state">State to copy from.</param>
public State(Clusterer.State state)
{
    // Taken over by assignment (not cloned).
    doc = state.doc;
    mentionPairs = state.mentionPairs;
    globalFeatures = state.globalFeatures;
    hashedScores = state.hashedScores;
    hashedCosts = state.hashedCosts;
    hash = state.hash;
    currentIndex = state.currentIndex;

    // Each cluster is copied so the new state gets its own cluster objects,
    // and the mention lookup is rebuilt to point at those copies.
    clusters = new List<Clusterer.Cluster>();
    mentionToCluster = new Dictionary<int, Clusterer.Cluster>();
    foreach (Clusterer.Cluster source in state.clusters)
    {
        Clusterer.Cluster duplicate = new Clusterer.Cluster(source);
        clusters.Add(duplicate);
        foreach (int mentionId in duplicate.mentions)
        {
            mentionToCluster[mentionId] = duplicate;
        }
    }

    SetClusters();
}