/// <summary>
/// Loads training data, initializes the three layer-parameter entries in
/// <c>LayerInfo</c>, and then runs a fixed 50000-iteration training loop.
/// Each iteration calls <c>neuralnetworkforward</c>, feeds its result to a
/// <c>ThreeLayerNet</c>, logs loss and accuracy via <c>Debug.Write</c>,
/// back-propagates, updates parameters, and writes JSON snapshots to disk.
/// </summary>
/// <param name="epoch">
/// Forwarded to the <c>CnnSource</c> constructor. It does not control the
/// number of loop iterations in this method.
/// </param>
public void StartTrain(int epoch)
{
    Sources = new CnnSource(epoch);
    Sources.CollectImgInfo(@"D:\AI\test_img", @"D:\AI\test_img2");

    // The data is fetched once here and reused unchanged by every iteration below.
    Tuple<double[,,,], double[,]> batch = Sources.GetNextEpoch();
    MapInput = batch.Item1;
    LMatrix labels = batch.Item2;

    #region Initialize hidden-layer attributes and hidden-layer weights
    var layerSpecs = new LayerAttribute[]
    {
        new LayerAttribute { Count = 26, Depth = 3 },
        new LayerAttribute { Count = 52, Depth = 26 },
        new LayerAttribute { Count = 52, Depth = 38 }
    };

    LayerInfo = new LayerParamsInfo[layerSpecs.Length];
    for (int layerIndex = 0; layerIndex < layerSpecs.Length; layerIndex++)
    {
        LayerInfo[layerIndex] = new LayerParamsInfo();
        LayerInfo[layerIndex].CurrentLayer = layerIndex;
        LayerInfo[layerIndex].InitWeightBias(layerSpecs[layerIndex].Count, layerSpecs[layerIndex].Depth);
    }
    #endregion

    #region Prepare training objects
    LOptimizer optimizer = new LAdam(0.1);
    ThreeLayerNet net = null;
    bool isRead = false;
    string path = @"d:\wen\good\";

    for (int i = 0; i < 50000; i++)
    {
        // Output directory for this iteration: <path><iteration number>.
        string iterationDir = path + i.ToString();

        Debug.Write("开始时间" + DateTime.Now.ToString("HH:mm:ss\n"));

        var inputFeature = neuralnetworkforward();

        // The network is created lazily because its input width is taken
        // from the second dimension of the first forward result.
        if (net == null)
        {
            net = new ThreeLayerNet(inputFeature.GetLength(1), 200, 2, isRead: isRead);
        }

        net.Input = inputFeature;

        double loss = net.forward(labels);
        Debug.Write(string.Format("{0}:{1}当前误差:{2}\n", DateTime.Now.ToString("HH:mm:ss"), i, loss));

        double accuracy = net.Accuracy(labels);
        Debug.Write(string.Format("{0}:{1}当前识别精度:{2}\n", DateTime.Now.ToString("HH:mm:ss"), i, accuracy));

        LMatrix dout = net.backward(labels);
        List<LMatrix> grads = net.Gradient();
        List<LMatrix> param = net.GetParams();

        // Snapshot the network parameters; the accuracy is embedded in the file name.
        string serializedParams = JsonConvert.SerializeObject(param.Select(x => x.Matrix));
        if (!Directory.Exists(iterationDir))
        {
            Directory.CreateDirectory(iterationDir);
        }

        File.WriteAllText(iterationDir + @"\wb_" + accuracy.ToString() + "_.txt", serializedParams);

        // Resize the gradient from the network before handing it to neuralnetworkbackward.
        var reshapedDout = dout.ResetSize(38, CurRow, CurColumn);
        neuralnetworkbackward(reshapedDout);

        #region Convolution gradient update
        int layerCount = LayerInfo.Length;
        for (int layer = 0; layer < layerCount; layer++)
        {
            LayerInfo[layer].UpdateGradient(param, grads);
        }
        #endregion

        optimizer.Update(param, grads);
        //isRead = true;

        // Record the convolution-layer state of every iteration.
        File.WriteAllText(iterationDir + @"\LayerInfo.txt", JsonConvert.SerializeObject(LayerInfo));
        File.WriteAllText(iterationDir + @"\ConvInfo.txt", JsonConvert.SerializeObject(ConvLayer));
    }
    #endregion
}