Пример #1
0
        /// <summary>
        /// Checks <c>BinaryClassificationContext.TrainTestSplit</c> on the iris training file.
        /// A 50/50 split must produce halves of roughly equal size. Stratifying by the label
        /// column must put each label value entirely on one side of the split.
        /// </summary>
        public void TrainTestSplit()
        {
            var env        = new ConsoleEnvironment(seed: 0);
            var dataPath   = GetDataPath(TestDatasets.iris.trainFilename);
            var dataSource = new MultiFileSource(dataPath);

            var ctx = new BinaryClassificationContext(env);

            var reader = TextLoader.CreateReader(env,
                                                 c => (label: c.LoadFloat(0), features: c.LoadFloat(1, 4)));
            var data = reader.Read(dataSource);

            var (train, test) = ctx.TrainTestSplit(data, 0.5);

            // The two halves of a 50/50 split should be about the same size.
            var trainCount = train.GetColumn(r => r.label).Count();
            var testCount  = test.GetColumn(r => r.label).Count();

            // Fail with a clear message on an empty test set instead of an Infinity/NaN ratio.
            Assert.True(testCount > 0);
            Assert.InRange(trainCount * 1.0 / testCount, 0.8, 1.2);

            // Stratify by the label itself (not useful in practice, but easy to verify):
            // every label value must end up entirely in train or entirely in test.
            (train, test) = ctx.TrainTestSplit(data, 0.5, stratificationColumn: r => r.label);

            // Distinct() is deferred. Materialize once so the assertions below don't each re-scan the data.
            var trainLabels = train.GetColumn(r => r.label).Distinct().ToArray();
            var testLabels  = test.GetColumn(r => r.label).Distinct().ToArray();

            Assert.NotEmpty(trainLabels);
            Assert.NotEmpty(testLabels);
            Assert.Empty(trainLabels.Intersect(testLabels));
        }
