// Maps user id to MAL username for the users in m_trainingData.
// Always update this when updating m_trainingData. Is null when rec sources are finalized.
private IDictionary<int, string> m_usernames;

        #endregion Fields

        #region Constructors

        /// <summary>
        /// Creates the service state and eagerly loads the initial training data, usernames,
        /// animes and prerequisites using a loader obtained from the given factory.
        /// The factory is also stored so that later reloads can create fresh loaders.
        /// </summary>
        /// <param name="trainingDataLoaderFactory">Must be thread-safe. Must not be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="trainingDataLoaderFactory"/> is null.</exception>
        public RecServiceState(IMalTrainingDataLoaderFactory trainingDataLoaderFactory)
        {
            // Fail fast with a descriptive exception instead of a NullReferenceException
            // when the factory is first dereferenced below.
            if (trainingDataLoaderFactory == null)
            {
                throw new ArgumentNullException("trainingDataLoaderFactory");
            }

            m_trainingDataLoaderFactory = trainingDataLoaderFactory;
            JsonRecSourceTypes = GetJsonRecSourceTypes();

            // The loader is only needed for the initial load; dispose it as soon as that is done.
            using (IMalTrainingDataLoader trainingDataLoader = trainingDataLoaderFactory.GetTrainingDataLoader())
            {
                m_trainingData = LoadTrainingDataOnInit(trainingDataLoader);
                // m_usernames and m_animes are derived from m_trainingData and must be kept in sync with it.
                m_usernames = GetUsernamesFromTrainingData(m_trainingData);
                m_animes = m_trainingData.Animes;
                m_prereqs = LoadPrereqsOnInit(trainingDataLoader);
            }

            m_trainingDataLock = new ReaderWriterLockSlim();
            m_recSourcesLock = new ReaderWriterLockSlim();
        }
