コード例 #1
0
 /// <summary>
 /// Writes the target language model files for <paramref name="lmPrefix"/>: the n-gram counts
 /// file, the weights file and the word prediction file, in that order.
 /// </summary>
 /// <param name="lmPrefix">File name prefix handed to each of the writer helpers.</param>
 /// <param name="ngramSize">N-gram size; also determines how many initial weights are written.</param>
 /// <param name="reporter">Progress reporter that is advanced by one step.</param>
 private void TrainLanguageModel(string lmPrefix, int ngramSize, ThotTrainProgressReporter reporter)
 {
     reporter.Step("Training target language model");

     WriteNgramCountsFile(lmPrefix, ngramSize);

     // The weights file starts out with ngramSize * 3 entries, every one of them 0.5.
     var initialWeights = Enumerable.Repeat(0.5, ngramSize * 3);
     WriteLanguageModelWeightsFile(lmPrefix, ngramSize, initialWeights);

     WriteWordPredictionFile(lmPrefix);
 }
コード例 #2
0
        /// <summary>
        /// Trains the translation model whose files share the prefix <paramref name="tmPrefix"/>.
        /// Generates a "direct" and an "inverse" single word alignment model, merges their best
        /// alignments, generates and filters the phrase table, and finally writes the fixed
        /// parameter files.
        /// </summary>
        /// <param name="tmPrefix">Path prefix for every translation model file that is produced.</param>
        /// <param name="reporter">Progress reporter advanced at each stage.</param>
        private void TrainTranslationModel(string tmPrefix, ThotTrainProgressReporter reporter)
        {
            // Model trained on the corpus as-is.
            string invertedModelPrefix = $"{tmPrefix}_invswm";
            GenerateSingleWordAlignmentModel(invertedModelPrefix, _sourcePreprocessor, _targetPreprocessor,
                                             _parallelCorpus, "direct", reporter);

            // Model trained on the inverted corpus, with the preprocessors swapped to match.
            string modelPrefix = $"{tmPrefix}_swm";
            GenerateSingleWordAlignmentModel(modelPrefix, _targetPreprocessor, _sourcePreprocessor,
                                             _parallelCorpus.Invert(), "inverse", reporter);

            string mergedAlignmentsFile = $"{tmPrefix}.A3.final";
            string phraseTableFile      = $"{tmPrefix}.ttable";

            reporter.Step("Merging alignments");
            Thot.giza_symmetr1($"{modelPrefix}.bestal", $"{invertedModelPrefix}.bestal", mergedAlignmentsFile, true);

            reporter.Step("Generating phrase table");
            Thot.phraseModel_generate(mergedAlignmentsFile, 10, phraseTableFile);

            reporter.Step("Filtering phrase table");
            FilterPhraseTableNBest(phraseTableFile, 20);

            // Fixed-content parameter files written next to the phrase table.
            File.WriteAllText($"{tmPrefix}.lambda", "0.7 0.7");
            File.WriteAllText($"{tmPrefix}.srcsegmlentable", "Uniform");
            File.WriteAllText($"{tmPrefix}.trgcutstable", "0.999");
            File.WriteAllText($"{tmPrefix}.trgsegmlentable", "Geometric");
        }
コード例 #3
0
        /// <summary>
        /// Trains the translation model whose files share the prefix <paramref name="tmPrefix"/>.
        /// Generates a word alignment model for the corpus and for its inversion, then runs the
        /// symmetrization, phrase table generation and phrase table filtering calls, each inside its
        /// own progress phase, and finally writes the fixed parameter files.
        /// </summary>
        /// <param name="tmPrefix">Path prefix for every translation model file that is produced.</param>
        /// <param name="reporter">Progress reporter from which each phase is started.</param>
        private void TrainTranslationModel(string tmPrefix, ThotTrainProgressReporter reporter)
        {
            string invertedModelPrefix = $"{tmPrefix}_invswm";
            GenerateWordAlignmentModel(invertedModelPrefix, _sourcePreprocessor, _targetPreprocessor,
                                       _parallelCorpus, reporter, false);

            // Same procedure on the inverted corpus, with the preprocessors swapped to match.
            string modelPrefix = $"{tmPrefix}_swm";
            GenerateWordAlignmentModel(modelPrefix, _targetPreprocessor, _sourcePreprocessor,
                                       _parallelCorpus.Invert(), reporter, true);

            string mergedAlignmentsFile = $"{tmPrefix}.A3.final";
            string phraseTableFile      = $"{tmPrefix}.ttable";

            // Each phase object is only needed for its disposal at the end of the block.
            using (reporter.StartNextPhase())
            {
                Thot.giza_symmetr1($"{modelPrefix}.bestal", $"{invertedModelPrefix}.bestal", mergedAlignmentsFile,
                                   true);
            }

            using (reporter.StartNextPhase())
            {
                Thot.phraseModel_generate(mergedAlignmentsFile, 10, phraseTableFile);
            }

            using (reporter.StartNextPhase())
            {
                FilterPhraseTableNBest(phraseTableFile, 20);
            }

            // Fixed-content parameter files written next to the phrase table.
            File.WriteAllText($"{tmPrefix}.lambda", "0.7 0.7");
            File.WriteAllText($"{tmPrefix}.srcsegmlentable", "Uniform");
            File.WriteAllText($"{tmPrefix}.trgcutstable", "0.999");
            File.WriteAllText($"{tmPrefix}.trgsegmlentable", "Geometric");
        }
