Пример #1
0
        /// <summary>
        /// Decodes every segment of <paramref name="sourceCorpus"/> and collects its n-best translations.
        /// </summary>
        /// <param name="parameters">Model and decoder settings used to load the native Thot objects.</param>
        /// <param name="sourceCorpus">Tokenized source segments to translate.</param>
        /// <returns>
        /// An array holding one n-best list (requested with size <c>K</c>) per source segment, in corpus order.
        /// </returns>
        private IEnumerable <IList <TranslationInfo> > GetNBestLists(ThotSmtParameters parameters,
                                                                     IReadOnlyList <IReadOnlyList <string> > sourceCorpus)
        {
            var nbestLists = new IList <TranslationInfo> [sourceCorpus.Count];

            // Every range partition loads its own SMT model and decoder, so no native handle
            // is shared between worker threads.
            Parallel.ForEach(Partitioner.Create(0, sourceCorpus.Count), bounds =>
            {
                IntPtr modelHandle   = IntPtr.Zero;
                IntPtr decoderHandle = IntPtr.Zero;
                try
                {
                    modelHandle   = Thot.LoadSmtModel(parameters);
                    decoderHandle = Thot.LoadDecoder(modelHandle, parameters);
                    for (int index = bounds.Item1; index < bounds.Item2; index++)
                    {
                        IReadOnlyList <string> segment = sourceCorpus[index];
                        nbestLists[index] = Thot.DoTranslateNBest(decoderHandle, Thot.decoder_translateNBest, K,
                                                                  segment, false, segment, CreateTranslationInfo)
                                            .ToArray();
                    }
                }
                finally
                {
                    // Release in reverse order of creation: the decoder was created from the model handle.
                    if (decoderHandle != IntPtr.Zero)
                    {
                        Thot.decoder_close(decoderHandle);
                    }
                    if (modelHandle != IntPtr.Zero)
                    {
                        Thot.smtModel_close(modelHandle);
                    }
                }
            });

            return nbestLists;
        }
Пример #2
0
 /// <summary>
 /// Creates a batch trainer associated with <paramref name="smtModel"/>; every other argument
 /// is forwarded unchanged to the base trainer constructor.
 /// </summary>
 /// <param name="smtModel">The model this trainer keeps a reference to.</param>
 /// <param name="parameters">Training parameters passed to the base constructor.</param>
 /// <param name="sourcePreprocessor">Source-side text preprocessor passed to the base constructor.</param>
 /// <param name="targetPreprocessor">Target-side text preprocessor passed to the base constructor.</param>
 /// <param name="corpus">Parallel training corpus passed to the base constructor.</param>
 public BatchTrainer(ThotSmtModel smtModel, ThotSmtParameters parameters,
                     Func <string, string> sourcePreprocessor, Func <string, string> targetPreprocessor,
                     ParallelTextCorpus corpus)
     : base(parameters, sourcePreprocessor, targetPreprocessor, corpus)
     => _smtModel = smtModel;
        public void TrainModels_EmptyCorpus_GeneratesModels()
        {
            using (var tempDir = new TempDirectory("ThotSmtEngineTests"))
            {
                // Asserts that the trainer wrote <tempDir>/<subDirectory>/<fileName>.
                void AssertModelFileExists(string subDirectory, string fileName) =>
                    Assert.That(File.Exists(Path.Combine(tempDir.Path, subDirectory, fileName)), Is.True);

                // Arrange: a parallel corpus with no texts and no alignments at all.
                var corpus = new ParallelTextCorpus(
                    new DictionaryTextCorpus(Enumerable.Empty <MemoryText>()),
                    new DictionaryTextCorpus(Enumerable.Empty <MemoryText>()),
                    new DictionaryTextAlignmentCorpus(Enumerable.Empty <MemoryTextAlignmentCollection>()));

                string tmPrefix = Path.Combine(tempDir.Path, "tm", "src_trg");
                string lmPrefix = Path.Combine(tempDir.Path, "lm", "trg.lm");
                var parameters = new ThotSmtParameters
                {
                    TranslationModelFileNamePrefix = tmPrefix,
                    LanguageModelFileNamePrefix    = lmPrefix
                };

                // Act: train and persist with identity preprocessors.
                using (var trainer = new ThotSmtBatchTrainer(parameters, s => s, s => s, corpus))
                {
                    trainer.Train();
                    trainer.Save();
                }

                // Assert: the model files are produced even though there was nothing to train on.
                AssertModelFileExists("lm", "trg.lm");
                AssertModelFileExists("tm", "src_trg_swm.hmm_alignd");
                AssertModelFileExists("tm", "src_trg_invswm.hmm_alignd");
                AssertModelFileExists("tm", "src_trg.ttable");
                // TODO: test for more than just existence of files
            }
        }
