/// <summary>
/// Round-trips a two-user / two-item factorization through <see cref="FilePersistenceStrategy"/>:
/// an empty store must load as null, and after persisting, the loaded copy must equal the original.
/// </summary>
public void persistAndLoad()
        {
            FastByIDMap<int?> userIDMapping = new FastByIDMap<int?>();
            FastByIDMap<int?> itemIDMapping = new FastByIDMap<int?>();

            userIDMapping.Put(123, 0);
            userIDMapping.Put(456, 1);

            itemIDMapping.Put(12, 0);
            itemIDMapping.Put(34, 1);

            double[][] userFeatures = new double[][] { new double[] { 0.1, 0.2, 0.3 }, new double[] { 0.4, 0.5, 0.6 } };
            double[][] itemFeatures = new double[][] { new double[] { 0.7, 0.8, 0.9 }, new double[] { 1.0, 1.1, 1.2 } };

            Factorization original = new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);

            // Use a unique file name: with a fixed name, a file left behind by an aborted run (or written
            // by another test running at the same time) would make the first Load() return non-null.
            // GetRandomFileName only produces a name; it does not create the file.
            string storage = Path.Combine(Path.GetTempPath(), "storage-" + Path.GetRandomFileName());

            try
            {
                IPersistenceStrategy persistenceStrategy = new FilePersistenceStrategy(storage);

                Assert.IsNull(persistenceStrategy.Load());

                persistenceStrategy.MaybePersist(original);
                Factorization clone = persistenceStrategy.Load();

                Assert.True(original.Equals(clone));
            }
            finally
            {
                if (File.Exists(storage))
                {
                    // Best-effort cleanup: a failed delete must not mask the test's own outcome.
                    try { File.Delete(storage); } catch { }
                }
            }
        }
        /// <summary>
        /// Persists a small factorization to a temporary file and verifies that loading it back
        /// yields an equal factorization (and that a not-yet-written store loads as null).
        /// </summary>
        public void persistAndLoad()
        {
            FastByIDMap<int?> userIDMapping = new FastByIDMap<int?>();
            FastByIDMap<int?> itemIDMapping = new FastByIDMap<int?>();

            userIDMapping.Put(123, 0);
            userIDMapping.Put(456, 1);

            itemIDMapping.Put(12, 0);
            itemIDMapping.Put(34, 1);

            double[][] userFeatures = new double[][] { new double[] { 0.1, 0.2, 0.3 }, new double[] { 0.4, 0.5, 0.6 } };
            double[][] itemFeatures = new double[][] { new double[] { 0.7, 0.8, 0.9 }, new double[] { 1.0, 1.1, 1.2 } };

            Factorization original = new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);

            // A per-run unique name isolates this test from stale files and from concurrent runs;
            // a shared "storage.bin" could already exist and break the initial IsNull assertion.
            string storage = Path.Combine(Path.GetTempPath(), "storage-" + Path.GetRandomFileName());

            try {
                IPersistenceStrategy persistenceStrategy = new FilePersistenceStrategy(storage);

                Assert.IsNull(persistenceStrategy.Load());

                persistenceStrategy.MaybePersist(original);
                Factorization clone = persistenceStrategy.Load();

                Assert.True(original.Equals(clone));
            } finally {
                if (File.Exists(storage))
                {
                    // Deliberately best-effort: cleanup failures are ignored so they cannot hide the real result.
                    try { File.Delete(storage); } catch { }
                }
            }
        }
