Exemplo n.º 1
0
        /// <summary>
        /// Writes an <see cref="Alpha2"/> model to <c>AlphaNetwork.bin</c> inside the
        /// <c>NeuralNets</c> folder returned by <c>"NeuralNets".GetMyDocs()</c>,
        /// creating that folder first if it does not exist.
        /// </summary>
        /// <param name="network">The model to write.</param>
        public static void Save(this Alpha2 network)
        {
            string folder = "NeuralNets".GetMyDocs();

            // CreateDirectory does nothing when the directory already exists,
            // so a separate Directory.Exists check is unnecessary.
            Directory.CreateDirectory(folder);

            // Path.Combine uses the platform's directory separator. A hard-coded "\\"
            // only yields a path inside `folder` on Windows.
            string fileName = Path.Combine(folder, "AlphaNetwork.bin");

            // NOTE(review): the meaning of the trailing `true` argument is defined by
            // WriteToBinaryFile, which is not visible here. If that helper relies on
            // BinaryFormatter, it is obsolete (removed in .NET 9) — confirm.
            WriteToBinaryFile(fileName, network, true);
        }
        /// <summary>
        /// Returns the network loaded via <c>datatype.LoadNetwork</c>. If the loaded
        /// network reports <c>Datatype.None</c>, a fresh three-layer network is built instead.
        /// </summary>
        /// <param name="write">Logging callback passed to the loader.</param>
        /// <returns>The loaded network, or a newly constructed one.</returns>
        public static NeuralNetwork GetNetwork(WriteToCMDLine write)
        {
            // Created before loading (as originally) so its GetSize() can supply the
            // input width of the first layer if a new network has to be built.
            Alpha2        alpha   = new Alpha2(CMDLibrary.WriteNull);
            NeuralNetwork network = datatype.LoadNetwork(write);

            if (network.Datatype != Datatype.None)
            {
                return network;
            }

            // Nothing usable was loaded. Build layers with first constructor arguments
            // 100, 100 and 40. Each later layer takes its input width from the previous
            // layer's Weights.GetLength(0).
            network = new NeuralNetwork(datatype);
            var layers = network.Layers;
            layers.Add(new Layer(100, alpha.GetSize(), Activation.LRelu, 1e-5, 1e-5));
            layers.Add(new Layer(100, layers.Last().Weights.GetLength(0), Activation.LRelu, 1e-5, 1e-5));
            layers.Add(new Layer(40, layers.Last().Weights.GetLength(0), Activation.CombinedCrossEntropySoftmax));
            return network;
        }
        /// <summary>
        /// Runs one training step on a batch requested from <paramref name="s"/>, in this order:
        /// <list type="number">
        /// <item>Forward and backward passes through the Alpha2 model and the network for every sample.</item>
        /// <item>Weight updates for both models.</item>
        /// <item>A save of both models.</item>
        /// <item>A re-evaluation of the same batch with the updated weights.</item>
        /// </list>
        /// </summary>
        /// <param name="s">Sample source; <c>ReadSamples(24)</c> is called on it.</param>
        /// <param name="write">Logging callback; receives the pre- and post-update errors.</param>
        /// <param name="tf">Unused. It only appeared in a since-removed, commented-out
        /// misprediction check; kept so existing callers still compile.</param>
        /// <returns>
        /// Sum over the batch of the per-sample loss (the <c>Max()</c> of
        /// <c>CategoricalCrossEntropy.Forward</c>), measured after the update with 0
        /// passed to <c>net.Forward</c> in place of <c>dropout</c>.
        /// </returns>
        public static double Propogate
            (Sample s, WriteToCMDLine write, bool tf = false)
        {
            NeuralNetwork net     = GetNetwork(write);
            var           Samples = s.ReadSamples(24);
            int           count   = Samples.Count();
            Alpha2        a       = datatype.LoadAlpha(write);
            var           am      = a.CreateMemory();
            NetworkMem    MFMem   = new NetworkMem(net);

            // Each iteration writes its loss into its own slot, and the slots are summed
            // afterwards. `error += ...` inside Parallel.For is an unsynchronised
            // read-modify-write on a shared double and silently drops contributions.
            // NOTE(review): MFMem and `am` are shared by all iterations. This assumes
            // net.Backward / a.Backward accumulate into them thread-safely — confirm.
            double[] trainLosses = new double[count];
            try
            {
                Parallel.For(0, count, j =>
                {
                    var AMem   = a.CreateAlphaMemory(Samples[j].TextInput);
                    var output = a.Forward(Samples[j].TextInput, AMem, write);
                    var F      = net.Forward(output, dropout, write);
                    trainLosses[j] = CategoricalCrossEntropy.Forward(F.Last().GetRank(0), Samples[j].DesiredOutput).Max();

                    var DValues = net.Backward(F, Samples[j].DesiredOutput, MFMem, write);
                    a.Backward(Samples[j].TextInput, DValues, AMem, am, write);
                });
            }
            catch (Exception e) { e.OutputError(); }

            // NOTE(review): with an empty batch (count == 0) these updates still run with a
            // count of 0 — confirm Update tolerates that.
            MFMem.Update(count, 0.00001, net);
            a.Update(am, count);
            write("Pre Training Error : " + trainLosses.Sum());

            net.Save();
            a.Save();

            // Re-evaluate the batch with the updated weights, with no backward pass.
            double[] evalLosses = new double[count];
            Parallel.For(0, count, j =>
            {
                var AMem   = a.CreateAlphaMemory(Samples[j].TextInput);
                var output = a.Forward(Samples[j].TextInput, AMem, write);
                var F      = net.Forward(output, 0, write);
                evalLosses[j] = CategoricalCrossEntropy.Forward(F.Last().GetRank(0), Samples[j].DesiredOutput).Max();
            });

            double error = evalLosses.Sum();
            write("Post Training Error : " + error);
            return error;
        }