Ejemplo n.º 1
0
        /// <summary>
        /// Runs the training engine for a fixed number of iterations, periodically scoring it
        /// against a test data source and reporting progress to the console.
        /// </summary>
        /// <remarks>
        /// The test data is evaluated once before the first iteration. After that it is evaluated
        /// every time <paramref name="testCadence"/> training iterations have completed since the
        /// previous evaluation. If <paramref name="testCadence"/> is never reached by the counter
        /// (for example zero or a negative value) no further evaluation takes place.
        /// </remarks>
        /// <param name="engine">The graph training engine to train and test</param>
        /// <param name="numIterations">How many training iterations to execute</param>
        /// <param name="testData">The data source handed to the engine's test step</param>
        /// <param name="errorMetric">The error metric handed to the engine's test step</param>
        /// <param name="onImprovement">Optional callback, invoked with a freshly built <see cref="GraphModel"/> whenever a periodic test returns true</param>
        /// <param name="testCadence">Number of training iterations between evaluations of the test data</param>
        public static void Train(this IGraphTrainingEngine engine, int numIterations,
                                 IDataSource testData, IErrorMetric errorMetric, Action <GraphModel> onImprovement = null,
                                 int testCadence = 1)
        {
            var context = new ExecutionContext(engine.LinearAlgebraProvider);

            // Initial evaluation before any training; its result is intentionally not inspected.
            engine.Test(testData, errorMetric, 128, progress => Console.Write("\rTesting... ({0:P})    ", progress));

            var iterationsSinceTest = 0;
            for (var iteration = 0; iteration < numIterations; iteration++)
            {
                engine.Train(context, progress => Console.Write("\rTraining... ({0:P})    ", progress));

                iterationsSinceTest++;
                if (iterationsSinceTest != testCadence)
                {
                    continue;
                }

                // The cadence has been reached: start counting again and evaluate the test data.
                iterationsSinceTest = 0;

                var improved = engine.Test(testData, errorMetric, 128, progress => Console.Write("\rTesting... ({0:P})    ", progress));
                if (!improved || onImprovement == null)
                {
                    continue;
                }

                var model = new GraphModel {
                    Graph = engine.Graph
                };

                // Only adaptive data sources contribute a data source model to the snapshot.
                if (engine.DataSource is IAdaptiveDataSource adaptive)
                {
                    model.DataSource = adaptive.GetModel();
                }

                onImprovement(model);
            }
        }
 /// <summary>
 /// Stores the linear algebra provider, cluster count and cost function for later use.
 /// </summary>
 /// <param name="lap">The linear algebra provider to keep</param>
 /// <param name="numClusters">The number of clusters to keep</param>
 /// <param name="costFunction">Optional cost function; a new <see cref="Quadratic"/> is used when null</param>
 public NonNegativeMatrixFactorisation(ILinearAlgebraProvider lap, int numClusters,
                                       IErrorMetric costFunction = null)
 {
     // Fall back to the quadratic cost when the caller did not supply one.
     if (costFunction == null)
     {
         costFunction = new Quadratic();
     }

     _lap = lap;
     _numClusters = numClusters;
     _costFunction = costFunction;
 }