Пример #4
0
        /// <summary>
        /// Creates a native decoder on top of an SMT model and applies the decoder settings to it.
        /// </summary>
        /// <param name="smtModelHandle">Handle of an already-loaded native SMT model.</param>
        /// <param name="parameters">Supplies <c>DecoderS</c>, <c>DecoderBreadthFirst</c> and <c>DecoderG</c>.</param>
        /// <returns>
        /// The new decoder handle. Callers own it and release it with <c>decoder_close</c>.
        /// </returns>
        public static IntPtr LoadDecoder(IntPtr smtModelHandle, ThotSmtParameters parameters)
        {
            IntPtr decoder = decoder_create(smtModelHandle);
            decoder_setS(decoder, parameters.DecoderS);
            decoder_setBreadthFirst(decoder, parameters.DecoderBreadthFirst);
            decoder_setG(decoder, parameters.DecoderG);
            return decoder;
        }
Пример #5
0
        /// <summary>
        /// Loads a native SMT model described by <paramref name="parameters"/> and wraps its direct
        /// and inverse single-word alignment models.
        /// </summary>
        /// <param name="parameters">
        /// Model settings. The caller's instance itself is frozen (no copy is made).
        /// </param>
        public ThotSmtModel(ThotSmtParameters parameters)
        {
            Parameters = parameters;
            // NOTE(review): presumably prevents further modification of the settings the model was loaded with.
            Parameters.Freeze();

            Handle = Thot.LoadSmtModel(Parameters);

            var directAlignmentHandle = Thot.smtModel_getSingleWordAlignmentModel(Handle);
            _directWordAlignmentModel = new ThotWordAlignmentModel(directAlignmentHandle);

            var inverseAlignmentHandle = Thot.smtModel_getInverseSingleWordAlignmentModel(Handle);
            _inverseWordAlignmentModel = new ThotWordAlignmentModel(inverseAlignmentHandle);
        }
Пример #6
0
        /// <summary>
        /// Creates a native SMT model, loads its translation and language models from disk and
        /// applies the model and online-training settings from <paramref name="parameters"/>.
        /// </summary>
        /// <param name="parameters">File-name prefixes and numeric settings for the model.</param>
        /// <returns>
        /// The native model handle. Callers own it and release it with <c>smtModel_close</c>.
        /// </returns>
        public static IntPtr LoadSmtModel(ThotSmtParameters parameters)
        {
            IntPtr model = smtModel_create();

            // Model files on disk.
            smtModel_loadTranslationModel(model, parameters.TranslationModelFileNamePrefix);
            smtModel_loadLanguageModel(model, parameters.LanguageModelFileNamePrefix);

            // Model settings.
            smtModel_setNonMonotonicity(model, parameters.ModelNonMonotonicity);
            smtModel_setW(model, parameters.ModelW);
            smtModel_setA(model, parameters.ModelA);
            smtModel_setE(model, parameters.ModelE);
            smtModel_setHeuristic(model, (uint)parameters.ModelHeuristic);

            // Online-training settings.
            smtModel_setOnlineTrainingParameters(model,
                                                 (uint)parameters.LearningAlgorithm,
                                                 (uint)parameters.LearningRatePolicy,
                                                 parameters.LearningStepSize,
                                                 parameters.LearningEMIters,
                                                 parameters.LearningE,
                                                 parameters.LearningR);

            // Weights are optional; when absent nothing is sent (presumably the native defaults remain).
            var weights = parameters.ModelWeights;
            if (weights != null)
            {
                smtModel_setWeights(model, weights.ToArray(), (uint)weights.Count);
            }

            return model;
        }