コード例 #4
0
        /// <summary>
        /// Trains a word alignment model under <paramref name="swmPrefix"/>, prunes its
        /// ".hmm_lexnd" table and writes the best alignments to the ".bestal" file.
        /// </summary>
        /// <param name="swmPrefix">Path prefix of the alignment model files.</param>
        /// <param name="sourcePreprocessor">Preprocessor passed on for source text.</param>
        /// <param name="targetPreprocessor">Preprocessor passed on for target text.</param>
        /// <param name="corpus">Parallel corpus passed on to training and alignment generation.</param>
        /// <param name="name">Label passed on to the helpers for their progress messages.</param>
        /// <param name="reporter">Progress reporter; also polled for cancellation after training.</param>
        private void GenerateSingleWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                                      Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
                                                      ThotTrainProgressReporter reporter)
        {
            const double pruneThreshold = 0.00001;

            string lexTableFile       = $"{swmPrefix}.hmm_lexnd";
            string bestAlignmentsFile = $"{swmPrefix}.bestal";

            TrainWordAlignmentModel(swmPrefix, sourcePreprocessor, targetPreprocessor, corpus, name, reporter);
            reporter.CheckCanceled();

            PruneLexTable(lexTableFile, pruneThreshold);

            GenerateBestAlignments(swmPrefix, bestAlignmentsFile, sourcePreprocessor, targetPreprocessor, corpus,
                                   name, reporter);
        }
コード例 #5
0
        /// <summary>
        /// Tunes the language model weights by minimizing the value of
        /// <c>CalculatePerplexity</c> over <paramref name="tuneTargetCorpus"/> with a Nelder-Mead
        /// simplex, then writes the minimizing weights and records the final error value in
        /// <c>Stats.LanguageModelPerplexity</c>.
        /// </summary>
        /// <param name="lmPrefix">File name prefix of the language model.</param>
        /// <param name="tuneTargetCorpus">Target segments used for tuning; nothing is tuned when empty.</param>
        /// <param name="ngramSize">N-gram size; the search starts from ngramSize * 3 weights of 0.5.</param>
        /// <param name="reporter">Progress reporter that is advanced by one step.</param>
        private void TuneLanguageModel(string lmPrefix, IList <IReadOnlyList <string> > tuneTargetCorpus,
                                       int ngramSize, ThotTrainProgressReporter reporter)
        {
            // The step is reported even when there is nothing to tune.
            reporter.Step("Tuning target language model");

            bool nothingToTune = tuneTargetCorpus.Count == 0;
            if (nothingToTune)
                return;

            var optimizer = new NelderMeadSimplex(0.1, 200, 1.0);
            var startingWeights = Enumerable.Repeat(0.5, ngramSize * 3);

            MinimizationResult minimum = optimizer.FindMinimum(
                weights => CalculatePerplexity(tuneTargetCorpus, lmPrefix, ngramSize, weights),
                startingWeights);

            WriteLanguageModelWeightsFile(lmPrefix, ngramSize, minimum.MinimizingPoint);
            Stats.LanguageModelPerplexity = minimum.ErrorValue;
        }