Exemplo n.º 3
0
 /// <summary>
 /// Computes a fresh factorization with the configured factorizer, makes it the current
 /// factorization, and then hands it to the persistence strategy.
 /// </summary>
 /// <exception cref="TasteException">
 /// Thrown when the persistence strategy fails with an <see cref="IOException"/>; the
 /// IOException is kept as the inner exception.
 /// </exception>
 private void train()
 {
     this.factorization = factorizer.Factorize();

     try
     {
         persistenceStrategy.MaybePersist(this.factorization);
     }
     catch (IOException ioError)
     {
         throw new TasteException("Error persisting factorization", ioError);
     }
 }
 /// <summary>
 /// Writes <paramref name="factorization"/> to the backing file, replacing any previous contents.
 /// </summary>
 /// <param name="factorization">The factorization to serialize.</param>
 /// <exception cref="IOException">
 /// Thrown when the file cannot be opened or written, or when serialization reports a failure.
 /// </exception>
 public void MaybePersist(Factorization factorization) {
   log.Info("Writing factorization to {0}...", file);
   // FileMode.Create truncates an existing file. OpenOrCreate would leave stale trailing bytes
   // behind whenever the new factorization is smaller than the one previously stored.
   // The using block always closes the stream. Unlike the former finally { outFile.Close(); },
   // it cannot throw NullReferenceException (masking the real error) when opening the file fails.
   using (Stream outFile = new FileStream(file, FileMode.Create, FileAccess.Write)) {
     writeBinary(factorization, outFile);
   }
 }
Exemplo n.º 5
0
 /// <summary>
 /// Value equality. Another <see cref="Factorization"/> is equal when its user and item ID
 /// mappings are Equals to this one's and its user and item feature matrices compare equal
 /// under Utils.ArrayDeepEquals. Any other object (including null) is not equal.
 /// </summary>
 /// NOTE(review): confirm GetHashCode is overridden consistently with this Equals.
 public override bool Equals(object o)
 {
     var other = o as Factorization;
     if ((object)other == null)
     {
         return false;
     }

     bool sameMappings = userIDMapping.Equals(other.userIDMapping)
                         && itemIDMapping.Equals(other.itemIDMapping);

     return sameMappings
            && Utils.ArrayDeepEquals(userFeatures, other.userFeatures)
            && Utils.ArrayDeepEquals(itemFeatures, other.itemFeatures);
 }
        /// <summary>
        /// Serializes <paramref name="factorization"/> into the backing file, overwriting whatever
        /// the file held before.
        /// </summary>
        /// <param name="factorization">The factorization to write.</param>
        /// <exception cref="IOException">
        /// Thrown when the file cannot be opened or written, or when serialization reports a failure.
        /// </exception>
        public void MaybePersist(Factorization factorization)
        {
            log.Info("Writing factorization to {0}...", file);

            // Create (not OpenOrCreate) truncates the file, so no bytes from an older, larger
            // factorization survive past the end of the new data.
            // using guarantees the stream is closed without the null dereference the old
            // finally-block hit when the FileStream constructor itself threw.
            using (Stream outFile = new FileStream(file, FileMode.Create, FileAccess.Write))
            {
                writeBinary(factorization, outFile);
            }
        }
        /// <summary>
        /// Serializes <paramref name="factorization"/> into <paramref name="outFile"/> with a
        /// <see cref="BinaryWriter"/>.
        /// </summary>
        /// <remarks>
        /// Layout:
        /// <list type="number">
        /// <item>a header of numFeatures(), numUsers() and numItems();</item>
        /// <item>for each user mapping that has an index: the index, the user ID, and numFeatures() feature values;</item>
        /// <item>the same record shape for each item mapping that has an index.</item>
        /// </list>
        /// The writer is neither flushed nor disposed here, so <paramref name="outFile"/> stays open for the caller.
        /// </remarks>
        /// <exception cref="IOException">
        /// Thrown when a mapped user or item has no feature vector (NoSuchUserException / NoSuchItemException).
        /// </exception>
        protected static void writeBinary(Factorization factorization, Stream outFile)
        {
            var writer = new BinaryWriter(outFile);

            // Header.
            writer.Write(factorization.numFeatures());
            writer.Write(factorization.numUsers());
            writer.Write(factorization.numItems());

            // User records.
            foreach (var userEntry in factorization.getUserIDMappings())
            {
                // NOTE(review): entries without an index are skipped, but the header count above is
                // written unconditionally. Confirm numUsers() excludes such entries; otherwise the
                // record stream would not match the header.
                if (userEntry.Value.HasValue)
                {
                    long userId = userEntry.Key;
                    writer.Write(userEntry.Value.Value);
                    writer.Write(userId);
                    try
                    {
                        double[] vector = factorization.getUserFeatures(userId);
                        for (int f = 0; f < factorization.numFeatures(); f++)
                        {
                            writer.Write(vector[f]);
                        }
                    }
                    catch (NoSuchUserException cause)
                    {
                        throw new IOException("Unable to persist factorization", cause);
                    }
                }
            }

            // Item records.
            foreach (var itemEntry in factorization.getItemIDMappings())
            {
                // NOTE(review): same unconditional-header caveat as for users.
                if (itemEntry.Value.HasValue)
                {
                    long itemId = itemEntry.Key;
                    writer.Write(itemEntry.Value.Value);
                    writer.Write(itemId);
                    try
                    {
                        double[] vector = factorization.getItemFeatures(itemId);
                        for (int f = 0; f < factorization.numFeatures(); f++)
                        {
                            writer.Write(vector[f]);
                        }
                    }
                    catch (NoSuchItemException cause)
                    {
                        throw new IOException("Unable to persist factorization", cause);
                    }
                }
            }
        }
  /// <summary>
  /// Writes a header (numFeatures(), numUsers(), numItems()) followed by one record per indexed user
  /// and then one per indexed item, each record being (index, ID, numFeatures() feature values).
  /// <paramref name="outFile"/> is left open; the BinaryWriter is not disposed here.
  /// </summary>
  /// <exception cref="IOException">
  /// Thrown when a mapped user or item has no feature vector (NoSuchUserException / NoSuchItemException).
  /// </exception>
  protected static void writeBinary(Factorization factorization, Stream outFile) {
    var output = new BinaryWriter(outFile);

    output.Write(factorization.numFeatures());
    output.Write(factorization.numUsers());
    output.Write(factorization.numItems());

    foreach (var userMapping in factorization.getUserIDMappings()) {
      // NOTE(review): index-less entries are skipped although the header count is written as-is;
      // verify numUsers() does not count them.
      if (userMapping.Value.HasValue) {
        long id = userMapping.Key;
        output.Write(userMapping.Value.Value);
        output.Write(id);
        try {
          double[] features = factorization.getUserFeatures(id);
          int f = 0;
          while (f < factorization.numFeatures()) {
            output.Write(features[f]);
            f++;
          }
        } catch (NoSuchUserException cause) {
          throw new IOException("Unable to persist factorization", cause);
        }
      }
    }

    foreach (var itemMapping in factorization.getItemIDMappings()) {
      // NOTE(review): same caveat as above, for numItems().
      if (itemMapping.Value.HasValue) {
        long id = itemMapping.Key;
        output.Write(itemMapping.Value.Value);
        output.Write(id);
        try {
          double[] features = factorization.getItemFeatures(id);
          int f = 0;
          while (f < factorization.numFeatures()) {
            output.Write(features[f]);
            f++;
          }
        } catch (NoSuchItemException cause) {
          throw new IOException("Unable to persist factorization", cause);
        }
      }
    }
  }
