Example #1
        public void StartTrainingAsync(Network network, TrainingCallback callback)
        {
            if (alreadyDisposed)
            {
                throw new ObjectDisposedException(GetType().Name);
            }

            Action<Network, TrainingCallback> action = Train;

            action.BeginInvoke(network, callback, action.EndInvoke, action);
        }
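The snippet above uses Delegate.BeginInvoke (the APM pattern), which only works on the .NET Framework; on .NET Core and later it throws PlatformNotSupportedException. Below is a minimal Task-based sketch of the same idea, assuming the same Train(Network, TrainingCallback) method and fields; it is not the library's actual implementation.

        public Task StartTrainingAsync(Network network, TrainingCallback callback)
        {
            if (alreadyDisposed)
            {
                throw new ObjectDisposedException(GetType().Name);
            }

            // Queue the blocking Train call on the thread pool; the returned Task
            // lets the caller observe completion or any exception thrown by Train.
            return Task.Run(() => Train(network, callback));
        }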
Example #2
        public void Train(Network network, TrainingCallback callback)
        {
            IActivationFunction activationFunctionInput = network.GetActivation(0);
            int outputNeurons = network.GetLayerNeuronCount(network.LayerCount - 1);
            double error = 0;
            callback.Invoke(TrainingStatus.FillingStandardInputs, 0, 0, 0); /*First operation is filling standard input/outputs*/
            Dictionary<int, List<BasicMLData>> trackIdFingerprints = GetNormalizedTrackFingerprints(activationFunctionInput, trainingSongSnippets, outputNeurons);
            workingThread = Thread.CurrentThread;
            IActivationFunction activationFunctionOutput = network.GetActivation(network.LayerCount - 1);
            double[][] normalizedBinaryCodes = GetNormalizedBinaryCodes(activationFunctionOutput, outputNeurons);
            Tuple<double[][], double[][]> tuple = FillStandardInputsOutputs(trackIdFingerprints, normalizedBinaryCodes); /*Fill standard input output*/
            double[][] inputs = tuple.Item1;
            double[][] outputs = tuple.Item2;

            if (inputs == null || outputs == null)
            {
                callback.Invoke(TrainingStatus.Exception, 0, 0, 0);
                return;
            }

            int currentIterration = 0;
            double correctOutputs = 0.0;
            BasicNeuralDataSet dataset = new BasicNeuralDataSet(inputs, outputs);
            ITrain learner = new ResilientPropagation(network, dataset);
            try
            {
                // Dynamic output reordering cycle
                /*Idyn = 50*/
                for (int i = 0; i < Idyn; i++)
                {
                    if (paused)
                    {
                        pauseSem.WaitOne();
                    }

                    correctOutputs = NetworkPerformanceMeter.MeasurePerformance(network, dataset);
                    callback.Invoke(TrainingStatus.OutputReordering, correctOutputs, error, currentIterration);
                    ReorderOutput(network, dataset, trackIdFingerprints, normalizedBinaryCodes);
                    /*Edyn = 10*/
                    for (int j = 0; j < Edyn; j++)
                    {
                        if (paused)
                        {
                            pauseSem.WaitOne();
                        }

                        correctOutputs = NetworkPerformanceMeter.MeasurePerformance(network, dataset);
                        callback.Invoke(TrainingStatus.RunningDynamicEpoch, correctOutputs, error, currentIterration);
                        learner.Iteration();
                        error = learner.Error;
                        currentIterration++;
                    }
                }

                for (int i = 0; i < Efixed; i++)
                {
                    if (paused)
                    {
                        pauseSem.WaitOne();
                    }

                    correctOutputs = NetworkPerformanceMeter.MeasurePerformance(network, dataset);
                    callback.Invoke(TrainingStatus.FixedTraining, correctOutputs, error, currentIterration);
                    learner.Iteration();
                    error = learner.Error;
                    currentIterration++;
                }

                network.ComputeMedianResponses(inputs, trainingSongSnippets);
                callback.Invoke(TrainingStatus.Finished, correctOutputs, error, currentIterration);
            }
            catch (ThreadAbortException)
            {
                callback.Invoke(TrainingStatus.Aborted, correctOutputs, error, currentIterration);
                paused = false;
            }
        }
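Judging by the Invoke calls above, TrainingCallback takes a TrainingStatus plus the current correct-output measure, training error and iteration count. A minimal usage sketch under that assumption, logging progress to the console (trainer and network construction omitted):

        TrainingCallback progressLogger = (status, correctOutputs, error, iteration) =>
            Console.WriteLine($"{status}: correct={correctOutputs}, error={error}, iteration={iteration}");

        trainer.Train(network, progressLogger);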
Example #3
        public Task<Network> Train(NetworkConfiguration networkConfiguration, int[] spectralImagesIndexesToConsider, TrainingCallback callback)
        {
            var network = networkFactory.Create(
                networkConfiguration.ActivationFunction,
                DefaultFingerprintSize,
                networkConfiguration.HiddenLayerCount,
                networkConfiguration.OutputCount);

            var spectralImagesToTrain = trainingDataProvider.GetSpectralImagesToTrain(
                spectralImagesIndexesToConsider, (int)System.Math.Pow(2, networkConfiguration.OutputCount));
            var trainingSet = trainingDataProvider.MapSpectralImagesToBinaryOutputs(
                spectralImagesToTrain, networkConfiguration.OutputCount);

            normalizeStrategy.NormalizeInputInPlace(networkConfiguration.ActivationFunction, trainingSet.Inputs);
            normalizeStrategy.NormalizeOutputInPlace(networkConfiguration.ActivationFunction, trainingSet.Outputs);
            return Task.Factory.StartNew(() =>
            {
                var dataset = new BasicNeuralDataSet(trainingSet.Inputs, trainingSet.Outputs);
                var learner = new Backpropagation(network, dataset);
                double correctOutputs = 0.0;
                for (int idynIndex = 0; idynIndex < Idyn; idynIndex++)
                {
                    correctOutputs = networkPerformanceMeter.MeasurePerformance(
                        network, dataset, networkConfiguration.ActivationFunction);
                    callback(TrainingStatus.OutputReordering, correctOutputs, learner.Error, idynIndex * Edyn);
                    var bestPairs = GetBestPairsForReordering(
                        (int)System.Math.Pow(2, networkConfiguration.OutputCount), network, spectralImagesToTrain, trainingSet);
                    ReorderOutputsAccordingToBestPairs(bestPairs, trainingSet, dataset);

                    for (int edynIndex = 0; edynIndex < Edyn; edynIndex++)
                    {
                        correctOutputs = networkPerformanceMeter.MeasurePerformance(
                            network, dataset, networkConfiguration.ActivationFunction);
                        callback(
                            TrainingStatus.RunningDynamicEpoch,
                            correctOutputs,
                            learner.Error,
                            (idynIndex * Edyn) + edynIndex);
                        learner.Iteration();
                    }
                }

                for (int efixedIndex = 0; efixedIndex < Efixed; efixedIndex++)
                {
                    correctOutputs = networkPerformanceMeter.MeasurePerformance(
                        network, dataset, networkConfiguration.ActivationFunction);
                    callback(
                        TrainingStatus.FixedTraining, correctOutputs, learner.Error, (Idyn * Edyn) + efixedIndex);
                    learner.Iteration();
                }

                network.ComputeMedianResponses(trainingSet.Inputs, TrainingSongSnippets);
                callback(TrainingStatus.Finished, correctOutputs, learner.Error, (Idyn * Edyn) + Efixed);
                return network;
            });
        }
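A hedged usage sketch for the Task-based overload above. It assumes NetworkConfiguration exposes settable ActivationFunction, HiddenLayerCount and OutputCount properties (the ones the method reads); the trainer instance, activation function and index array are placeholders.

        var configuration = new NetworkConfiguration
        {
            ActivationFunction = activationFunction, // placeholder IActivationFunction
            HiddenLayerCount = 2,
            OutputCount = 10
        };

        Network trained = await trainer.Train(
            configuration,
            new[] { 0, 1, 2 },                        // spectral image indexes to consider
            (status, correctOutputs, error, iteration) =>
                Console.WriteLine($"{status} @ {iteration}: correct={correctOutputs}, error={error}"));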
Example #4
        /* Method: SetCallback

           Sets the callback function for use during training. The user_data is passed to
           the callback. It can point to arbitrary data that the callback might require and
           can be NULL if it is not used.

           See <FANN::callback_type at http://libfann.github.io/fann/docs/files/fann_data_cpp-h.html#callback_type> for more information about the callback function.

           The default callback function simply prints out some status information.

           This function appears in FANN >= 2.0.0.
         */
        public void SetCallback(TrainingCallback callback, Object userData)
        {
            Callback = callback;
            UserData = userData;
            GCHandle handle = GCHandle.Alloc(userData);
            UnmanagedCallback = new training_callback(InternalCallback);
            fannfloatPINVOKE.neural_net_set_callback(
                neural_net.getCPtr(this.net),
                Marshal.GetFunctionPointerForDelegate(UnmanagedCallback),
                (IntPtr)handle);
        }
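A usage sketch for SetCallback, assuming the wrapper's TrainingCallback delegate mirrors FANN's C callback_type (network, training data, max epochs, epochs between reports, desired error, current epoch) plus the user data object; verify the delegate shape against the wrapper's declaration. In FANN, returning -1 from the callback terminates training.

        int ReportProgress(NeuralNet net, TrainingData data, uint maxEpochs,
                           uint epochsBetweenReports, float desiredError, uint epochs, object userData)
        {
            // Assumed parameter list; print basic progress information.
            Console.WriteLine($"Epoch {epochs} of {maxEpochs} (target error {desiredError})");
            return 0; // return -1 to stop training early
        }

        net.SetCallback(ReportProgress, null);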