コード例 #6
0
        /// <summary>
        /// Runs the full batch training pipeline: trains the language and translation models,
        /// copies the translation model files to a "tm_tune" directory, tunes both models on the
        /// tuning segments and finally trains on the tuning corpus. Cancellation is checked
        /// between the stages.
        /// </summary>
        /// <param name="progress">Optional progress sink handed to the progress reporter.</param>
        /// <param name="checkCanceled">Optional cancellation callback handed to the progress reporter.</param>
        public virtual void Train(IProgress <ProgressData> progress = null, Action checkCanceled = null)
        {
            const int ngramSize = 3;

            var reporter = new ThotTrainProgressReporter(TrainingStepCount, progress, checkCanceled);

            // Working locations for the language model and the translation model.
            Directory.CreateDirectory(_trainLMDir);
            string lmPrefix = Path.Combine(_trainLMDir, _lmFilePrefix);

            Directory.CreateDirectory(_trainTMDir);
            string tmPrefix = Path.Combine(_trainTMDir, _tmFilePrefix);

            TrainLanguageModel(lmPrefix, ngramSize, reporter);
            reporter.CheckCanceled();

            TrainTranslationModel(tmPrefix, reporter);
            reporter.CheckCanceled();

            // Tuning of the translation model works on a copy of the trained files.
            string tuneDir = Path.Combine(_tempDir, "tm_tune");
            Directory.CreateDirectory(tuneDir);
            string tuneTMPrefix = Path.Combine(tuneDir, _tmFilePrefix);
            CopyFiles(_trainTMDir, tuneDir, _tmFilePrefix);

            // Preprocess both sides of every tuning segment.
            int expectedCount = _tuneCorpusIndices.Count;
            var tuneSources = new List<IReadOnlyList<string>>(expectedCount);
            var tuneTargets = new List<IReadOnlyList<string>>(expectedCount);
            foreach (ParallelTextSegment tuningSegment in GetTuningSegments(_parallelCorpus))
            {
                tuneSources.Add(tuningSegment.SourceSegment.Preprocess(_sourcePreprocessor));
                tuneTargets.Add(tuningSegment.TargetSegment.Preprocess(_targetPreprocessor));
            }

            TuneLanguageModel(lmPrefix, tuneTargets, ngramSize, reporter);
            reporter.CheckCanceled();

            TuneTranslationModel(tuneTMPrefix, lmPrefix, tuneSources, tuneTargets, reporter);
            reporter.CheckCanceled();

            TrainTuneCorpus(tmPrefix, lmPrefix, tuneSources, tuneTargets, reporter);

            reporter.Step("Completed");
        }
コード例 #7
0
        /// <summary>
        /// Opens the word alignment model at <paramref name="swmPrefix"/> and writes, for every
        /// training segment of <paramref name="corpus"/>, a "# id ref" header line followed by the
        /// model's GIZA format string to <paramref name="fileName"/>.
        /// </summary>
        /// <param name="swmPrefix">Path prefix of the alignment model to load.</param>
        /// <param name="fileName">Output file that receives the alignments.</param>
        /// <param name="sourcePreprocessor">Preprocessor passed to the model for source text.</param>
        /// <param name="targetPreprocessor">Preprocessor passed to the model for target text.</param>
        /// <param name="corpus">Parallel corpus whose training segments are aligned.</param>
        /// <param name="name">Label inserted into the progress message.</param>
        /// <param name="reporter">Progress reporter; cancellation is checked after every segment.</param>
        private void GenerateBestAlignments(string swmPrefix, string fileName, Func <string, string> sourcePreprocessor,
                                            Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
                                            ThotTrainProgressReporter reporter)
        {
            reporter.Step($"Generating best {name} alignments");

            using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix))
            {
                using (var output = new StreamWriter(fileName))
                {
                    foreach (ParallelTextSegment pair in GetTrainingSegments(corpus))
                    {
                        // Header first, so it is already written if formatting the alignment fails.
                        output.Write($"# {pair.Text.Id} {pair.SegmentRef}\n");
                        output.Write(alignmentModel.GetGizaFormatString(pair, sourcePreprocessor, targetPreprocessor));

                        reporter.CheckCanceled();
                    }
                }
            }
        }