Пример #7
0
        /// <summary>
        /// Error score for a candidate parameter set, where lower is better.
        /// </summary>
        /// <remarks>
        /// The score is <c>(1 - BLEU)</c> of the corpus translated with <paramref name="parameters"/>,
        /// plus <c>1000 * |w|</c> for every negative weight <c>w</c>. Weights at indices 0, 2 and 7 are
        /// exempt from the penalty (index 7 is the sentence-length weight; NOTE(review): the reason
        /// for exempting 0 and 2 is not visible here, confirm).
        /// </remarks>
        /// <param name="parameters">Parameters (including <c>ModelWeights</c>) used to translate.</param>
        /// <param name="sourceCorpus">Tokenized tuning source segments.</param>
        /// <param name="tuneTargetCorpus">Reference translations for <paramref name="sourceCorpus"/>.</param>
        /// <returns><c>(1 - BLEU) + penalty</c>.</returns>
        private static double CalculateBleu(ThotSmtParameters parameters,
                                            IReadOnlyList <IReadOnlyList <string> > sourceCorpus, IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus)
        {
            IEnumerable <IReadOnlyList <string> > translations = GenerateTranslations(parameters, sourceCorpus);
            double error = 1.0 - Evaluation.CalculateBleu(translations, tuneTargetCorpus);

            double penalty = 0;
            for (int index = 0; index < parameters.ModelWeights.Count; index++)
            {
                bool exempt = index == 0 || index == 2 || index == 7;
                if (!exempt && parameters.ModelWeights[index] < 0)
                {
                    // Subtracting a negative weight adds its magnitude.
                    penalty -= parameters.ModelWeights[index] * 1000;
                }
            }

            return error + penalty;
        }
Пример #8
0
        /// <summary>
        /// Translates every segment of <paramref name="sourceCorpus"/> in parallel and returns the
        /// best translation of each.
        /// </summary>
        /// <param name="parameters">Model and decoder settings.</param>
        /// <param name="sourceCorpus">Tokenized source segments.</param>
        /// <returns>An array with one translation per source segment, in corpus order.</returns>
        private static IEnumerable <IReadOnlyList <string> > GenerateTranslations(ThotSmtParameters parameters,
                                                                                  IReadOnlyList <IReadOnlyList <string> > sourceCorpus)
        {
            IntPtr modelHandle = IntPtr.Zero;
            try
            {
                // A single SMT model is loaded once and shared; each partition builds its own decoder on it.
                modelHandle = Thot.LoadSmtModel(parameters);
                var translations = new IReadOnlyList <string> [sourceCorpus.Count];

                // Translates the half-open index range [start, end) with a decoder owned by this call.
                void TranslateRange(int start, int end)
                {
                    IntPtr decoder = IntPtr.Zero;
                    try
                    {
                        decoder = Thot.LoadDecoder(modelHandle, parameters);
                        for (int index = start; index < end; index++)
                        {
                            IReadOnlyList <string> segment = sourceCorpus[index];
                            translations[index] = Thot.DoTranslate(decoder, Thot.decoder_translate, segment, false,
                                                                   segment, (source, target, data) => target);
                        }
                    }
                    finally
                    {
                        if (decoder != IntPtr.Zero)
                        {
                            Thot.decoder_close(decoder);
                        }
                    }
                }

                Parallel.ForEach(Partitioner.Create(0, sourceCorpus.Count),
                                 bounds => TranslateRange(bounds.Item1, bounds.Item2));
                return translations;
            }
            finally
            {
                // Safe to close here: Parallel.ForEach has returned, so no decoder still uses the model.
                if (modelHandle != IntPtr.Zero)
                {
                    Thot.smtModel_close(modelHandle);
                }
            }
        }
