Ejemplo n.º 1
0
        /// <summary>
        /// Creates one network per requested model type through <c>MMNodeFactory.obtain</c>
        /// (with neuronsCount = 50), subscribes the epoch and divergence handlers and registers
        /// the visual and saccade modalities on each network.
        /// </summary>
        /// <param name="modelsToTrain">Model types to instantiate.</param>
        /// <returns>
        /// The networks keyed by "&lt;model&gt;_50". If a model type is listed twice, only the
        /// network created last is kept in the dictionary.
        /// </returns>
        Dictionary <string, MMNode> createModels(List <MMNodeFactory.Model> modelsToTrain)
        {
            // Fixed size; an earlier experiment swept 10..90 neurons in steps of 20.
            const int neuronsCount = 50;
            Dictionary <string, MMNode> networksByName = new Dictionary <string, MMNode>();

            foreach (MMNodeFactory.Model modelType in modelsToTrain)
            {
                string name = modelType.ToString() + "_" + neuronsCount;
                MMNode node = MMNodeFactory.obtain(modelType, neuronsCount);
                networksByName[name] = node;

                node.onEpoch += network_onEpoch;

                // Visual inputs at t0, the saccade between them, then the same visual inputs at t1.
                // The XY and Shape modalities are currently disabled.
                node.addModality(new Signal(retinaSize * 4, retinaSize), "Vision-t0-Color");
                node.addModality(new Signal(retinaSize * 2, retinaSize), "Vision-t0-Orientation");
                node.addModality(new Signal(4, 1), "Saccade");
                node.addModality(new Signal(retinaSize * 4, retinaSize), "Vision-t1-Color");
                node.addModality(new Signal(retinaSize * 2, retinaSize), "Vision-t1-Orientation");

                // Originally noted as "apply a threshold function on the modalities";
                // presumably done inside network_onDivergence — verify there.
                node.onDivergence += network_onDivergence;
            }

            return networksByName;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Factory: creates and configures a multimodal node of the requested type.
        /// </summary>
        /// <param name="mdl">Kind of node to build.</param>
        /// <param name="nbNeurons">
        /// Map side for SOM, MWSOM and AFSOM (the map is nbNeurons x nbNeurons); for DeepBelief,
        /// nbNeurons * nbNeurons is passed as the single entry of its layer array. Ignored by LUT and MLP.
        /// </param>
        /// <returns>The newly created node.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="mdl"/> is not a supported model.</exception>
        public static MMNode obtain(Model mdl, int nbNeurons = 20)
        {
            switch (mdl)
            {
            case Model.SOM:
            {
                // Typed local instead of repeated "as" casts: the object was just created with this type.
                MMNodeSOM som = new CDZNET.Core.MMNodeSOM(new CDZNET.Point2D(nbNeurons, nbNeurons), false);
                som.learningRate            = 0.1;
                som.elasticity              = 2.0;
                som.activityRatioToConsider = 1.0;
                return som;
            }

            case Model.LUT:
            {
                MMNodeLookupTable lut = new CDZNET.Core.MMNodeLookupTable(new Point2D(1, 1));
                lut.TRESHOLD_SIMILARITY = 0.01;
                lut.learningRate        = 0.5;
                return lut;
            }

            // A Matlab-backed node (MMNodeMatLab) used to be available here; it is currently disabled.

            case Model.MWSOM:
                return new CDZNET.Core.MMNodeMWSOM(new CDZNET.Point2D(nbNeurons, nbNeurons));

            case Model.DeepBelief:
                return new CDZNET.Core.MMNodeDeepBeliefNetwork(new CDZNET.Point2D(1, 1), new int[] { nbNeurons * nbNeurons });

            case Model.AFSOM:
                return new CDZNET.Core.MMNodeAFSOM(new CDZNET.Point2D(nbNeurons, nbNeurons));

            case Model.MLP:
                // Layer sizes are fixed here; nbNeurons is not used by this model.
                return new CDZNET.Core.MMNodeMLP(new CDZNET.Point2D(1, 1), 75, 25, 25, 75);

            default:
                // Specific exception type instead of the reserved base System.Exception;
                // still caught by any existing catch (Exception) handler.
                throw new ArgumentOutOfRangeException("mdl", mdl, "Unknown model type.");
            }
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Builds the hippocampus model displayed by this form.
        /// The MEC areas are grid-cell IONodes and the LEC areas are adaptive-SOM IONodes.
        /// CA3 is a SOM whose modalities are copies of every MEC and LEC output.
        /// CA1 is a lookup table over "FOVEA" and "ODOMETRY" signals.
        /// Each node's control is added to, or attached to, the matching control of the form.
        /// </summary>
        public HippocampusForm()
        {
            InitializeComponent();

            //--------------------------------------------------------------------------
            // MEC: grid-cell areas; their spacing grows with the area index
            mec = new List <IONode>();
            const int mecAreaCount = 1;
            for (int area = 1; area <= mecAreaCount; area++)
            {
                IONode gridCells = new IONodeGridCells(
                    new Point2D(2, 1),   // input: we cheat, these are the odometer values X, Y
                    new Point2D(10, 10), // number of grid cells; defines the granularity and should vary among MEC areas
                    0.1,                 // receptive-field size
                    0.5 * area);         // spacing
                mec.Add(gridCells);
                flowLayoutPanelMEC.Controls.Add(gridCells.GetCtrl());
            }

            //--------------------------------------------------------------------------
            // LEC: adaptive SOM areas looking at the fovea
            lec = new List <IONode>();
            const int lecAreaCount = 1;
            for (int area = 1; area <= lecAreaCount; area++)
            {
                IONode foveaMap = new IONodeAdaptiveSOM(
                    new Point2D(foveaSize, foveaSize), // input: size of the fovea
                    new Point2D(20, 20),               // map size: number of templates/filters used
                    true);                             // use only the winner as output
                lec.Add(foveaMap);
                flowLayoutPanelLEC.Controls.Add(foveaMap.GetCtrl());
            }

            //--------------------------------------------------------------------------
            // CA3: SOM fed by every MEC output, then every LEC output
            CA3Inputs = new List <SignalLink>();
            CA3 = new MMNodeSOM(
                new Point2D(20, 20), // size of the map
                true);               // use only the winner for prediction

            // CA3 receives a clone of each output (new Signal(output)), not the output itself;
            // the SignalLink pairs the source with its clone.
            int mecIndex = 0;
            foreach (IONode source in mec)
            {
                SignalLink link = new SignalLink(source.output, new Signal(source.output));
                CA3.addModality(link.to, "MEC_" + mecIndex);
                CA3Inputs.Add(link);
                mecIndex++;
            }

            int lecIndex = 0;
            foreach (IONode source in lec)
            {
                SignalLink link = new SignalLink(source.output, new Signal(source.output));
                CA3.addModality(link.to, "LEC_" + lecIndex);
                CA3Inputs.Add(link);
                lecIndex++;
            }

            ctrlMMNode1.attach(CA3);

            //--------------------------------------------------------------------------
            // CA1: lookup table over the fovea and the odometry
            CA1 = new MMNodeLookupTable(new Point2D(1, 1)); // size of the map
            CA1.addModality(new Signal(foveaSize, foveaSize), "FOVEA");
            CA1.addModality(new Signal(2, 1), "ODOMETRY");
            ctrlMMNode2.attach(CA1);
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Tests every model on every sample set except one named "train", at several corruption
        /// levels of the "t1" modalities. It writes one CSV line per (model, set, corruption level,
        /// sample) to <paramref name="logFile"/>.
        /// </summary>
        /// <remarks>
        /// On the first call (<c>hasWrittenHeaders</c> false), the file is overwritten and the metadata
        /// and column headers are written; later calls append. The per-modality header columns come
        /// from the first model, so all models are expected to expose the same modalities in the same order.
        /// Learning is locked on every tested model and is not unlocked here.
        /// Numbers are written with the invariant culture, so the file parses the same on every locale.
        /// </remarks>
        /// <param name="models">Models to test, keyed by the name written in the "Model" column.</param>
        /// <param name="sets">
        /// Named sample sets; each sample maps a modality label to its matrix. The count of the
        /// "usedForTraining" set is written in the TrainingSetSize column, so that set is expected to exist.
        /// </param>
        /// <param name="logFile">Path of the CSV log.</param>
        /// <param name="worldType">Value written in the WorldType column.</param>
        void evaluateOnSets(Dictionary <string, MMNode> models, Dictionary <string, List <Dictionary <string, double[, ]> > > sets, string logFile, WorldType worldType)
        {
            Console.WriteLine(DateTime.Now + "\t" + "Starting to test log.");

            // The log is comma-separated: a locale using ',' as decimal separator would otherwise
            // split a single number over two columns.
            System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

            // Append once the headers exist, otherwise start a fresh file. The using block closes
            // the file even when a model throws halfway through the test.
            using (StreamWriter file = new StreamWriter(logFile, hasWrittenHeaders))
            {
                //Write some metadata
                if (!hasWrittenHeaders)
                {
                    file.WriteLine("worldWidth\t" + Convert.ToString(worldWidth, invariant));
                    file.WriteLine("worldHeight\t" + Convert.ToString(worldHeight, invariant));
                    file.WriteLine("seedsNumber\t" + Convert.ToString(seedsNumber, invariant));
                    file.WriteLine("orientationVariability\t" + Convert.ToString(orientationVariability, invariant));
                    file.WriteLine("retinaSize\t" + Convert.ToString(retinaSize, invariant));
                    file.WriteLine("saccadeSize\t" + Convert.ToString(saccadeSize, invariant));
                    file.WriteLine();

                    //write the headers
                    file.Write("WorldType,");
                    file.Write("Model,");
                    file.Write("SetName,");
                    file.Write("InvertedBits,");
                    file.Write("TrainingSetSize,");
                    file.Write("AllModSumError,");
                    if (EXTENSIVE_LOG)
                    {
                        file.Write(getMatrixHeadingS("XY-t0", 1, 1, 2));
                        file.Write(getMatrixHeadingS("XY-t1", 1, 1, 2));
                        file.Write(getMatrixHeadingS("Vision-t0-Color", retinaSize, retinaSize, 4));
                        file.Write(getMatrixHeadingS("Vision-t0-Orientation", retinaSize, retinaSize, 2));
                        file.Write(getMatrixHeadingS("Vision-t0-Shape", retinaSize, retinaSize, 4));
                        file.Write(getMatrixHeadingS("Saccade", 1, 1, 4));
                        file.Write(getMatrixHeadingS("Vision-t1-Color", retinaSize, retinaSize, 4));
                        file.Write(getMatrixHeadingS("Vision-t1-Orientation", retinaSize, retinaSize, 2));
                        file.Write(getMatrixHeadingS("Vision-t1-Shape", retinaSize, retinaSize, 4));
                    }

                    // Per-modality columns, taken from the first model only.
                    MMNode reference = models.First().Value;
                    foreach (Signal mod in reference.modalities)
                    {
                        string modName = reference.labelsModalities[mod];
                        file.Write(getMatrixHeadingS("corruption_" + modName, 1, 1, 1));
                        if (EXTENSIVE_LOG)
                        {
                            file.Write(getMatrixHeadingS("reality_" + modName, 1, 1, mod.Width * mod.Height));
                            file.Write(getMatrixHeadingS("prediction_" + modName, 1, 1, mod.Width * mod.Height));
                        }
                        file.Write(getMatrixHeadingS("originalMaxError_" + modName, 1, 1, 1));
                        file.Write(getMatrixHeadingS("corruptedMaxError_" + modName, 1, 1, 1));
                        file.Write(getMatrixHeadingS("originalSumError_" + modName, 1, 1, 1));
                        file.Write(getMatrixHeadingS("corruptedSumError_" + modName, 1, 1, 1));
                        if (modName.Contains("Vision"))
                        {
                            file.Write("wrongPixels_" + modName + ",");
                        }
                    }
                    file.WriteLine();
                    hasWrittenHeaders = true;
                }

                //Start the test
                foreach (KeyValuePair <string, MMNode> entry in models)
                {
                    string modelName = entry.Key;
                    MMNode network = entry.Value;

                    Console.Write("Testing " + modelName + " ...");
                    Stopwatch watch = new Stopwatch();
                    watch.Start();

                    network.learningLocked = true;
                    foreach (string setName in sets.Keys)
                    {
                        // NOTE(review): only a set literally named "train" is skipped; the "usedForTraining"
                        // set (whose size is logged below) is evaluated like the others. Confirm intended.
                        if (setName == "train")
                        {
                            continue;
                        }

                        List <Dictionary <string, double[, ]> > set = sets[setName];
                        //Test with different level of noise: 0.0, 0.5 and 1.0 (all exact in binary floating point)
                        for (double bitShiftProb = 0.0; bitShiftProb <= 1.0; bitShiftProb += 0.5)
                        {
                            foreach (Dictionary <string, double[, ]> sample in set)
                            {
                                //Set the modalities from a copy of the sample; only the "t1" ones get corrupted
                                int invertedBits = 0;
                                Dictionary <Signal, double> modalityCorruption = new Dictionary <Signal, double>();
                                foreach (Signal s in network.modalities)
                                {
                                    string label = network.labelsModalities[s];
                                    modalityCorruption[s] = 0.0;
                                    s.reality = sample[label].Clone() as double[, ];

                                    if (label.Contains("t1"))
                                    {
                                        modalityCorruption[s] = bitShiftProb;
                                        // Each element is set to 0.5 with probability bitShiftProb
                                        // (an earlier variant toggled it to |value - 1| instead).
                                        // invertedBits counts the altered elements over all t1 modalities.
                                        ArrayHelper.ForEach(s.reality, false, (x, y) =>
                                        {
                                            if (MathHelpers.Rand.NextDouble() < bitShiftProb)
                                            {
                                                s.reality[x, y] = 0.5;
                                                invertedBits++;
                                            }
                                        });
                                    }
                                }

                                network.Converge();
                                network.Diverge();

                                double globalError = 0.0;
                                foreach (Signal s in network.modalities)
                                {
                                    globalError += s.ComputeSumAbsoluteError();
                                }

                                //Dump the info
                                string line =
                                    worldType.ToString() + "," +
                                    modelName + "," +
                                    setName + "," +
                                    invertedBits + "," +
                                    sets["usedForTraining"].Count + "," +
                                    globalError.ToString(invariant) + ",";

                                if (EXTENSIVE_LOG)
                                {
                                    line +=
                                        GetString(sample["XY-t0"]) + "," +
                                        GetString(sample["XY-t1"]) + "," +
                                        GetString(sample["Vision-t0-Color"]) + "," +
                                        GetString(sample["Vision-t0-Orientation"]) + "," +
                                        GetString(sample["Vision-t0-Shape"]) + "," +
                                        GetString(sample["Saccade"]) + "," +
                                        GetString(sample["Vision-t1-Color"]) + "," +
                                        GetString(sample["Vision-t1-Orientation"]) + "," +
                                        GetString(sample["Vision-t1-Shape"]) + ",";
                                }

                                foreach (Signal s in network.modalities)
                                {
                                    string modName = network.labelsModalities[s];
                                    double[, ] original = sample[modName];

                                    line += modalityCorruption[s].ToString(invariant) + ",";
                                    if (EXTENSIVE_LOG)
                                    {
                                        line +=
                                            GetString(s.reality) + "," +
                                            GetString(s.prediction) + ",";
                                    }

                                    // "original*" errors compare against the clean sample,
                                    // "corrupted*" errors against the (possibly corrupted) input.
                                    line +=
                                        Convert.ToString(MathHelpers.maximumAbsoluteDistance(s.prediction, original), invariant) + "," +
                                        Convert.ToString(MathHelpers.maximumAbsoluteDistance(s.prediction, s.reality), invariant) + "," +
                                        Convert.ToString(MathHelpers.sumAbsoluteDistance(s.prediction, original), invariant) + "," +
                                        Convert.ToString(MathHelpers.sumAbsoluteDistance(s.prediction, s.reality), invariant) + ",";

                                    if (modName.Contains("Vision"))
                                    {
                                        line += Convert.ToString(countWrongPixels(modName, s.prediction, original), invariant) + ",";
                                    }
                                }

                                file.WriteLine(line);
                            }
                        }
                    }
                    file.Flush();
                    watch.Stop();
                    Console.WriteLine("Done {0}", watch.Elapsed);
                }
            }
            Console.WriteLine(DateTime.Now + "\t" + "Test log written.");
        }