コード例 #8
0
        /// <summary>
        /// Creates a word alignment model at <paramref name="swmPrefix"/>, adds every training
        /// segment pair of <paramref name="corpus"/>, runs five training iterations and saves the
        /// model.
        /// </summary>
        /// <param name="swmPrefix">Path prefix of the alignment model files.</param>
        /// <param name="sourcePreprocessor">Preprocessor passed to the model for source text.</param>
        /// <param name="targetPreprocessor">Preprocessor passed to the model for target text.</param>
        /// <param name="corpus">Parallel corpus whose training segments are added.</param>
        /// <param name="name">Label inserted into the progress message.</param>
        /// <param name="reporter">Progress reporter advanced once per training iteration.</param>
        private void TrainWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                             Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
                                             ThotTrainProgressReporter reporter)
        {
            const int iterationCount = 5;

            using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix, true))
            {
                foreach (ParallelTextSegment pair in GetTrainingSegments(corpus))
                    alignmentModel.AddSegmentPair(pair, sourcePreprocessor, targetPreprocessor);

                // One progress step is reported ahead of each iteration.
                for (int iteration = 0; iteration < iterationCount; iteration++)
                {
                    reporter.Step($"Training {name} alignment model");
                    alignmentModel.TrainingIteration();
                }

                alignmentModel.Save();
            }
        }
コード例 #9
0
        /// <summary>
        /// Reports the "Finalizing" step and, when there is a tuning corpus, opens an SMT model on
        /// the trained files and trains its engine on every tuning segment pair.
        /// </summary>
        /// <param name="trainTMPrefix">Translation model file name prefix set on the cloned parameters.</param>
        /// <param name="trainLMPrefix">Language model file name prefix set on the cloned parameters.</param>
        /// <param name="tuneSourceCorpus">Source side of the tuning corpus; nothing is trained when empty.</param>
        /// <param name="tuneTargetCorpus">Target side of the tuning corpus, indexed in parallel with the source side.</param>
        /// <param name="reporter">Progress reporter that is advanced to the second-to-last step.</param>
        private void TrainTuneCorpus(string trainTMPrefix, string trainLMPrefix,
                                     IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                     IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, ThotTrainProgressReporter reporter)
        {
            // The step is reported even when there is no tuning corpus.
            reporter.Step("Finalizing", TrainingStepCount - 1);

            if (tuneSourceCorpus.Count == 0)
                return;

            // Work on a clone so that the Parameters property itself is left untouched.
            ThotSmtParameters modelParameters = Parameters.Clone();
            modelParameters.TranslationModelFileNamePrefix = trainTMPrefix;
            modelParameters.LanguageModelFileNamePrefix = trainLMPrefix;

            using (var model = new ThotSmtModel(modelParameters))
            {
                using (ISmtEngine trainer = model.CreateEngine())
                {
                    for (int index = 0; index < tuneSourceCorpus.Count; index++)
                    {
                        IReadOnlyList<string> sourceSegment = tuneSourceCorpus[index];
                        IReadOnlyList<string> targetSegment = tuneTargetCorpus[index];
                        trainer.TrainSegment(sourceSegment, targetSegment);
                    }
                }
            }
        }
コード例 #10
0
        /// <summary>
        /// Tunes the translation model weights. The phrase table under
        /// <paramref name="tuneTMPrefix"/> is filtered against the tuning source corpus and to the
        /// 20 best entries, the model weight tuner is run from a fixed starting weight vector, and
        /// <c>Parameters</c> is replaced by a frozen clone of the tuned parameters that keeps the
        /// previous model file name prefixes.
        /// </summary>
        /// <param name="tuneTMPrefix">Prefix of the translation model files used during tuning.</param>
        /// <param name="tuneLMPrefix">Prefix of the language model files used during tuning.</param>
        /// <param name="tuneSourceCorpus">Source side of the tuning corpus; nothing is tuned when empty.</param>
        /// <param name="tuneTargetCorpus">Target side of the tuning corpus.</param>
        /// <param name="reporter">Progress reporter; advanced here and handed to the tuner.</param>
        private void TuneTranslationModel(string tuneTMPrefix, string tuneLMPrefix,
                                          IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                          IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, ThotTrainProgressReporter reporter)
        {
            // The step is reported even when there is no tuning corpus.
            reporter.Step("Tuning translation model");

            if (tuneSourceCorpus.Count == 0)
                return;

            string phraseTableFile = $"{tuneTMPrefix}.ttable";
            FilterPhraseTableUsingCorpus(phraseTableFile, tuneSourceCorpus);
            FilterPhraseTableNBest(phraseTableFile, 20);

            ThotSmtParameters previous = Parameters;

            // Starting point for the tuner: the tuning copies of the model files and a weight
            // vector of seven 1.0 entries followed by a single 0.
            ThotSmtParameters startingPoint = previous.Clone();
            startingPoint.TranslationModelFileNamePrefix = tuneTMPrefix;
            startingPoint.LanguageModelFileNamePrefix = tuneLMPrefix;
            startingPoint.ModelWeights = new[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0f };
            startingPoint.Freeze();

            ThotSmtParameters tuned = _modelWeightTuner.Tune(startingPoint, tuneSourceCorpus, tuneTargetCorpus,
                                                             reporter, Stats);

            // Adopt the tuned values but point back at the original model files.
            Parameters = tuned.Clone();
            Parameters.TranslationModelFileNamePrefix = previous.TranslationModelFileNamePrefix;
            Parameters.LanguageModelFileNamePrefix = previous.LanguageModelFileNamePrefix;
            Parameters.Freeze();
        }
