public void TrainModels_EmptyCorpus_GeneratesModels()
        {
            // Training on corpora that contain no texts at all must still write every model file.
            using (var tempDir = new TempDirectory("ThotSmtEngineTests"))
            {
                // Arrange: empty source, target and alignment corpora combined into one parallel corpus.
                var corpus = new ParallelTextCorpus(
                    new DictionaryTextCorpus(Enumerable.Empty <MemoryText>()),
                    new DictionaryTextCorpus(Enumerable.Empty <MemoryText>()),
                    new DictionaryTextAlignmentCorpus(Enumerable.Empty <MemoryTextAlignmentCollection>()));

                var parameters = new ThotSmtParameters
                {
                    TranslationModelFileNamePrefix = Path.Combine(tempDir.Path, "tm", "src_trg"),
                    LanguageModelFileNamePrefix    = Path.Combine(tempDir.Path, "lm", "trg.lm")
                };

                // Act: train and persist; the trainer is disposed before any file is inspected.
                using (var trainer = new ThotSmtBatchTrainer(parameters, s => s, s => s, corpus))
                {
                    trainer.Train();
                    trainer.Save();
                }

                // Assert: every expected model file exists under the temporary directory.
                string[] expectedRelativePaths =
                {
                    Path.Combine("lm", "trg.lm"),
                    Path.Combine("tm", "src_trg_swm.hmm_alignd"),
                    Path.Combine("tm", "src_trg_invswm.hmm_alignd"),
                    Path.Combine("tm", "src_trg.ttable")
                };
                foreach (string relativePath in expectedRelativePaths)
                {
                    Assert.That(File.Exists(Path.Combine(tempDir.Path, relativePath)), Is.True);
                }
                // TODO: test for more than just existence of files
            }
        }
예제 #2
0
        /// <summary>
        /// Creates a trainer for Thot SMT models over <paramref name="corpus"/> and creates a fresh,
        /// uniquely named working directory under the system temp path.
        /// </summary>
        /// <param name="parameters">
        /// Model parameters; stored in <c>Parameters</c> and then <c>Freeze()</c> is called on them
        /// (presumably making the caller's instance read-only from here on — confirm).
        /// </param>
        /// <param name="sourcePreprocessor">Source-side text preprocessor, kept for later training steps.</param>
        /// <param name="targetPreprocessor">Target-side text preprocessor, kept for later training steps.</param>
        /// <param name="corpus">Parallel corpus to train from.</param>
        /// <param name="maxCorpusCount">Stored in <c>_maxCorpusCount</c>; caps how many corpus segments are used.</param>
        public ThotSmtBatchTrainer(ThotSmtParameters parameters, Func <string, string> sourcePreprocessor,
                                   Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, int maxCorpusCount = int.MaxValue)
        {
            Parameters = parameters;
            Parameters.Freeze();

            _parallelCorpus     = corpus;
            _maxCorpusCount     = maxCorpusCount;
            _sourcePreprocessor = sourcePreprocessor;
            _targetPreprocessor = targetPreprocessor;

            var weightTuner = new SimplexModelWeightTuner();
            weightTuner.ProgressIncrementInterval = 10;
            _modelWeightTuner = weightTuner;

            // Keep this after the field assignments above; CreateTuneCorpus presumably reads them.
            _tuneCorpusIndices = CreateTuneCorpus();

            // Keep drawing GUID-based names until one is not already on disk, then create that directory.
            for (;;)
            {
                string candidate = Path.Combine(Path.GetTempPath(), "thot-train-" + Guid.NewGuid());
                if (!Directory.Exists(candidate))
                {
                    _tempDir = candidate;
                    break;
                }
            }
            Directory.CreateDirectory(_tempDir);

            _trainLMDir   = Path.Combine(_tempDir, "lm");
            _trainTMDir   = Path.Combine(_tempDir, "tm_train");
            _lmFilePrefix = Path.GetFileName(Parameters.LanguageModelFileNamePrefix);
            _tmFilePrefix = Path.GetFileName(Parameters.TranslationModelFileNamePrefix);
        }