Пример #2
0
        PrepareData(MLContext mlContext)
        {
            // This method:
            //   1. reads the transactions file (_dataSetFile);
            //   2. splits it 80:20 into train/test;
            //   3. caches both splits as binary .idv files under _outputPath.
            // Later runs reload the cached splits instead of re-splitting.
            // Returns (classification context, raw-file reader, train data, test data).

            // Step one: read the data as an IDataView.
            // Create the reader: define the data columns
            // and where to find them in the text file.
            var reader = new TextLoader(mlContext, new TextLoader.Arguments
            {
                Column = new[] {
                    // A boolean column depicting the 'label'.
                    new TextLoader.Column("Label", DataKind.BL, 30),
                    // 29 Features V1..V28 + Amount
                    new TextLoader.Column("V1", DataKind.R4, 1),
                    new TextLoader.Column("V2", DataKind.R4, 2),
                    new TextLoader.Column("V3", DataKind.R4, 3),
                    new TextLoader.Column("V4", DataKind.R4, 4),
                    new TextLoader.Column("V5", DataKind.R4, 5),
                    new TextLoader.Column("V6", DataKind.R4, 6),
                    new TextLoader.Column("V7", DataKind.R4, 7),
                    new TextLoader.Column("V8", DataKind.R4, 8),
                    new TextLoader.Column("V9", DataKind.R4, 9),
                    new TextLoader.Column("V10", DataKind.R4, 10),
                    new TextLoader.Column("V11", DataKind.R4, 11),
                    new TextLoader.Column("V12", DataKind.R4, 12),
                    new TextLoader.Column("V13", DataKind.R4, 13),
                    new TextLoader.Column("V14", DataKind.R4, 14),
                    new TextLoader.Column("V15", DataKind.R4, 15),
                    new TextLoader.Column("V16", DataKind.R4, 16),
                    new TextLoader.Column("V17", DataKind.R4, 17),
                    new TextLoader.Column("V18", DataKind.R4, 18),
                    new TextLoader.Column("V19", DataKind.R4, 19),
                    new TextLoader.Column("V20", DataKind.R4, 20),
                    new TextLoader.Column("V21", DataKind.R4, 21),
                    new TextLoader.Column("V22", DataKind.R4, 22),
                    new TextLoader.Column("V23", DataKind.R4, 23),
                    new TextLoader.Column("V24", DataKind.R4, 24),
                    new TextLoader.Column("V25", DataKind.R4, 25),
                    new TextLoader.Column("V26", DataKind.R4, 26),
                    new TextLoader.Column("V27", DataKind.R4, 27),
                    new TextLoader.Column("V28", DataKind.R4, 28),
                    new TextLoader.Column("Amount", DataKind.R4, 29),
                },
                // First line of the file is a header, not a data row.
                HasHeader = true,
                Separator = ","
            });

            // We know that this is a Binary Classification task,
            // so we create a Binary Classification context:
            // it will give us the algorithms we need,
            // as well as the evaluation procedure.
            var classification = new BinaryClassificationContext(mlContext);

            string testDataPath  = Path.Combine(_outputPath, "testData.idv");
            string trainDataPath = Path.Combine(_outputPath, "trainData.idv");

            IDataView trainData;
            IDataView testData;

            // Regenerate the split if EITHER cached file is missing. The else-branch loads both files,
            // so it must only run when both exist.
            if (!File.Exists(testDataPath) || !File.Exists(trainDataPath))
            {
                // Split the data 80:20 into train and test sets.
                IDataView data = reader.Read(new MultiFileSource(_dataSetFile));
                ConsoleHelpers.ConsoleWriteHeader("Show 4 transactions fraud (true) and 4 transactions not fraud (false) -  (source)");
                ConsoleHelpers.InspectData(mlContext, data, 4);

                // The original author notes that stratifying on the boolean "Label" column is not
                // possible here, so the split is unstratified.
                (trainData, testData) = classification.TrainTestSplit(data, testFraction: 0.2);

                SaveBinary(testData, testDataPath);
                SaveBinary(trainData, trainDataPath);
            }
            else
            {
                // Load the previously cached split.
                trainData = LoadBinary(trainDataPath);
                testData  = LoadBinary(testDataPath);
            }

            ConsoleHelpers.ConsoleWriteHeader("Show 4 transactions fraud (true) and 4 transactions not fraud (false) -  (traindata)");
            ConsoleHelpers.InspectData(mlContext, trainData, 4);

            ConsoleHelpers.ConsoleWriteHeader("Show 4 transactions fraud (true) and 4 transactions not fraud (false) -  (testData)");
            ConsoleHelpers.InspectData(mlContext, testData, 4);

            return(classification, reader, trainData, testData);

            // Writes a data view to 'path' in ML.NET's binary (.idv) format.
            void SaveBinary(IDataView view, string path)
            {
                var host = (IHostEnvironment)mlContext;
                using (var ch = host.Start("SaveData"))
                using (var file = host.CreateOutputFile(path))
                {
                    var saver = new BinarySaver(mlContext, new BinarySaver.Arguments());
                    DataSaverUtils.SaveDataView(ch, saver, view, file);
                }
            }

            // Loads a binary (.idv) file and applies the transaction column roles.
            IDataView LoadBinary(string path)
            {
                var loader = new BinaryLoader(mlContext, new BinaryLoader.Arguments(), new MultiFileSource(path));
                return new RoleMappedData(loader, roles: TransactionObservation.Roles()).Data;
            }
        }