Exemplo n.º 9
0
        /// <summary>
        /// Creates an SVD recommender that uses a persistent store to cache factorizations. A factorization
        /// is loaded from the store if present; otherwise a new factorization is computed and handed to the store.
        /// </summary>
        /// <remarks>
        /// Refreshing re-runs training, which recomputes the factorization and passes it to
        /// <paramref name="persistenceStrategy"/> again.
        /// </remarks>
        /// <param name="dataModel">Data model passed to the base recommender; registered as a refresh dependency.</param>
        /// <param name="factorizer">Computes the factorization when none is loaded and on every refresh.</param>
        /// <param name="candidateItemsStrategy">Passed to the base recommender; registered as a refresh dependency.</param>
        /// <param name="persistenceStrategy">Source of a cached factorization and sink for newly computed ones.</param>
        /// <exception cref="System.ArgumentNullException">
        /// Thrown when <paramref name="factorizer"/> or <paramref name="persistenceStrategy"/> is null.
        /// </exception>
        /// <exception cref="TasteException">
        /// Thrown when loading the cached factorization, or persisting a newly trained one, fails with an IOException.
        /// </exception>
        public SVDRecommender(IDataModel dataModel, IFactorizer factorizer, ICandidateItemsStrategy candidateItemsStrategy,
                              IPersistenceStrategy persistenceStrategy) : base(dataModel, candidateItemsStrategy)
        {
            // These null checks were previously commented out (Preconditions.checkNotNull). Without them,
            // a null argument surfaces later as a NullReferenceException from Load()/train(), or as a
            // null refresh dependency.
            if (factorizer == null)
            {
                throw new System.ArgumentNullException(nameof(factorizer));
            }
            if (persistenceStrategy == null)
            {
                throw new System.ArgumentNullException(nameof(persistenceStrategy));
            }

            this.factorizer          = factorizer;
            this.persistenceStrategy = persistenceStrategy;

            try {
                factorization = persistenceStrategy.Load();
            } catch (IOException e) {
                throw new TasteException("Error loading factorization", e);
            }

            // Nothing cached yet: compute a factorization now (train() also persists it).
            if (factorization == null)
            {
                train();
            }

            refreshHelper = new RefreshHelper(() => {
                train();
            });
            refreshHelper.AddDependency(GetDataModel());
            refreshHelper.AddDependency(factorizer);
            refreshHelper.AddDependency(candidateItemsStrategy);
        }
 /// <summary>
 /// Ignores <paramref name="factorization"/>: this implementation never writes anything.
 /// </summary>
 public void MaybePersist(Factorization factorization)
 {
     // Intentionally empty.
 }
 /// <summary>No-op persistence: the supplied factorization is discarded.</summary>
 public void MaybePersist(Factorization factorization) {
     // Nothing to persist by design.
 }
