Пример #1
0
        /// <summary>
        /// Trains the centroid model from a labeled dataset of sparse vectors.
        /// One <c>CentroidData</c> accumulator is built per distinct label (labels are grouped by the
        /// dictionary comparer <c>mLblCmp</c>). Then, for <c>mIterations</c> passes, every training example
        /// is compared against all centroids; when the best-scoring centroid is not the centroid of the
        /// example's own label, <c>AddToDiff</c> is called with <c>-learnRate</c> on the best-scoring centroid
        /// and with <c>+learnRate</c> on the own-label centroid. The learning rate starts at 1 and is
        /// multiplied by <c>mDamping</c> after each pass. Finally the centroid vectors are stored (transposed)
        /// in <c>mCentroidMtxTr</c> and the corresponding labels in <c>mLabels</c>.
        /// </summary>
        /// <param name="dataset">
        /// Training examples. When null, an <c>ArgumentNullException</c> is handed to
        /// <c>Utils.ThrowException</c>; when empty, an <c>ArgumentValueException</c> is.
        /// </param>
        public void Train(ILabeledExampleCollection <LblT, SparseVector <double> > dataset)
        {
            Utils.ThrowException(dataset == null ? new ArgumentNullException("dataset") : null);
            Utils.ThrowException(dataset.Count == 0 ? new ArgumentValueException("dataset") : null);

            // accumulate one centroid per distinct label (single dictionary lookup per example)
            Dictionary <LblT, CentroidData> centroids = new Dictionary <LblT, CentroidData>(mLblCmp);
            foreach (LabeledExample <LblT, SparseVector <double> > labeledExample in dataset)
            {
                CentroidData centroidData;
                if (!centroids.TryGetValue(labeledExample.Label, out centroidData))
                {
                    centroidData = new CentroidData();
                    centroids.Add(labeledExample.Label, centroidData);
                }
                centroidData.AddToSum(labeledExample.Example);
            }
            foreach (CentroidData cenData in centroids.Values)
            {
                cenData.UpdateCentroidLen();
            }
            double learnRate = 1;

            double[][]            dotProd = null;
            SparseMatrix <double> dsMtx   = null;

            // the transposed dataset matrix is only needed by the iterative refinement below
            if (mIterations > 0)
            {
                dotProd = new double[centroids.Count][];
                dsMtx   = ModelUtils.GetTransposedMatrix(ModelUtils.ConvertToUnlabeledDataset(dataset));
            }
            for (int iter = 1; iter <= mIterations; iter++)
            {
                mLogger.Info("Train", "Iteration {0} / {1} ...", iter, mIterations);
                // compute dot products; dotProd[c][i] pairs the c-th enumerated centroid with the i-th example
                mLogger.Info("Train", "Computing dot products ...");
                int j = 0;
                foreach (KeyValuePair <LblT, CentroidData> labeledCentroid in centroids)
                {
                    mLogger.ProgressNormal(Logger.Level.Info, /*sender=*/ this, "Train", "Centroid {0} / {1} ...", j + 1, centroids.Count);
                    SparseVector <double> cenVec = labeledCentroid.Value.GetSparseVector();
                    dotProd[j] = ModelUtils.GetDotProductSimilarity(dsMtx, dataset.Count, cenVec);
                    j++;
                }
                // classify training examples
                mLogger.Info("Train", "Classifying training examples ...");
                int errCount = 0;
                for (int instIdx = 0; instIdx < dataset.Count; instIdx++)
                {
                    mLogger.ProgressFast(Logger.Level.Info, /*sender=*/ this, "Train", "Example {0} / {1} ...", instIdx + 1, dataset.Count);
                    LabeledExample <LblT, SparseVector <double> > labeledExample = dataset[instIdx];
                    SparseVector <double> vec = labeledExample.Example;
                    // Resolve the own-label centroid through the dictionary so that the comparer that
                    // grouped the labels (mLblCmp) is also the one that matches them here. A plain
                    // Key.Equals(...) can disagree with a custom comparer and leave the centroid unresolved.
                    CentroidData actualCentroid = centroids[labeledExample.Label];
                    double       maxSim           = double.MinValue;
                    CentroidData assignedCentroid = null;
                    // enumerate in the same order as the dot-product pass so cenIdx lines up with dotProd
                    int cenIdx = 0;
                    foreach (KeyValuePair <LblT, CentroidData> labeledCentroid in centroids)
                    {
                        double sim = dotProd[cenIdx][instIdx];
                        if (sim > maxSim)
                        {
                            maxSim           = sim;
                            assignedCentroid = labeledCentroid.Value;
                        }
                        cenIdx++;
                    }
                    if (assignedCentroid != actualCentroid)
                    {
                        assignedCentroid.AddToDiff(-learnRate, vec);
                        actualCentroid.AddToDiff(learnRate, vec);
                        errCount++;
                    }
                }
                mLogger.Info("Train", "Training set error rate: {0:0.00}%", (double)errCount / (double)dataset.Count * 100.0);
                // update centroids
                int k = 0;
                foreach (CentroidData centroidData in centroids.Values)
                {
                    mLogger.ProgressNormal(Logger.Level.Info, /*sender=*/ this, "Train", "Centroid {0} / {1} ...", ++k, centroids.Count);
                    centroidData.Update(mPositiveValuesOnly);
                    centroidData.UpdateCentroidLen();
                }
                learnRate *= mDamping;
            }

            // publish the model: row r of the (pre-transpose) matrix corresponds to mLabels[r]
            mCentroidMtxTr = new SparseMatrix <double>();
            mLabels        = new ArrayList <LblT>();
            int rowIdx = 0;

            foreach (KeyValuePair <LblT, CentroidData> labeledCentroid in centroids)
            {
                mCentroidMtxTr[rowIdx++] = labeledCentroid.Value.GetSparseVector();
                mLabels.Add(labeledCentroid.Key);
            }
            mCentroidMtxTr = mCentroidMtxTr.GetTransposedCopy();
        }
