/// <summary>
/// Builds a new <see cref="RMSPropWeightUpdater"/> from this instance's
/// <c>EwmaDecayRate</c>, <c>Momentum</c>, <c>LearningRate</c> and <c>RMSRegularizer</c>,
/// then copies the contents of the <c>_n</c>, <c>_g</c> and <c>_delta</c> arrays into it.
/// </summary>
/// <returns>The newly constructed updater holding the copied array contents.</returns>
public RMSPropWeightUpdater Clone()
{
    // NOTE(review): assumes the constructor allocates _n, _g and _delta with (at least)
    // the length passed as its first argument - confirm against the constructor.
    var duplicate = new RMSPropWeightUpdater(
        _n.Length,
        EwmaDecayRate,
        Momentum,
        LearningRate,
        RMSRegularizer);

    Array.Copy(_n, duplicate._n, _n.Length);
    Array.Copy(_g, duplicate._g, _g.Length);
    Array.Copy(_delta, duplicate._delta, _delta.Length);

    return duplicate;
}
/// <summary>
/// Console entry point. Builds a <c>NeuralTuringMachine</c> with hard-coded sizes, trains it
/// through a <c>BPTTTeacher</c> on sequences produced by <c>SequenceGenerator</c>, prints a
/// progress line every 100 iterations and finally saves the machine under a
/// timestamp-based name.
/// </summary>
static void Main()
{
    // Size of the rolling buffers used for the console statistics; also the print interval.
    const int window = 100;

    // Reporting is best-effort: if connecting or registering fails, the exception message
    // is printed and training continues with reportStream left as null.
    DataStream reportStream = null;
    try
    {
        var client = new YoVisionClientHelper();
        client.Connect(EndpointType.NetTcp, 8081, "localhost", "YoVisionServer");
        reportStream = client.RegisterDataStream(
            "Copy task training",
            new Int32DataType("Iteration"),
            new DoubleDataType("Average data loss"),
            new Int32DataType("Training time"),
            new Int32DataType("Sequence length"));
    }
    catch (Exception ex)
    {
        Console.WriteLine(ex.Message);
    }

    // Error slots start at 1 and time slots at 0 until real measurements overwrite them.
    double[] recentErrors = Enumerable.Repeat(1.0, window).ToArray();
    long[] recentTimes = new long[window];

    const int seed = 32702;
    Console.WriteLine(seed);

    // TODO: read these settings from command-line arguments.
    Random rng = new Random(seed);

    const int vectorSize = 8;
    const int controllerSize = 100;
    const int headsCount = 1;
    const int memoryN = 128;
    const int memoryM = 20;
    const int inputSize = vectorSize + 2;
    const int outputSize = vectorSize;

    // TODO: remove rand
    NeuralTuringMachine machine = new NeuralTuringMachine(
        inputSize,
        outputSize,
        controllerSize,
        headsCount,
        memoryN,
        memoryM,
        new RandomWeightInitializer(rng));

    // TODO: extract weight count calculation
    int headUnitSize = Head.GetUnitSize(memoryM);
    int weightsCount =
        (headsCount * memoryN)
        + (memoryN * memoryM)
        + (controllerSize * headsCount * memoryM)
        + (controllerSize * inputSize)
        + controllerSize
        + (outputSize * (controllerSize + 1))
        + (headsCount * headUnitSize * (controllerSize + 1));
    Console.WriteLine(weightsCount);

    RMSPropWeightUpdater weightUpdater = new RMSPropWeightUpdater(weightsCount, 0.95, 0.5, 0.001, 0.001);
    BPTTTeacher teacher = new BPTTTeacher(machine, weightUpdater);

    for (int iteration = 1; iteration < 10000; iteration++)
    {
        // Position of this iteration inside the rolling statistics buffers.
        int slot = iteration % window;

        Tuple<double[][], double[][]> sequence = SequenceGenerator.GenerateSequence(rng.Next(20) + 1, vectorSize);
        double[][] input = sequence.Item1;
        double[][] target = sequence.Item2;

        // Only the Train call itself is timed.
        Stopwatch timer = Stopwatch.StartNew();
        double[][] machinesOutput = teacher.Train(input, target);
        timer.Stop();

        long elapsedMs = timer.ElapsedMilliseconds;
        recentTimes[slot] = elapsedMs;

        // Loss divided by the number of elements in the target (rows * columns of the first row).
        double totalError = CalculateLogLoss(target, machinesOutput);
        double averageError = totalError / (target.Length * target[0].Length);
        recentErrors[slot] = averageError;

        if (reportStream != null)
        {
            reportStream.Set("Iteration", iteration);
            reportStream.Set("Average data loss", averageError);
            reportStream.Set("Training time", elapsedMs);
            reportStream.Set("Sequence length", (input.Length - 2) / 2);
            reportStream.SendData();
        }

        if (slot == 0)
        {
            Console.WriteLine(
                "Iteration: {0}, average error: {1}, iterations per second: {2:0.0}",
                iteration,
                recentErrors.Average(),
                1000 / recentTimes.Average());
        }
    }

    // File name is "NTM" plus the sortable ("s") timestamp with the colons stripped.
    string timestamp = DateTime.Now.ToString("s").Replace(":", "");
    machine.Save("NTM" + timestamp);
}