Exemplo n.º 1
0
        /// <summary>
        /// Iteratively tunes the model weights against a tuning corpus.
        /// </summary>
        /// <remarks>
        /// Each iteration clones <paramref name="parameters"/>, applies the current weights, produces n-best
        /// lists for the tuning source corpus, and scores the first hypothesis of every list with BLEU
        /// against the tuning target corpus. The n-best lists are accumulated across iterations and passed
        /// to <c>UpdateWeights</c> to obtain the weights for the next iteration. The loop stops once
        /// <c>MaxIterations</c> iterations have run or <c>IsTuningConverged</c> reports convergence.
        /// </remarks>
        /// <param name="parameters">The starting parameters; cloned each iteration, never modified directly.</param>
        /// <param name="tuneSourceCorpus">The source segments of the tuning corpus.</param>
        /// <param name="tuneTargetCorpus">The reference target segments of the tuning corpus.</param>
        /// <param name="stats">Receives the best BLEU score in <c>TranslationModelBleu</c>.</param>
        /// <param name="progress">Receives a progress report at the start of every iteration.</param>
        /// <returns>The frozen parameters of the iteration that achieved the highest BLEU score.</returns>
        public ThotSmtParameters Tune(ThotSmtParameters parameters,
                                      IReadOnlyList<IReadOnlyList<string>> tuneSourceCorpus,
                                      IReadOnlyList<IReadOnlyList<string>> tuneTargetCorpus, SmtBatchTrainStats stats,
                                      IProgress<ProgressStatus> progress)
        {
            IntPtr weightUpdaterHandle = Thot.llWeightUpdater_create();

            try
            {
                var iterQualities = new List<double>();
                double bestQuality = double.MinValue;
                ThotSmtParameters bestParameters = null;
                int iter = 0;
                HashSet<TranslationInfo>[] curNBestLists = null;
                float[] curWeights = parameters.ModelWeights.ToArray();

                while (true)
                {
                    progress.Report(new ProgressStatus(iter, MaxIterations));

                    ThotSmtParameters newParameters = parameters.Clone();
                    // Hand every iteration its own snapshot of the weights. curWeights is a single buffer
                    // that is passed to UpdateWeights below and is never reassigned, so it is presumably
                    // overwritten in place. If the ModelWeights setter stores the array by reference
                    // (not visible here), sharing the buffer would make bestParameters silently track the
                    // weights of the last iteration instead of the best one.
                    newParameters.ModelWeights = curWeights.ToArray();
                    newParameters.Freeze();

                    IList<TranslationInfo>[] nbestLists = GetNBestLists(newParameters, tuneSourceCorpus).ToArray();
                    // Only the top hypothesis of each n-best list is scored.
                    double quality = Evaluation.CalculateBleu(nbestLists.Select(nbl => nbl.First().Translation),
                                                              tuneTargetCorpus);
                    iterQualities.Add(quality);
                    if (quality > bestQuality)
                    {
                        bestQuality = quality;
                        bestParameters = newParameters;
                    }

                    iter++;
                    if (iter >= MaxIterations || IsTuningConverged(iterQualities))
                    {
                        break;
                    }

                    // Accumulate the hypotheses seen so far so the weight update works on the union of
                    // all n-best lists, not just those of the current iteration.
                    if (curNBestLists == null)
                    {
                        curNBestLists = nbestLists.Select(nbl => new HashSet<TranslationInfo>(nbl)).ToArray();
                    }
                    else
                    {
                        for (int i = 0; i < nbestLists.Length; i++)
                        {
                            curNBestLists[i].UnionWith(nbestLists[i]);
                        }
                    }

                    UpdateWeights(weightUpdaterHandle, tuneTargetCorpus, curNBestLists, curWeights);
                }

                // Stopped early on convergence: report completion explicitly.
                if (iter < MaxIterations)
                {
                    progress.Report(new ProgressStatus(1.0));
                }

                stats.TranslationModelBleu = bestQuality;
                return bestParameters;
            }
            finally
            {
                // Always release the native weight updater, even if tuning throws.
                Thot.llWeightUpdater_close(weightUpdaterHandle);
            }
        }