/// <summary>
        /// Splits the observations and targets into a training and a test set.
        /// The splitter decides, from the targets, which row indices go to which set;
        /// the matching observation rows and target values are then copied into each set.
        /// </summary>
        /// <param name="splitter">The type of splitter used for determining the distribution of observations</param>
        /// <param name="observations">The observations for the problem</param>
        /// <param name="targets">The targets for the problem</param>
        /// <returns>A <see cref="TrainingTestSetSplit"/> built from the training set and the test set</returns>
        /// <exception cref="ArgumentNullException">
        /// Thrown when <paramref name="splitter"/>, <paramref name="observations"/> or <paramref name="targets"/> is null
        /// </exception>
        /// <exception cref="ArgumentException">
        /// Thrown when the number of observation rows differs from the number of targets
        /// </exception>
        public static TrainingTestSetSplit SplitSet(this ITrainingTestIndexSplitter <double> splitter,
                                                    F64Matrix observations, double[] targets)
        {
            // An extension method can be invoked on a null receiver, so guard it like any other argument.
            if (splitter == null)
            {
                throw new ArgumentNullException(nameof(splitter));
            }

            if (observations == null)
            {
                throw new ArgumentNullException(nameof(observations));
            }

            if (targets == null)
            {
                throw new ArgumentNullException(nameof(targets));
            }

            if (observations.RowCount != targets.Length)
            {
                throw new ArgumentException("Observations and targets has different number of rows");
            }

            var indexSplit = splitter.Split(targets);

            var trainingSet = new ObservationTargetSet((F64Matrix)observations.Rows(indexSplit.TrainingIndices),
                                                       targets.GetIndices(indexSplit.TrainingIndices));

            var testSet = new ObservationTargetSet((F64Matrix)observations.Rows(indexSplit.TestIndices),
                                                   targets.GetIndices(indexSplit.TestIndices));

            return new TrainingTestSetSplit(trainingSet, testSet);
        }
コード例 #2
0
        // Verifies that SplitSet with a no-shuffle splitter at 0.6 yields the same split as
        // manually taking rows 0-5 for training and rows 6-9 for test.
        public void TrainingTestIndexSplitterExtensions_SplitSet()
        {
            // Ten single-column observations with identical target values.
            double[] observationValues = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
            double[] targets           = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
            var observations = new F64Matrix(observationValues, 10, 1);

            // Build the expected split by hand from explicit row indices.
            int[] trainingIndices = { 0, 1, 2, 3, 4, 5 };
            int[] testIndices     = { 6, 7, 8, 9 };

            var expectedTraining = new ObservationTargetSet(
                (F64Matrix)observations.Rows(trainingIndices),
                targets.GetIndices(trainingIndices));

            var expectedTest = new ObservationTargetSet(
                (F64Matrix)observations.Rows(testIndices),
                targets.GetIndices(testIndices));

            var expected = new TrainingTestSetSplit(expectedTraining, expectedTest);

            // Split the same data through the extension method under test.
            var sut    = new NoShuffleTrainingTestIndexSplitter <double>(0.6);
            var actual = sut.SplitSet(observations, targets);

            Assert.AreEqual(expected, actual);
        }
コード例 #3
0
        // Usage example: selecting a subset of rows from an F64Matrix.
        public void F64Matrix_GetRows()
        {
            // The backing values are laid out row by row: three rows of three values.
            double[] values =
            {
                1, 2, 3,
                4, 5, 6,
                7, 8, 9
            };
            var matrix = new F64Matrix(values, 3, 3);

            // Select the first and the last row; they are returned as a new matrix:
            // [1, 2, 3,
            //  7, 8, 9]
            int[] selectedRows = { 0, 2 };
            var rows = matrix.Rows(selectedRows);
        }
コード例 #4
0
    /// <summary>
    /// Computes, for every group and every column of <c>subVals</c>, the mean of that column over
    /// the group's rows, and appends the resulting [group, column] table to
    /// <c>Global.LastClassificationResults.meanSubValues</c>.
    /// </summary>
    private void UpdateClassificationResults()
    {
        var columnCount = subVals.ColumnCount;
        float[,] meanSubVals = new float[numGroups, columnCount];

        for (int group = 0; group < numGroups; group++)
        {
            // The row selection depends only on the group, not on the column, so it is built once
            // per group instead of once per (group, column) pair.
            // NOTE(review): presumably Global.IntArray yields the row indices from
            // group * numPerGroup up to (group + 1) * numPerGroup - confirm against its definition.
            int[] correctRows = Global.IntArray(group * numPerGroup, (group + 1) * numPerGroup);

            IMatrix <double> membersOfGroup = subVals.Rows(correctRows);

            for (int sub = 0; sub < columnCount; sub++)
            {
                double[] substrateFromGroup = membersOfGroup.Column(sub);

                // The sum is narrowed to float before dividing, exactly as before, so results are unchanged.
                float meanVal = (float)Global.Sum(substrateFromGroup) / substrateFromGroup.Length;

                meanSubVals[group, sub] = meanVal;
            }
        }

        Global.LastClassificationResults.meanSubValues.Add(meanSubVals);
    }