Пример #3
0
        PrepareData(MLContext mlContext)
        {
            // This method:
            //   1. reads the transactions file (_dataSetFile);
            //   2. splits it 80:20 into train/test;
            //   3. caches both splits as CSV files under _outputPath.
            // Later runs reload the cached CSVs instead of re-splitting.
            // Returns (classification context, raw-file reader, train data, test data).

            TextLoader.Column[] columns = new[] {
                // A boolean column depicting the 'label'.
                new TextLoader.Column("Label", DataKind.BL, 30),
                // 29 Features V1..V28 + Amount
                new TextLoader.Column("V1", DataKind.R4, 1),
                new TextLoader.Column("V2", DataKind.R4, 2),
                new TextLoader.Column("V3", DataKind.R4, 3),
                new TextLoader.Column("V4", DataKind.R4, 4),
                new TextLoader.Column("V5", DataKind.R4, 5),
                new TextLoader.Column("V6", DataKind.R4, 6),
                new TextLoader.Column("V7", DataKind.R4, 7),
                new TextLoader.Column("V8", DataKind.R4, 8),
                new TextLoader.Column("V9", DataKind.R4, 9),
                new TextLoader.Column("V10", DataKind.R4, 10),
                new TextLoader.Column("V11", DataKind.R4, 11),
                new TextLoader.Column("V12", DataKind.R4, 12),
                new TextLoader.Column("V13", DataKind.R4, 13),
                new TextLoader.Column("V14", DataKind.R4, 14),
                new TextLoader.Column("V15", DataKind.R4, 15),
                new TextLoader.Column("V16", DataKind.R4, 16),
                new TextLoader.Column("V17", DataKind.R4, 17),
                new TextLoader.Column("V18", DataKind.R4, 18),
                new TextLoader.Column("V19", DataKind.R4, 19),
                new TextLoader.Column("V20", DataKind.R4, 20),
                new TextLoader.Column("V21", DataKind.R4, 21),
                new TextLoader.Column("V22", DataKind.R4, 22),
                new TextLoader.Column("V23", DataKind.R4, 23),
                new TextLoader.Column("V24", DataKind.R4, 24),
                new TextLoader.Column("V25", DataKind.R4, 25),
                new TextLoader.Column("V26", DataKind.R4, 26),
                new TextLoader.Column("V27", DataKind.R4, 27),
                new TextLoader.Column("V28", DataKind.R4, 28),
                new TextLoader.Column("Amount", DataKind.R4, 29)
            };

            TextLoader.Arguments txtLoaderArgs = new TextLoader.Arguments
            {
                Column = columns,
                // First line of the file is a header, not a data row.
                HasHeader = true,
                Separator = ","
            };

            // Step one: read the data as an IDataView.
            // Create the reader: define the data columns
            // and where to find them in the text file.
            var reader = new TextLoader(mlContext, txtLoaderArgs);

            // We know that this is a Binary Classification task,
            // so we create a Binary Classification context:
            // it will give us the algorithms we need,
            // as well as the evaluation procedure.
            var classification = new BinaryClassificationContext(mlContext);

            // The existence check, the writer and the reader must all use the same file paths.
            string testDataPath  = Path.Combine(_outputPath, "testData.csv");
            string trainDataPath = Path.Combine(_outputPath, "trainData.csv");

            IDataView trainData;
            IDataView testData;

            // Regenerate the split if EITHER cached file is missing. The else-branch reads both files,
            // so it must only run when both exist.
            if (!File.Exists(testDataPath) || !File.Exists(trainDataPath))
            {
                // Split the data 80:20 into train and test sets.
                IDataView data = reader.Read(new MultiFileSource(_dataSetFile));
                ConsoleHelpers.ConsoleWriteHeader("Show 4 transactions fraud (true) and 4 transactions not fraud (false) -  (source)");
                ConsoleHelpers.InspectData(mlContext, data, 4);

                // The original author notes that stratifying on the boolean "Label" column is not
                // possible here, so the split is unstratified.
                (trainData, testData) = classification.TrainTestSplit(data, testFraction: 0.2);

                SaveCsv(testData, testDataPath);
                // Each split is written to its own file; trainData.csv must hold the TRAIN rows.
                SaveCsv(trainData, trainDataPath);
            }
            else
            {
                // Column layout of the cached CSVs, per the original author:
                //   - Label is at column 0;
                //   - V1..V28 + Amount follow at columns 1..29;
                //   - the "StratificationColumn" added by classification.TrainTestSplit() is at column 30.
                TextLoader.Column[] columnsPlus = new[] {
                    // A boolean column depicting the 'label'.
                    new TextLoader.Column("Label", DataKind.BL, 0),
                    // 30 Features V1..V28 + Amount + StratificationColumn
                    new TextLoader.Column("V1", DataKind.R4, 1),
                    new TextLoader.Column("V2", DataKind.R4, 2),
                    new TextLoader.Column("V3", DataKind.R4, 3),
                    new TextLoader.Column("V4", DataKind.R4, 4),
                    new TextLoader.Column("V5", DataKind.R4, 5),
                    new TextLoader.Column("V6", DataKind.R4, 6),
                    new TextLoader.Column("V7", DataKind.R4, 7),
                    new TextLoader.Column("V8", DataKind.R4, 8),
                    new TextLoader.Column("V9", DataKind.R4, 9),
                    new TextLoader.Column("V10", DataKind.R4, 10),
                    new TextLoader.Column("V11", DataKind.R4, 11),
                    new TextLoader.Column("V12", DataKind.R4, 12),
                    new TextLoader.Column("V13", DataKind.R4, 13),
                    new TextLoader.Column("V14", DataKind.R4, 14),
                    new TextLoader.Column("V15", DataKind.R4, 15),
                    new TextLoader.Column("V16", DataKind.R4, 16),
                    new TextLoader.Column("V17", DataKind.R4, 17),
                    new TextLoader.Column("V18", DataKind.R4, 18),
                    new TextLoader.Column("V19", DataKind.R4, 19),
                    new TextLoader.Column("V20", DataKind.R4, 20),
                    new TextLoader.Column("V21", DataKind.R4, 21),
                    new TextLoader.Column("V22", DataKind.R4, 22),
                    new TextLoader.Column("V23", DataKind.R4, 23),
                    new TextLoader.Column("V24", DataKind.R4, 24),
                    new TextLoader.Column("V25", DataKind.R4, 25),
                    new TextLoader.Column("V26", DataKind.R4, 26),
                    new TextLoader.Column("V27", DataKind.R4, 27),
                    new TextLoader.Column("V28", DataKind.R4, 28),
                    new TextLoader.Column("Amount", DataKind.R4, 29),
                    new TextLoader.Column("StratificationColumn", DataKind.R4, 30)
                };

                // Load the previously cached split.
                trainData = ReadCsv(columnsPlus, trainDataPath);
                testData  = ReadCsv(columnsPlus, testDataPath);
            }

            ConsoleHelpers.ConsoleWriteHeader("Show 4 transactions fraud (true) and 4 transactions not fraud (false) -  (traindata)");
            ConsoleHelpers.InspectData(mlContext, trainData, 4);

            ConsoleHelpers.ConsoleWriteHeader("Show 4 transactions fraud (true) and 4 transactions not fraud (false) -  (testData)");
            ConsoleHelpers.InspectData(mlContext, testData, 4);

            return(classification, reader, trainData, testData);

            // Writes a data view to 'path' as comma-separated text with a header row and schema.
            void SaveCsv(IDataView view, string path)
            {
                using (var fileStream = File.Create(path))
                {
                    mlContext.Data.SaveAsText(view, fileStream, separator: ',', headerRow: true, schema: true);
                }
            }

            // Reads a cached CSV split using the same header/separator settings as the raw-file loader.
            IDataView ReadCsv(TextLoader.Column[] cols, string path)
            {
                return mlContext.Data.ReadFromTextFile(cols, path,
                                                       advancedSettings: s =>
                                                       {
                                                           s.HasHeader = txtLoaderArgs.HasHeader;
                                                           s.Separator = txtLoaderArgs.Separator;
                                                       });
            }
        }