コード例 #11
0
        /// <summary>
        /// Trains a word alignment model under <paramref name="swmPrefix"/> in one progress phase,
        /// prunes its ".hmm_lexnd" table, and writes the best alignments to the ".bestal" file in a
        /// second progress phase.
        /// </summary>
        /// <param name="swmPrefix">Path prefix of the alignment model files.</param>
        /// <param name="sourcePreprocessor">Preprocessor passed on for source text.</param>
        /// <param name="targetPreprocessor">Preprocessor passed on for target text.</param>
        /// <param name="corpus">Parallel corpus passed on to training and alignment generation.</param>
        /// <param name="reporter">Progress reporter that starts the phases and is polled for cancellation.</param>
        /// <param name="inverted">Not referenced anywhere in this method body.</param>
        private void GenerateWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                                Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, ThotTrainProgressReporter reporter,
                                                bool inverted)
        {
            const double pruneThreshold = 0.00001;

            string lexTableFile       = $"{swmPrefix}.hmm_lexnd";
            string bestAlignmentsFile = $"{swmPrefix}.bestal";

            using (PhaseProgress trainingPhase = reporter.StartNextPhase())
            {
                TrainWordAlignmentModel(swmPrefix, sourcePreprocessor, targetPreprocessor, corpus, trainingPhase);
            }

            reporter.CheckCanceled();

            // Pruning runs between the two phases and is not reported as one.
            PruneLexTable(lexTableFile, pruneThreshold);

            using (PhaseProgress alignmentPhase = reporter.StartNextPhase())
            {
                GenerateBestAlignments(swmPrefix, bestAlignmentsFile, sourcePreprocessor, targetPreprocessor, corpus,
                                       alignmentPhase);
            }
        }
コード例 #12
0
        /// <summary>
        /// Tunes the first seven model weights with a Nelder-Mead simplex, using the value returned
        /// by <c>CalculateBleu</c> as the function to minimize. The weight at index 7 of
        /// <paramref name="parameters"/> is kept fixed and re-attached to every candidate.
        /// </summary>
        /// <param name="parameters">Starting parameters; cloned for every candidate, never modified.</param>
        /// <param name="tuneSourceCorpus">Source side of the tuning corpus.</param>
        /// <param name="tuneTargetCorpus">Target side of the tuning corpus.</param>
        /// <param name="reporter">Progress reporter; stepped or polled for cancellation on every evaluation.</param>
        /// <param name="stats">Receives <c>1.0 - ErrorValue</c> of the minimization as <c>TranslationModelBleu</c>.</param>
        /// <returns>A frozen clone of <paramref name="parameters"/> carrying the minimizing weights.</returns>
        public ThotSmtParameters Tune(ThotSmtParameters parameters,
                                      IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                      IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, ThotTrainProgressReporter reporter,
                                      SmtBatchTrainStats stats)
        {
            float fixedWeight = parameters.ModelWeights[7];
            int evaluationCount = 0;

            double Evaluate(Vector weights)
            {
                ThotSmtParameters candidate = parameters.Clone();
                candidate.ModelWeights = weights.Select(w => (float)w).Concat(fixedWeight).ToArray();
                candidate.Freeze();

                double quality = CalculateBleu(candidate, tuneSourceCorpus, tuneTargetCorpus);

                evaluationCount++;

                // A progress step is reported on every ProgressIncrementInterval-th evaluation that
                // is still below the evaluation limit; all other evaluations only poll for
                // cancellation.
                bool reportStep = evaluationCount < MaxFunctionEvaluations
                                  && ProgressIncrementInterval > 0
                                  && evaluationCount % ProgressIncrementInterval == 0;
                if (reportStep)
                    reporter.Step();
                else
                    reporter.CheckCanceled();

                return quality;
            }

            var optimizer = new NelderMeadSimplex(ConvergenceTolerance, MaxFunctionEvaluations, 1.0);
            MinimizationResult minimum = optimizer.FindMinimum(Evaluate,
                                                               parameters.ModelWeights.Select(w => (double)w).Take(7));

            stats.TranslationModelBleu = 1.0 - minimum.ErrorValue;

            ThotSmtParameters tuned = parameters.Clone();
            tuned.ModelWeights = minimum.MinimizingPoint.Select(w => (float)w).Concat(fixedWeight).ToArray();
            tuned.Freeze();
            return tuned;
        }