Exemple #2
0
        /// <summary>
        /// Randomly splits the users of <paramref name="rawData"/> into two halves: the first half
        /// (rounded down) becomes a new training data set, the rest is returned for evaluation.
        /// </summary>
        /// <param name="rawData">Data set whose users are split. Its animes are reused as-is.</param>
        /// <returns>Item1 is the training data, Item2 the users held out for evaluation.</returns>
        static Tuple<MalTrainingData, ICollection<MalUserListEntries>> GetDataForTrainingAndEvaluation(MalTrainingData rawData)
        {
            List<int> shuffledIds = rawData.Users.Keys.ToList();
            shuffledIds.Shuffle();

            int trainingCount = shuffledIds.Count / 2;
            var trainingUsers = new Dictionary<int, MalUserListEntries>();
            var evaluationUsers = new List<MalUserListEntries>();

            // Single pass over the shuffled ids: positions below the cut-off go to training,
            // everything else is held out for evaluation.
            for (int position = 0; position < shuffledIds.Count; position++)
            {
                int userId = shuffledIds[position];
                MalUserListEntries userEntries = rawData.Users[userId];

                if (position < trainingCount)
                {
                    trainingUsers[userId] = userEntries;
                }
                else
                {
                    evaluationUsers.Add(userEntries);
                }
            }

            var trainingData = new MalTrainingData(trainingUsers, rawData.Animes);
            return Tuple.Create<MalTrainingData, ICollection<MalUserListEntries>>(trainingData, evaluationUsers);
        }
        /// <summary>
        /// Reloads training data and prerequisites and retrains all rec sources while holding the
        /// write locks on both the training data and the rec sources for the whole operation.
        /// All currently loaded state is cleared before the new data is loaded.
        /// </summary>
        /// <param name="finalize">
        /// If true, m_trainingData and m_usernames are set back to null after retraining and
        /// m_finalized is set to true.
        /// </param>
        private void ReloadTrainingDataLowMemory(bool finalize)
        {
            // Both write locks are held until the end of the using block below.
            using (var trainingDataWriteLock = m_trainingDataLock.ScopedWriteLock())
            using (var recSourcesWriteLock = m_recSourcesLock.ScopedWriteLock())
            {
                Logging.Log.Info("Reloading training data and prerequisites and retraining rec sources. Rec sources will not be available until retraining all rec sources is complete.");
                Stopwatch totalTimer = Stopwatch.StartNew();

                // Drop every reference to the old state first so the collection below can reclaim it
                // before the new data is loaded.
                m_recSources.Clear();
                m_trainingData = null;
                m_usernames = null;
                m_animes = null;
                m_prereqs = null;
                m_finalized = false;

                GC.Collect();
                Logging.Log.Info("Rec sources cleared.");
                Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));

                Stopwatch timer = Stopwatch.StartNew();

                // Load new training data
                // If this throws an error, m_trainingData is left null. Methods that use m_trainingData should check it for null.
                using (IMalTrainingDataLoader malTrainingDataLoader = m_trainingDataLoaderFactory.GetTrainingDataLoader())
                {
                    Logging.Log.Debug("Created training data loader.");

                    m_trainingData = malTrainingDataLoader.LoadMalTrainingData();
                    // Keep the derived fields in sync with the freshly loaded training data.
                    m_usernames = GetUsernamesFromTrainingData(m_trainingData);
                    m_animes = m_trainingData.Animes;
                    timer.Stop();

                    Logging.Log.InfoFormat("Training data loaded. {0} users, {1} animes, {2} entries. Took {3}.",
                        m_trainingData.Users.Count, m_trainingData.Animes.Count,
                        m_trainingData.Users.Keys.Sum(userId => m_trainingData.Users[userId].Entries.Count),
                        timer.Elapsed);

                    // The same stopwatch is reused to time the prerequisite load separately.
                    timer.Restart();
                    m_prereqs = malTrainingDataLoader.LoadPrerequisites();
                    timer.Stop();

                    int numPrereqs = m_prereqs.Values.Sum(prereqList => prereqList.Count);
                    Logging.Log.InfoFormat("Prerequisites loaded. {0} prerequisites for {1} animes. Took {2}.",
                        numPrereqs, m_prereqs.Count, timer.Elapsed);
                }

                GC.Collect();
                Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));

                // Then retrain all loaded rec sources.
                // ToList() so we can unload a rec source as we iterate if it errors while training.
                foreach (string recSourceName in m_recSourceFactories.Keys.ToList())
                {
                    // Each factory produces a new rec source instance, which is then trained below.
                    ITrainableJsonRecSource recSource = m_recSourceFactories[recSourceName]();
                    try
                    {
                        Logging.Log.InfoFormat("Retraining rec source {0} ({1}).", recSourceName, recSource);
                        timer.Restart();

                        recSource.Train(m_trainingData, m_usernames);
                        m_recSources[recSourceName] = recSource;

                        timer.Stop();
                        Logging.Log.InfoFormat("Trained rec source {0} ({1}). Took {2}.", recSourceName, recSource, timer.Elapsed);
                        GC.Collect();
                        Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));
                    }
                    catch (Exception ex)
                    {
                        // NOTE(review): the exception is passed ahead of the three format arguments; presumably
                        // Logging.Log offers an overload taking (format, exception, args) - confirm, otherwise
                        // the placeholders are shifted by one.
                        Logging.Log.ErrorFormat("Error retraining rec source {0} ({1}): {2} Unloading it.",
                            ex, recSourceName, recSource, ex.Message);
                        // A rec source that fails to train is unloaded: its factory is removed and it was
                        // never added back to m_recSources.
                        m_recSourceFactories.Remove(recSourceName);
                    }
                }

                if (finalize)
                {
                    // Finalizing discards the training data and usernames again.
                    m_trainingData = null;
                    m_usernames = null;
                    m_finalized = true;
                    Logging.Log.Info("Finalized rec sources.");
                }

                totalTimer.Stop();
                Logging.Log.InfoFormat("All rec sources retrained with the latest data. Total time: {0}", totalTimer.Elapsed);
            }

            GC.Collect();
            Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));
        }
        /// <summary>
        /// Reloads training data and prerequisites and retrains every registered rec source into new
        /// collections, leaving the currently loaded state untouched until the very end, where the new
        /// state is swapped in under write locks on both the training data and the rec sources.
        /// </summary>
        /// <param name="finalize">
        /// If true, m_trainingData and m_usernames are set to null after the swap and m_finalized is
        /// set to true; otherwise the newly loaded data and usernames are kept and m_finalized is false.
        /// </param>
        private void ReloadTrainingDataHighAvailability(bool finalize)
        {
            Logging.Log.Info("Reloading training data and retraining rec sources. Rec sources will remain available.");
            Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));

            Stopwatch timer = Stopwatch.StartNew();
            Stopwatch totalTimer = Stopwatch.StartNew();

            // Load new training data into locals; no lock is held while loading.
            MalTrainingData newData;
            IDictionary<int, string> newUsernames;
            IDictionary<int, IList<int>> newPrereqs;
            using (IMalTrainingDataLoader malTrainingDataLoader = m_trainingDataLoaderFactory.GetTrainingDataLoader())
            {
                Logging.Log.Debug("Created training data loader.");

                newData = malTrainingDataLoader.LoadMalTrainingData();
                newUsernames = GetUsernamesFromTrainingData(newData);

                timer.Stop();
                Logging.Log.InfoFormat("Training data loaded. {0} users, {1} animes, {2} entries. Took {3}.",
                    newData.Users.Count, newData.Animes.Count,
                    newData.Users.Keys.Sum(userId => newData.Users[userId].Entries.Count),
                    timer.Elapsed);

                // The same stopwatch is reused to time the prerequisite load separately.
                timer.Restart();
                newPrereqs = malTrainingDataLoader.LoadPrerequisites();
                timer.Stop();

                int numPrereqs = newPrereqs.Values.Sum(prereqList => prereqList.Count);
                Logging.Log.InfoFormat("Prerequisites loaded. {0} prerequisites for {1} animes. Took {2}.",
                    numPrereqs, newPrereqs.Count, timer.Elapsed);
            }

            GC.Collect();
            Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));

            // Upgradeable read locks are taken for the training phase and upgraded to write locks
            // only for the final swap.
            using (var trainingDataUpgradeableLock = m_trainingDataLock.ScopedUpgradeableReadLock())
            using (var recSourcesUpgradeableLock = m_recSourcesLock.ScopedUpgradeableReadLock())
            {
                // clone the json rec sources without the training state and train each one with the new data.
                Dictionary<string, ITrainableJsonRecSource> newRecSources = new Dictionary<string, ITrainableJsonRecSource>(StringComparer.OrdinalIgnoreCase);
                Dictionary<string, Func<ITrainableJsonRecSource>> newRecSourceFactories = new Dictionary<string, Func<ITrainableJsonRecSource>>(m_recSourceFactories, StringComparer.OrdinalIgnoreCase);

                // ToList() so we can unload a rec source as we iterate if it errors while training.
                // Enumerating Keys directly while removing from the dictionary can throw
                // InvalidOperationException ("Collection was modified") on the next iteration.
                foreach (string recSourceName in newRecSourceFactories.Keys.ToList())
                {
                    ITrainableJsonRecSource recSource = newRecSourceFactories[recSourceName]();
                    Logging.Log.InfoFormat("Retraining rec source {0} ({1}).", recSourceName, recSource);
                    timer.Restart();
                    try
                    {
                        recSource.Train(newData, newUsernames);
                        timer.Stop();
                        Logging.Log.InfoFormat("Trained rec source {0} ({1}). Took {2}.", recSourceName, recSource, timer.Elapsed);
                        newRecSources[recSourceName] = recSource;

                        GC.Collect();
                        Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));
                    }
                    catch (Exception ex)
                    {
                        // NOTE(review): the exception is passed ahead of the three format arguments; presumably
                        // Logging.Log offers an overload taking (format, exception, args) - confirm.
                        Logging.Log.ErrorFormat("Error retraining rec source {0} ({1}): {2} Unloading it.",
                            ex, recSourceName, recSource, ex.Message);
                        // A rec source that fails to train is dropped from the new factory set, so it is
                        // no longer registered once the swap below happens.
                        newRecSourceFactories.Remove(recSourceName);
                    }
                }

                // Swap in the newly trained rec sources.
                using (var trainingDataWriteLock = m_trainingDataLock.ScopedWriteLock())
                using (var recSourcesWriteLock = m_recSourcesLock.ScopedWriteLock())
                {
                    m_recSources = newRecSources;
                    m_recSourceFactories = newRecSourceFactories;

                    m_animes = newData.Animes;
                    m_prereqs = newPrereqs;

                    if (finalize)
                    {
                        // Finalizing discards the training data and usernames.
                        m_trainingData = null;
                        m_usernames = null;
                        m_finalized = true;
                        Logging.Log.Info("Finalized rec sources.");
                    }
                    else
                    {
                        m_trainingData = newData;
                        m_usernames = newUsernames;
                        m_finalized = false;
                    }
                }
            }

            totalTimer.Stop();
            Logging.Log.InfoFormat("All rec sources retrained with the latest data. Total time: {0}", totalTimer.Elapsed);

            GC.Collect();
            Logging.Log.InfoFormat("Memory use: {0} bytes", GC.GetTotalMemory(forceFullCollection: false));
        }
 /// <summary>
 /// Builds a lookup from user id to MAL username for every user in the given training data.
 /// </summary>
 /// <param name="trainingData">Training data whose users are read.</param>
 /// <returns>A new dictionary with one entry per user.</returns>
 private static IDictionary<int, string> GetUsernamesFromTrainingData(MalTrainingData trainingData)
 {
     var usernamesById = new Dictionary<int, string>(trainingData.Users.Count);
     foreach (var user in trainingData.Users)
     {
         usernamesById[user.Key] = user.Value.MalUsername;
     }
     return usernamesById;
 }
 /// <summary>
 /// Marks the rec sources as finalized and drops the references to the training data and
 /// usernames, then forces a garbage collection and logs the resulting memory use.
 /// </summary>
 public void FinalizeRecSources()
 {
     // The write lock is held only while the fields are being reset.
     using (m_trainingDataLock.ScopedWriteLock())
     {
         m_finalized = true;
         m_usernames = null;
         m_trainingData = null;
     }

     GC.Collect();

     Logging.Log.Info("Finalized rec sources.");
     long bytesInUse = GC.GetTotalMemory(forceFullCollection: false);
     Logging.Log.InfoFormat("Memory use: {0} bytes", bytesInUse);
 }