Exemplo n.º 12
0
        /// <summary>
        /// Trains a <see cref="ParallelSGDFactorizer"/> on the synthetic data set and asserts that the
        /// reconstructed ratings have an RMSE below 0.2. The regularized loss and training time are only logged.
        /// </summary>
        public void testFactorizerWithWithSyntheticData()
        {
            setUpSyntheticData();

            var stopWatch = new System.Diagnostics.Stopwatch();
            stopWatch.Start();

            factorizer = new ParallelSGDFactorizer(dataModel, rank, lambda, numIterations, 0.01, 1, 0, 0);
            Factorization factorization = factorizer.Factorize();

            stopWatch.Stop();
            long duration = stopWatch.ElapsedMilliseconds;

            // A hold-out test would be better, but this is just a toy example, so we only check that the
            // factorization is close to the original matrix.
            IRunningAverage avg = new FullRunningAverage();
            var userIDs = dataModel.GetUserIDs();

            while (userIDs.MoveNext())
            {
                long userID = userIDs.Current;
                foreach (IPreference pref in dataModel.GetPreferencesFromUser(userID))
                {
                    double rating     = pref.GetValue();
                    var    userVector = factorization.getUserFeatures(userID);
                    var    itemVector = factorization.getItemFeatures(pref.GetItemID());
                    double estimate   = vectorDot(userVector, itemVector);
                    double err        = rating - estimate;

                    avg.AddDatum(err * err);
                }
            }

            // Regularization term: squared norms of every user vector and every item vector.
            double sum = 0.0;

            userIDs = dataModel.GetUserIDs();
            while (userIDs.MoveNext())
            {
                long userID = userIDs.Current;
                var  userVector = factorization.getUserFeatures(userID);
                sum += vectorDot(userVector, userVector);
            }

            IEnumerator<long> itemIDs = dataModel.GetItemIDs();
            while (itemIDs.MoveNext())
            {
                long itemID = itemIDs.Current;
                // Item IDs must be resolved against the item features. The previous getUserFeatures(itemID)
                // call looked them up in the user mapping, which yields the wrong vectors or throws.
                var  itemVector = factorization.getItemFeatures(itemID);
                sum += vectorDot(itemVector, itemVector);
            }

            double rmse = Math.Sqrt(avg.GetAverage());
            double loss = avg.GetAverage() / 2 + lambda / 2 * sum;

            logger.Info("RMSE: " + rmse + ";\tLoss: " + loss + ";\tTime Used: " + duration + "ms");
            Assert.True(rmse < 0.2);
        }