Esempio n. 1
0
        /// <summary>
        /// Train one fold.
        /// </summary>
        /// <param name="k">The fold id.</param>
        /// <param name="fold">The fold.</param>
        public void TrainFold(int k, CrossValidateFold fold)
        {
            int noImprove = 0;
            double localBest = 0;

            // Get the training and cross validation sets.
            IList<BasicData> training = fold.TrainingSet;
            IList<BasicData> validation = fold.ValidationSet;

            // Create random particles for the RBF.
            IGenerateRandom rnd = new MersenneTwisterGenerateRandom();
            var particles = new RBFNetwork[TitanicConfig.ParticleCount];
            for (int i = 0; i < particles.Length; i++)
            {
                particles[i] = new RBFNetwork(TitanicConfig.InputFeatureCount, TitanicConfig.RbfCount, 1);
                particles[i].Reset(rnd);
            }

            /**
             * Construct a network to hold the best network.
             */
            if (_bestNetwork == null)
            {
                _bestNetwork = new RBFNetwork(TitanicConfig.InputFeatureCount, TitanicConfig.RbfCount, 1);
            }

            /**
             * Setup the scoring function.
             */
            IScoreFunction score = new ScoreTitanic(training);
            IScoreFunction scoreValidate = new ScoreTitanic(validation);

            /**
             * Setup particle swarm.
             */
            bool done = false;
            var train = new TrainPSO(particles, score);
            int iterationNumber = 0;
            var line = new StringBuilder();

            do
            {
                iterationNumber++;

                train.Iteration();

                var best = (RBFNetwork) train.BestParticle;

                double trainingScore = train.LastError;
                double validationScore = scoreValidate.CalculateScore(best);

                if (validationScore > _bestScore)
                {
                    Array.Copy(best.LongTermMemory, 0, _bestNetwork.LongTermMemory, 0, best.LongTermMemory.Length);
                    _bestScore = validationScore;
                }

                if (validationScore > localBest)
                {
                    noImprove = 0;
                    localBest = validationScore;
                }
                else
                {
                    noImprove++;
                }

                line.Length = 0;
                line.Append("Fold #");
                line.Append(k + 1);
                line.Append(", Iteration #");
                line.Append(iterationNumber);
                line.Append(": training correct: ");
                line.Append(trainingScore);
                line.Append(", validation correct: ");
                line.Append(validationScore);
                line.Append(", no improvement: ");
                line.Append(noImprove);

                if (noImprove > TitanicConfig.AllowNoImprovement)
                {
                    done = true;
                }

                Console.WriteLine(line.ToString());
            } while (!done);

            fold.Score = localBest;
        }
Esempio n. 2
0
        /// <summary>
        ///     Run the example.
        /// </summary>
        public void Process()
        {
            // read the iris data from the resources
            Assembly assembly = Assembly.GetExecutingAssembly();
            Stream res = assembly.GetManifestResourceStream("AIFH_Vol2.Resources.iris.csv");

            // did we fail to read the resouce
            if (res == null)
            {
                Console.WriteLine("Can't read iris data from embedded resources.");
                return;
            }

            // load the data
            var istream = new StreamReader(res);
            DataSet ds = DataSet.Load(istream);
            istream.Close();

            IGenerateRandom rnd = new MersenneTwisterGenerateRandom();

            // The following ranges are setup for the Iris data set.  If you wish to normalize other files you will
            // need to modify the below function calls other files.
            ds.NormalizeRange(0, -1, 1);
            ds.NormalizeRange(1, -1, 1);
            ds.NormalizeRange(2, -1, 1);
            ds.NormalizeRange(3, -1, 1);
            IDictionary<string, int> species = ds.EncodeOneOfN(4);

            var particles = new RBFNetwork[ParticleCount];
            for (int i = 0; i < particles.Length; i++)
            {
                particles[i] = new RBFNetwork(4, 4, 3);
                particles[i].Reset(rnd);
            }

            IList<BasicData> trainingData = ds.ExtractSupervised(0, 4, 4, 3);

            IScoreFunction score = new ScoreRegressionData(trainingData);

            var train = new TrainPSO(particles, score);

            PerformIterations(train, 100000, 0.05, true);

            var winner = (RBFNetwork) train.BestParticle;

            QueryOneOfN(winner, trainingData, species);
        }