/// <summary>
/// Advanced analysis: determines which attack technique each analysed behavior belongs to,
/// by running every converted sample through the trained deep belief network.
/// </summary>
/// <param name="FlowStatistics">Flow statistics to classify.</param>
/// <returns>
/// One (label, category) pair per sample produced by
/// <c>DeepLearningTools.FlowStatisticsToLearningData</c>, in the same order. The label is the
/// sample's string component. The category is the network output after
/// <c>DeepLearningTools.FormatOutputResult</c>, converted to <see cref="int"/>.
/// </returns>
/// <remarks>
/// Reads the <c>DBNetwork</c> member. It is presumably populated by <c>Run()</c> or by loading
/// a saved model; confirm it is initialised before this is called.
/// </remarks>
private static List<(string, int)> DeepAnalysis(DataFlowStatistics[] FlowStatistics)
{
    List<(double[], string)> samples = DeepLearningTools.FlowStatisticsToLearningData(FlowStatistics);
    var classified = new List<(string, int)>(samples.Count);

    foreach ((double[] features, string label) in samples)
    {
        double[] networkOutput = DBNetwork.Compute(features);
        int category = Convert.ToInt32(DeepLearningTools.FormatOutputResult(networkOutput));
        classified.Add((label, category));
    }

    return classified;
}
/// <summary>
/// Starts learning: trains the deep belief network (<c>DBNetwork</c>) on the stored flow samples
/// and saves the resulting model to <c>Path</c>.
/// </summary>
/// <remarks>
/// Only samples whose <c>BehaviorNumber</c> is non-zero are used. Training has two phases:
/// <list type="number">
/// <item>Unsupervised, layer-by-layer pre-training with contrastive divergence on every
/// machine except the last one.</item>
/// <item>Supervised fine-tuning of the whole network with parallel resilient backpropagation.</item>
/// </list>
/// Progress is printed to the console. During fine-tuning the model is checkpointed periodically
/// and saved once more at the end.
/// </remarks>
/// <returns>
/// <c>true</c> when training finished and the model was saved. Otherwise <c>false</c>: either
/// there was no training data, or an exception occurred (its text goes to the debug output).
/// </returns>
public bool Run()
{
    const int PretrainEpochs = 200;
    const int FineTuneEpochs = 500;
    const int PretrainLogInterval = 10;
    const int CheckpointInterval = 10;

    try
    {
        double[][] inputs;
        double[][] outputs;

        // `using` guarantees the data context is disposed even if the query or conversion throws.
        using (FlowDatas db = new FlowDatas())
        {
            (inputs, outputs) = DeepLearningTools.FlowSampleToLearningData(
                db.FlowSampleStatistics.Where(c => c.BehaviorNumber != 0).ToArray());
        }

        // Without samples there is nothing to size the network from.
        if (inputs is null || inputs.Length == 0 || outputs is null || outputs.Length == 0)
        {
            Debug.WriteLine("Run: no training samples with BehaviorNumber != 0; training skipped.");
            return false;
        }

        int inputCount = inputs[0].Length;
        int outputCount = outputs[0].Length;

        // Build the DBN. The input width comes first. It is followed by two intermediate layer
        // widths derived from the input/output sizes, then the output width.
        DBNetwork = new DeepBeliefNetwork(
            inputCount,
            (int)((inputCount + outputCount) / 1.5),
            (int)((inputCount + outputCount) / 2),
            outputCount);

        // Randomize all network weights, then propagate them to the visible (reconstruction) side.
        new GaussianWeights(DBNetwork, 0.1).Randomize();
        DBNetwork.UpdateVisibleWeights();

        // Configure unsupervised learning.
        var pretrainer = new DeepBeliefNetworkLearning(DBNetwork)
        {
            Algorithm = (h, v, i) => new ContrastiveDivergenceLearning(h, v)
            {
                LearningRate = 0.01,
                Momentum = 0.5,
                Decay = 0.001,
            }
        };

        // Split the samples into random mini-batches (about 10 samples each, at least one batch)
        // to speed up learning.
        int batchCount = Math.Max(1, inputs.Length / 10);
        int[] groups = Accord.Statistics.Classes.Random(inputs.Length, batchCount);
        double[][][] batches = inputs.Subgroups(groups);

        // Run unsupervised learning one layer at a time. Each layer is fed the data as
        // transformed by the layers below it.
        for (int layerIndex = 0; layerIndex < DBNetwork.Machines.Count - 1; layerIndex++)
        {
            pretrainer.LayerIndex = layerIndex;
            double[][][] layerData = pretrainer.GetLayerInput(batches);

            for (int i = 0; i < PretrainEpochs; i++)
            {
                double error = pretrainer.RunEpoch(layerData) / inputs.Length;
                if (i % PretrainLogInterval == 0)
                {
                    Console.WriteLine(i + ", Error = " + error);
                }
            }
        }

        // Run supervised learning on the whole network to produce the output classification.
        var fineTuner = new ParallelResilientBackpropagationLearning(DBNetwork);
        for (int i = 0; i < FineTuneEpochs; i++)
        {
            double error = fineTuner.RunEpoch(inputs, outputs) / inputs.Length;
            Console.WriteLine(i + ", Error = " + error);

            // Checkpoint periodically so an interrupted run still leaves a recent model on disk,
            // without serializing the whole network after every epoch. The final epoch is
            // covered by the save after the loop.
            int completed = i + 1;
            if (completed % CheckpointInterval == 0 && completed < FineTuneEpochs)
            {
                DBNetwork.Save(Path);
                Console.WriteLine("Save Done");
            }
        }

        DBNetwork.Save(Path);
        Console.WriteLine("Save Done");
        return true;
    }
    catch (Exception ex)
    {
        // Deliberate best-effort contract: report the failure and signal it via the return value.
        Debug.WriteLine(ex.ToString());
        return false;
    }
}