示例#1
0
        /// <summary>
        /// Returns an independent copy of this updater: a fresh instance constructed with the
        /// same size (<c>_n.Length</c>) and the same hyper-parameters, whose <c>_n</c>,
        /// <c>_g</c> and <c>_delta</c> buffers hold copies of this instance's values.
        /// </summary>
        /// <returns>A new <see cref="RMSPropWeightUpdater"/> that shares no buffers with this one.</returns>
        public RMSPropWeightUpdater Clone()
        {
            // Same configuration as this instance; buffer contents are copied below.
            var clone = new RMSPropWeightUpdater(_n.Length, EwmaDecayRate, Momentum, LearningRate, RMSRegularizer);

            // Element-wise copies so later updates to either instance do not affect the other.
            _n.CopyTo(clone._n, 0);
            _g.CopyTo(clone._g, 0);
            _delta.CopyTo(clone._delta, 0);

            return clone;
        }
示例#2
0
文件: Program.cs 项目: weizh/NTM-1
        /// <summary>
        /// Trains a Neural Turing Machine on the copy task using a BPTT teacher with an RMSProp
        /// weight updater. Rolling averages of loss and throughput are printed to the console,
        /// per-iteration metrics are optionally streamed to a YoVision server, and the trained
        /// machine is saved to a timestamped file at the end.
        /// </summary>
        static void Main()
        {
            // Metrics streaming is best-effort: if connecting/registering fails, the message is
            // printed and training continues with reportStream left null.
            DataStream reportStream = null;
            try
            {
                YoVisionClientHelper yoVisionClientHelper = new YoVisionClientHelper();
                yoVisionClientHelper.Connect(EndpointType.NetTcp, 8081, "localhost", "YoVisionServer");
                reportStream = yoVisionClientHelper.RegisterDataStream("Copy task training",
                    new Int32DataType("Iteration"),
                    new DoubleDataType("Average data loss"),
                    new Int32DataType("Training time"),
                    new Int32DataType("Sequence length"));
            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.Message);
            }

            // Training schedule.
            const int iterations = 10000;     // the loop runs for i = 1 .. iterations - 1
            const int statsWindow = 100;      // number of recent iterations averaged in console reports
            const int maxSequenceLength = 20; // sequence lengths are drawn from 1 .. maxSequenceLength

            // Ring buffers holding the losses and training times of the last statsWindow iterations.
            double[] errors = new double[statsWindow];
            long[] times = new long[statsWindow];
            for (int i = 0; i < statsWindow; i++)
            {
                errors[i] = 1;
            }

            const int seed = 32702;
            Console.WriteLine(seed);
            //TODO parse settings (seed, sizes, schedule) from command-line arguments
            Random rand = new Random(seed);

            const int vectorSize = 8;
            const int controllerSize = 100;
            const int headsCount = 1;
            const int memoryN = 128;
            const int memoryM = 20;
            const int inputSize = vectorSize + 2;
            const int outputSize = vectorSize;

            // Use the same inputSize/outputSize constants as the weight-count formula below so the
            // two cannot drift apart.
            //TODO remove rand
            NeuralTuringMachine machine = new NeuralTuringMachine(inputSize, outputSize, controllerSize, headsCount, memoryN, memoryM, new RandomWeightInitializer(rand));

            //TODO extract weight count calculation
            int headUnitSize = Head.GetUnitSize(memoryM);

            // Total number of trainable weights; passed to RMSPropWeightUpdater, which presumably
            // sizes its per-weight state from it — must match the machine's layout (TODO confirm).
            var weightsCount =
                (headsCount * memoryN) +
                (memoryN * memoryM) +
                (controllerSize * headsCount * memoryM) +
                (controllerSize * inputSize) +
                (controllerSize) +
                (outputSize * (controllerSize + 1)) +
                (headsCount * headUnitSize * (controllerSize + 1));

            Console.WriteLine(weightsCount);

            RMSPropWeightUpdater rmsPropWeightUpdater = new RMSPropWeightUpdater(weightsCount, 0.95, 0.5, 0.001, 0.001);

            //NeuralTuringMachine machine2 = NeuralTuringMachine.Load(@"NTM2015-03-22T210312");

            BPTTTeacher teacher = new BPTTTeacher(machine, rmsPropWeightUpdater);

            for (int i = 1; i < iterations; i++)
            {
                Tuple<double[][], double[][]> sequence = SequenceGenerator.GenerateSequence(rand.Next(maxSequenceLength) + 1, vectorSize);

                Stopwatch stopwatch = Stopwatch.StartNew();
                double[][] machinesOutput = teacher.Train(sequence.Item1, sequence.Item2);
                stopwatch.Stop();
                long trainingTime = stopwatch.ElapsedMilliseconds;

                double error = CalculateLogLoss(sequence.Item2, machinesOutput);
                // Normalize the total loss by the number of target values (rows x columns).
                double averageError = error / (sequence.Item2.Length * sequence.Item2[0].Length);

                int slot = i % statsWindow;
                errors[slot] = averageError;
                times[slot] = trainingTime;

                if (reportStream != null)
                {
                    reportStream.Set("Iteration", i);
                    reportStream.Set("Average data loss", averageError);
                    // "Training time" is registered as Int32DataType, so send an int instead of
                    // Stopwatch's long millisecond count.
                    reportStream.Set("Training time", (int)trainingTime);
                    // NOTE(review): assumes the generated input has 2 * length + 2 rows — confirm
                    // against SequenceGenerator.
                    reportStream.Set("Sequence length", (sequence.Item1.Length - 2)/2);
                    reportStream.SendData();
                }

                // By the first report (i == statsWindow) every ring-buffer slot has been written.
                if (i % statsWindow == 0)
                {
                    Console.WriteLine("Iteration: {0}, average error: {1}, iterations per second: {2:0.0}", i, errors.Average(), 1000/times.Average());
                }
            }

            // Colons are stripped from the sortable ("s") timestamp, presumably so the name is a
            // valid file name on Windows.
            machine.Save("NTM"+DateTime.Now.ToString("s").Replace(":",""));
        }