// Runs statistical coreference resolution on one document: scores every unlabeled
// mention pair with the classification, ranking, and (per-anaphor) anaphoricity
// models, then applies the clusterer's proposed merges to the document's clusters.
public virtual void RunCoref(Document document)
{
    IDictionary<Pair<int, int>, bool> mentionPairs = CorefUtils.GetUnlabeledMentionPairs(document);
    if (mentionPairs.Count == 0)
    {
        // No candidate pairs: nothing to resolve.
        return;
    }
    Compressor<string> compressor = new Compressor<string>();
    DocumentExamples examples = extractor.Extract(0, document, mentionPairs, compressor);
    ICounter<Pair<int, int>> classificationScores = new ClassicCounter<Pair<int, int>>();
    ICounter<Pair<int, int>> rankingScores = new ClassicCounter<Pair<int, int>>();
    ICounter<int> anaphoricityScores = new ClassicCounter<int>();
    foreach (Example example in examples.examples)
    {
        CorefUtils.CheckForInterrupt();
        Pair<int, int> mentionPair = new Pair<int, int>(example.mentionId1, example.mentionId2);
        classificationScores.IncrementCount(mentionPair, classificationModel.Predict(example, examples.mentionFeatures, compressor));
        rankingScores.IncrementCount(mentionPair, rankingModel.Predict(example, examples.mentionFeatures, compressor));
        // Anaphoricity is a property of the anaphor alone, so score each mention once.
        if (!anaphoricityScores.ContainsKey(example.mentionId2))
        {
            anaphoricityScores.IncrementCount(example.mentionId2, anaphoricityModel.Predict(new Example(example, false), examples.mentionFeatures, compressor));
        }
    }
    // NOTE(review): Collectors.ToMap(null, null) looks like a broken automatic
    // translation of Java method references (presumably mention id -> mention
    // info for the ClustererDoc); as written it will throw at runtime.
    // TODO: confirm the intended key/value functions against the Java original.
    ClustererDataLoader.ClustererDoc doc = new ClustererDataLoader.ClustererDoc(0, classificationScores, rankingScores, anaphoricityScores, mentionPairs, null, document.predictedMentionsByID.Stream().Collect(Collectors.ToMap(null, null)));
    // Apply each proposed cluster merge to the document in order.
    foreach (Pair<int, int> mentionPair_1 in clusterer.GetClusterMerges(doc))
    {
        CorefUtils.MergeCoreferenceClusters(mentionPair_1, document);
    }
}
// Builds the initial search state for a document: one singleton cluster per
// mention, a pruned list of candidate mention pairs, and per-pair global
// features used by the merge classifier.
public State(ClustererDataLoader.ClustererDoc doc)
{
    currentDocId = doc.id;
    this.doc = doc;
    this.hashedScores = new Dictionary<Clusterer.MergeKey, bool>();
    this.hashedCosts = new Dictionary<long, double>();
    this.clusters = new List<Clusterer.Cluster>();
    this.hash = 0;
    mentionToCluster = new Dictionary<int, Clusterer.Cluster>();
    // Start with every mention in its own singleton cluster; fold each cluster
    // hash into the state hash so equivalent states hash equally.
    foreach (int m in doc.mentions)
    {
        Clusterer.Cluster c = new Clusterer.Cluster(m);
        clusters.Add(c);
        mentionToCluster[m] = c;
        hash ^= c.hash * 7;
    }
    IList<Pair<int, int>> allPairs = new List<Pair<int, int>>(doc.classificationScores.KeySet());
    ICounter<Pair<int, int>> scores = UseRanking ? doc.rankingScores : doc.classificationScores;
    // NOTE(review): Sort(null) sorts with a null/default comparator; the Java
    // original almost certainly sorted pairs by descending model score here.
    // TODO: confirm the lost comparator — pruning below assumes score order.
    allPairs.Sort(null);
    int i = 0;
    for (i = 0; i < allPairs.Count; i++)
    {
        double score = scores.GetCount(allPairs[i]);
        // Stop once scores fall below the floor (after a minimum number of
        // pairs), or when the early-stopping ratio of rank to score is exceeded.
        if (score < MinPairwiseScore && i > MinPairs)
        {
            break;
        }
        if (i >= EarlyStopThreshold && i / score > EarlyStopVal)
        {
            break;
        }
    }
    // Keep only the top i pairs as merge candidates.
    mentionPairs = allPairs.SubList(0, i);
    ICounter<int> seenAnaphors = new ClassicCounter<int>();
    ICounter<int> seenAntecedents = new ClassicCounter<int>();
    globalFeatures = new List<Clusterer.GlobalFeatures>();
    // Global features are computed over ALL pairs (not just the pruned list),
    // indexed by position in allPairs.
    for (int j = 0; j < allPairs.Count; j++)
    {
        Pair<int, int> mentionPair = allPairs[j];
        Clusterer.GlobalFeatures gf = new Clusterer.GlobalFeatures();
        gf.currentIndex = j;
        gf.anaphorSeen = seenAnaphors.ContainsKey(mentionPair.second);
        gf.size = mentionPairs.Count;
        // Rough document-size feature, normalized by a 300-mention scale.
        gf.docSize = doc.mentions.Count / 300.0;
        globalFeatures.Add(gf);
        seenAnaphors.IncrementCount(mentionPair.second);
        seenAntecedents.IncrementCount(mentionPair.first);
    }
    currentIndex = 0;
    SetClusters();
}
// Computes (or retrieves from the compressed cache) the feature vector used to
// score merging clusters c1 and c2 at the current point in the search.
private static ICounter<string> GetFeatures(ClustererDataLoader.ClustererDoc doc, Clusterer.Cluster c1, Clusterer.Cluster c2, Clusterer.GlobalFeatures gf)
{
    // Cache lookup: a hit is only counted when decompression yields a vector.
    Clusterer.MergeKey cacheKey = new Clusterer.MergeKey(c1, c2, gf.currentIndex);
    CompressedFeatureVector cached = featuresCache[cacheKey];
    ICounter<string> feats = cached == null ? null : compressor.Uncompress(cached);
    if (feats != null)
    {
        featuresCacheHits += isTraining;
        return feats;
    }
    featuresCacheMisses += isTraining;

    feats = new ClassicCounter<string>();
    if (gf.anaphorSeen)
    {
        feats.IncrementCount("anaphorSeen");
    }
    feats.IncrementCount("docSize", gf.docSize);
    feats.IncrementCount("percentComplete", gf.currentIndex / (double)gf.size);
    feats.IncrementCount("bias", 1.0);

    // Order the clusters' earliest mentions by document position so the
    // anaphoricity feature is taken from the later one.
    int firstMention = EarliestMention(c1, doc);
    int secondMention = EarliestMention(c2, doc);
    if (doc.mentionIndices[firstMention] > doc.mentionIndices[secondMention])
    {
        int swap = firstMention;
        firstMention = secondMention;
        secondMention = swap;
    }
    feats.IncrementCount("anaphoricity", doc.anaphoricityScores.GetCount(secondMention));

    if (c1.mentions.Count == 1 && c2.mentions.Count == 1)
    {
        // Singleton-singleton merge: score the single cross pair directly.
        Pair<int, int> singlePair = new Pair<int, int>(c1.mentions[0], c2.mentions[0]);
        feats.AddAll(AddSuffix(GetFeatures(doc, singlePair, doc.classificationScores), "-classification"));
        feats.AddAll(AddSuffix(GetFeatures(doc, singlePair, doc.rankingScores), "-ranking"));
        feats = AddSuffix(feats, "-single");
    }
    else
    {
        // General merge: aggregate scores over the full cross product of mentions.
        IList<Pair<int, int>> crossPairs = new List<Pair<int, int>>();
        foreach (int m1 in c1.mentions)
        {
            foreach (int m2 in c2.mentions)
            {
                crossPairs.Add(new Pair<int, int>(m1, m2));
            }
        }
        feats.AddAll(AddSuffix(GetFeatures(doc, crossPairs, doc.classificationScores), "-classification"));
        feats.AddAll(AddSuffix(GetFeatures(doc, crossPairs, doc.rankingScores), "-ranking"));
    }
    featuresCache[cacheKey] = compressor.Compress(feats);
    return feats;
}
// Single-pair feature: the pairwise score of the mention pair under the given
// score counter, emitted under the key "max".
private static ICounter<string> GetFeatures(ClustererDataLoader.ClustererDoc doc, Pair<int, int> mentionPair, ICounter<Pair<int, int>> scores)
{
    // Scores may be keyed in either orientation; flip the pair if the given
    // orientation is not present.
    if (!scores.ContainsKey(mentionPair))
    {
        mentionPair = new Pair<int, int>(mentionPair.second, mentionPair.first);
    }
    ICounter<string> features = new ClassicCounter<string>();
    features.IncrementCount("max", scores.GetCount(mentionPair));
    return features;
}
// Returns the mention in cluster c with the smallest document index
// (i.e. the mention that appears earliest), or -1 for an empty cluster.
private static int EarliestMention(Clusterer.Cluster c, ClustererDataLoader.ClustererDoc doc)
{
    int best = -1;
    foreach (int mention in c.mentions)
    {
        if (best == -1 || doc.mentionIndices[mention] < doc.mentionIndices[best])
        {
            best = mention;
        }
    }
    return best;
}
// Greedily runs the merge policy over the document and returns, in order, the
// mention pairs whose clusters the classifier chose to merge.
public virtual IList<Pair<int, int>> GetClusterMerges(ClustererDataLoader.ClustererDoc doc)
{
    IList<Pair<int, int>> merges = new List<Pair<int, int>>();
    Clusterer.State state = new Clusterer.State(doc);
    while (!state.IsComplete())
    {
        // Capture the candidate pair before DoBestAction advances the state.
        Pair<int, int> candidate = state.mentionPairs[state.currentIndex];
        bool merged = state.DoBestAction(classifier);
        if (merged)
        {
            merges.Add(candidate);
        }
    }
    return merges;
}
// Aggregate pairwise-score features over a list of mention pairs: the max and
// min score, plus average and average-log score overall and broken down by the
// (pronominal / non-pronominal) type conjunction of each pair.
private static ICounter<string> GetFeatures(ClustererDataLoader.ClustererDoc doc, IList<Pair<int, int>> mentionPairs, ICounter<Pair<int, int>> scores)
{
    ICounter<string> features = new ClassicCounter<string>();
    double maxScore = 0;
    double minScore = 1;
    ICounter<string> totals = new ClassicCounter<string>();
    ICounter<string> totalsLog = new ClassicCounter<string>();
    ICounter<string> counts = new ClassicCounter<string>();
    foreach (Pair<int, int> mentionPair in mentionPairs)
    {
        // BUG FIX: the original assigned to the foreach iteration variable,
        // which C# forbids (CS1656). Use a local for the possibly-flipped pair:
        // scores may be keyed in either orientation.
        Pair<int, int> pair = scores.ContainsKey(mentionPair)
            ? mentionPair
            : new Pair<int, int>(mentionPair.second, mentionPair.first);
        double score = scores.GetCount(pair);
        double logScore = CappedLog(score);
        // Collapse mention types to a binary pronominal distinction.
        string mt1 = doc.mentionTypes[pair.first];
        string mt2 = doc.mentionTypes[pair.second];
        mt1 = mt1.Equals("PRONOMINAL") ? "PRONOMINAL" : "NON_PRONOMINAL";
        mt2 = mt2.Equals("PRONOMINAL") ? "PRONOMINAL" : "NON_PRONOMINAL";
        string conj = "_" + mt1 + "_" + mt2;
        maxScore = Math.Max(maxScore, score);
        minScore = Math.Min(minScore, score);
        // Accumulate both overall (empty key) and per-conjunction totals.
        totals.IncrementCount(string.Empty, score);
        totalsLog.IncrementCount(string.Empty, logScore);
        counts.IncrementCount(string.Empty);
        totals.IncrementCount(conj, score);
        totalsLog.IncrementCount(conj, logScore);
        counts.IncrementCount(conj);
    }
    features.IncrementCount("max", maxScore);
    features.IncrementCount("min", minScore);
    // Averages are normalized by the total pair count, not the per-bucket count.
    foreach (string key in counts.KeySet())
    {
        features.IncrementCount("avg" + key, totals.GetCount(key) / mentionPairs.Count);
        features.IncrementCount("avgLog" + key, totalsLog.GetCount(key) / mentionPairs.Count);
    }
    return features;
}
// Copy constructor: shares the immutable score/pair data with the source state
// and deep-copies only the mutable cluster structures.
public State(Clusterer.State state)
{
    this.doc = state.doc;
    this.hashedScores = state.hashedScores;
    this.hashedCosts = state.hashedCosts;
    this.hash = state.hash;
    this.mentionPairs = state.mentionPairs;
    this.currentIndex = state.currentIndex;
    this.globalFeatures = state.globalFeatures;
    this.clusters = new List<Clusterer.Cluster>();
    this.mentionToCluster = new Dictionary<int, Clusterer.Cluster>();
    foreach (Clusterer.Cluster original in state.clusters)
    {
        // Each cluster is copied, and every mention is re-pointed at its copy.
        Clusterer.Cluster copy = new Clusterer.Cluster(original);
        this.clusters.Add(copy);
        foreach (int mention in copy.mentions)
        {
            this.mentionToCluster[mention] = copy;
        }
    }
    SetClusters();
}
// DAgger-style rollout: runs the merge policy over the document, collecting the
// (merge, skip) candidate-action pairs seen along the way. With probability
// beta the expert (negated cost) picks the action taken; otherwise the current
// classifier's scores do.
private IList<Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>> RunPolicy(ClustererDataLoader.ClustererDoc doc, double beta)
{
    IList<Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>> examples = new List<Pair<Clusterer.CandidateAction, Clusterer.CandidateAction>>();
    Clusterer.State state = new Clusterer.State(doc);
    while (!state.IsComplete())
    {
        Pair<Clusterer.CandidateAction, Clusterer.CandidateAction> actions = state.GetActions(classifier);
        if (actions == null)
        {
            continue;
        }
        examples.Add(actions);
        bool useExpert = random.NextDouble() < beta;
        double firstScore = useExpert
            ? -actions.first.cost
            : classifier.WeightFeatureProduct(actions.first.features);
        double secondScore = useExpert
            ? -actions.second.cost
            : classifier.WeightFeatureProduct(actions.second.features);
        // Ties go to the first action.
        state.DoAction(firstScore >= secondScore);
    }
    return examples;
}