Пример #1
0
 /// <summary>
 /// Builds the model descriptor from the artifacts produced by training and evaluation.
 /// </summary>
 /// <param name="metrics">Metrics measured on the model; stored wrapped in a <c>MyModelMetrics</c> instance.</param>
 /// <param name="type">Kind of model that was trained.</param>
 /// <param name="model">The trained model as an <c>ITransformer</c>.</param>
 /// <param name="schema">Schema of the data the model was trained and evaluated on.</param>
 /// <param name="parameters">Macro-parameters the model was trained with.</param>
 public MLModel(MulticlassClassificationMetrics metrics, MLModelTypes type, ITransformer model, DataViewSchema schema, TreeParameters parameters)
 {
     // Wrap the raw ML.NET metrics before anything else is stored.
     var wrappedMetrics = new MyModelMetrics(metrics);

     this.Metrics    = wrappedMetrics;
     this.ModelType  = type;
     this.Model      = model;
     this.Schema     = schema;
     this.Parameters = parameters;
 }
        /// <summary>
        /// Main entry point for building machine-learning models in the program.
        /// It loads the data, validates the labels, trains a model of the requested type
        /// and evaluates it on the test set.
        /// </summary>
        /// <param name="trainPath">Path to the training data.</param>
        /// <param name="type">Kind of model to build.</param>
        /// <param name="parameters">Additional (tree) parameters of the model.</param>
        /// <param name="testPath">
        /// Optional path to the test data.
        /// If null or empty, the test set is a random subset of the training data.
        /// Its size is the <c>TrainTestSplit</c> default of 10%, and the remaining data is used for training.
        /// </param>
        /// <returns>
        /// The trained model together with its evaluation metrics, or <c>null</c> when
        /// <paramref name="type"/> is not one of the model types handled here.
        /// </returns>
        /// <exception cref="IndexOutOfRangeException">
        /// No test path was given and the mapped training labels do not form exactly seven classes.
        /// </exception>
        /// <exception cref="FormatException">
        /// A test path was given and the label sets of the training and test data differ.
        /// </exception>
        public static MLModel GetModelOverPipeline(string trainPath, MLModelTypes type, TreeParameters parameters, string testPath = "")
        {
            // Number of classes the label mapping is expected to produce.
            const int ExpectedLabelCount = 7;

            // Subscribe to the logging event. The handler is removed first so there is always
            // exactly one subscription. Without this, every call adds another copy and each
            // log line is reported once per previous call.
            mlContext.Log -= LogMessage;
            mlContext.Log += LogMessage;

            IDataView trainData, testData;
            EstimatorChain<ValueToKeyMappingTransformer> dataTransform;

            if (string.IsNullOrEmpty(testPath))
            {
                // No separate test file: split the training file at random.
                var data = mlContext.Data.TrainTestSplit(loader.Load(trainPath), seed: rand.Next(10000));
                trainData = data.TrainSet;
                testData  = data.TestSet;

                // Map raw labels to the target classes, then key-encode them.
                dataTransform = mlContext.Transforms.CustomMapping <InputLabelsRow, OutputLabelsRow>(CustomMappings.LabelMapping,
                                                                                                     nameof(CustomMappings.LabelMapping))
                                .Append(mlContext.Transforms.Conversion.MapValueToKey("CodedLabel", "CorrectLabel"));

                var      trainDataView = dataTransform.Fit(trainData).Transform(trainData);
                string[] trainLabels   = GetLabelNamesArray(trainDataView);

                if (trainLabels.Length != ExpectedLabelCount)
                {
                    // NOTE(review): IndexOutOfRangeException is a runtime-reserved exception type.
                    // It is kept because existing callers may catch it.
                    throw new IndexOutOfRangeException();
                }
            }
            else
            {
                trainData = loader.Load(trainPath);
                testData  = loader.Load(testPath);

                // The labels are already in the "Label" column, so they only need key-encoding.
                // Starting from an empty chain gives both branches the same pipeline type.
                dataTransform = new EstimatorChain<ITransformer>()
                                .Append(mlContext.Transforms.Conversion.MapValueToKey("CodedLabel", "Label"));

                var testDataView  = dataTransform.Fit(testData).Transform(testData);
                var trainDataView = dataTransform.Fit(trainData).Transform(trainData);

                string[] testLabels  = GetLabelNamesArray(testDataView);
                string[] trainLabels = GetLabelNamesArray(trainDataView);

                // Sort ordinally so identical label sets always line up element by element.
                Array.Sort(testLabels, StringComparer.Ordinal);
                Array.Sort(trainLabels, StringComparer.Ordinal);

                // Compare the complete label sets, including their sizes.
                bool labelsMatch = testLabels.Length == trainLabels.Length;
                for (int i = 0; labelsMatch && i < testLabels.Length; i++)
                {
                    labelsMatch = testLabels[i] == trainLabels[i];
                }

                if (!labelsMatch)
                {
                    throw new FormatException("Не вышло сопоставить ярлыки наборов данных!");
                }
            }

            // Fits the full pipeline on the training data, evaluates it on the test data and
            // wraps the result. The helper is generic so the fitted chain keeps its concrete
            // predictor type without any casts.
            MLModel FitAndEvaluate<TTransformer>(EstimatorChain<TTransformer> pipeline)
                where TTransformer : class, ITransformer
            {
                var fittedModel    = pipeline.Fit(trainData);
                var scoredTestData = fittedModel.Transform(testData);
                var metrics        = mlContext.MulticlassClassification.Evaluate(scoredTestData, "CodedLabel", topKPredictionCount: 3);
                return new MLModel(metrics, type, fittedModel, scoredTestData.Schema, parameters);
            }

            // Build and train a model of the requested type.
            switch (type)
            {
            case MLModelTypes.LightGbm:
                return FitAndEvaluate(dataTransform.Append(
                    mlContext.MulticlassClassification.Trainers.LightGbm("CodedLabel", "Features",
                                                                         numberOfLeaves: parameters.LeavesPerTree,
                                                                         minimumExampleCountPerLeaf: parameters.DataPointsNumber,
                                                                         numberOfIterations: parameters.NumberOfTrees,
                                                                         learningRate: parameters.LearningRate)));

            case MLModelTypes.PCFastTree:
                return FitAndEvaluate(dataTransform.Append(
                    mlContext.MulticlassClassification.Trainers.PairwiseCoupling(mlContext.BinaryClassification.Trainers.FastTree(
                                                                                     "CodedLabel", "Features",
                                                                                     numberOfLeaves: parameters.LeavesPerTree,
                                                                                     minimumExampleCountPerLeaf: parameters.DataPointsNumber,
                                                                                     numberOfTrees: parameters.NumberOfTrees,
                                                                                     learningRate: parameters.LearningRate), "CodedLabel")));

            case MLModelTypes.PCFastForest:
                return FitAndEvaluate(dataTransform.Append(
                    mlContext.MulticlassClassification.Trainers.PairwiseCoupling(mlContext.BinaryClassification.Trainers.FastForest(
                                                                                     "CodedLabel", "Features",
                                                                                     numberOfLeaves: parameters.LeavesPerTree,
                                                                                     minimumExampleCountPerLeaf: parameters.DataPointsNumber,
                                                                                     numberOfTrees: parameters.NumberOfTrees), "CodedLabel")));

            case MLModelTypes.OVAFastTree:
                return FitAndEvaluate(dataTransform.Append(
                    mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastTree(
                                                                                 "CodedLabel", "Features",
                                                                                 numberOfLeaves: parameters.LeavesPerTree,
                                                                                 minimumExampleCountPerLeaf: parameters.DataPointsNumber,
                                                                                 numberOfTrees: parameters.NumberOfTrees,
                                                                                 learningRate: parameters.LearningRate), "CodedLabel")));

            case MLModelTypes.OVAFastForest:
                return FitAndEvaluate(dataTransform.Append(
                    mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(
                                                                                 "CodedLabel", "Features",
                                                                                 numberOfLeaves: parameters.LeavesPerTree,
                                                                                 minimumExampleCountPerLeaf: parameters.DataPointsNumber,
                                                                                 numberOfTrees: parameters.NumberOfTrees), "CodedLabel")));
            }

            // Model types not handled above produce no model.
            return null;
        }