Ejemplo n.º 3
0
        /// <summary>
        /// Stores the inputs, copies the data into a matrix and creates randomly initialised
        /// weight and feature matrices.
        /// </summary>
        /// <param name="lap">The linear algebra provider used to create the matrices</param>
        /// <param name="data">The vectors to factorise; one matrix row is created per vector, and the column count is taken from the first vector</param>
        /// <param name="numClusters">The number of clusters (columns of the weight matrix and rows of the feature matrix)</param>
        /// <param name="costFunction">Optional cost function; <c>ErrorMetricType.RMSE.Create()</c> is used when null</param>
        public NNMF(ILinearAlgebraProvider lap, IReadOnlyList <IIndexableVector> data, int numClusters, IErrorMetric costFunction = null)
        {
            _lap = lap;
            _data = data;
            _numClusters = numClusters;
            _costFunction = costFunction ?? ErrorMetricType.RMSE.Create();

            var random = new Random();

            // Copy the input vectors into a single matrix: one row per vector.
            var rowCount = data.Count;
            var columnCount = data.First().Count;
            _dataMatrix = _lap.Create(rowCount, columnCount, (row, column) => data[row][column]);

            // Weights are created before features so the random draws keep their original order.
            _weights = _lap.Create(
                _dataMatrix.RowCount,
                _numClusters,
                (row, column) => Convert.ToSingle(random.NextDouble()));
            _features = _lap.Create(
                _numClusters,
                _dataMatrix.ColumnCount,
                (row, column) => Convert.ToSingle(random.NextDouble()));
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Adds a backpropagation through time action to the wire as a forward action
 /// </summary>
 /// <param name="errorMetric">Error metric passed to the new <see cref="BackpropagateThroughTime"/> action</param>
 /// <param name="name">Optional name passed along with the action</param>
 /// <returns>This builder, so that further calls can be chained</returns>
 public WireBuilder AddBackpropagationThroughTime(IErrorMetric errorMetric, string name = null)
 {
     var backpropagation = new BackpropagateThroughTime(errorMetric);
     AddForwardAction(backpropagation, name);
     return this;
 }
Ejemplo n.º 5
0
 /// <summary>
 /// Computes the error of each output against its corresponding target and averages the results
 /// </summary>
 /// <param name="errorMetric">The error metric used to compare each output/target pair</param>
 /// <returns>The mean of the per-pair errors</returns>
 public float CalculateError(IErrorMetric errorMetric)
 {
     // Pair outputs with targets positionally; Zip stops at the shorter of the two sequences.
     var pairErrors = Output.Zip(Target, (output, target) => errorMetric.Compute(output, target));
     return pairErrors.Average();
 }
Ejemplo n.º 6
0
 /// <summary>
 /// Creates the finder and sets its error metric to the one returned by <c>ErrorMetric.Low()</c>
 /// </summary>
 public LowTransformFinder()
 {
     errorMetric = ErrorMetric.Low();
 }
Ejemplo n.º 7
0
		/// <summary>
		/// Sets the error metric by loading the type identified by <paramref name="data"/>
		/// and creating an instance of it.
		/// </summary>
		/// <param name="data">The string handed to <c>TypeLoader.LoadType</c> to resolve the error metric type</param>
		/// <exception cref="InvalidCastException">The created instance does not implement <see cref="IErrorMetric"/></exception>
		public void Initialise(string data)
		{
			// NOTE(review): the type is chosen by the incoming string - confirm callers only pass trusted data.
			var metricType = TypeLoader.LoadType(data);
			var metricInstance = Activator.CreateInstance(metricType);
			_errorMetric = (IErrorMetric)metricInstance;
		}
Ejemplo n.º 8
0
		/// <summary>
		/// Creates the action and stores the error metric it was given
		/// </summary>
		/// <param name="errorMetric">The error metric to store</param>
		public Backpropagate(IErrorMetric errorMetric)
		{
			this._errorMetric = errorMetric;
		}
Ejemplo n.º 9
0
 /// <summary>
 /// Creates a new training context from the supplied settings
 /// </summary>
 /// <param name="trainingRate">Passed as the first constructor argument of <see cref="TrainingContext"/></param>
 /// <param name="batchSize">Passed as the second constructor argument of <see cref="TrainingContext"/></param>
 /// <param name="errorMetric">Passed as the third constructor argument of <see cref="TrainingContext"/></param>
 /// <returns>The newly created <see cref="TrainingContext"/></returns>
 public ITrainingContext CreateContext(float trainingRate, int batchSize, IErrorMetric errorMetric)
 {
     ITrainingContext context = new TrainingContext(trainingRate, batchSize, errorMetric);
     return context;
 }
Ejemplo n.º 10
0
 /// <summary>
 /// Creates the action and stores the error metric it was given
 /// </summary>
 /// <param name="errorMetric">The error metric to store</param>
 public BackpropagateThroughTime(IErrorMetric errorMetric)
 {
     this._errorMetric = errorMetric;
 }
Ejemplo n.º 11
0
 /// <summary>
 /// Creates a new training context from the supplied settings
 /// </summary>
 /// <param name="errorMetric">Passed as the third constructor argument of <see cref="TrainingContext"/></param>
 /// <param name="learningRate">Passed as the first constructor argument of <see cref="TrainingContext"/></param>
 /// <param name="batchSize">Passed as the second constructor argument of <see cref="TrainingContext"/></param>
 /// <returns>The newly created <see cref="TrainingContext"/></returns>
 public ITrainingContext CreateTrainingContext(IErrorMetric errorMetric, float learningRate, int batchSize)
 {
     // Note the argument order differs from this method's parameter order.
     ITrainingContext context = new TrainingContext(learningRate, batchSize, errorMetric);
     return context;
 }
Ejemplo n.º 12
0
 /// <summary>
 /// Creates the finder and sets its error metric to the one returned by <c>ErrorMetric.Horn()</c>
 /// </summary>
 public HornTransformFinder()
 {
     errorMetric = ErrorMetric.Horn();
 }
Ejemplo n.º 13
0
 /// <summary>
 /// Creates a training context and stores the supplied settings
 /// </summary>
 /// <param name="trainingRate">Value assigned to <c>TrainingRate</c></param>
 /// <param name="miniBatchSize">The mini batch size to store</param>
 /// <param name="errorMetric">The error metric to store</param>
 public TrainingContext(float trainingRate, int miniBatchSize, IErrorMetric errorMetric)
 {
     // Assignment order is kept as-is in case the TrainingRate setter has side effects.
     this._miniBatchSize = miniBatchSize;
     this.TrainingRate = trainingRate;
     this._errorMetric = errorMetric;
 }