Exemple #7
0
        /// <summary>
        /// Evaluates the configured recommenders: loads the MAL training data from Postgres, then for
        /// each pass and each recommender trains on a fresh random split of the users and records
        /// the evaluation results, and finally prints precision and recall per recommender and pass.
        /// </summary>
        static void Main(string[] args)
        {
            const int minEpisodesToCountIncomplete = 26;
            const double targetPercentile = 0.25;
            const int numEvaluations = 5;
            const int numRecsToGet = 25;

            var topNEvaluator = new TopNEvaluator();

            var recommendersUnderTest = new List<ITrainableRecSource<MalTrainingData, MalUserListEntries, IEnumerable<IRecommendation>, IRecommendation>>();
            var resultsForEachRecommender = new List<List<EvaluationResults>>();

            // Alternative recommender configurations, kept for reference. Uncomment the declaration
            // and the matching Add call below to include one in the evaluation.
            //var averageScoreRecSourceWithoutDropped = new MalAverageScoreRecSource(minEpisodesToCountIncomplete, useDropped: false, minUsersToCountAnime: 50);
            //var mostPopularRecSourceWithoutDropped = new MalMostPopularRecSource(minEpisodesToCountIncomplete, useDropped: false);
            //var defaultBiasedMatrixFactorizationRecSource = new MalMyMediaLiteRatingPredictionRecSource<BiasedMatrixFactorization>
            //    (new BiasedMatrixFactorization(), minEpisodesToCountIncomplete, useDropped: true, minUsersToCountAnime: 50);
            //var biasedMatrixFactorizationRecSourceWithBoldDriver = new MalMyMediaLiteRatingPredictionRecSource<BiasedMatrixFactorization>
            //    (new BiasedMatrixFactorization() { BoldDriver = true }, minEpisodesToCountIncomplete, useDropped: true, minUsersToCountAnime: 50);
            //var biasedMatrixFactorizationRecSourceWithFactors = new MalMyMediaLiteRatingPredictionRecSource<BiasedMatrixFactorization>
            //    (new BiasedMatrixFactorization() { BoldDriver = true, FrequencyRegularization = true, NumFactors = 50 },
            //    minEpisodesToCountIncomplete, useDropped: true, minUsersToCountAnime: 50);
            //var biasedMatrixFactorizationRecSourceWithFactorsAndIters = new MalMyMediaLiteRatingPredictionRecSource<BiasedMatrixFactorization>
            //    (new BiasedMatrixFactorization() { BoldDriver = true, FrequencyRegularization = true, NumFactors = 50, NumIter = 50 },
            //    minEpisodesToCountIncomplete, useDropped: true, minUsersToCountAnime: 50);
            //var defaultMatrixFactorizationRecSource = new MalMyMediaLiteRatingPredictionRecSource<MatrixFactorization>
            //    (new MatrixFactorization(), minEpisodesToCountIncomplete, useDropped: true, minUsersToCountAnime: 50);
            var animeRecsRecSource35 = new MalAnimeRecsRecSourceWithConstantPercentTarget(
                numRecommendersToUse: 100,
                fractionConsideredRecommended: 0.35,
                targetFraction: 0.35,
                minEpisodesToClassifyIncomplete: minEpisodesToCountIncomplete);

            //var animeRecsRecSource25 = new MalAnimeRecsRecSourceWithConstantPercentTarget(
            //    numRecommendersToUse: 100,
            //    fractionConsideredRecommended: 0.25,
            //    targetFraction: 0.25,
            //    minEpisodesToClassifyIncomplete: minEpisodesToCountIncomplete
            //);

#if MYMEDIALITE
            var bprmfRecSource = new MalMyMediaLiteItemRecommenderRecSourceWithConstantPercentTarget<BPRMF>(
                new BPRMF() { BiasReg = .01f },
                fractionConsideredRecommended: 0.25,
                minEpisodesToClassifyIncomplete: minEpisodesToCountIncomplete,
                minUsersToCountAnime: 30,
                targetFraction: 0.25);
#endif

            //recommendersUnderTest.Add(averageScoreRecSourceWithoutDropped);
            //recommendersUnderTest.Add(mostPopularRecSourceWithoutDropped);
            //recommendersUnderTest.Add(defaultBiasedMatrixFactorizationRecSource);
            //recommendersUnderTest.Add(biasedMatrixFactorizationRecSourceWithBoldDriver);
            //recommendersUnderTest.Add(biasedMatrixFactorizationRecSourceWithFactors);
            //recommendersUnderTest.Add(biasedMatrixFactorizationRecSourceWithFactorsAndIters);
            //recommendersUnderTest.Add(defaultMatrixFactorizationRecSource);
            recommendersUnderTest.Add(animeRecsRecSource35);
            //recommendersUnderTest.Add(animeRecsRecSource25);
#if MYMEDIALITE
            recommendersUnderTest.Add(bprmfRecSource);
#endif

            // One (initially empty) result list per recommender, at the same index as the recommender.
            while (resultsForEachRecommender.Count < recommendersUnderTest.Count)
            {
                resultsForEachRecommender.Add(new List<EvaluationResults>());
            }

            IUserInputClassifier<MalUserListEntries> targetClassifier = new MalPercentageRatingClassifier(targetPercentile, minEpisodesToCountIncomplete);

            // config_overrides.xml is optional and is added after config_base.xml.
            IConfigurationRoot configRoot = new ConfigurationBuilder()
                .AddXmlFile("config_base.xml")
                .AddXmlFile("config_overrides.xml", optional: true)
                .Build();
            Config config = configRoot.Get<Config>();

            string postgresConnectionString = config.ConnectionStrings.AnimeRecs;

            MalTrainingData rawData;
            using (PgMalDataLoader dataLoader = new PgMalDataLoader(postgresConnectionString))
            using (CancellationTokenSource loadTimeout = new CancellationTokenSource(TimeSpan.FromSeconds(60)))
            {
                // Block synchronously on the async load; the token cancels it after 60 seconds.
                rawData = dataLoader.LoadMalTrainingDataAsync(loadTimeout.Token).ConfigureAwait(false).GetAwaiter().GetResult();
            }

            for (int evaluationPass = 0; evaluationPass < numEvaluations; evaluationPass++)
            {
                for (int index = 0; index < recommendersUnderTest.Count; index++)
                {
                    var recSource = recommendersUnderTest[index];

                    // A new random training/evaluation split is drawn for every recommender on every pass.
                    Tuple<MalTrainingData, ICollection<MalUserListEntries>> split = GetDataForTrainingAndEvaluation(rawData);

                    recSource.Train(split.Item1);

                    EvaluationResults results = topNEvaluator.Evaluate(
                        recSource: recSource,
                        users: split.Item2,
                        goodBadClassifier: targetClassifier,
                        inputDivisionFunc: MalUserListEntries.DivideClassifiedForInputAndEvaluation,
                        numRecsToTryToGet: numRecsToGet);

                    resultsForEachRecommender[index].Add(results);
                }
            }

            for (int index = 0; index < recommendersUnderTest.Count; index++)
            {
                Console.WriteLine(recommendersUnderTest[index]);

                foreach (EvaluationResults passResults in resultsForEachRecommender[index])
                {
                    Console.WriteLine("Precision: {0:P2}\tRecall: {1:P2}", passResults.Precision, passResults.Recall);
                }

                // Four blank lines separate the output of consecutive recommenders.
                for (int blankLine = 0; blankLine < 4; blankLine++)
                {
                    Console.WriteLine();
                }
            }
        }