예제 #3
0
 /// <summary>
 /// Creates a batch trainer tied to <paramref name="smtModel"/>, configured from in-memory
 /// <paramref name="parameters"/>.
 /// </summary>
 /// <param name="smtModel">Owning model; kept in <c>_smtModel</c>.</param>
 /// <param name="parameters">Training parameters forwarded to the base trainer.</param>
 /// <param name="sourcePreprocessor">Source-side preprocessor forwarded to the base trainer.</param>
 /// <param name="targetPreprocessor">Target-side preprocessor forwarded to the base trainer.</param>
 /// <param name="corpus">Parallel corpus forwarded to the base trainer.</param>
 public BatchTrainer(ThotSmtModel smtModel, ThotSmtParameters parameters,
                     Func <string, string> sourcePreprocessor, Func <string, string> targetPreprocessor,
                     ParallelTextCorpus corpus)
     : base(parameters, sourcePreprocessor, targetPreprocessor, corpus) => _smtModel = smtModel;
예제 #4
0
 /// <summary>
 /// Adds the first <paramref name="maxCount"/> non-empty segments of <paramref name="corpus"/> to
 /// <paramref name="model"/> as training segment pairs.
 /// </summary>
 /// <param name="model">Word alignment model receiving the segment pairs.</param>
 /// <param name="corpus">Parallel corpus supplying the segments; empty segments are skipped.</param>
 /// <param name="sourcePreprocessor">Optional source-side preprocessor passed through to <c>AddSegmentPair</c>.</param>
 /// <param name="targetPreprocessor">Optional target-side preprocessor passed through to <c>AddSegmentPair</c>.</param>
 /// <param name="maxCount">Maximum number of non-empty segments to add.</param>
 public static void AddSegmentPairs(this IWordAlignmentModel model, ParallelTextCorpus corpus,
                                    Func <string, string> sourcePreprocessor = null, Func <string, string> targetPreprocessor = null,
                                    int maxCount = int.MaxValue)
 {
     var nonEmptySegments = from candidate in corpus.Segments
                            where !candidate.IsEmpty
                            select candidate;

     foreach (ParallelTextSegment pair in nonEmptySegments.Take(maxCount))
     {
         model.AddSegmentPair(pair, sourcePreprocessor, targetPreprocessor);
     }
 }
예제 #5
0
        /// <summary>
        /// Creates a batch trainer for this model from separate source, target and (optional) alignment corpora.
        /// </summary>
        /// <param name="sourcePreprocessor">Source-side text preprocessor handed to the trainer.</param>
        /// <param name="sourceCorpus">Source-language corpus.</param>
        /// <param name="targetPreprocessor">Target-side text preprocessor handed to the trainer.</param>
        /// <param name="targetCorpus">Target-language corpus.</param>
        /// <param name="alignmentCorpus">Optional partial alignments; may be <c>null</c>.</param>
        /// <returns>
        /// A <c>BatchTrainer</c> configured from <c>Parameters</c> when <c>ConfigFileName</c> is null or empty,
        /// otherwise from <c>ConfigFileName</c>.
        /// </returns>
        public ISmtBatchTrainer CreateBatchTrainer(Func <string, string> sourcePreprocessor, ITextCorpus sourceCorpus,
                                                   Func <string, string> targetPreprocessor, ITextCorpus targetCorpus,
                                                   ITextAlignmentCorpus alignmentCorpus = null)
        {
            CheckDisposed();

            var parallelCorpus = new ParallelTextCorpus(sourceCorpus, targetCorpus, alignmentCorpus);

            if (string.IsNullOrEmpty(ConfigFileName))
            {
                // No config file on disk: hand the in-memory parameters to the trainer.
                return new BatchTrainer(this, Parameters, sourcePreprocessor, targetPreprocessor, parallelCorpus);
            }

            return new BatchTrainer(this, ConfigFileName, sourcePreprocessor, targetPreprocessor, parallelCorpus);
        }