Пример #9
0
        /// <summary>
        /// Tunes the first seven model weights with a Nelder–Mead simplex search that minimizes
        /// <c>CalculateBleu</c>. The eighth (sentence-length) weight is held fixed.
        /// </summary>
        /// <param name="parameters">Starting parameters; must contain at least eight model weights.</param>
        /// <param name="tuneSourceCorpus">Tokenized tuning source segments.</param>
        /// <param name="tuneTargetCorpus">Reference translations for the tuning set.</param>
        /// <param name="reporter">Stepped every <c>ProgressIncrementInterval</c> evaluations, otherwise checked for cancellation.</param>
        /// <param name="stats">Receives the BLEU of the best point in <c>TranslationModelBleu</c>.</param>
        /// <returns>A frozen clone of <paramref name="parameters"/> carrying the best weights found.</returns>
        public ThotSmtParameters Tune(ThotSmtParameters parameters,
                                      IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                      IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, ThotTrainProgressReporter reporter,
                                      SmtBatchTrainStats stats)
        {
            float sentLenWeight   = parameters.ModelWeights[7];
            int   evaluationCount = 0;

            // Frozen clone of the input whose weights are the tuned values followed by the fixed sentence-length weight.
            ThotSmtParameters WithWeights(IEnumerable <float> tunedWeights)
            {
                ThotSmtParameters candidate = parameters.Clone();
                candidate.ModelWeights = tunedWeights.Concat(sentLenWeight).ToArray();
                candidate.Freeze();
                return candidate;
            }

            // Objective for the simplex: translation error of the candidate weights.
            double Evaluate(Vector weights)
            {
                double quality = CalculateBleu(WithWeights(weights.Select(w => (float)w)),
                                               tuneSourceCorpus, tuneTargetCorpus);

                evaluationCount++;
                bool isProgressStep = evaluationCount < MaxFunctionEvaluations
                                      && ProgressIncrementInterval > 0
                                      && evaluationCount % ProgressIncrementInterval == 0;
                if (isProgressStep)
                {
                    reporter.Step();
                }
                else
                {
                    reporter.CheckCanceled();
                }
                return quality;
            }

            var simplex      = new NelderMeadSimplex(ConvergenceTolerance, MaxFunctionEvaluations, 1.0);
            var initialPoint = parameters.ModelWeights.Select(w => (double)w).Take(7);
            MinimizationResult result = simplex.FindMinimum(Evaluate, initialPoint);

            // The objective is 1 - BLEU (plus any penalty), so invert it back.
            stats.TranslationModelBleu = 1.0 - result.ErrorValue;
            return WithWeights(result.MinimizingPoint.Select(w => (float)w));
        }