Пример #2
0
 /// <summary>
 /// Appends a labeled example to the underlying item list (<c>mItems</c>).
 /// </summary>
 /// <param name="labeledExample">
 /// Example to append. When it is null, an <c>ArgumentNullException</c> naming the parameter is
 /// handed to <c>Utils.ThrowException</c> before anything is added.
 /// </param>
 public void Add(LabeledExample <LblT, ExT> labeledExample)
 {
     // build the argument error (if any) first, then hand it to the shared throw helper
     ArgumentNullException nullArgError = null;
     if (labeledExample == null)
     {
         nullArgError = new ArgumentNullException("labeledExample");
     }
     Utils.ThrowException(nullArgError);

     mItems.Add(labeledExample);
 }
Пример #3
0
 /// <summary>
 /// Builds a new dataset with the same labels as this one and with every example converted,
 /// via <c>ModelUtils.ConvertExample</c>, to the requested example type.
 /// </summary>
 /// <param name="new_ex_type">
 /// Target example type. Handled types: <c>SparseVector&lt;double&gt;</c>,
 /// <c>SparseVector&lt;double&gt;.ReadOnly</c>, <c>BinaryVector&lt;int&gt;</c> and
 /// <c>BinaryVector&lt;int&gt;.ReadOnly</c>. A null value causes an <c>ArgumentNullException</c>
 /// to be handed to <c>Utils.ThrowException</c>.
 /// </param>
 /// <param name="move">
 /// When true, each source slot is overwritten with a default-constructed <c>LabeledExample</c>
 /// right after it has been converted, and the source list is cleared once all items are done.
 /// </param>
 /// <returns>The newly created dataset holding the converted examples.</returns>
 /// <exception cref="ArgumentNotSupportedException">Thrown when <paramref name="new_ex_type"/> is none of the handled types.</exception>
 public IDataset <LblT> ConvertDataset(Type new_ex_type, bool move)
 {
     Utils.ThrowException(new_ex_type == null ? new ArgumentNullException("new_ex_type") : null);

     // SparseVector<double>
     if (new_ex_type == typeof(SparseVector <double>))
     {
         Dataset <LblT, SparseVector <double> > converted = new Dataset <LblT, SparseVector <double> >();
         for (int idx = 0; idx < m_items.Count; idx++)
         {
             LabeledExample <LblT, ExT> item = m_items[idx];
             SparseVector <double> convertedExample = ModelUtils.ConvertExample <SparseVector <double> >(item.Example);
             converted.Add(item.Label, convertedExample);
             if (move)
             {
                 // drop the source item as soon as its converted copy exists
                 m_items[idx] = new LabeledExample <LblT, ExT>();
             }
         }
         if (move)
         {
             m_items.Clear();
         }
         return converted;
     }

     // SparseVector<double>.ReadOnly
     if (new_ex_type == typeof(SparseVector <double> .ReadOnly))
     {
         Dataset <LblT, SparseVector <double> .ReadOnly> converted = new Dataset <LblT, SparseVector <double> .ReadOnly>();
         for (int idx = 0; idx < m_items.Count; idx++)
         {
             LabeledExample <LblT, ExT> item = m_items[idx];
             SparseVector <double> .ReadOnly convertedExample = ModelUtils.ConvertExample <SparseVector <double> .ReadOnly>(item.Example);
             converted.Add(item.Label, convertedExample);
             if (move)
             {
                 m_items[idx] = new LabeledExample <LblT, ExT>();
             }
         }
         if (move)
         {
             m_items.Clear();
         }
         return converted;
     }

     // BinaryVector<int>
     if (new_ex_type == typeof(BinaryVector <int>))
     {
         Dataset <LblT, BinaryVector <int> > converted = new Dataset <LblT, BinaryVector <int> >();
         for (int idx = 0; idx < m_items.Count; idx++)
         {
             LabeledExample <LblT, ExT> item = m_items[idx];
             BinaryVector <int> convertedExample = ModelUtils.ConvertExample <BinaryVector <int> >(item.Example);
             converted.Add(item.Label, convertedExample);
             if (move)
             {
                 m_items[idx] = new LabeledExample <LblT, ExT>();
             }
         }
         if (move)
         {
             m_items.Clear();
         }
         return converted;
     }

     // BinaryVector<int>.ReadOnly
     if (new_ex_type == typeof(BinaryVector <int> .ReadOnly))
     {
         Dataset <LblT, BinaryVector <int> .ReadOnly> converted = new Dataset <LblT, BinaryVector <int> .ReadOnly>();
         for (int idx = 0; idx < m_items.Count; idx++)
         {
             LabeledExample <LblT, ExT> item = m_items[idx];
             BinaryVector <int> .ReadOnly convertedExample = ModelUtils.ConvertExample <BinaryVector <int> .ReadOnly>(item.Example);
             converted.Add(item.Label, convertedExample);
             if (move)
             {
                 m_items[idx] = new LabeledExample <LblT, ExT>();
             }
         }
         if (move)
         {
             m_items.Clear();
         }
         return converted;
     }

     // NOTE: a SvmFeatureVector branch used to sit here but was commented out in the source;
     // that type therefore falls through to the "not supported" error below.
     throw new ArgumentNotSupportedException("new_ex_type");
 }