예제 #6
0
        /// <summary>
        /// Lazily enumerates the non-empty segments of <paramref name="corpus"/> whose position satisfies
        /// <paramref name="filter"/>, stopping once <c>_maxCorpusCount</c> non-empty segments have been seen.
        /// </summary>
        /// <param name="corpus">Parallel corpus to read.</param>
        /// <param name="filter">
        /// Predicate over the zero-based position of a segment in <c>corpus.Segments</c>. The position also
        /// counts empty segments, so it matches the raw segment order.
        /// </param>
        /// <returns>The selected non-empty segments, in corpus order.</returns>
        /// <remarks>
        /// The cap counts every non-empty segment, including ones rejected by <paramref name="filter"/>.
        /// A cap of zero or less yields nothing.
        /// </remarks>
        private IEnumerable <ParallelTextSegment> GetSegments(ParallelTextCorpus corpus, Func <int, bool> filter)
        {
            // The in-loop check below only runs after a segment has been counted. Without this guard, a
            // non-positive cap would never be reached and every segment would be returned.
            if (_maxCorpusCount <= 0)
            {
                yield break;
            }

            int corpusCount = 0;
            int index       = 0;

            foreach (ParallelTextSegment segment in corpus.Segments)
            {
                if (!segment.IsEmpty)
                {
                    if (filter(index))
                    {
                        yield return(segment);
                    }
                    corpusCount++;
                }
                index++;

                // Stop as soon as the cap is reached, without pulling another segment from the corpus.
                if (corpusCount >= _maxCorpusCount)
                {
                    break;
                }
            }
        }
예제 #7
0
 /// <summary>
 /// Returns the segments of <paramref name="corpus"/> whose position is listed in <c>_tuneCorpusIndices</c>.
 /// </summary>
 /// <param name="corpus">Parallel corpus to select the tuning segments from.</param>
 /// <returns>The tuning segments, as produced by <c>GetSegments</c>.</returns>
 private IEnumerable <ParallelTextSegment> GetTuningSegments(ParallelTextCorpus corpus) =>
     GetSegments(corpus, segmentIndex => _tuneCorpusIndices.Contains(segmentIndex));
예제 #8
0
        /// <summary>
        /// Writes the best alignment of every training segment to <paramref name="fileName"/>, using the
        /// GIZA-format text produced by the word alignment model.
        /// </summary>
        /// <param name="swmPrefix">File prefix of the trained single-word alignment model to load.</param>
        /// <param name="fileName">Output file (opened with <c>new StreamWriter(fileName)</c>, so it is overwritten).</param>
        /// <param name="sourcePreprocessor">Source-side preprocessor passed to <c>GetGizaFormatString</c>.</param>
        /// <param name="targetPreprocessor">Target-side preprocessor passed to <c>GetGizaFormatString</c>.</param>
        /// <param name="corpus">Parallel corpus whose training segments are aligned.</param>
        /// <param name="name">Label inserted into the progress step message.</param>
        /// <param name="reporter">Receives one progress step; polled for cancellation after every segment.</param>
        private void GenerateBestAlignments(string swmPrefix, string fileName, Func <string, string> sourcePreprocessor,
                                            Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
                                            ThotTrainProgressReporter reporter)
        {
            reporter.Step($"Generating best {name} alignments");

            using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix))
            {
                using (var output = new StreamWriter(fileName))
                {
                    foreach (ParallelTextSegment pair in GetTrainingSegments(corpus))
                    {
                        // Each entry is preceded by a "# <text id> <segment ref>" header line.
                        string header = $"# {pair.Text.Id} {pair.SegmentRef}\n";
                        output.Write(header);
                        output.Write(alignmentModel.GetGizaFormatString(pair, sourcePreprocessor, targetPreprocessor));

                        reporter.CheckCanceled();
                    }
                }
            }
        }
예제 #9
0
        /// <summary>
        /// Trains a single-word alignment model on the training segments of <paramref name="corpus"/> for a
        /// fixed number of iterations and saves it.
        /// </summary>
        /// <param name="swmPrefix">File prefix of the alignment model to create and save.</param>
        /// <param name="sourcePreprocessor">Source-side preprocessor applied when adding segment pairs.</param>
        /// <param name="targetPreprocessor">Target-side preprocessor applied when adding segment pairs.</param>
        /// <param name="corpus">Parallel corpus supplying the training segments.</param>
        /// <param name="name">Label inserted into each progress step message.</param>
        /// <param name="reporter">Receives one progress step per training iteration.</param>
        private void TrainWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                             Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
                                             ThotTrainProgressReporter reporter)
        {
            const int trainingIterations = 5;

            using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix, true))
            {
                foreach (ParallelTextSegment pair in GetTrainingSegments(corpus))
                {
                    alignmentModel.AddSegmentPair(pair, sourcePreprocessor, targetPreprocessor);
                }

                int remaining = trainingIterations;
                while (remaining-- > 0)
                {
                    reporter.Step($"Training {name} alignment model");
                    alignmentModel.TrainingIteration();
                }

                alignmentModel.Save();
            }
        }
