/// <summary>
/// Trains a three-cluster K-Means model on the data file at <c>DataPath</c>,
/// saves the trained model to <c>ModelPath</c>, and prints the cluster
/// prediction for a single hard-coded sample.
/// </summary>
private static void Main()
{
    // Fixed seed for the context's random number generator.
    var mlContext = new MLContext(seed: 0);

    // The data file is read as comma-separated text without a header row.
    var dataView = mlContext.Data.LoadFromTextFile<IrisData>(DataPath, hasHeader: false, separatorChar: ',');

    // Combine the four measurement columns into one feature column and
    // feed it to a K-Means trainer configured for three clusters.
    const string featuresColumnName = "Features";
    var pipeline = mlContext.Transforms
        .Concatenate(featuresColumnName, "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
        .Append(mlContext.Clustering.Trainers.KMeans(featuresColumnName, numberOfClusters: 3));

    var model = pipeline.Fit(dataView);

    // Persist the trained model together with the input schema.
    using (var fileStream = new FileStream(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
    {
        mlContext.Model.Save(model, dataView.Schema, fileStream);
    }

    // Sample measurements to classify (named after the Iris setosa species).
    var setosa = new IrisData
    {
        SepalLength = 5.1f,
        SepalWidth = 3.5f,
        PetalLength = 1.4f,
        PetalWidth = 0.2f
    };

    // The prediction engine is disposable; release it once the
    // prediction has been printed instead of leaking it.
    using (var predictor = mlContext.Model.CreatePredictionEngine<IrisData, ClusterPrediction>(model))
    {
        var prediction = predictor.Predict(setosa);
        Console.WriteLine($"Cluster: {prediction.PredictedClusterId}");
        Console.WriteLine($"Distances: {string.Join(" ", prediction.Distances)}");
    }

    // Keep the console window open until a key is pressed.
    Console.ReadKey();
}
/// <summary>
/// Loads a previously saved clustering model from <c>ModelPath</c> if the file
/// exists; otherwise trains a three-cluster K-Means model on the data at
/// <c>TrainingDataPath</c> and saves it. Then prints the cluster prediction
/// for a single hard-coded sample.
/// </summary>
/// <param name="args">Command-line arguments; not used.</param>
static void Main(string[] args)
{
    Helper.PrintLine("创建 MLContext...");
    // Fixed seed for the context's random number generator.
    MLContext mlContext = new MLContext(seed: 0);

    ITransformer model;
    if (File.Exists(ModelPath))
    {
        // Reuse the saved model; the input schema it was saved with is not needed here.
        Helper.PrintLine("加载神经网络模型...");
        model = mlContext.Model.Load(ModelPath, out DataViewSchema _);
    }
    else
    {
        // Training data set: comma-separated text without a header row.
        IDataView trainingDataView = mlContext.Data.LoadFromTextFile<IrisData>(TrainingDataPath, hasHeader: false, separatorChar: ',');

        // Build the training pipeline: concatenate the four measurement
        // columns into "Features", then cluster them with K-Means.
        Helper.PrintLine("创建神经网络管道...");
        IEstimator<ITransformer> pipeline = mlContext.Transforms
            .Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
            // Split into three clusters.
            .Append(mlContext.Clustering.Trainers.KMeans("Features", numberOfClusters: 3));

        // Train the model.
        Helper.PrintSplit();
        Helper.PrintLine("开始训练神经网络...");
        model = pipeline.Fit(trainingDataView);
        Helper.PrintLine("训练神经网络完成");
        Helper.PrintSplit();

        // Save the trained model so later runs can load it instead of retraining.
        Helper.PrintLine("导出神经网络模型...");
        mlContext.Model.Save(model, trainingDataView.Schema, ModelPath);
    }

    // Sample measurements to classify (named after the Iris setosa species).
    IrisData setosa = new IrisData
    {
        SepalLength = 5.1f,
        SepalWidth = 3.5f,
        PetalLength = 1.4f,
        PetalWidth = 0.2f
    };

    // Predict. The prediction engine is disposable, so it is released
    // before the program exits rather than leaked.
    Helper.PrintLine("预测:");
    using (var predictor = mlContext.Model.CreatePredictionEngine<IrisData, ClusterPrediction>(model))
    {
        var prediction = predictor.Predict(setosa);
        Helper.PrintLine($"所属集群: {prediction.PredictedClusterId}");
        Helper.PrintLine($"特征差距: {string.Join(" ", prediction.Distances)}");
    }

    // NOTE(review): Helper.Exit presumably ends the program with the given exit code — confirm.
    Helper.Exit(0);
}