Пример #10
0
        /// <summary>
        /// Tunes the first seven model weights with a Nelder–Mead simplex search that minimizes
        /// <c>CalculateBleu</c>. The eighth (sentence-length) weight is held fixed.
        /// </summary>
        /// <param name="parameters">Starting parameters; must contain at least eight model weights.</param>
        /// <param name="tuneSourceCorpus">Tokenized tuning source segments.</param>
        /// <param name="tuneTargetCorpus">Reference translations for the tuning set.</param>
        /// <param name="stats">Receives the BLEU of the best point in <c>TranslationModelBleu</c>.</param>
        /// <param name="progress">
        /// Receives one report per objective evaluation, measured against
        /// <c>MaxProgressFunctionEvaluations</c> and capped there.
        /// </param>
        /// <returns>A frozen clone of <paramref name="parameters"/> carrying the best weights found.</returns>
        public ThotSmtParameters Tune(ThotSmtParameters parameters,
                                      IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                      IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, SmtBatchTrainStats stats,
                                      IProgress <ProgressStatus> progress)
        {
            float sentLenWeight = parameters.ModelWeights[7];
            int   numFuncEvals  = 0;

            // Frozen clone of the input whose weights are the tuned values followed by the fixed sentence-length weight.
            ThotSmtParameters CreateParameters(IEnumerable <float> tunedWeights)
            {
                ThotSmtParameters newParameters = parameters.Clone();
                newParameters.ModelWeights = tunedWeights.Concat(sentLenWeight).ToArray();
                newParameters.Freeze();
                return newParameters;
            }

            // Objective for the simplex: translation error of the candidate weights.
            double Evaluate(Vector weights)
            {
                double quality = CalculateBleu(CreateParameters(weights.Select(w => (float)w)),
                                               tuneSourceCorpus, tuneTargetCorpus);

                numFuncEvals++;
                // The simplex may run up to MaxFunctionEvaluations, but progress is scaled to
                // MaxProgressFunctionEvaluations and never reported past it.
                int currentStep = Math.Min(numFuncEvals, MaxProgressFunctionEvaluations);
                progress.Report(new ProgressStatus(currentStep, MaxProgressFunctionEvaluations));
                return quality;
            }

            // Use the same step total as every later report and the completion check below;
            // the previous MaxFunctionEvaluations total made the denominator change mid-run.
            progress.Report(new ProgressStatus(0, MaxProgressFunctionEvaluations));
            var simplex = new NelderMeadSimplex(ConvergenceTolerance, MaxFunctionEvaluations, 1.0);
            MinimizationResult result = simplex.FindMinimum(Evaluate,
                                                            parameters.ModelWeights.Select(w => (double)w).Take(7));

            // The objective is 1 - BLEU (plus any penalty), so invert it back.
            stats.TranslationModelBleu = 1.0 - result.ErrorValue;

            ThotSmtParameters bestParameters = CreateParameters(result.MinimizingPoint.Select(w => (float)w));

            if (result.FunctionEvaluationCount < MaxProgressFunctionEvaluations)
            {
                // The search stopped before the progress cap was reached, so the per-evaluation
                // reports never got to the end; report completion explicitly.
                progress.Report(new ProgressStatus(1.0));
            }
            return bestParameters;
        }
Пример #11
0
        /// <summary>
        /// Incrementally trains the freshly trained models on each tuning segment pair.
        /// </summary>
        /// <param name="trainTMPrefix">Translation-model file prefix to load and train.</param>
        /// <param name="trainLMPrefix">Language-model file prefix to load and train.</param>
        /// <param name="tuneSourceCorpus">Tuning source segments.</param>
        /// <param name="tuneTargetCorpus">Tuning target segments, paired by index with the source segments.</param>
        /// <param name="reporter">Receives the "Finalizing" step, which is reported even if the tuning corpus is empty.</param>
        private void TrainTuneCorpus(string trainTMPrefix, string trainLMPrefix,
                                     IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                     IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, ThotTrainProgressReporter reporter)
        {
            // Announce the step before the early exit so it is reported unconditionally.
            reporter.Step("Finalizing", TrainingStepCount - 1);

            int segmentCount = tuneSourceCorpus.Count;
            if (segmentCount == 0)
            {
                return;
            }

            // Same settings as this trainer, but pointed at the newly trained model files.
            ThotSmtParameters modelParameters = Parameters.Clone();
            modelParameters.TranslationModelFileNamePrefix = trainTMPrefix;
            modelParameters.LanguageModelFileNamePrefix    = trainLMPrefix;

            using (var model = new ThotSmtModel(modelParameters))
            using (ISmtEngine engine = model.CreateEngine())
            {
                for (int segment = 0; segment < segmentCount; segment++)
                {
                    engine.TrainSegment(tuneSourceCorpus[segment], tuneTargetCorpus[segment]);
                }
            }
        }
