示例#1
0
        public void ReconfigurablePredictionNoPipeline()
        {
            // Train a single logistic-regression model (fixed seed, one thread) on the shared test dataset.
            var ctx          = new MLContext(seed: 1);
            var trainingData = ctx.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
            var trainerOptions = new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options
            {
                NumberOfThreads = 1
            };
            var trainer      = ctx.BinaryClassification.Trainers.LbfgsLogisticRegression(trainerOptions);
            var trainedModel = trainer.Fit(trainingData);

            // Derive a second model that differs only in its decision threshold (-2 instead of the default).
            var relaxedModel = ctx.BinaryClassification.ChangeModelThreshold(trainedModel, -2.0f);

            // One deterministic sample, scored by both models.
            var sample = TypeTestData.GetRandomInstance(new Random(1));

            var defaultEngine     = ctx.Model.CreatePredictionEngine<TypeTestData, Prediction>(trainedModel);
            var defaultPrediction = defaultEngine.Predict(sample);

            // With the default threshold, a non-positive score (about -1.38 here) yields a negative label.
            Assert.False(defaultPrediction.PredictedLabel);
            Assert.True(defaultPrediction.Score <= 0);

            var relaxedEngine     = ctx.Model.CreatePredictionEngine<TypeTestData, Prediction>(relaxedModel);
            var relaxedPrediction = relaxedEngine.Predict(sample);

            // The score itself is unchanged, but it now clears the lowered -2 threshold, so the label flips to true.
            Assert.True(relaxedPrediction.PredictedLabel);
            Assert.True(relaxedPrediction.Score <= 0);
        }
示例#2
0
 public void TestDataTypes(string typeName, int id, TypeTestData data, string context)
 {
     // Open a connection for the given configuration, run the type-specific query,
     // and require the returned value to equal the case's expected result.
     using (var db = new DataConnection(context))
     {
         var actual   = data.Func(typeName, this, db);
         var expected = data.Result;
         Assert.That(actual, Is.EqualTo(expected));
     }
 }
示例#3
0
        public void ReadFromIEnumerable()
        {
            // Fixed seed; conc: 1 presumably pins concurrency for reproducibility.
            var context = new MLContext(seed: 1, conc: 1);

            // Wrap the generated in-memory rows as a data view and check it against the known test dataset.
            var rows = TypeTestData.GenerateDataset();
            var view = context.Data.ReadFromEnumerable(rows);
            Common.AssertTypeTestDataset(view);
        }
示例#4
0
        public void WriteAndReadAFromABinaryFile()
        {
            var context  = new MLContext(seed: 1, conc: 1);
            var original = context.Data.ReadFromEnumerable(TypeTestData.GenerateDataset());

            // Round-trip: write the known-schema dataset to a binary file, then load it back.
            var binaryPath   = SerializeDatasetToBinaryFile(context, original);
            var roundTripped = context.Data.ReadFromBinary(binaryPath);

            // The reloaded data must match what was written.
            Common.AssertTestTypeDatasetsAreEqual(context, original, roundTripped);
        }
示例#5
0
        public void ExportToIEnumerable()
        {
            var context = new MLContext(seed: 1, conc: 1);

            // Start from in-memory rows and wrap them as a data view.
            var expectedRows = TypeTestData.GenerateDataset();
            var view         = context.Data.ReadFromEnumerable(expectedRows);

            // Materialize the view back into TypeTestData objects and compare with the source rows.
            var actualRows = context.CreateEnumerable<TypeTestData>(view, true);
            Common.AssertEqual(expectedRows, actualRows);
        }
示例#6
0
        public void WriteToAndReadASchemaFromADelimitedFile()
        {
            var context  = new MLContext(seed: 1, conc: 1);
            var original = context.Data.ReadFromEnumerable(TypeTestData.GenerateDataset());

            // For every separator under test, write the dataset as delimited text and read it back
            // using the TypeTestData type argument to describe the columns.
            foreach (var sep in _separators)
            {
                var textPath = SerializeDatasetToFile(context, original, sep);
                var reloaded = context.Data.ReadFromTextFile<TypeTestData>(textPath, hasHeader: true, separatorChar: sep);
                Common.AssertTestTypeDatasetsAreEqual(context, original, reloaded);
            }
        }
        public void WriteToAndReadFromADelimetedFile()
        {
            var context  = new MLContext(seed: 1);
            var original = context.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());

            foreach (var sep in _separators)
            {
                // Write with this separator, then reload through the loader TypeTestData builds for it.
                var textPath = SerializeDatasetToFile(context, original, sep);
                var loader   = TypeTestData.GetTextLoader(context, sep);
                var reloaded = loader.Load(textPath);

                Common.AssertTestTypeDatasetsAreEqual(context, original, reloaded);
            }
        }
