public override void run(string format, string[] args)
        {
            base.run(format, args);

            mlParams = CmdLineUtil.loadTrainingParameters(@params.Params, false);

            if (mlParams != null)
            {
                if (!TrainUtil.isValid(mlParams.Settings))
                {
                    throw new TerminateToolException(1, "Training parameters file '" + @params.Params + "' is invalid!");
                }

                if (TrainUtil.isSequenceTraining(mlParams.Settings))
                {
                    throw new TerminateToolException(1, "Sequence training is not supported!");
                }
            }

            if (mlParams == null)
            {
                mlParams = ModelUtil.createTrainingParameters(@params.Iterations.Value, @params.Cutoff.Value);
            }

            File modelOutFile = @params.Model;

            CmdLineUtil.checkOutputFile("tokenizer model", modelOutFile);

            TokenizerModel model;

            try
            {
                Dictionary dict = loadDict(@params.AbbDict);

                TokenizerFactory tokFactory = TokenizerFactory.create(@params.Factory, @params.Lang, dict, @params.AlphaNumOpt.Value, null);
                model = opennlp.tools.tokenize.TokenizerME.train(sampleStream, tokFactory, mlParams);
            }
            catch (IOException e)
            {
                throw new TerminateToolException(-1, "IO error while reading training data or indexing data: " + e.Message, e);
            }
            finally
            {
                try
                {
                    sampleStream.close();
                }
                catch (IOException)
                {
                    // sorry that this can fail
                }
            }

            CmdLineUtil.writeModel("tokenizer", modelOutFile, model);
        }
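
For orientation, here is a minimal sketch of how the tokenizer model written above could be loaded and applied afterwards. The TokenizerModel(InputStream) constructor and the TokenizerME.tokenize call mirror the Java OpenNLP API; whether this port exposes exactly the same members is an assumption, not something shown on this page.

// Hedged sketch: load the freshly written model and tokenize a sentence.
// Assumes the port mirrors the Java API (TokenizerModel(InputStream),
// TokenizerME(TokenizerModel), tokenize(string)).
InputStream modelIn = null;
try
{
    modelIn = new FileInputStream(modelOutFile);
    var loadedModel = new TokenizerModel(modelIn);
    var tokenizer = new opennlp.tools.tokenize.TokenizerME(loadedModel);
    string[] tokens = tokenizer.tokenize("A sample sentence to split into tokens.");
}
finally
{
    if (modelIn != null)
    {
        modelIn.close();
    }
}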
Example #2
        public override void run(string format, string[] args)
        {
            base.run(format, args);

            mlParams = CmdLineUtil.loadTrainingParameters(@params.Params, false);

            if (mlParams != null)
            {
                if (TrainUtil.isSequenceTraining(mlParams.Settings))
                {
                    throw new TerminateToolException(1, "Sequence training is not supported!");
                }
            }

            if (mlParams == null)
            {
                mlParams = ModelUtil.createTrainingParameters(@params.Iterations.Value, @params.Cutoff.Value);
            }

            Jfile modelOutFile = @params.Model;

            CmdLineUtil.checkOutputFile("sentence detector model", modelOutFile);

            char[] eos = null;
            if (@params.EosChars != null)
            {
                eos = @params.EosChars.ToCharArray();
            }

            SentenceModel model;

            try
            {
                Dictionary dict = loadDict(@params.AbbDict);
                SentenceDetectorFactory sdFactory = SentenceDetectorFactory.create(@params.Factory, @params.Lang, true, dict, eos);
                model = SentenceDetectorME.train(@params.Lang, sampleStream, sdFactory, mlParams);
            }
            catch (IOException e)
            {
                throw new TerminateToolException(-1, "IO error while reading training data or indexing data: " + e.Message, e);
            }
            finally
            {
                try
                {
                    sampleStream.close();
                }
                catch (IOException)
                {
                    // sorry that this can fail
                }
            }

            CmdLineUtil.writeModel("sentence detector", modelOutFile, model);
        }
Example #3
 public void GenerateTrainPath_Test()
 {
     if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
     {
         var path1 = TrainUtil.GenerateTrainPath("D:\\Test1", "_000001_");
         Assert.Equal("D:\\Test1\\_000001_", path1);
     }
     else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
     {
         var path1 = TrainUtil.GenerateTrainPath("/usr/local/", "_000001_");
         Assert.Equal("/usr/local/_000001_", path1);
     }
 }
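
The assertions above suggest that GenerateTrainPath simply joins the pool root and the train name. Below is a minimal sketch of such an implementation, assuming it is a thin wrapper over Path.Combine; the real FilePool implementation may normalize or validate further.

using System.IO;

public static class TrainPathSketch
{
    // Hypothetical stand-in for TrainUtil.GenerateTrainPath, written only to
    // match the expectations of the test above (platform-specific separators
    // come from Path.Combine).
    public static string GenerateTrainPath(string poolPath, string trainName)
    {
        return Path.Combine(poolPath, trainName);
    }
}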
Example #4
        /// <summary>
        /// Ctor
        /// </summary>
        /// <param name="logger"></param>
        /// <param name="configuration"></param>
        /// <param name="index"></param>
        public Train(
            ILogger<Train> logger,
            FilePoolConfiguration configuration,
            int index)
        {
            _logger        = logger;
            _configuration = configuration;

            Name      = TrainUtil.GenerateTrainName(index);
            Path      = TrainUtil.GenerateTrainPath(configuration.Path, Name);
            Index     = index;
            TrainType = TrainType.Default;

            _pendingQueue    = new ConcurrentQueue<SpoolFile>();
            _progressingDict = new ConcurrentDictionary<string, SpoolFile>();
        }
Example #5
        // it's optional, passing null is allowed
        public static TrainingParameters loadTrainingParameters(string paramFile, bool supportSequenceTraining)
        {
            TrainingParameters @params = null;

            if (paramFile != null)
            {
                checkInputFile("Training Parameter", new Jfile(paramFile));

                InputStream paramsIn = null;
                try
                {
                    paramsIn = new FileInputStream(new Jfile(paramFile));

                    @params = new TrainingParameters(paramsIn);
                }
                catch (IOException e)
                {
                    throw new TerminateToolException(-1, "Error during parameters loading: " + e.Message, e);
                }
                finally
                {
                    try
                    {
                        if (paramsIn != null)
                        {
                            paramsIn.close();
                        }
                    }
                    catch (IOException)
                    {
                        //sorry that this can fail
                    }
                }

                if (!TrainUtil.isValid(@params.getSettings()))
                {
                    throw new TerminateToolException(1, "Training parameters file '" + paramFile + "' is invalid!");
                }

                if (!supportSequenceTraining && TrainUtil.isSequenceTraining(@params.getSettings()))
                {
                    throw new TerminateToolException(1, "Sequence training is not supported!");
                }
            }

            return @params;
        }
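
A short usage sketch of the helper above. The file name and the fallback values are placeholders, and the key names in the comment follow the usual OpenNLP convention for properties-style parameter files; they are assumptions, not values taken from this page.

// Hedged usage sketch: load parameters from an optional file, falling back to
// defaults the same way the trainer tools on this page do.
//
// A typical parameters file (by OpenNLP convention, assumed here) looks like:
//   Algorithm=MAXENT
//   Iterations=100
//   Cutoff=5
TrainingParameters mlParams = CmdLineUtil.loadTrainingParameters("train.params", false);
if (mlParams == null)
{
    mlParams = ModelUtil.createTrainingParameters(100, 5);
}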
        // TODO: Add param to train tree insert parser
        public override void run(string format, string[] args)
        {
            base.run(format, args);

            mlParams = CmdLineUtil.loadTrainingParameters(@params.Params, true);

            if (mlParams != null)
            {
                if (!TrainUtil.isValid(mlParams.getSettings("build")))
                {
                    throw new TerminateToolException(1, "Build training parameters are invalid!");
                }

                if (!TrainUtil.isValid(mlParams.getSettings("check")))
                {
                    throw new TerminateToolException(1, "Check training parameters are invalid!");
                }

                if (!TrainUtil.isValid(mlParams.getSettings("attach")))
                {
                    throw new TerminateToolException(1, "Attach training parameters are invalid!");
                }

                if (!TrainUtil.isValid(mlParams.getSettings("tagger")))
                {
                    throw new TerminateToolException(1, "Tagger training parameters are invalid!");
                }

                if (!TrainUtil.isValid(mlParams.getSettings("chunker")))
                {
                    throw new TerminateToolException(1, "Chunker training parameters are invalid!");
                }
            }

            if (mlParams == null)
            {
                mlParams = ModelUtil.createTrainingParameters(@params.Iterations.Value, @params.Cutoff.Value);
            }

            Jfile modelOutFile = @params.Model;

            CmdLineUtil.checkOutputFile("parser model", modelOutFile);

            ParserModel model;

            try
            {
                // TODO hard-coded language reference
                HeadRules rules = new opennlp.tools.parser.lang.en.HeadRules(new InputStreamReader(new FileInputStream(@params.HeadRules), @params.Encoding));

                var type = parseParserType(@params.ParserType);
                if (@params.Fun.Value)
                {
                    Parse.useFunctionTags(true);
                }

                if (ParserType.CHUNKING == type)
                {
                    model = Parser.train(@params.Lang, sampleStream, rules, mlParams);
                }
                else if (ParserType.TREEINSERT == type)
                {
                    model = opennlp.tools.parser.treeinsert.Parser.train(@params.Lang, sampleStream, rules, mlParams);
                }
                else
                {
                    throw new IllegalStateException();
                }
            }
            catch (IOException e)
            {
                throw new TerminateToolException(-1, "IO error while reading training data or indexing data: " + e.Message, e);
            }
            finally
            {
                try
                {
                    sampleStream.close();
                }
                catch (IOException)
                {
                    // sorry that this can fail
                }
            }

            CmdLineUtil.writeModel("parser", modelOutFile, model);
        }
Example #7
        /// <summary>
        /// Run the training.
        /// </summary>
        /// <param name="pathToSettings">Path to the saved settings file.</param>
        private static void DoTrain(out string pathToSettings)
        {
            Console.Clear();
            ConsoleExtensions.WriteWithColors(ConsoleColor.Black, ConsoleColor.Green,
                                              "Вас приветствует обучение!\nУкажите директорию обучающей выборки " +
                                              "(enter для директории по-умолчанию):");

            var input = Console.ReadLine();

            if (input.Equals(string.Empty))
            {
                input = PathHelper.GetResourcesPath();
            }

            if (!Directory.Exists(input))
            {
                throw new Exception($"Указанная директория не существует!\nДиректория: {input}");
            }

            var directories      = Directory.GetDirectories(input).ToList();
            var matrixDictionary = new Dictionary<int, List<double[,]>>();

            var key = 0;

            foreach (var directory in directories)
            {
                var files = Directory.GetFiles(directory).ToList();

                if (!files.Any())
                {
                    throw new Exception($"Файлы не найдены!\nДиректория: {directory}");
                }

                var images        = PathToImageConverter.LoadImages(files);
                var resizedImages = NormilizeUtil.ResizeImages(images, 6, 6);

                var normilizedMatrixies = NormilizeUtil.GetNormilizedMatrixesFromImages(resizedImages);
                matrixDictionary.Add(key, normilizedMatrixies);

                ++key;
            }

            var hyperParameters = new HyperParameters
            {
                EpochCount = 1,
                Epsilon    = 0.75,
                Alpha      = 0.001
            };

            var topology = new Topology();

            topology.Add(2, 2, LayerType.Convolution);
            topology.Add(3, 2, LayerType.Subsampling);
            topology.Add(4, 2, LayerType.Hidden, true);

            if (!topology.IsClosed)
            {
                throw new Exception("Не удалось замкнуть топологию.");
            }

            var dataSet = new DataSet(DataSetType.ForNumberRecognizing);

            if (!matrixDictionary.Count.Equals(dataSet.MaxCountInDataSet()))
            {
                throw new Exception("Не соответсвие количества выборок для распознавания чисел!");
            }

            foreach (var pair in matrixDictionary)
            {
                dataSet.Add(pair.Value);
            }

            var trainUtil = new TrainUtil(dataSet, TrainType.Backpropagation, hyperParameters, topology);

            trainUtil.Start(3, 2, out var error, out pathToSettings);

            var errorString = $"{Math.Round(error * 100, 2)}%";

            ConsoleExtensions.WriteWithColors(ConsoleColor.Black, ConsoleColor.Yellow,
                                              $"\nОшибка: {errorString}");

            if (pathToSettings.Equals(string.Empty))
            {
                ConsoleExtensions.WriteWithColors(ConsoleColor.Black, ConsoleColor.Red,
                                                  "\nДанные не были сохранены!");
            }
            else
            {
                ConsoleExtensions.WriteWithColors(ConsoleColor.Black, ConsoleColor.Blue,
                                                  $"\nДанные сохранены!\nДиректория: {pathToSettings}");
            }
        }
        public override void run(string format, string[] args)
        {
            base.run(format, args);

            mlParams = CmdLineUtil.loadTrainingParameters(@params.Params, true);
            if (mlParams != null && !TrainUtil.isValid(mlParams.Settings))
            {
                throw new TerminateToolException(1, "Training parameters file '" + @params.Params + "' is invalid!");
            }

            if (mlParams == null)
            {
                mlParams = ModelUtil.createTrainingParameters(@params.Iterations.Value, @params.Cutoff.Value);
                mlParams.put(TrainingParameters.ALGORITHM_PARAM, getModelType(@params.Type).ToString());
            }

            File modelOutFile = @params.Model;

            CmdLineUtil.checkOutputFile("pos tagger model", modelOutFile);

            Dictionary ngramDict = null;

            int? ngramCutoff = @params.Ngram;

            if (ngramCutoff != null)
            {
                Console.Error.Write("Building ngram dictionary ... ");
                try
                {
                    ngramDict = POSTaggerME.buildNGramDictionary(sampleStream, ngramCutoff.Value);
                    sampleStream.reset();
                }
                catch (IOException e)
                {
                    throw new TerminateToolException(-1, "IO error while building NGram Dictionary: " + e.Message, e);
                }
                Console.Error.WriteLine("done");
            }

            POSTaggerFactory postaggerFactory = null;

            try
            {
                postaggerFactory = POSTaggerFactory.create(@params.Factory, ngramDict, null);
            }
            catch (InvalidFormatException e)
            {
                throw new TerminateToolException(-1, e.Message, e);
            }

            if (@params.Dict != null)
            {
                try
                {
                    postaggerFactory.TagDictionary = postaggerFactory.createTagDictionary(@params.Dict);
                }
                catch (IOException e)
                {
                    throw new TerminateToolException(-1, "IO error while loading POS Dictionary: " + e.Message, e);
                }
            }

            if (@params.TagDictCutoff != null)
            {
                try
                {
                    TagDictionary dict = postaggerFactory.TagDictionary;
                    if (dict == null)
                    {
                        dict = postaggerFactory.createEmptyTagDictionary();
                        postaggerFactory.TagDictionary = dict;
                    }
                    if (dict is MutableTagDictionary)
                    {
                        POSTaggerME.populatePOSDictionary(sampleStream, (MutableTagDictionary)dict, @params.TagDictCutoff.Value);
                    }
                    else
                    {
                        throw new System.ArgumentException("Can't extend a POSDictionary that does not implement MutableTagDictionary.");
                    }
                    sampleStream.reset();
                }
                catch (IOException e)
                {
                    throw new TerminateToolException(-1, "IO error while creating/extending POS Dictionary: " + e.Message, e);
                }
            }

            POSModel model;

            try
            {
                model = POSTaggerME.train(@params.Lang, sampleStream, mlParams, postaggerFactory);
            }
            catch (IOException e)
            {
                throw new TerminateToolException(-1, "IO error while reading training data or indexing data: " + e.Message, e);
            }
            finally
            {
                try
                {
                    sampleStream.close();
                }
                catch (IOException)
                {
                    // sorry that this can fail
                }
            }

            CmdLineUtil.writeModel("pos tagger", modelOutFile, model);
        }
Example #9
 public void GenerateTrainName_Test(int index, string expected)
 {
     Assert.Equal(expected, TrainUtil.GenerateTrainName(index));
 }
Example #10
 public void IsTrainName_Test(string name, bool expected)
 {
     Assert.Equal(expected, TrainUtil.IsTrainName(name));
 }
Example #11
 public void GetTrainIndex_Test(string name, int? expected)
 {
     Assert.Equal(expected, TrainUtil.GetTrainIndex(name));
 }
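
The parameterized tests above (GenerateTrainName_Test, IsTrainName_Test, GetTrainIndex_Test) have evidently lost their attributes in this listing. Below is a hedged sketch of how GenerateTrainName_Test would typically be declared with xUnit; the [InlineData] rows are assumptions based on the "_000001_" naming seen in GenerateTrainPath_Test, not values from the original suite.

// Hedged sketch of the attribute-based declaration these tests presumably use.
// The data rows are illustrative assumptions only.
[Theory]
[InlineData(1, "_000001_")]
[InlineData(42, "_000042_")]
public void GenerateTrainName_Test(int index, string expected)
{
    Assert.Equal(expected, TrainUtil.GenerateTrainName(index));
}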