예제 #10
0
        /// <summary>
        /// Runs the single-word alignment pipeline: trains the model, prunes its lexical table, and writes the
        /// best alignments to <c><paramref name="swmPrefix"/>.bestal</c>.
        /// </summary>
        /// <param name="swmPrefix">File prefix shared by the model files produced in this pipeline.</param>
        /// <param name="sourcePreprocessor">Source-side preprocessor used for training and alignment.</param>
        /// <param name="targetPreprocessor">Target-side preprocessor used for training and alignment.</param>
        /// <param name="corpus">Parallel corpus supplying the segments.</param>
        /// <param name="name">Label used in progress messages.</param>
        /// <param name="reporter">Progress reporter; checked for cancellation between training and pruning.</param>
        private void GenerateSingleWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                                      Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, string name,
                                                      ThotTrainProgressReporter reporter)
        {
            const double lexTablePruneThreshold = 0.00001;

            TrainWordAlignmentModel(swmPrefix, sourcePreprocessor, targetPreprocessor, corpus, name, reporter);
            reporter.CheckCanceled();

            string lexTableFile       = swmPrefix + ".hmm_lexnd";
            string bestAlignmentsFile = swmPrefix + ".bestal";

            PruneLexTable(lexTableFile, lexTablePruneThreshold);
            GenerateBestAlignments(swmPrefix, bestAlignmentsFile, sourcePreprocessor, targetPreprocessor, corpus,
                                   name, reporter);
        }
예제 #11
0
        /// <summary>
        /// Writes the best alignment of every training segment to <paramref name="fileName"/> in GIZA format,
        /// reporting per-segment progress.
        /// </summary>
        /// <param name="swmPrefix">File prefix of the trained single-word alignment model to load.</param>
        /// <param name="fileName">Output file (opened with <c>new StreamWriter(fileName)</c>, so it is overwritten).</param>
        /// <param name="sourcePreprocessor">Source-side preprocessor passed to <c>GetGizaFormatString</c>.</param>
        /// <param name="targetPreprocessor">Target-side preprocessor passed to <c>GetGizaFormatString</c>.</param>
        /// <param name="corpus">Parallel corpus whose training segments are aligned.</param>
        /// <param name="progress">
        /// Receives (segments done, <c>Stats.TrainedSegmentCount</c>) before each segment, then a final 1.0.
        /// </param>
        private void GenerateBestAlignments(string swmPrefix, string fileName, Func <string, string> sourcePreprocessor,
                                            Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, IProgress <ProgressStatus> progress)
        {
            using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix))
            {
                using (var output = new StreamWriter(fileName))
                {
                    int completed = 0;
                    foreach (ParallelTextSegment pair in GetTrainingSegments(corpus))
                    {
                        // Report how many segments were finished before starting this one.
                        progress.Report(new ProgressStatus(completed++, Stats.TrainedSegmentCount));

                        output.Write($"# {pair.Text.Id} {pair.SegmentRef}\n");
                        output.Write(alignmentModel.GetGizaFormatString(pair, sourcePreprocessor, targetPreprocessor));
                    }

                    progress.Report(new ProgressStatus(1.0));
                }
            }
        }
