コード例 #1
0
        /// <summary>
        /// Trains a matrix factorization model on the trivial matrix factorization training file, scores the
        /// matching test file, and verifies both the output schema's column names and the regression L2 metric
        /// against a per-platform baseline. Finally exercises training with a validation set.
        /// </summary>
        [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441.
        public void MatrixFactorizationSimpleTrainAndPredict()
        {
            // Allowed absolute deviation of the L2 metric from the recorded platform baseline.
            const double Tolerance = 1e-7;

            var mlContext = new MLContext(seed: 1, conc: 1);

            // Specific column names of the considered data set
            string labelColumnName = "Label";
            string userColumnName  = "User";
            string itemColumnName  = "Item";
            string scoreColumnName = "Score";

            // Create reader for both of training and test data sets
            var reader = new TextLoader(mlContext, GetLoaderArgs(labelColumnName, userColumnName, itemColumnName));

            // Read training data as an IDataView object
            var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.trainFilename)));

            // Create a pipeline with a single operator.
            var pipeline = new MatrixFactorizationTrainer(mlContext, userColumnName, itemColumnName, labelColumnName,
                                                          advancedSettings: s =>
            {
                s.NumIterations = 3;
                s.NumThreads    = 1;  // To eliminate randomness, # of threads must be 1.
                s.K             = 7;
            });

            // Train a matrix factorization model.
            var model = pipeline.Fit(data);

            // Read the test data set as an IDataView
            var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename)));

            // Apply the trained model to the test set
            var prediction = model.Transform(testData);

            // Get output schema and check its column names
            var outputSchema        = model.GetOutputSchema(data.Schema);
            var expectedOutputNames = new string[] { labelColumnName, userColumnName, itemColumnName, scoreColumnName };

            foreach (var (i, col) in outputSchema.GetColumns())
            {
                // Assert.Equal reports both the expected and the actual name when the check fails.
                Assert.Equal(expectedOutputNames[i], col.Name);
            }

            // The label column must be present in the test IDataView. Only its presence matters here,
            // so the resolved index is discarded.
            Assert.True(testData.Schema.TryGetColumnIndex(labelColumnName, out _),
                        "Label column is missing from the test data schema.");

            // The score column must be present in the IDataView produced by the trained model.
            Assert.True(prediction.Schema.TryGetColumnIndex(scoreColumnName, out _),
                        "Score column is missing from the prediction schema.");

            // Compute prediction errors
            var metrics = mlContext.Regression.Evaluate(prediction, label: labelColumnName, score: scoreColumnName);

            // Determine if the selected metric is reasonable for different platforms
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
            {
                // Linux case
                var expectedUnixL2Error = 0.616821448679879; // Linux baseline
                Assert.InRange(metrics.L2, expectedUnixL2Error - Tolerance, expectedUnixL2Error + Tolerance);
            }
            else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
            {
                // The Mac case is just broken. Should be fixed later. Re-enable when done.
                // Mac case
                //var expectedMacL2Error = 0.61192207960271; // Mac baseline
                //Assert.InRange(metrics.L2, expectedMacL2Error - 5e-3, expectedMacL2Error + 5e-3); // 1e-7 is too small for Mac, so a looser 5e-3 tolerance is used
            }
            else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
            {
                // Windows case
                var expectedWindowsL2Error = 0.61528733643754685; // Windows baseline
                Assert.InRange(metrics.L2, expectedWindowsL2Error - Tolerance, expectedWindowsL2Error + Tolerance);
            }

            // Exercise the training path that also takes a validation set; it must produce a model.
            var modelWithValidation = pipeline.Train(data, testData);
            Assert.NotNull(modelWithValidation);
        }