コード例 #1
0
        /// <summary>
        /// Fits an <see cref="SsaForecastingEstimator"/> on a synthetic series and verifies that the
        /// first transformed row carries the expected forecast values and confidence bounds.
        /// </summary>
        public void SsaForecast()
        {
            const int changeHistorySize = 10;
            const int seasonalitySize = 10;
            const int numberOfSeasonsInTraining = 5;
            // Number of forecast steps; shared by the transformer options and the assertions
            // below so the two cannot drift apart.
            const int horizon = 4;

            var env = new MLContext(1);

            // Build the input series: numberOfSeasonsInTraining repetitions of the ramp
            // 0..seasonalitySize-1, followed by changeHistorySize values 0, 100, 200, ...
            // The list is fully populated before it is wrapped in a data view, so the test
            // does not rely on the view reading the list lazily.
            var data = new List<Data>();
            for (int season = 0; season < numberOfSeasonsInTraining; season++)
            {
                for (int step = 0; step < seasonalitySize; step++)
                {
                    data.Add(new Data(step));
                }
            }

            for (int step = 0; step < changeHistorySize; step++)
            {
                data.Add(new Data(step * 100));
            }

            var dataView = env.Data.LoadFromEnumerable(data);

            var args = new SsaForecastingTransformer.Options()
            {
                ConfidenceLevel = 0.95f,
                Source = "Value",
                Name = "Forecast",
                ConfidenceLowerBoundColumn = "MinCnf",
                ConfidenceUpperBoundColumn = "MaxCnf",
                WindowSize = 10,
                SeriesLength = 11,
                TrainSize = 22,
                Horizon = horizon,
                IsAdaptive = true
            };

            // Train
            var model = new SsaForecastingEstimator(env, args).Fit(dataView);
            // Transform
            var output = model.Transform(dataView);

            // Hard-coded expected values for the first output row, one entry per forecast step.
            var expectedForecast = new List<float>
            {
                0.191491723f, 2.53994083f, 5.26454258f, 7.37313938f
            };
            var minCnf = new List<float>
            {
                -3.9741993f, -2.36872721f, 0.09407653f, 2.18899345f
            };
            var maxCnf = new List<float>
            {
                4.3571825f, 7.448609f, 10.435009f, 12.5572853f
            };

            // Get predictions. The enumerator is disposed deterministically, and an empty result
            // fails with a clear assertion instead of a NullReferenceException further down.
            using (var enumerator = env.Data.CreateEnumerable<ForecastPrediction>(output, true).GetEnumerator())
            {
                Assert.True(enumerator.MoveNext(), "Expected at least one forecast row.");
                ForecastPrediction row = enumerator.Current;

                for (int localIndex = 0; localIndex < horizon; localIndex++)
                {
                    Assert.Equal(expectedForecast[localIndex], row.Forecast[localIndex], precision: 7);
                    Assert.Equal(minCnf[localIndex], row.MinCnf[localIndex], precision: 7);
                    Assert.Equal(maxCnf[localIndex], row.MaxCnf[localIndex], precision: 7);
                }
            }
        }