예제 #12
0
 /// <summary>
 /// Trains a single-word alignment model on the training segments of <paramref name="corpus"/> and saves it.
 /// </summary>
 /// <param name="swmPrefix">File prefix of the alignment model to create and save.</param>
 /// <param name="sourcePreprocessor">Source-side preprocessor applied when adding segment pairs.</param>
 /// <param name="targetPreprocessor">Target-side preprocessor applied when adding segment pairs.</param>
 /// <param name="corpus">Parallel corpus supplying the training segments.</param>
 /// <param name="progress">Progress sink handed to the model's <c>Train</c> call.</param>
 private void TrainWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                      Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, IProgress <ProgressStatus> progress)
 {
     using (var alignmentModel = new ThotWordAlignmentModel(swmPrefix, true))
     {
         var trainingSegments = GetTrainingSegments(corpus);
         foreach (ParallelTextSegment pair in trainingSegments)
         {
             alignmentModel.AddSegmentPair(pair, sourcePreprocessor, targetPreprocessor);
         }

         alignmentModel.Train(progress);
         alignmentModel.Save();
     }
 }
예제 #13
0
        /// <summary>
        /// Runs the word alignment pipeline in two reporter phases: trains and saves the model, prunes its
        /// lexical table, then writes the best alignments to <c><paramref name="swmPrefix"/>.bestal</c>.
        /// </summary>
        /// <param name="swmPrefix">File prefix shared by the model files produced in this pipeline.</param>
        /// <param name="sourcePreprocessor">Source-side preprocessor used for training and alignment.</param>
        /// <param name="targetPreprocessor">Target-side preprocessor used for training and alignment.</param>
        /// <param name="corpus">Parallel corpus supplying the segments.</param>
        /// <param name="reporter">Supplies one phase per step; checked for cancellation between the phases.</param>
        /// <param name="inverted">Not referenced in this method (NOTE(review): confirm whether it is still needed).</param>
        private void GenerateWordAlignmentModel(string swmPrefix, Func <string, string> sourcePreprocessor,
                                                Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, ThotTrainProgressReporter reporter,
                                                bool inverted)
        {
            // Phase 1: train the alignment model.
            using (PhaseProgress trainingPhase = reporter.StartNextPhase())
            {
                TrainWordAlignmentModel(swmPrefix, sourcePreprocessor, targetPreprocessor, corpus, trainingPhase);
            }

            reporter.CheckCanceled();

            const double lexTablePruneThreshold = 0.00001;
            PruneLexTable(swmPrefix + ".hmm_lexnd", lexTablePruneThreshold);

            // Phase 2: write the best alignment of each training segment.
            string bestAlignmentsFile = swmPrefix + ".bestal";
            using (PhaseProgress alignmentPhase = reporter.StartNextPhase())
            {
                GenerateBestAlignments(swmPrefix, bestAlignmentsFile, sourcePreprocessor, targetPreprocessor, corpus,
                                       alignmentPhase);
            }
        }
        public void Train_NonEmptyCorpus_GeneratesModels()
        {
            // Training on a small Spanish→English corpus with partial alignments must write every model file.
            using (var tempDir = new TempDirectory("ThotSmtEngineTests"))
            {
                // Turns a list of sentences into segments of book 1, numbered 1..n, split on whitespace.
                TextSegment[] ToSegments(params string[] sentences)
                {
                    return sentences
                           .Select((sentence, i) => new TextSegment(new TextSegmentRef(1, i + 1), sentence.Split()))
                           .ToArray();
                }

                // Arrange: source side.
                TextSegment[] sourceSegments = ToSegments(
                    "¿ le importaría darnos las llaves de la habitación , por favor ?",
                    "he hecho la reserva de una habitación tranquila doble con teléfono y televisión a nombre de rosario cabedo .",
                    "¿ le importaría cambiarme a otra habitación más tranquila ?",
                    "por favor , tengo reservada una habitación .",
                    "me parece que existe un problema .");
                var sourceCorpus = new DictionaryTextCorpus(new[] { new MemoryText("text1", sourceSegments) });

                // Arrange: target side.
                TextSegment[] targetSegments = ToSegments(
                    "would you mind giving us the keys to the room , please ?",
                    "i have made a reservation for a quiet , double room with a telephone and a tv for rosario cabedo .",
                    "would you mind moving me to a quieter room ?",
                    "i have booked a room .",
                    "i think that there is a problem .");
                var targetCorpus = new DictionaryTextCorpus(new[] { new MemoryText("text1", targetSegments) });

                // Arrange: one known word pair per segment, except the last segment, which has none.
                AlignedWordPair[][] knownPairs =
                {
                    new[] { new AlignedWordPair(8, 9) },
                    new[] { new AlignedWordPair(6, 10) },
                    new[] { new AlignedWordPair(6, 8) },
                    new[] { new AlignedWordPair(6, 4) },
                    new AlignedWordPair[0]
                };
                TextAlignment[] alignments = knownPairs
                                             .Select((pairs, i) => new TextAlignment(new TextSegmentRef(1, i + 1), pairs))
                                             .ToArray();
                var alignmentCorpus = new DictionaryTextAlignmentCorpus(new[]
                {
                    new MemoryTextAlignmentCollection("text1", alignments)
                });

                var corpus = new ParallelTextCorpus(sourceCorpus, targetCorpus, alignmentCorpus);

                var parameters = new ThotSmtParameters
                {
                    TranslationModelFileNamePrefix = Path.Combine(tempDir.Path, "tm", "src_trg"),
                    LanguageModelFileNamePrefix    = Path.Combine(tempDir.Path, "lm", "trg.lm")
                };

                // Act: train and persist; the trainer is disposed before any file is inspected.
                using (var trainer = new ThotSmtBatchTrainer(parameters, s => s, s => s, corpus))
                {
                    trainer.Train();
                    trainer.Save();
                }

                // Assert: every expected model file exists under the temporary directory.
                string[] expectedRelativePaths =
                {
                    Path.Combine("lm", "trg.lm"),
                    Path.Combine("tm", "src_trg_swm.hmm_alignd"),
                    Path.Combine("tm", "src_trg_invswm.hmm_alignd"),
                    Path.Combine("tm", "src_trg.ttable")
                };
                foreach (string relativePath in expectedRelativePaths)
                {
                    Assert.That(File.Exists(Path.Combine(tempDir.Path, relativePath)), Is.True);
                }
                // TODO: test for more than just existence of files
            }
        }
        /// <summary>
        /// Validates the corpus-related command-line options. If all are valid, it builds <c>SourceCorpus</c>,
        /// <c>TargetCorpus</c>, the optional <c>AlignmentsCorpus</c>, and the combined <c>ParallelCorpus</c>.
        /// The include/exclude text filters are applied when given.
        /// </summary>
        /// <returns>
        /// <c>0</c> on success. <c>1</c> after writing an explanatory message to <c>Out</c> when a required option
        /// is missing or any supplied option is invalid.
        /// </returns>
        protected override int ExecuteCommand()
        {
            if (!_sourceOption.HasValue())
            {
                Out.WriteLine("The source corpus was not specified.");
                return(1);
            }

            if (!_targetOption.HasValue())
            {
                Out.WriteLine("The target corpus was not specified.");
                return(1);
            }

            if (!ValidateTextCorpusOption(_sourceOption.Value(), out string sourceType, out string sourcePath))
            {
                Out.WriteLine("The specified source corpus is invalid.");
                return(1);
            }

            if (!ValidateTextCorpusOption(_targetOption.Value(), out string targetType, out string targetPath))
            {
                Out.WriteLine("The specified target corpus is invalid.");
                return(1);
            }

            // The partial alignments corpus is optional. Validate, load and filter it only when the option is
            // both registered and actually supplied. Validating an unset option would hand
            // ValidateAlignmentsOption an empty value (presumably null) and reject an otherwise valid command.
            bool hasAlignments = _alignmentsOption != null && _alignmentsOption.HasValue();

            string alignmentsType = null, alignmentsPath = null;

            if (hasAlignments && !ValidateAlignmentsOption(_alignmentsOption.Value(), out alignmentsType,
                                                           out alignmentsPath))
            {
                Out.WriteLine("The specified partial alignments corpus is invalid.");
                return(1);
            }

            if (!ValidateWordTokenizerOption(_sourceWordTokenizerOption.Value()))
            {
                Out.WriteLine("The specified source word tokenizer type is invalid.");
                return(1);
            }

            if (!ValidateWordTokenizerOption(_targetWordTokenizerOption.Value()))
            {
                Out.WriteLine("The specified target word tokenizer type is invalid.");
                return(1);
            }

            if (_maxCorpusSizeOption.HasValue())
            {
                if (!int.TryParse(_maxCorpusSizeOption.Value(), out int maxCorpusSize) || maxCorpusSize <= 0)
                {
                    Out.WriteLine("The specified maximum corpus size is invalid.");
                    return(1);
                }
                MaxParallelCorpusCount = maxCorpusSize;
            }

            StringTokenizer sourceWordTokenizer = CreateWordTokenizer(_sourceWordTokenizerOption.Value());
            StringTokenizer targetWordTokenizer = CreateWordTokenizer(_targetWordTokenizerOption.Value());

            SourceCorpus     = CreateTextCorpus(sourceWordTokenizer, sourceType, sourcePath);
            TargetCorpus     = CreateTextCorpus(targetWordTokenizer, targetType, targetPath);
            AlignmentsCorpus = null;
            if (hasAlignments)
            {
                AlignmentsCorpus = CreateAlignmentsCorpus(alignmentsType, alignmentsPath);
            }

            ISet <string> includeTexts = null;

            if (_includeOption.HasValue())
            {
                includeTexts = GetTexts(_includeOption.Values);
            }

            ISet <string> excludeTexts = null;

            if (_excludeOption.HasValue())
            {
                excludeTexts = GetTexts(_excludeOption.Values);
            }

            if (includeTexts != null || excludeTexts != null)
            {
                // Exclusion wins over inclusion. When no include list is given, every text that is not excluded is kept.
                bool Filter(string id)
                {
                    if (excludeTexts != null && excludeTexts.Contains(id))
                    {
                        return(false);
                    }

                    if (includeTexts != null && includeTexts.Contains(id))
                    {
                        return(true);
                    }

                    return(includeTexts == null);
                }

                SourceCorpus = new FilteredTextCorpus(SourceCorpus, text => Filter(text.Id));
                TargetCorpus = new FilteredTextCorpus(TargetCorpus, text => Filter(text.Id));
                if (hasAlignments)
                {
                    AlignmentsCorpus = new FilteredTextAlignmentCorpus(AlignmentsCorpus,
                                                                       alignments => Filter(alignments.Id));
                }
            }

            ParallelCorpus = new ParallelTextCorpus(SourceCorpus, TargetCorpus, AlignmentsCorpus);

            return(0);
        }