Пример #4
0
        /// <summary>
        /// Trains a FastTree binary classifier that predicts the <c>NextHourAlert</c> column.
        /// The data comes from the CSV at <c>_datapath</c> and is split 80:20 into train and test.
        /// The method prints the evaluation metrics on the held-out test split.
        /// </summary>
        /// <returns>A prediction function mapping <see cref="TrafficObservation"/> to <see cref="AlertPrediction"/>.</returns>
        public static PredictionFunction<TrafficObservation, AlertPrediction> ModelAndTrain()
        {
            Console.WriteLine("Starting Machine Learning Binary Classification");
            var mlContext = new MLContext(seed: 1);

            // Step one: read the data as an IDataView.
            // Create the reader: define the data columns
            // and where to find them in the text file.
            var reader = new TextLoader(mlContext, new TextLoader.Arguments
            {
                Column = new[] {
                    // A boolean column depicting the 'label'.
                    new TextLoader.Column("NextHourAlert", DataKind.BL, 20),
                    // 18 Features
                    new TextLoader.Column("AvgTotalBytes", DataKind.R4, 2),
                    new TextLoader.Column("AvgTotalPackets", DataKind.R4, 3),
                    new TextLoader.Column("AvgAveragebps", DataKind.R4, 4),
                    new TextLoader.Column("AvgOutPercentUtil", DataKind.R4, 5),
                    new TextLoader.Column("AvgInPercentUtil", DataKind.R4, 6),
                    new TextLoader.Column("AvgPercentUtil", DataKind.R4, 7),
                    new TextLoader.Column("MinTotalBytes", DataKind.R4, 8),
                    new TextLoader.Column("MinTotalPackets", DataKind.R4, 9),
                    new TextLoader.Column("MinAveragebps", DataKind.R4, 10),
                    new TextLoader.Column("MinOutPercentUtil", DataKind.R4, 11),
                    new TextLoader.Column("MinInPercentUtil", DataKind.R4, 12),
                    new TextLoader.Column("MinPercentUtil", DataKind.R4, 13),
                    new TextLoader.Column("MaxTotalBytes", DataKind.R4, 14),
                    new TextLoader.Column("MaxTotalPackets", DataKind.R4, 15),
                    new TextLoader.Column("MaxAveragebps", DataKind.R4, 16),
                    new TextLoader.Column("MaxOutPercentUtil", DataKind.R4, 17),
                    new TextLoader.Column("MaxInPercentUtil", DataKind.R4, 18),
                    new TextLoader.Column("MaxPercentUtil", DataKind.R4, 19)
                },
                // First line of the file is a header, not a data row.
                HasHeader = true,
                Separator = ","
            });

            // We know that this is a Binary Classification task,
            // so we create a Binary Classification context:
            // it will give us the algorithms we need,
            // as well as the evaluation procedure.
            var classification = new BinaryClassificationContext(mlContext);

            IDataView data = reader.Read(new MultiFileSource(_datapath));
            var (trainData, testData) = classification.TrainTestSplit(data, testFraction: 0.2);

            // All 18 numeric inputs are concatenated into a single "Features" vector.
            string[] featureColumns =
            {
                "AvgTotalBytes", "AvgTotalPackets", "AvgAveragebps", "AvgOutPercentUtil", "AvgInPercentUtil", "AvgPercentUtil",
                "MinTotalBytes", "MinTotalPackets", "MinAveragebps", "MinOutPercentUtil", "MinInPercentUtil", "MinPercentUtil",
                "MaxTotalBytes", "MaxTotalPackets", "MaxAveragebps", "MaxOutPercentUtil", "MaxInPercentUtil", "MaxPercentUtil"
            };

            // Build the training pipeline as a chain of estimators.
            // NOTE(review): the normalizer writes "FeaturesNormalizedByMeanVar", but the trainer reads
            // "Features", so the normalized column is never used by FastTree. Confirm whether the
            // trainer should consume it, or whether the step can be removed.
            var pipeline = mlContext.Transforms.Concatenate("Features", featureColumns)
                           .Append(mlContext.Transforms.Normalize(inputName: "Features", outputName: "FeaturesNormalizedByMeanVar", mode: NormalizerMode.MeanVariance))
                           .Append(mlContext.BinaryClassification.Trainers.FastTree(label: "NextHourAlert",
                                                                                    features: "Features",
                                                                                    numLeaves: 20,
                                                                                    numTrees: 100,
                                                                                    minDatapointsInLeafs: 10,
                                                                                    learningRate: 0.2));
            var model = pipeline.Fit(trainData);

            // Score the held-out split and report the evaluation metrics.
            var metrics = classification.Evaluate(model.Transform(testData), "NextHourAlert");

            Console.WriteLine($"Accuracy: {metrics.Accuracy}");
            Console.WriteLine($"Area under ROC curve: {metrics.Auc}");
            Console.WriteLine($"Area under the precision/recall curve: {metrics.Auprc}");
            Console.WriteLine($"Entropy: {metrics.Entropy}");
            Console.WriteLine($"F1 Score: {metrics.F1Score}");
            Console.WriteLine($"Log loss: {metrics.LogLoss}");
            Console.WriteLine($"Log loss reduction: {metrics.LogLossReduction}");
            Console.WriteLine($"Negative precision: {metrics.NegativePrecision}");
            Console.WriteLine($"Positive precision: {metrics.PositivePrecision}");
            Console.WriteLine($"Positive recall: {metrics.PositiveRecall}");

            return model.MakePredictionFunction<TrafficObservation, AlertPrediction>(mlContext);
        }