Exemple #8
0
        /// <summary>
        /// Shuffles the user ids of <paramref name="rawData"/> and splits them in two: the first half
        /// (rounded down) forms a new training data set, the remaining users are returned for evaluation.
        /// </summary>
        /// <param name="rawData">Data set whose users are split. Its animes are reused as-is.</param>
        /// <returns>Item1 is the training data, Item2 the users held out for evaluation.</returns>
        static Tuple<MalTrainingData, ICollection<MalUserListEntries>> GetDataForTrainingAndEvaluation(MalTrainingData rawData)
        {
            List<int> shuffledUserIds = rawData.Users.Keys.ToList();
            shuffledUserIds.Shuffle();

            int trainingCount = shuffledUserIds.Count / 2;

            // Leading half of the shuffled ids: users to train on, keyed by user id.
            Dictionary<int, MalUserListEntries> trainingUsers = shuffledUserIds
                .Take(trainingCount)
                .ToDictionary(userId => userId, userId => rawData.Users[userId]);

            // Trailing half: users held out for evaluation.
            List<MalUserListEntries> evaluationUsers = shuffledUserIds
                .Skip(trainingCount)
                .Select(userId => rawData.Users[userId])
                .ToList();

            MalTrainingData trainingData = new MalTrainingData(trainingUsers, rawData.Animes);

            return Tuple.Create<MalTrainingData, ICollection<MalUserListEntries>>(trainingData, evaluationUsers);
        }