コード例 #5
0
        /// <summary>
        /// Loads the per-row model data from the JSON cache at <c>Path()</c> if that file exists; otherwise
        /// builds it from the UK HPI CSV plus the auxiliary extractors, runs the target calculator and
        /// writes the cache. Then, in parallel for each region, trains a model on a 70/30 split of the
        /// rows that have a target, and logs the test error and a prediction for the region's latest row.
        /// </summary>
        /// <param name="iterations">Stored in <c>_iterations</c> and logged.</param>
        /// <param name="targetOffset">Stored in <c>_targetOffset</c>, logged, and passed to <c>_targetCalculator.Calculate</c>.</param>
        /// <param name="targetName">Stored in <c>_targetName</c>; names the CSV column read as the original target.</param>
        /// <param name="pauseAtEnd">When true, waits for a key press on the console before returning.</param>
        public void Predict(int iterations = DefaultNNIterations, int targetOffset = 1, string targetName = DefaultTargetName, bool pauseAtEnd = false)
        {
            _iterations   = iterations;
            _targetName   = targetName;
            _targetOffset = targetOffset;

            Program.StatusLogger.Info($"Iterations: {_iterations}");
            Program.StatusLogger.Info($"Target: {_targetName}");
            Program.StatusLogger.Info($"Offset: {_targetOffset}");

            var data = new ConcurrentDictionary <int, ModelData>();

            if (File.Exists(Path()))
            {
                data = JsonConvert.DeserializeObject <ConcurrentDictionary <int, ModelData> >(File.ReadAllText(Path()));

                Program.StatusLogger.Info("Cached data was loaded.");
            }
            else
            {
                //http://publicdata.landregistry.gov.uk/market-trend-data/house-price-index-data/UK-HPI-full-file-2019-07.csv
                const string hpiFile = "UK-HPI-full-file-2019-07.csv";

                var header      = File.ReadLines(hpiFile).First();
                var columnNames = header.Split(",");

                var parser = new CsvParser(() => new StringReader(File.ReadAllText(hpiFile)), ',', false, true);

                var creditData          = _creditDataExtractor.Extract();
                var populationData      = _populationDataExtractor.Extract();
                var otherPopulationData = _otherPopulationDataExtractor.Extract();
                var densityData         = _londonDensityDataExtractor.Extract();
                var gvaData             = _gvaDataExtractor.Extract();

                var featureRows = parser.EnumerateRows().ToArray();

                // These values are the same for every row, so compute them once rather than per iteration.
                var featureColumns = columnNames.Except(excludeColumns).ToArray();
                var creditDataKeys = creditData.Keys.ToArray();
                var dateCulture    = new CultureInfo("en-GB");

                string previousKey = null;

                for (int i = 0; i < featureRows.Length; i++)
                {
                    var item = featureRows[i];
                    var key  = item.GetValue("RegionName");
                    var date = DateTime.ParseExact(item.GetValue("Date"), "dd/MM/yyyy", dateCulture, DateTimeStyles.AssumeLocal);

                    // Log only when the region changes from the previous row.
                    if (key != previousKey)
                    {
                        Program.StatusLogger.Info($"Processing {key}");
                    }
                    previousKey = key;

                    var regionFeatures = item.GetValues(featureColumns).Select(s => ParseRowValue(s));

                    var creditDataKey = _creditDataExtractor.GetKey(date, creditDataKeys);
                    if (creditData.TryGetValue(creditDataKey, out var creditValues))
                    {
                        regionFeatures = regionFeatures.Concat(creditValues);
                    }
                    else
                    {
                        // No credit data for this date: pad with -1 so every row keeps the same feature count.
                        regionFeatures = regionFeatures.Concat(Enumerable.Repeat(-1d, creditData.Values.First().Length));
                        Trace.WriteLine($"Credit data not found: {creditDataKey}");
                    }

                    var modelData = new ModelData
                    {
                        Name           = key,
                        Code           = item.GetValue("AreaCode"),
                        Date           = date,
                        Observations   = regionFeatures.ToArray(),
                        OriginalTarget = ParseTarget(item.GetValue(_targetName))
                    };

                    // The remaining extractors receive the partially built modelData, so they are
                    // appended only after the object exists.
                    modelData.Observations = modelData.Observations
                                             .Concat(_populationDataExtractor.Get(populationData, modelData))
                                             .Concat(_londonDensityDataExtractor.Get(densityData, modelData))
                                             .Concat(_otherPopulationDataExtractor.Get(otherPopulationData, modelData))
                                             .Concat(_gvaDataExtractor.Get(gvaData, modelData))
                                             .ToArray();

                    data.TryAdd(i, modelData);
                }

                _targetCalculator.Calculate(data, _targetOffset);

                var json = JsonConvert.SerializeObject(data, Formatting.Indented);
                File.WriteAllText(Path(), json);
            }

            var itemCount = 0;

            // Rows are ordered by date before grouping, so each region's rows are in date order
            // and grouping.Last() is that region's most recent row.
            var regions = data.OrderBy(o => o.Value.Date).GroupBy(g => g.Value.Name);

            // Parallel.ForEach partitions the source itself; wrapping it in AsParallel() only added a
            // redundant PLINQ merge step in front of it.
            Parallel.ForEach(regions, new ParallelOptions {
                MaxDegreeOfParallelism = -1
            }, (grouping) =>
            {
                var regionName = grouping.Key;

                // Materialise once: the filtered rows are read several times below.
                var dataWithTarget = grouping.Where(s => s.Value.OriginalTarget.HasValue && s.Value.Target != -1).ToArray();

                if (dataWithTarget.Length == 0)
                {
                    return;
                }

                var allObservations = dataWithTarget.Select(s => s.Value.Observations).ToArray();
                var allTargets      = dataWithTarget.Select(s => s.Value.Target).ToArray();

                var meanZeroTransformer = new MeanZeroFeatureTransformer();
                var minMaxTransformer   = new MinMaxTransformer(0d, 1d);
                var lastObservations    = grouping.Last().Value.Observations;

                // The latest row is appended so it is transformed together with the training data;
                // it ends up as the final row of allTransformed.
                F64Matrix allTransformed = minMaxTransformer.Transform(meanZeroTransformer.Transform(allObservations.Append(lastObservations).ToArray()));
                var lastRowIndex         = allTransformed.RowCount - 1;

                // Every row except the appended one, i.e. the rows that have a target.
                var observationsWithTarget = new F64Matrix(allTransformed.Rows(Enumerable.Range(0, lastRowIndex).ToArray()).Data(), lastRowIndex, allTransformed.ColumnCount);

                var splitter = new RandomTrainingTestIndexSplitter <double>(trainingPercentage: 0.7, seed: 24);

                var trainingTestSplit    = splitter.SplitSet(observationsWithTarget, allTargets);
                var trainingObservations = trainingTestSplit.TrainingSet.Observations;
                var testSet              = trainingTestSplit.TestSet;

                // Alternative learners: GetRandomForest(), GetAda(), GetNeuralNet(featureCount, rowCount).
                var learner = GetEnsemble(grouping.First().Value.Observations.Length, trainingObservations.RowCount);

                Program.StatusLogger.Info("Learning commenced " + regionName);

                var model = learner.Learn(trainingObservations, trainingTestSplit.TrainingSet.Targets);

                Program.StatusLogger.Info("Learning completed " + regionName);

                var rawImportance = model.GetRawVariableImportance();
                if (rawImportance.Any(a => a > 0))
                {
                    var importanceSummary = string.Join(",\r\n", rawImportance.Select((d, i) => i.ToString() + ":" + d.ToString()));
                    Program.StatusLogger.Info("Raw variable importance:\r\n" + importanceSummary);
                }

                // Predict for the appended latest row. The previous code indexed with the training-set
                // row count (the variable had been reassigned to the training subset), which selected
                // an arbitrary earlier row instead of the last one.
                var lastTransformed = allTransformed.Row(lastRowIndex);
                var prediction      = model.Predict(lastTransformed);

                // Placeholder: the relative change (prediction / previous value) is not computed yet.
                var change = -1;

                var testPrediction = model.Predict(testSet.Observations);

                var metric = new MeanSquaredErrorRegressionMetric();
                var error  = metric.Error(testSet.Targets, testPrediction);

                var totalErrorSnapshot = 0d;
                var averageError       = 0d;
                lock (Locker)
                {
                    _totalError += error;
                    itemCount++;

                    // Snapshot the shared total while holding the lock so the logged total and
                    // average belong to the same update.
                    totalErrorSnapshot = _totalError;
                    averageError       = Math.Round(_totalError / itemCount, 3);
                }

                var isLondon = London.Contains(regionName);

                var message = $"TotalError: {Math.Round(totalErrorSnapshot, 3)}, AverageError: {averageError}, Target: {_targetName}, Offset: {_targetOffset}, Region: {regionName}, London: {isLondon}, Error: {Math.Round(error, 3)}, Next: {Math.Round(prediction, 3)}, Change: {change}";

                Program.Logger.Info(message);
            });

            if (pauseAtEnd)
            {
                Console.WriteLine("Press any key to continue");
                Console.ReadKey();
            }
        }