/// <summary>
/// Produces the target-side language model files that share <paramref name="lmPrefix"/>:
/// the n-gram counts, an initial weights file, and the word-prediction file.
/// </summary>
/// <param name="lmPrefix">Path prefix for every file written by this method.</param>
/// <param name="ngramSize">N-gram order passed to the file writers.</param>
/// <param name="reporter">Progress reporter; advanced by exactly one step.</param>
private void TrainLanguageModel(string lmPrefix, int ngramSize, ThotTrainProgressReporter reporter)
{
    reporter.Step("Training target language model");

    WriteNgramCountsFile(lmPrefix, ngramSize);

    // Seed the weights file with (ngramSize * 3) values, all equal to 0.5.
    var initialWeights = Enumerable.Repeat(0.5, ngramSize * 3);
    WriteLanguageModelWeightsFile(lmPrefix, ngramSize, initialWeights);

    WriteWordPredictionFile(lmPrefix);
}
/// <summary>
/// Builds the phrase-based translation model under <paramref name="tmPrefix"/>. It trains
/// two single-word alignment models (one per direction), symmetrizes their best alignments,
/// generates and filters the phrase table, and writes the fixed model configuration files.
/// </summary>
/// <param name="tmPrefix">Path prefix for every translation model file.</param>
/// <param name="reporter">Progress reporter shared by all sub-steps.</param>
private void TrainTranslationModel(string tmPrefix, ThotTrainProgressReporter reporter)
{
    // Model labelled "direct": source preprocessor/corpus order as given.
    string invswmPrefix = tmPrefix + "_invswm";
    GenerateSingleWordAlignmentModel(invswmPrefix, _sourcePreprocessor, _targetPreprocessor, _parallelCorpus,
        "direct", reporter);

    // Model labelled "inverse": preprocessors swapped and trained on the inverted corpus.
    string swmPrefix = tmPrefix + "_swm";
    GenerateSingleWordAlignmentModel(swmPrefix, _targetPreprocessor, _sourcePreprocessor, _parallelCorpus.Invert(),
        "inverse", reporter);

    string symmetrizedAlignmentsFile = tmPrefix + ".A3.final";
    string phraseTableFile = tmPrefix + ".ttable";

    reporter.Step("Merging alignments");
    Thot.giza_symmetr1(swmPrefix + ".bestal", invswmPrefix + ".bestal", symmetrizedAlignmentsFile, true);

    reporter.Step("Generating phrase table");
    Thot.phraseModel_generate(symmetrizedAlignmentsFile, 10, phraseTableFile);

    reporter.Step("Filtering phrase table");
    FilterPhraseTableNBest(phraseTableFile, 20);

    // Fixed configuration files, written in this order.
    void WriteSetting(string extension, string contents) => File.WriteAllText(tmPrefix + extension, contents);
    WriteSetting(".lambda", "0.7 0.7");
    WriteSetting(".srcsegmlentable", "Uniform");
    WriteSetting(".trgcutstable", "0.999");
    WriteSetting(".trgsegmlentable", "Geometric");
}
/// <summary>
/// Builds the phrase-based translation model under <paramref name="tmPrefix"/>. Each native
/// stage (symmetrization, phrase-table generation, n-best filtering) runs inside its own
/// progress phase. The fixed model configuration files are then written.
/// </summary>
/// <param name="tmPrefix">Path prefix for every translation model file.</param>
/// <param name="reporter">Reporter whose phases are started for each stage.</param>
private void TrainTranslationModel(string tmPrefix, ThotTrainProgressReporter reporter)
{
    string invswmPrefix = tmPrefix + "_invswm";
    GenerateWordAlignmentModel(invswmPrefix, _sourcePreprocessor, _targetPreprocessor, _parallelCorpus, reporter,
        false);

    string swmPrefix = tmPrefix + "_swm";
    GenerateWordAlignmentModel(swmPrefix, _targetPreprocessor, _sourcePreprocessor, _parallelCorpus.Invert(),
        reporter, true);

    string symmetrizedAlignmentsFile = tmPrefix + ".A3.final";
    string phraseTableFile = tmPrefix + ".ttable";

    // Each phase object is only needed for its lifetime, so the expression form of `using` suffices.
    using (reporter.StartNextPhase())
    {
        Thot.giza_symmetr1(swmPrefix + ".bestal", invswmPrefix + ".bestal", symmetrizedAlignmentsFile, true);
    }

    using (reporter.StartNextPhase())
    {
        Thot.phraseModel_generate(symmetrizedAlignmentsFile, 10, phraseTableFile);
    }

    using (reporter.StartNextPhase())
    {
        FilterPhraseTableNBest(phraseTableFile, 20);
    }

    File.WriteAllText(tmPrefix + ".lambda", "0.7 0.7");
    File.WriteAllText(tmPrefix + ".srcsegmlentable", "Uniform");
    File.WriteAllText(tmPrefix + ".trgcutstable", "0.999");
    File.WriteAllText(tmPrefix + ".trgsegmlentable", "Geometric");
}
/// <summary>
/// Trains one single-word alignment model, prunes its lexical table, and writes the best
/// alignments for the training segments to "<paramref name="swmPrefix"/>.bestal".
/// </summary>
/// <param name="swmPrefix">Path prefix for the alignment model files.</param>
/// <param name="sourcePreprocessor">Preprocessor applied to the source side of each segment.</param>
/// <param name="targetPreprocessor">Preprocessor applied to the target side of each segment.</param>
/// <param name="corpus">Parallel corpus to train on (already oriented by the caller).</param>
/// <param name="name">Label used in progress messages.</param>
/// <param name="reporter">Progress/cancellation reporter.</param>
private void GenerateSingleWordAlignmentModel(string swmPrefix, Func<string, string> sourcePreprocessor,
    Func<string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
    ThotTrainProgressReporter reporter)
{
    TrainWordAlignmentModel(swmPrefix, sourcePreprocessor, targetPreprocessor, corpus, name, reporter);
    reporter.CheckCanceled();

    // Prune the lexical table with a 1e-5 threshold (presumably dropping low-probability entries).
    string lexTableFile = swmPrefix + ".hmm_lexnd";
    PruneLexTable(lexTableFile, 0.00001);

    string bestAlignmentsFile = swmPrefix + ".bestal";
    GenerateBestAlignments(swmPrefix, bestAlignmentsFile, sourcePreprocessor, targetPreprocessor, corpus, name,
        reporter);
}
/// <summary>
/// Tunes the language model weights by minimizing perplexity on the tuning target corpus with
/// a Nelder–Mead simplex. The result is written back to the weights file under
/// <paramref name="lmPrefix"/>. Does nothing beyond the progress step when the corpus is empty.
/// </summary>
/// <param name="lmPrefix">Path prefix of the language model to tune.</param>
/// <param name="tuneTargetCorpus">Preprocessed target-side tuning segments.</param>
/// <param name="ngramSize">N-gram order; (ngramSize * 3) weights are optimized.</param>
/// <param name="reporter">Progress reporter; advanced by one step.</param>
private void TuneLanguageModel(string lmPrefix, IList<IReadOnlyList<string>> tuneTargetCorpus, int ngramSize,
    ThotTrainProgressReporter reporter)
{
    reporter.Step("Tuning target language model");

    bool nothingToTune = tuneTargetCorpus.Count == 0;
    if (nothingToTune)
    {
        return;
    }

    var optimizer = new NelderMeadSimplex(0.1, 200, 1.0);
    var startingPoint = Enumerable.Repeat(0.5, ngramSize * 3);
    MinimizationResult best = optimizer.FindMinimum(
        w => CalculatePerplexity(tuneTargetCorpus, lmPrefix, ngramSize, w),
        startingPoint);

    WriteLanguageModelWeightsFile(lmPrefix, ngramSize, best.MinimizingPoint);
    Stats.LanguageModelPerplexity = best.ErrorValue;
}
/// <summary>
/// Runs the whole batch training pipeline:
/// 1. trains the language and translation models;
/// 2. tunes them on the tuning segments;
/// 3. trains the resulting engine incrementally on the tuning corpus.
/// </summary>
/// <param name="progress">Optional receiver for progress updates.</param>
/// <param name="checkCanceled">Optional cancellation callback handed to the progress reporter.</param>
public virtual void Train(IProgress<ProgressData> progress = null, Action checkCanceled = null)
{
    var reporter = new ThotTrainProgressReporter(TrainingStepCount, progress, checkCanceled);

    Directory.CreateDirectory(_trainLMDir);
    string trainLMPrefix = Path.Combine(_trainLMDir, _lmFilePrefix);
    Directory.CreateDirectory(_trainTMDir);
    string trainTMPrefix = Path.Combine(_trainTMDir, _tmFilePrefix);

    // Training phase (n-gram order 3).
    TrainLanguageModel(trainLMPrefix, 3, reporter);
    reporter.CheckCanceled();
    TrainTranslationModel(trainTMPrefix, reporter);
    reporter.CheckCanceled();

    // The translation model is tuned on a copy placed under _tempDir; the language model is tuned in place.
    string tuneTMDir = Path.Combine(_tempDir, "tm_tune");
    Directory.CreateDirectory(tuneTMDir);
    string tuneTMPrefix = Path.Combine(tuneTMDir, _tmFilePrefix);
    CopyFiles(_trainTMDir, tuneTMDir, _tmFilePrefix);

    int tuneSegmentCapacity = _tuneCorpusIndices.Count;
    var tuneSourceCorpus = new List<IReadOnlyList<string>>(tuneSegmentCapacity);
    var tuneTargetCorpus = new List<IReadOnlyList<string>>(tuneSegmentCapacity);
    foreach (ParallelTextSegment tuningSegment in GetTuningSegments(_parallelCorpus))
    {
        tuneSourceCorpus.Add(tuningSegment.SourceSegment.Preprocess(_sourcePreprocessor));
        tuneTargetCorpus.Add(tuningSegment.TargetSegment.Preprocess(_targetPreprocessor));
    }

    // Tuning phase.
    TuneLanguageModel(trainLMPrefix, tuneTargetCorpus, 3, reporter);
    reporter.CheckCanceled();
    TuneTranslationModel(tuneTMPrefix, trainLMPrefix, tuneSourceCorpus, tuneTargetCorpus, reporter);
    reporter.CheckCanceled();

    // Final pass: feed the tuning pairs to the trained model.
    TrainTuneCorpus(trainTMPrefix, trainLMPrefix, tuneSourceCorpus, tuneTargetCorpus, reporter);
    reporter.Step("Completed");
}
/// <summary>
/// Writes the best alignment for every training segment to <paramref name="fileName"/>.
/// Each record is a "# {text id} {segment ref}" header line followed by the model's
/// GIZA-format string for the segment.
/// </summary>
/// <param name="swmPrefix">Path prefix of the trained alignment model to load.</param>
/// <param name="fileName">Output file path.</param>
/// <param name="sourcePreprocessor">Preprocessor applied to the source side.</param>
/// <param name="targetPreprocessor">Preprocessor applied to the target side.</param>
/// <param name="corpus">Corpus whose training segments are aligned.</param>
/// <param name="name">Label used in the progress message.</param>
/// <param name="reporter">Progress reporter; cancellation is checked after each segment.</param>
private void GenerateBestAlignments(string swmPrefix, string fileName, Func<string, string> sourcePreprocessor,
    Func<string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
    ThotTrainProgressReporter reporter)
{
    reporter.Step($"Generating best {name} alignments");

    using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix))
    {
        using (var output = new StreamWriter(fileName))
        {
            foreach (ParallelTextSegment trainingSegment in GetTrainingSegments(corpus))
            {
                // An explicit "\n" is used (not WriteLine) so the line ending is the same on every platform.
                output.Write($"# {trainingSegment.Text.Id} {trainingSegment.SegmentRef}\n");
                output.Write(alignmentModel.GetGizaFormatString(trainingSegment, sourcePreprocessor,
                    targetPreprocessor));
                reporter.CheckCanceled();
            }
        }
    }
}
/// <summary>
/// Creates a new alignment model at <paramref name="swmPrefix"/>, loads every training
/// segment pair into it, runs five training iterations, and saves it.
/// </summary>
/// <param name="swmPrefix">Path prefix for the model files.</param>
/// <param name="sourcePreprocessor">Preprocessor applied to the source side.</param>
/// <param name="targetPreprocessor">Preprocessor applied to the target side.</param>
/// <param name="corpus">Corpus whose training segments are added.</param>
/// <param name="name">Label used in progress messages.</param>
/// <param name="reporter">Progress reporter; advanced once per iteration.</param>
private void TrainWordAlignmentModel(string swmPrefix, Func<string, string> sourcePreprocessor,
    Func<string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
    ThotTrainProgressReporter reporter)
{
    using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix, true))
    {
        foreach (ParallelTextSegment trainingSegment in GetTrainingSegments(corpus))
        {
            alignmentModel.AddSegmentPair(trainingSegment, sourcePreprocessor, targetPreprocessor);
        }

        // Five iterations, each preceded by a progress step.
        for (int iteration = 1; iteration <= 5; iteration++)
        {
            reporter.Step($"Training {name} alignment model");
            alignmentModel.TrainingIteration();
        }

        alignmentModel.Save();
    }
}
/// <summary>
/// Opens an engine over the trained model files and incrementally trains it on each tuning
/// (source, target) pair. Only the progress step runs when the tuning corpus is empty.
/// </summary>
/// <param name="trainTMPrefix">Path prefix of the trained translation model.</param>
/// <param name="trainLMPrefix">Path prefix of the trained language model.</param>
/// <param name="tuneSourceCorpus">Preprocessed source-side tuning segments.</param>
/// <param name="tuneTargetCorpus">Preprocessed target-side tuning segments (parallel to the source).</param>
/// <param name="reporter">Progress reporter; moved to the "Finalizing" step.</param>
private void TrainTuneCorpus(string trainTMPrefix, string trainLMPrefix,
    IReadOnlyList<IReadOnlyList<string>> tuneSourceCorpus, IReadOnlyList<IReadOnlyList<string>> tuneTargetCorpus,
    ThotTrainProgressReporter reporter)
{
    reporter.Step("Finalizing", TrainingStepCount - 1);

    int pairCount = tuneSourceCorpus.Count;
    if (pairCount == 0)
    {
        return;
    }

    ThotSmtParameters engineParameters = Parameters.Clone();
    engineParameters.TranslationModelFileNamePrefix = trainTMPrefix;
    engineParameters.LanguageModelFileNamePrefix = trainLMPrefix;

    using (var smtModel = new ThotSmtModel(engineParameters))
    {
        using (ISmtEngine engine = smtModel.CreateEngine())
        {
            for (int index = 0; index < pairCount; index++)
            {
                engine.TrainSegment(tuneSourceCorpus[index], tuneTargetCorpus[index]);
            }
        }
    }
}
/// <summary>
/// Tunes the log-linear model weights against the tuning corpus:
/// 1. filters the tuning copy of the phrase table;
/// 2. runs the configured weight tuner from uniform initial weights;
/// 3. installs the tuned weights into <c>Parameters</c>, keeping the original model file prefixes.
/// Only the progress step runs when the tuning corpus is empty.
/// </summary>
/// <param name="tuneTMPrefix">Path prefix of the translation model copy used for tuning.</param>
/// <param name="tuneLMPrefix">Path prefix of the language model used for tuning.</param>
/// <param name="tuneSourceCorpus">Preprocessed source-side tuning segments.</param>
/// <param name="tuneTargetCorpus">Preprocessed target-side tuning segments.</param>
/// <param name="reporter">Progress reporter passed through to the tuner.</param>
private void TuneTranslationModel(string tuneTMPrefix, string tuneLMPrefix,
    IReadOnlyList<IReadOnlyList<string>> tuneSourceCorpus, IReadOnlyList<IReadOnlyList<string>> tuneTargetCorpus,
    ThotTrainProgressReporter reporter)
{
    reporter.Step("Tuning translation model");
    if (tuneSourceCorpus.Count == 0)
    {
        return;
    }

    // Shrink the tuning phrase table to entries relevant to the tuning corpus, then keep the 20 best.
    string tuningPhraseTable = tuneTMPrefix + ".ttable";
    FilterPhraseTableUsingCorpus(tuningPhraseTable, tuneSourceCorpus);
    FilterPhraseTableNBest(tuningPhraseTable, 20);

    ThotSmtParameters originalParameters = Parameters;

    // Start from weight 1 for the first seven features and 0 for the eighth.
    ThotSmtParameters startParameters = originalParameters.Clone();
    startParameters.TranslationModelFileNamePrefix = tuneTMPrefix;
    startParameters.LanguageModelFileNamePrefix = tuneLMPrefix;
    startParameters.ModelWeights = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0f };
    startParameters.Freeze();

    ThotSmtParameters tuned = _modelWeightTuner.Tune(startParameters, tuneSourceCorpus, tuneTargetCorpus,
        reporter, Stats);

    // Keep the tuned weights, but point back at the original (non-tuning) model files.
    Parameters = tuned.Clone();
    Parameters.TranslationModelFileNamePrefix = originalParameters.TranslationModelFileNamePrefix;
    Parameters.LanguageModelFileNamePrefix = originalParameters.LanguageModelFileNamePrefix;
    Parameters.Freeze();
}
/// <summary>
/// Trains one word alignment model, prunes its lexical table, and writes the best alignments
/// to "<paramref name="swmPrefix"/>.bestal". Training and alignment generation each run
/// inside their own progress phase.
/// </summary>
/// <param name="swmPrefix">Path prefix for the alignment model files.</param>
/// <param name="sourcePreprocessor">Preprocessor applied to the source side.</param>
/// <param name="targetPreprocessor">Preprocessor applied to the target side.</param>
/// <param name="corpus">Corpus to train on (already oriented by the caller).</param>
/// <param name="reporter">Reporter whose next phases are started here.</param>
/// <param name="inverted">NOTE(review): not read by this method; kept for signature compatibility.</param>
private void GenerateWordAlignmentModel(string swmPrefix, Func<string, string> sourcePreprocessor,
    Func<string, string> targetPreprocessor, ParallelTextCorpus corpus, ThotTrainProgressReporter reporter,
    bool inverted)
{
    using (PhaseProgress trainingPhase = reporter.StartNextPhase())
    {
        TrainWordAlignmentModel(swmPrefix, sourcePreprocessor, targetPreprocessor, corpus, trainingPhase);
    }

    reporter.CheckCanceled();

    string lexTableFile = swmPrefix + ".hmm_lexnd";
    PruneLexTable(lexTableFile, 0.00001);

    string bestAlignmentsFile = swmPrefix + ".bestal";
    using (PhaseProgress alignmentPhase = reporter.StartNextPhase())
    {
        GenerateBestAlignments(swmPrefix, bestAlignmentsFile, sourcePreprocessor, targetPreprocessor, corpus,
            alignmentPhase);
    }
}
/// <summary>
/// Tunes the first seven model weights with a Nelder–Mead simplex. The weight at index 7
/// (named sentLenWeight) is held fixed. The objective returned by <c>CalculateBleu</c> is
/// minimized, and the reported BLEU is <c>1.0 - ErrorValue</c>.
/// </summary>
/// <param name="parameters">Starting parameters; its weights seed the simplex.</param>
/// <param name="tuneSourceCorpus">Preprocessed source-side tuning segments.</param>
/// <param name="tuneTargetCorpus">Preprocessed target-side tuning segments.</param>
/// <param name="reporter">Stepped every <c>ProgressIncrementInterval</c> evaluations; otherwise only checked for cancellation.</param>
/// <param name="stats">Receives the resulting BLEU score.</param>
/// <returns>A frozen clone of <paramref name="parameters"/> with the optimized weights.</returns>
public ThotSmtParameters Tune(ThotSmtParameters parameters, IReadOnlyList<IReadOnlyList<string>> tuneSourceCorpus,
    IReadOnlyList<IReadOnlyList<string>> tuneTargetCorpus, ThotTrainProgressReporter reporter,
    SmtBatchTrainStats stats)
{
    float sentLenWeight = parameters.ModelWeights[7];
    int evaluationCount = 0;

    // Objective for the simplex. The value is minimized, so it is an error rather than a quality
    // (presumably 1 - BLEU, given how stats.TranslationModelBleu is derived below).
    double ComputeError(Vector weights)
    {
        ThotSmtParameters candidate = parameters.Clone();
        candidate.ModelWeights = weights.Select(w => (float)w).Concat(sentLenWeight).ToArray();
        candidate.Freeze();

        double error = CalculateBleu(candidate, tuneSourceCorpus, tuneTargetCorpus);
        evaluationCount++;

        bool shouldStep = evaluationCount < MaxFunctionEvaluations
            && ProgressIncrementInterval > 0
            && evaluationCount % ProgressIncrementInterval == 0;
        if (shouldStep)
        {
            reporter.Step();
        }
        else
        {
            reporter.CheckCanceled();
        }

        return error;
    }

    var optimizer = new NelderMeadSimplex(ConvergenceTolerance, MaxFunctionEvaluations, 1.0);
    var startingPoint = parameters.ModelWeights.Select(w => (double)w).Take(7);
    MinimizationResult outcome = optimizer.FindMinimum(ComputeError, startingPoint);

    stats.TranslationModelBleu = 1.0 - outcome.ErrorValue;

    ThotSmtParameters tunedParameters = parameters.Clone();
    tunedParameters.ModelWeights = outcome.MinimizingPoint.Select(w => (float)w).Concat(sentLenWeight).ToArray();
    tunedParameters.Freeze();
    return tunedParameters;
}
/// <summary>
/// Tunes the model weights with Thot's iterative weight updater (<c>llWeightUpdater</c>).
/// Each iteration:
/// 1. decodes the tuning corpus with the current weights;
/// 2. scores the 1-best translations with BLEU;
/// 3. merges the n-best lists into the accumulated ones;
/// 4. asks the updater for new weights.
/// Iteration stops after <c>MaxIterations</c> or when <c>IsTuningConverged</c> reports convergence.
/// </summary>
/// <param name="parameters">Starting parameters; its weights seed the first iteration.</param>
/// <param name="tuneSourceCorpus">Preprocessed source-side tuning segments.</param>
/// <param name="tuneTargetCorpus">Preprocessed target-side tuning segments.</param>
/// <param name="reporter">Progress reporter; stepped once per weight update.</param>
/// <param name="stats">Receives the best BLEU score observed.</param>
/// <returns>The frozen parameters whose weights produced the highest BLEU.</returns>
public ThotSmtParameters Tune(ThotSmtParameters parameters, IReadOnlyList<IReadOnlyList<string>> tuneSourceCorpus,
    IReadOnlyList<IReadOnlyList<string>> tuneTargetCorpus, ThotTrainProgressReporter reporter,
    SmtBatchTrainStats stats)
{
    IntPtr weightUpdaterHandle = Thot.llWeightUpdater_create();
    try
    {
        var iterQualities = new List<double>();
        double bestQuality = double.MinValue;
        ThotSmtParameters bestParameters = null;
        int iter = 1;
        HashSet<TranslationInfo>[] curNBestLists = null;
        float[] curWeights = parameters.ModelWeights.ToArray();
        while (true)
        {
            ThotSmtParameters newParameters = parameters.Clone();
            // Snapshot the weights. curWeights is never reassigned in this loop, so UpdateWeights
            // presumably writes the new weights into it in place. Storing the live array would let
            // later updates silently overwrite the weights held by bestParameters.
            newParameters.ModelWeights = curWeights.ToArray();
            newParameters.Freeze();

            IList<TranslationInfo>[] nbestLists = GetNBestLists(newParameters, tuneSourceCorpus).ToArray();
            double quality = Evaluation.CalculateBleu(nbestLists.Select(nbl => nbl.First().Translation),
                tuneTargetCorpus);
            iterQualities.Add(quality);
            if (quality > bestQuality)
            {
                bestQuality = quality;
                bestParameters = newParameters;
            }

            if (iter >= MaxIterations || IsTuningConverged(iterQualities))
            {
                break;
            }

            // Accumulate n-best hypotheses across iterations; the updater sees the union.
            if (curNBestLists == null)
            {
                curNBestLists = nbestLists.Select(nbl => new HashSet<TranslationInfo>(nbl)).ToArray();
            }
            else
            {
                for (int i = 0; i < nbestLists.Length; i++)
                {
                    curNBestLists[i].UnionWith(nbestLists[i]);
                }
            }

            UpdateWeights(weightUpdaterHandle, tuneTargetCorpus, curNBestLists, curWeights);
            iter++;
            reporter.Step();
        }

        stats.TranslationModelBleu = bestQuality;
        return bestParameters;
    }
    finally
    {
        Thot.llWeightUpdater_close(weightUpdaterHandle);
    }
}