示例#8
0
 /// <summary>
 /// Runs the type-specific query for <paramref name="typeName"/> against the given connection
 /// configuration and checks the returned value against <c>data.Result</c>.
 /// </summary>
 /// <param name="typeName">Name of the database type under test, passed through to <c>data.Func</c>.</param>
 /// <param name="id">Test-case identifier; not used by the body.</param>
 /// <param name="data">Test case: the query delegate and its expected result.</param>
 /// <param name="context">Connection configuration name used to open the <see cref="DataConnection"/>.</param>
 public void TestDataTypes(string typeName, int id, TypeTestData data, string context)
 {
     using (var conn = new DataConnection(context))
     {
         var value = data.Func(typeName, this, conn);

         if (data.Result is NpgsqlPoint)
         {
             // NpgsqlPoint results are compared with object.Equals rather than NUnit's comparer
             // (presumably NUnit's comparer does not treat this type as intended — confirm).
             // Include both values so a failure is diagnosable instead of just "Expected: True".
             Assert.IsTrue(object.Equals(value, data.Result), "Expected {0} but was {1}", data.Result, value);
         }
         else
         {
             // NUnit's signature is AreEqual(expected, actual); keep that order so
             // failure messages label the values correctly.
             Assert.AreEqual(data.Result, value);
         }
     }
 }
示例#9
0
 /// <summary>
 /// Assert that two <see cref="TypeTestData"/> instances (single rows, not whole datasets)
 /// hold equal values in every column.
 /// </summary>
 /// <param name="testType1">The expected <see cref="TypeTestData"/> row.</param>
 /// <param name="testType2">The actual <see cref="TypeTestData"/> row.</param>
 public static void AssertEqual(TypeTestData testType1, TypeTestData testType2)
 {
     Assert.Equal(testType1.Label, testType2.Label);
     Common.AssertEqual(testType1.Features, testType2.Features);
     Assert.Equal(testType1.I1, testType2.I1);
     Assert.Equal(testType1.U1, testType2.U1);
     Assert.Equal(testType1.I2, testType2.I2);
     Assert.Equal(testType1.U2, testType2.U2);
     Assert.Equal(testType1.I4, testType2.I4);
     Assert.Equal(testType1.U4, testType2.U4);
     Assert.Equal(testType1.I8, testType2.I8);
     Assert.Equal(testType1.U8, testType2.U8);
     Assert.Equal(testType1.R4, testType2.R4);
     Assert.Equal(testType1.R8, testType2.R8);
     // Tx is compared through its string form (presumably because its type lacks value equality — confirm).
     Assert.Equal(testType1.Tx.ToString(), testType2.Tx.ToString());
     // These columns are compared with the type's own Equals. The message makes a failure report
     // the differing values instead of only "Expected: True, Actual: False".
     Assert.True(testType1.Ts.Equals(testType2.Ts), $"Ts mismatch: expected {testType1.Ts}, actual {testType2.Ts}");
     Assert.True(testType1.Dt.Equals(testType2.Dt), $"Dt mismatch: expected {testType1.Dt}, actual {testType2.Dt}");
     Assert.True(testType1.Dz.Equals(testType2.Dz), $"Dz mismatch: expected {testType1.Dz}, actual {testType2.Dz}");
 }
示例#10
0
        public void PredictionEngineModelDisposal()
        {
            var mlContext = new MLContext(seed: 1);
            var data      = mlContext.Data.LoadFromEnumerable(TypeTestData.GenerateDataset());
            var pipeline  = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(
                new Trainers.LbfgsLogisticRegressionBinaryTrainer.Options {
                NumberOfThreads = 1
            });
            var model = pipeline.Fit(data);

            // Default options: disposing the prediction engine is expected to dispose the model too.
            var engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, new PredictionEngineOptions());
            engine.Dispose();

            // The disposed flag is a private field declared two levels up the model's type hierarchy;
            // read it via reflection. Assert the lookup succeeded so a rename/move of the field fails
            // with a clear assertion instead of a NullReferenceException on GetValue.
            var bfIsDisposed = BindingFlags.Instance | BindingFlags.NonPublic;
            var field        = model.GetType().BaseType.BaseType.GetField("_disposed", bfIsDisposed);
            Assert.NotNull(field);

            // Make sure the model is actually disposed
            Assert.True((bool)field.GetValue(model));

            // Make a new model/prediction engine. Set the options so prediction engine doesn't dispose
            model = pipeline.Fit(data);
            try
            {
                var options = new PredictionEngineOptions()
                {
                    OwnsTransformer = false
                };

                engine = mlContext.Model.CreatePredictionEngine<TypeTestData, Prediction>(model, options);

                // Dispose of prediction engine, shouldn't dispose of model
                engine.Dispose();

                // Make sure model is not disposed of.
                Assert.False((bool)field.GetValue(model));
            }
            finally
            {
                // The engine does not own this model, so dispose it here, even when an assertion above fails.
                model.Dispose();
            }
        }