Пример #12
0
        /// <summary>
        /// Incrementally trains the freshly trained models on each tuning segment pair, reporting
        /// progress per segment.
        /// </summary>
        /// <param name="trainTMPrefix">Translation-model file prefix to load and train.</param>
        /// <param name="trainLMPrefix">Language-model file prefix to load and train.</param>
        /// <param name="tuneSourceCorpus">Tuning source segments.</param>
        /// <param name="tuneTargetCorpus">Tuning target segments, paired by index with the source segments.</param>
        /// <param name="progress">
        /// Receives one report before each segment and a final completion report. Nothing is
        /// reported when the tuning corpus is empty.
        /// </param>
        private void TrainTuneCorpus(string trainTMPrefix, string trainLMPrefix,
                                     IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                     IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, IProgress <ProgressStatus> progress)
        {
            int segmentCount = tuneSourceCorpus.Count;
            if (segmentCount == 0)
            {
                return;
            }

            // Same settings as this trainer, but pointed at the newly trained model files.
            ThotSmtParameters modelParameters = Parameters.Clone();
            modelParameters.TranslationModelFileNamePrefix = trainTMPrefix;
            modelParameters.LanguageModelFileNamePrefix    = trainLMPrefix;

            using (var model = new ThotSmtModel(modelParameters))
            using (ISmtEngine engine = model.CreateEngine())
            {
                for (int segment = 0; segment < segmentCount; segment++)
                {
                    progress.Report(new ProgressStatus(segment, segmentCount));
                    engine.TrainSegment(tuneSourceCorpus[segment], tuneTargetCorpus[segment]);
                }
                progress.Report(new ProgressStatus(1.0));
            }
        }
Пример #13
0
        /// <summary>
        /// Creates a batch trainer: freezes <paramref name="parameters"/>, selects the tuning subset
        /// of the corpus and creates a fresh scratch directory under the system temp path.
        /// </summary>
        /// <param name="parameters">Training parameters; the caller's instance itself is frozen.</param>
        /// <param name="sourcePreprocessor">Applied to source-side text.</param>
        /// <param name="targetPreprocessor">Applied to target-side text.</param>
        /// <param name="corpus">Parallel training corpus.</param>
        /// <param name="maxCorpusCount">Upper bound on corpus size to use; unbounded by default.</param>
        public ThotSmtBatchTrainer(ThotSmtParameters parameters, Func <string, string> sourcePreprocessor,
                                   Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, int maxCorpusCount = int.MaxValue)
        {
            Parameters = parameters;
            Parameters.Freeze();

            _parallelCorpus     = corpus;
            _maxCorpusCount     = maxCorpusCount;
            _sourcePreprocessor = sourcePreprocessor;
            _targetPreprocessor = targetPreprocessor;
            // The simplex tuner is used here; MiraModelWeightTuner is an alternative implementation.
            _modelWeightTuner   = new SimplexModelWeightTuner();
            // Kept after the field assignments above in case CreateTuneCorpus reads them.
            _tuneCorpusIndices  = CreateTuneCorpus();

            // Keep drawing random names until one does not exist yet, then create it.
            string scratchDir;
            do
            {
                scratchDir = Path.Combine(Path.GetTempPath(), "thot-train-" + Guid.NewGuid());
            }
            while (Directory.Exists(scratchDir));
            _tempDir = scratchDir;
            Directory.CreateDirectory(_tempDir);

            _lmFilePrefix = Path.GetFileName(Parameters.LanguageModelFileNamePrefix);
            _tmFilePrefix = Path.GetFileName(Parameters.TranslationModelFileNamePrefix);
            _trainLMDir   = Path.Combine(_tempDir, "lm");
            _trainTMDir   = Path.Combine(_tempDir, "tm_train");
        }
Пример #14
0
 /// <summary>
 /// Loads the model described by the configuration file <paramref name="cfgFileName"/> and
 /// records that file name in <c>ConfigFileName</c>.
 /// </summary>
 /// <param name="cfgFileName">Path of the Thot configuration file to read parameters from.</param>
 public ThotSmtModel(string cfgFileName)
     : this(ThotSmtParameters.Load(cfgFileName))
     => ConfigFileName = cfgFileName;
Пример #15
0
 /// <summary>
 /// Creates a batch trainer whose parameters are read from the configuration file
 /// <paramref name="cfgFileName"/>, and records that file name in <c>ConfigFileName</c>.
 /// </summary>
 /// <param name="cfgFileName">Path of the Thot configuration file to read parameters from.</param>
 /// <param name="sourcePreprocessor">Applied to source-side text.</param>
 /// <param name="targetPreprocessor">Applied to target-side text.</param>
 /// <param name="corpus">Parallel training corpus.</param>
 /// <param name="maxCorpusCount">Upper bound on corpus size to use; unbounded by default.</param>
 public ThotSmtBatchTrainer(string cfgFileName, Func <string, string> sourcePreprocessor,
                            Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, int maxCorpusCount = int.MaxValue)
     : this(ThotSmtParameters.Load(cfgFileName), sourcePreprocessor, targetPreprocessor, corpus, maxCorpusCount)
     => ConfigFileName = cfgFileName;
