/// <summary>
/// Trains a matrix factorization model on the trivial matrix-factorization data set and then:
/// checks that the model's output schema has exactly the expected column names in order,
/// checks that the label and score columns are present, evaluates the regression L2 error
/// on the test set against a per-platform baseline, and exercises training with a validation set.
/// </summary>
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // This test is being fixed as part of issue #1441.
public void MatrixFactorizationSimpleTrainAndPredict()
{
    var mlContext = new MLContext(seed: 1, conc: 1);

    // Specific column names of the considered data set
    const string labelColumnName = "Label";
    const string userColumnName = "User";
    const string itemColumnName = "Item";
    const string scoreColumnName = "Score";

    // Create a reader shared by the training and test data sets
    var reader = new TextLoader(mlContext, GetLoaderArgs(labelColumnName, userColumnName, itemColumnName));

    // Read training data as an IDataView object
    var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.trainFilename)));

    // Create a pipeline with a single operator.
    var pipeline = new MatrixFactorizationTrainer(mlContext, userColumnName, itemColumnName, labelColumnName,
        advancedSettings: s =>
        {
            s.NumIterations = 3;
            s.NumThreads = 1; // To eliminate randomness, # of threads must be 1.
            s.K = 7;
        });

    // Train a matrix factorization model.
    var model = pipeline.Fit(data);

    // Read the test data set as an IDataView
    var testData = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.trivialMatrixFactorization.testFilename)));

    // Apply the trained model to the test set
    var prediction = model.Transform(testData);

    // The output schema must contain exactly the expected columns, in this order. The bounds check
    // turns an unexpected extra column into an assertion failure (rather than an
    // IndexOutOfRangeException), and the final count check keeps a missing column from passing silently.
    var outputSchema = model.GetOutputSchema(data.Schema);
    var expectedOutputNames = new string[] { labelColumnName, userColumnName, itemColumnName, scoreColumnName };
    int columnCount = 0;
    foreach (var (i, col) in outputSchema.GetColumns())
    {
        Assert.True(i < expectedOutputNames.Length, $"Unexpected extra output column '{col.Name}'.");
        Assert.Equal(expectedOutputNames[i], col.Name);
        columnCount++;
    }
    Assert.Equal(expectedOutputNames.Length, columnCount);

    // The label column must exist in the test IDataView, and the score column must exist in the
    // IDataView produced by the trained model.
    Assert.True(testData.Schema.TryGetColumnIndex(labelColumnName, out _));
    Assert.True(prediction.Schema.TryGetColumnIndex(scoreColumnName, out _));

    // Compute prediction errors
    var metrics = mlContext.Regression.Evaluate(prediction, label: labelColumnName, score: scoreColumnName);

    // Determine if the selected metric is reasonable for different platforms
    const double tolerance = 1e-7;
    if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
    {
        // Linux case
        const double expectedUnixL2Error = 0.616821448679879; // Linux baseline
        Assert.InRange(metrics.L2, expectedUnixL2Error - tolerance, expectedUnixL2Error + tolerance);
    }
    else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
    {
        // TODO: The Mac case is currently broken, so no L2 assertion is made here. Re-enable when fixed.
        // The last recorded Mac baseline was 0.61192207960271; 1e-7 was too tight for Mac, and the
        // disabled check used a tolerance of 5e-3.
    }
    else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
    {
        // Windows case
        const double expectedWindowsL2Error = 0.61528733643754685; // Windows baseline
        Assert.InRange(metrics.L2, expectedWindowsL2Error - tolerance, expectedWindowsL2Error + tolerance);
    }

    // Also exercise training with a separate validation set.
    var modelWithValidation = pipeline.Train(data, testData);
    Assert.NotNull(modelWithValidation);
}