예제 #16
0
 /// <summary>
 /// Creates a trainer whose parameters are loaded from the Thot configuration file <paramref name="cfgFileName"/>.
 /// The file name is remembered in <c>ConfigFileName</c>.
 /// </summary>
 /// <param name="cfgFileName">Path of the configuration file passed to <c>ThotSmtParameters.Load</c>.</param>
 /// <param name="sourcePreprocessor">Source-side text preprocessor.</param>
 /// <param name="targetPreprocessor">Target-side text preprocessor.</param>
 /// <param name="corpus">Parallel corpus to train from.</param>
 /// <param name="maxCorpusCount">Cap on the number of corpus segments used.</param>
 public ThotSmtBatchTrainer(string cfgFileName, Func <string, string> sourcePreprocessor,
                            Func <string, string> targetPreprocessor, ParallelTextCorpus corpus, int maxCorpusCount = int.MaxValue)
     : this(ThotSmtParameters.Load(cfgFileName), sourcePreprocessor, targetPreprocessor, corpus, maxCorpusCount)
     => ConfigFileName = cfgFileName;
예제 #17
0
 /// <summary>
 /// Creates a batch trainer tied to <paramref name="smtModel"/>, configured from the Thot configuration
 /// file <paramref name="cfgFileName"/>.
 /// </summary>
 /// <param name="smtModel">Owning model; kept in <c>_smtModel</c>.</param>
 /// <param name="cfgFileName">Configuration file forwarded to the base trainer.</param>
 /// <param name="sourcePreprocessor">Source-side preprocessor forwarded to the base trainer.</param>
 /// <param name="targetPreprocessor">Target-side preprocessor forwarded to the base trainer.</param>
 /// <param name="corpus">Parallel corpus forwarded to the base trainer.</param>
 public BatchTrainer(ThotSmtModel smtModel, string cfgFileName, Func <string, string> sourcePreprocessor,
                     Func <string, string> targetPreprocessor, ParallelTextCorpus corpus)
     : base(cfgFileName, sourcePreprocessor, targetPreprocessor, corpus) => _smtModel = smtModel;