Пример #16
0
        /// <summary>
        /// Tunes the model weights iteratively with Thot's native log-linear weight updater.
        /// </summary>
        /// <remarks>
        /// Each iteration does the following:
        /// decodes the tuning corpus with the current weights;
        /// scores the first entry of every n-best list against the references with BLEU;
        /// merges the new n-best lists into per-segment pools accumulated across iterations;
        /// asks the native updater for new weights.
        /// Tuning stops after <c>MaxIterations</c> iterations or when <c>IsTuningConverged</c>
        /// reports convergence. The native updater handle is always released.
        /// </remarks>
        /// <param name="parameters">Starting parameters; their <c>ModelWeights</c> seed the search.</param>
        /// <param name="tuneSourceCorpus">Tokenized tuning source segments.</param>
        /// <param name="tuneTargetCorpus">Reference translations, paired by index with the source segments.</param>
        /// <param name="stats">Receives the best BLEU score in <c>TranslationModelBleu</c>.</param>
        /// <param name="progress">Receives one report per iteration plus a completion report on early convergence.</param>
        /// <returns>The frozen parameters that achieved the highest BLEU over all iterations.</returns>
        public ThotSmtParameters Tune(ThotSmtParameters parameters,
                                      IReadOnlyList <IReadOnlyList <string> > tuneSourceCorpus,
                                      IReadOnlyList <IReadOnlyList <string> > tuneTargetCorpus, SmtBatchTrainStats stats,
                                      IProgress <ProgressStatus> progress)
        {
            IntPtr weightUpdaterHandle = Thot.llWeightUpdater_create();

            try
            {
                var                         iterQualities  = new List <double>();
                double                      bestQuality    = double.MinValue;
                ThotSmtParameters           bestParameters = null;
                int                         iter           = 0;
                // Union of every n-best list seen so far, one pool per tuning segment.
                HashSet <TranslationInfo>[] curNBestLists  = null;
                float[]                     curWeights     = parameters.ModelWeights.ToArray();

                while (true)
                {
                    progress.Report(new ProgressStatus(iter, MaxIterations));

                    // Give each candidate its own snapshot of the weights. UpdateWeights below receives
                    // curWeights and its result is not captured any other way, so it must rewrite the
                    // array in place. Assigning curWeights directly would let later updates silently
                    // alter the weights of already-evaluated, frozen parameters, including bestParameters.
                    ThotSmtParameters newParameters = parameters.Clone();
                    newParameters.ModelWeights = curWeights.ToArray();
                    newParameters.Freeze();
                    IList <TranslationInfo>[] nbestLists = GetNBestLists(newParameters, tuneSourceCorpus).ToArray();
                    // Score the first entry of each n-best list (presumably the top-ranked translation).
                    double quality = Evaluation.CalculateBleu(nbestLists.Select(nbl => nbl.First().Translation),
                                                              tuneTargetCorpus);
                    iterQualities.Add(quality);
                    if (quality > bestQuality)
                    {
                        bestQuality    = quality;
                        bestParameters = newParameters;
                    }

                    iter++;
                    if (iter >= MaxIterations || IsTuningConverged(iterQualities))
                    {
                        break;
                    }

                    if (curNBestLists == null)
                    {
                        curNBestLists = nbestLists.Select(nbl => new HashSet <TranslationInfo>(nbl)).ToArray();
                    }
                    else
                    {
                        for (int i = 0; i < nbestLists.Length; i++)
                        {
                            curNBestLists[i].UnionWith(nbestLists[i]);
                        }
                    }

                    UpdateWeights(weightUpdaterHandle, tuneTargetCorpus, curNBestLists, curWeights);
                }

                if (iter < MaxIterations)
                {
                    // Converged early: the per-iteration reports never reached the end, so report completion.
                    progress.Report(new ProgressStatus(1.0));
                }
                stats.TranslationModelBleu = bestQuality;
                return(bestParameters);
            }
            finally
            {
                Thot.llWeightUpdater_close(weightUpdaterHandle);
            }
        }
        public void Train_NonEmptyCorpus_GeneratesModels()
        {
            using (var tempDir = new TempDirectory("ThotSmtEngineTests"))
            {
                // Builds segment (1, number) of "text1" from a space-separated sentence.
                TextSegment MakeSegment(int number, string sentence) =>
                    new TextSegment(new TextSegmentRef(1, number), sentence.Split());

                // Builds the word alignment for segment (1, number) of "text1".
                TextAlignment MakeAlignment(int number, params AlignedWordPair[] pairs) =>
                    new TextAlignment(new TextSegmentRef(1, number), pairs);

                // Asserts that the trainer wrote <tempDir>/<subDirectory>/<fileName>.
                void AssertModelFileExists(string subDirectory, string fileName) =>
                    Assert.That(File.Exists(Path.Combine(tempDir.Path, subDirectory, fileName)), Is.True);

                // Arrange: five Spanish–English sentence pairs with a few known word alignments.
                var sourceCorpus = new DictionaryTextCorpus(new[]
                {
                    new MemoryText("text1", new[]
                    {
                        MakeSegment(1, "¿ le importaría darnos las llaves de la habitación , por favor ?"),
                        MakeSegment(2, "he hecho la reserva de una habitación tranquila doble con teléfono y televisión a nombre de rosario cabedo ."),
                        MakeSegment(3, "¿ le importaría cambiarme a otra habitación más tranquila ?"),
                        MakeSegment(4, "por favor , tengo reservada una habitación ."),
                        MakeSegment(5, "me parece que existe un problema .")
                    })
                });

                var targetCorpus = new DictionaryTextCorpus(new[]
                {
                    new MemoryText("text1", new[]
                    {
                        MakeSegment(1, "would you mind giving us the keys to the room , please ?"),
                        MakeSegment(2, "i have made a reservation for a quiet , double room with a telephone and a tv for rosario cabedo ."),
                        MakeSegment(3, "would you mind moving me to a quieter room ?"),
                        MakeSegment(4, "i have booked a room ."),
                        MakeSegment(5, "i think that there is a problem .")
                    })
                });

                var alignmentCorpus = new DictionaryTextAlignmentCorpus(new[]
                {
                    new MemoryTextAlignmentCollection("text1", new[]
                    {
                        MakeAlignment(1, new AlignedWordPair(8, 9)),
                        MakeAlignment(2, new AlignedWordPair(6, 10)),
                        MakeAlignment(3, new AlignedWordPair(6, 8)),
                        MakeAlignment(4, new AlignedWordPair(6, 4)),
                        MakeAlignment(5)
                    })
                });

                var corpus = new ParallelTextCorpus(sourceCorpus, targetCorpus, alignmentCorpus);

                string tmPrefix = Path.Combine(tempDir.Path, "tm", "src_trg");
                string lmPrefix = Path.Combine(tempDir.Path, "lm", "trg.lm");
                var parameters = new ThotSmtParameters
                {
                    TranslationModelFileNamePrefix = tmPrefix,
                    LanguageModelFileNamePrefix    = lmPrefix
                };

                // Act: train and persist with identity preprocessors.
                using (var trainer = new ThotSmtBatchTrainer(parameters, s => s, s => s, corpus))
                {
                    trainer.Train();
                    trainer.Save();
                }

                // Assert
                AssertModelFileExists("lm", "trg.lm");
                AssertModelFileExists("tm", "src_trg_swm.hmm_alignd");
                AssertModelFileExists("tm", "src_trg_invswm.hmm_alignd");
                AssertModelFileExists("tm", "src_trg.ttable");
                // TODO: test for more than just existence of files
            }
        }