コード例 #13
0
        /// <summary>
        /// Tunes the model weights iteratively. Each iteration decodes n-best lists for the tuning
        /// source corpus with the current weights, scores the first translation of every list with
        /// BLEU against the tuning target corpus, and, unless the iteration limit is reached or the
        /// qualities have converged, merges the new n-best lists into the accumulated ones and lets
        /// the weight updater adjust the weights.
        /// </summary>
        /// <param name="parameters">Starting parameters; cloned for every iteration, never modified.</param>
        /// <param name="tuneSourceCorpus">Source side of the tuning corpus.</param>
        /// <param name="tuneTargetCorpus">Target side of the tuning corpus.</param>
        /// <param name="reporter">Progress reporter advanced once per completed weight update.</param>
        /// <param name="stats">Receives the best quality seen as <c>TranslationModelBleu</c>.</param>
        /// <returns>The frozen parameters of the iteration that achieved the best quality.</returns>
        public ThotSmtParameters Tune(ThotSmtParameters parameters,
                                      IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                      IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, ThotTrainProgressReporter reporter,
                                      SmtBatchTrainStats stats)
        {
            IntPtr weightUpdaterHandle = Thot.llWeightUpdater_create();

            try
            {
                var                         iterQualities  = new List <double>();
                double                      bestQuality    = double.MinValue;
                ThotSmtParameters           bestParameters = null;
                int                         iter           = 1;
                HashSet <TranslationInfo>[] curNBestLists  = null;
                float[]                     curWeights     = parameters.ModelWeights.ToArray();

                while (true)
                {
                    ThotSmtParameters newParameters = parameters.Clone();
                    // Give every candidate its own copy of the weights. curWeights is the working
                    // buffer that UpdateWeights changes in place (it is never reassigned in this
                    // loop), so sharing the array could let a later update alter the weights of an
                    // already frozen candidate, including the one remembered as bestParameters.
                    newParameters.ModelWeights = curWeights.ToArray();
                    newParameters.Freeze();
                    IList <TranslationInfo>[] nbestLists = GetNBestLists(newParameters, tuneSourceCorpus).ToArray();
                    double quality = Evaluation.CalculateBleu(nbestLists.Select(nbl => nbl.First().Translation),
                                                              tuneTargetCorpus);
                    iterQualities.Add(quality);
                    if (quality > bestQuality)
                    {
                        bestQuality    = quality;
                        bestParameters = newParameters;
                    }

                    if (iter >= MaxIterations || IsTuningConverged(iterQualities))
                    {
                        break;
                    }

                    // Accumulate the n-best lists over the iterations; the sets drop duplicates.
                    if (curNBestLists == null)
                    {
                        curNBestLists = nbestLists.Select(nbl => new HashSet <TranslationInfo>(nbl)).ToArray();
                    }
                    else
                    {
                        for (int i = 0; i < nbestLists.Length; i++)
                        {
                            curNBestLists[i].UnionWith(nbestLists[i]);
                        }
                    }

                    UpdateWeights(weightUpdaterHandle, tuneTargetCorpus, curNBestLists, curWeights);

                    iter++;

                    reporter.Step();
                }

                stats.TranslationModelBleu = bestQuality;
                return(bestParameters);
            }
            finally
            {
                // The native weight updater is released on every exit path, including exceptions.
                Thot.llWeightUpdater_close(weightUpdaterHandle);
            }
        }