예제 #1
0
        /// <summary>
        /// Validates the dataset, makes sure the class maps are populated, and then trains on
        /// the dataset wrapped in a <c>TranslatedDataset</c> built with <c>c2i</c>.
        /// </summary>
        /// <param name="ds">Dataset to train on; must report at least one sample and one feature.</param>
        /// <exception cref="Exception">The dataset reports no samples or no features.</exception>
        protected override void Train(IDataset ds)
        {
            if (ds.nSamples() <= 0)
            {
                throw new Exception("nSamples of IDataset must be > 0");
            }
            if (ds.nFeatures() <= 0)
            {
                throw new Exception("nFeatures of IDataset must be > 0");
            }

            // The class maps are only built while c2i is still empty; a populated map is reused.
            bool mapIsEmpty = c2i.Length() < 1;
            if (mapIsEmpty)
            {
                // Gather the label of every sample and hand the list to ClassMap with both maps.
                Intarray labels = new Intarray();
                labels.ReserveTo(ds.nSamples());
                for (int sample = 0; sample < ds.nSamples(); sample++)
                {
                    labels.Push(ds.Cls(sample));
                }
                ClassMap(c2i, i2c, labels);
            }

            TranslatedDataset translated = new TranslatedDataset(ds, c2i);
            TrainDense(translated);
        }
예제 #2
0
        /// <summary>
        /// Collects the distinct class labels of <paramref name="ds"/> in ascending order and
        /// prints a per-class sample count table to the console, eight entries per line.
        /// </summary>
        /// <param name="ds">Dataset whose class labels are scanned.</param>
        /// <returns>The distinct class labels sorted ascending; empty when there are no samples.</returns>
        private int[] CreateClassesFromDataset(IDataset ds)
        {
            // SortedDictionary keeps its keys ordered, so the class list comes out sorted.
            SortedDictionary<int, int> counts = new SortedDictionary<int, int>();

            for (int i = 0; i < ds.nSamples(); i++)
            {
                int cls = ds.Cls(i);                 // read the label once per sample
                int seen;
                counts.TryGetValue(cls, out seen);   // seen is 0 when the label is new
                counts[cls] = seen + 1;
            }
            int[] classes = new int[counts.Count];
            counts.Keys.CopyTo(classes, 0);

            // show class counts
            Console.WriteLine("Classes counts:");
            int shown = 0;

            foreach (KeyValuePair<int, int> entry in counts)
            {
                // Labels from 32 upwards are printed as characters, smaller ones as numbers.
                string label = entry.Key >= 32
                    ? String.Format("{0}[{1}]", (char)entry.Key, entry.Value)
                    : String.Format("{0}[{1}]", entry.Key, entry.Value);
                Console.Write(String.Format("{0,-9}", label));

                shown++;
                if (shown % 8 == 0)
                {
                    Console.WriteLine();
                }
            }
            if (shown % 8 != 0)
            {
                // Terminate a partially filled last line; a full (or absent) line needs no extra break.
                Console.WriteLine();
            }
            return classes;
        }
예제 #3
0
        /// <summary>
        /// Passes every sample of <paramref name="ds"/>, together with its class label, to <c>Add</c>.
        /// </summary>
        /// <param name="ds">Dataset supplying the inputs and class labels.</param>
        protected virtual void Train(IDataset ds)
        {
            // A single buffer is handed to ds.Input for each index and then forwarded to Add.
            Floatarray sample = new Floatarray();
            int index = 0;

            while (index < ds.nSamples())
            {
                ds.Input(sample, index);
                Add(sample, ds.Cls(index));
                index++;
            }
        }
예제 #4
0
        /// <summary>
        /// Sends the samples of <paramref name="ts"/> whose class is present in <c>C2i</c> to the
        /// network as test data, runs one test epoch and returns the resulting statistics.
        /// </summary>
        /// <param name="ts">Test dataset. Rank-1 inputs are reshaped with <c>csize</c> before use.</param>
        /// <returns>The <c>TrainInfo</c> passed by reference to <c>EndAndRunTestEpoch</c>.</returns>
        /// <exception cref="Exception">
        /// Adding a sample or running the epoch failed. The message is prefixed with the failing
        /// step and the original exception is kept as <see cref="Exception.InnerException"/>.
        /// </exception>
        public TrainInfo TestDense(IDataset ts)
        {
            TrainInfo  trInfo = new TrainInfo();
            bool       first  = true;
            Floatarray v      = new Floatarray();

            // Send Test Dataset to Lenet
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                {
                    v.Reshape(csize, csize, 0, 0);
                }
                StdInput linput = new StdInput(v);
                if (first)
                {
                    // The epoch is initialised lazily because the input size is only known
                    // once the first sample has been converted.
                    first = false;
                    BeginTestEpoch(HrLenet, linput.Height, linput.Width, ts.nSamples());   // init test
                }
                try
                {
                    // Samples whose class is not in C2i are skipped without a message.
                    if (C2i.ContainsKey(ts.Cls(i)))
                    {
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ts.Cls(i)]);
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get test messages
                    if (sbout.Length > 0)
                    {
                        Global.Debugf("error", sbout.ToString() + "\r\nException in AddSampleToTestOfEpoch");
                    }
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // do test one epoch
            try
            {
                EndAndRunTestEpoch(HrLenet, ref trInfo);
            }
            catch (Exception e)
            {
                GetStdout(sbout);   // get test messages
                if (sbout.Length > 0)
                {
                    Global.Debugf("error", sbout.ToString() + "\r\nException in EndAndRunTestEpoch");
                }
                // Keep the original exception as InnerException so its stack trace survives.
                throw new Exception("Exception in EndAndRunTestEpoch\r\n" + e.Message, e);
            }

            return trInfo;
        }
예제 #5
0
        /// <summary>
        /// Classifies the samples of <paramref name="ds"/> and returns the fraction for which
        /// <paramref name="classifier"/> disagrees with the dataset label.
        /// </summary>
        /// <param name="classifier">Model whose <c>Classify</c> result is compared with the label.</param>
        /// <param name="ds">Dataset to evaluate; samples whose class is -1 are skipped.</param>
        /// <param name="n">
        /// Not read by this method. NOTE(review): presumably intended as a sample limit - confirm.
        /// </param>
        /// <returns>
        /// Mismatches divided by evaluated samples. The division is done in float, so the result
        /// is NaN when no sample was evaluated.
        /// </returns>
        public static float estimate_errors(IModel classifier, IDataset ds, int n = 1000000)
        {
            Floatarray features   = new Floatarray();
            int        mismatches = 0;
            int        evaluated  = 0;

            for (int idx = 0; idx < ds.nSamples(); idx++)
            {
                int expected = ds.Cls(idx);
                if (expected != -1)
                {
                    ds.Input1d(features, idx);
                    int predicted = classifier.Classify(features);
                    evaluated++;
                    if (predicted != expected)
                    {
                        mismatches++;
                    }
                }
            }

            float rate = mismatches / (float)evaluated;
            return rate;
        }
예제 #6
0
 /// <summary>Returns the class label of sample <paramref name="i"/> by delegating to <c>_ds.Cls</c>.</summary>
 public override int Cls(int i)
 {
     int label = _ds.Cls(i);
     return label;
 }
예제 #7
0
        /// <summary>
        /// Loads the train and test datasets into the network and runs the requested number of
        /// training epochs, logging the statistics of each epoch.
        /// </summary>
        /// <param name="ds">Training dataset. Samples whose class is not in <c>C2i</c> are logged and skipped.</param>
        /// <param name="ts">Test dataset. Samples whose class is not in <c>C2i</c> are skipped without a message.</param>
        /// <param name="epochs">Number of training epochs to run.</param>
        /// <returns>
        /// The <c>TrainInfo</c> of the last epoch, or a freshly constructed one when
        /// <paramref name="epochs"/> is not positive.
        /// </returns>
        /// <exception cref="Exception">
        /// Adding a sample or running an epoch failed. The message is prefixed with the failing
        /// step and the original exception is kept as <see cref="Exception.InnerException"/>.
        /// </exception>
        public virtual TrainInfo TrainBatch(IDataset ds, IDataset ts, int epochs)
        {
            TrainInfo  trInfo = new TrainInfo();
            bool       first  = true;
            Floatarray v      = new Floatarray();

            // Send Train Dataset to Lenet
            for (int i = 0; i < ds.nSamples(); i++)
            {
                ds.Input(v, i);
                if (v.Rank() == 1)
                {
                    v.Reshape(csize, csize, 0, 0);
                }
                StdInput linput = new StdInput(v);
                try
                {
                    if (first)
                    {
                        // The epoch is initialised lazily because the input size is only
                        // known once the first sample has been converted.
                        first = false;
                        //StartRedirectStdout();  // start redirect cout to string buffer
                        BeginTrainEpoch(HrLenet, linput.Height, linput.Width, ds.nSamples(), ts.nSamples());   // init train
                    }
                    if (C2i.ContainsKey(ds.Cls(i)))
                    {
                        AddSampleToTrainOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ds.Cls(i)]);
                    }
                    else
                    {
                        Global.Debugf("error", "class '{0}' is not contained in the char list", (char)ds.Cls(i));
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                    {
                        Global.Debugf("error", sbout.ToString() + "\r\nException in AddSampleToTrainOfEpoch");
                    }
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTrainOfEpoch\r\n" + e.Message, e);
                }
            }

            // Send Test Dataset to Lenet
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                {
                    v.Reshape(csize, csize, 0, 0);
                }
                StdInput linput = new StdInput(v);
                try
                {
                    if (C2i.ContainsKey(ts.Cls(i)))
                    {
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ts.Cls(i)]);
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get test messages
                    if (sbout.Length > 0)
                    {
                        Global.Debugf("error", sbout.ToString() + "\r\nException in AddSampleToTestOfEpoch");
                    }
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // debug save mnist
            //SaveTrainMnist(HrLenet, "debug-images-idx3-ubyte", "debug-labels-idx1-ubyte");

            // do train epochs
            for (int epoch = 0; epoch < epochs; epoch++)
            {
                trInfo = new TrainInfo();
                try
                {
                    // Stopwatch is monotonic; wall-clock differences can jump when the system
                    // clock is adjusted while an epoch is running.
                    System.Diagnostics.Stopwatch timer = System.Diagnostics.Stopwatch.StartNew();
                    EndAndRunTrainEpoch(HrLenet, ref trInfo);   // do train one epoch
                    // show train info
                    Global.Debugf("info",
                                  String.Format("|{0,7}| Energy:{1:0.#####} Correct:{2:0.00#%} Errors:{3:0.00#%} Count:{4} ",
                                                trInfo.age, trInfo.energy, (trInfo.correct / (float)trInfo.size),
                                                (trInfo.error / (float)trInfo.size), trInfo.size));
                    Global.Debugf("info",
                                  String.Format("     TEST Energy={0:0.#####} Correct={1:0.00#%} Errors={2:0.00#%} Count={3} ",
                                                trInfo.tenergy, (trInfo.tcorrect / (float)trInfo.tsize),
                                                (trInfo.terror / (float)trInfo.tsize), trInfo.tsize));
                    TimeSpan spanTrain = timer.Elapsed;
                    Global.Debugf("info", String.Format("          training time: {0} minutes, {1} seconds",
                                                        (int)spanTrain.TotalMinutes, spanTrain.Seconds));

                    // get dll stdout messages
                    GetStdout(sbout);
                    if (sbout.Length > 0)
                    {
                        Console.Write(sbout.ToString());
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                    {
                        Global.Debugf("error", sbout.ToString() + "\r\nException in EndAndRunTrainEpoch");
                    }
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in EndAndRunTrainEpoch\r\n" + e.Message, e);
                }
            }
            return trInfo;
        }
예제 #8
0
        /// <summary>
        /// Trains the content classifier and, when junk handling is enabled, the junk classifier.
        /// </summary>
        /// <remarks>
        /// Junk handling is enabled when the "junk" parameter is true and <c>DisableJunk</c> is false.
        /// The sample count of <paramref name="ds"/> is added to any existing "%nsamples" parameter
        /// and written back once training has finished.
        /// </remarks>
        /// <param name="ds">Dataset to train on.</param>
        protected override void Train(IDataset ds)
        {
            bool junkEnabled = PGetb("junk") && !DisableJunk;

            // Running sample total: this dataset plus whatever "%nsamples" already holds.
            int totalSamples = ds.nSamples();
            if (PExists("%nsamples"))
            {
                totalSamples += PGeti("%nsamples");
            }

            Global.Debugf("info", "Training content classifier");

            if (CharClass.IsEmpty)
            {
                Initialize(CreateClassesFromDataset(ds));
            }

            if (!junkEnabled)
            {
                CharClass.TrainDense(ds, PGeti("epochs"));
            }
            else
            {
                // Content classifier: only the indices whose class differs from jc() are used.
                Intarray contentIndices = new Intarray();
                for (int s = 0; s < ds.nSamples(); s++)
                {
                    if (ds.Cls(s) != jc())
                    {
                        contentIndices.Push(s);
                    }
                }
                CharClass.TrainDense(new Datasubset(ds, contentIndices), PGeti("epochs"));

                // Junk classifier: one label per sample, JunkClass.Classes[1] for samples whose
                // class equals jc() and JunkClass.Classes[0] for all others.
                Global.Debugf("info", "Training junk classifier");
                Intarray junkLabels = new Intarray();
                int      junkCount  = 0;
                for (int s = 0; s < ds.nSamples(); s++)
                {
                    bool isJunk = ds.Cls(s) == jc();
                    junkLabels.Push(JunkClass.Classes[isJunk ? 1 : 0]);
                    if (isJunk)
                    {
                        junkCount++;
                    }
                }

                if (junkCount == 0)
                {
                    Global.Debugf("warn", "you are training a junk class but there are no samples to train on");
                    JunkClass.DeleteLenet();
                }
                else
                {
                    JunkClass.TrainDense(new MappedDataset(ds, junkLabels), PGeti("epochs"));
                }
            }

            PSet("%nsamples", totalSamples);
        }
예제 #9
0
        /// <summary>
        /// Creates any missing classifier components, then trains the content classifier and,
        /// when junk handling is active, the junk classifier.
        /// </summary>
        /// <remarks>
        /// Junk handling is active when the "junk" parameter is true, <c>DisableJunk</c> is false
        /// and <c>junkclass</c> holds a component. The "ul" check is only reached on that path.
        /// </remarks>
        /// <param name="ds">Dataset to train on.</param>
        /// <exception cref="Exception">
        /// Junk handling is active, the "ul" parameter is positive and <c>ulclass</c> is not empty.
        /// </exception>
        protected override void Train(IDataset ds)
        {
            bool junkRequested = PGetb("junk") && !DisableJunk;

            // Instantiate each component from its parameter if it has not been set yet.
            if (charclass.IsEmpty)
            {
                charclass.SetComponent(ComponentCreator.MakeComponent(PGet("charclass")));
                TryAttachCharClassifierEvent(charclass.Object);
            }
            if (junkclass.IsEmpty)
            {
                junkclass.SetComponent(ComponentCreator.MakeComponent(PGet("junkclass")));
                TryAttachJunkClassifierEvent(junkclass.Object);
            }
            if (ulclass.IsEmpty)
            {
                ulclass.SetComponent(ComponentCreator.MakeComponent(PGet("ulclass")));
            }

            Global.Debugf("info", "Training content classifier");
            bool filterJunk = junkRequested && !junkclass.IsEmpty;
            if (!filterJunk)
            {
                charclass.Object.XTrain(ds);
            }
            else
            {
                // Only the indices whose class differs from jc() go to the content classifier.
                Intarray contentIndices = new Intarray();
                for (int s = 0; s < ds.nSamples(); s++)
                {
                    if (ds.Cls(s) != jc())
                    {
                        contentIndices.Push(s);
                    }
                }
                charclass.Object.XTrain(new Datasubset(ds, contentIndices));
            }

            // The condition is evaluated again here, after the content classifier has trained.
            bool trainJunk = junkRequested && !junkclass.IsEmpty;
            if (trainJunk)
            {
                Global.Debugf("info", "Training junk classifier");

                // One label per sample: 1 when the class equals jc(), otherwise 0.
                Intarray junkLabels = new Intarray();
                int      junkCount  = 0;
                for (int s = 0; s < ds.nSamples(); s++)
                {
                    bool isJunk = ds.Cls(s) == jc();
                    junkLabels.Push(isJunk ? 1 : 0);
                    if (isJunk)
                    {
                        junkCount++;
                    }
                }

                if (junkCount == 0)
                {
                    Global.Debugf("warn", "you are training a junk class but there are no samples to train on");
                    junkclass.SetComponent(null);
                }
                else
                {
                    junkclass.Object.XTrain(new MappedDataset(ds, junkLabels));
                }

                if (PGeti("ul") > 0 && !ulclass.IsEmpty)
                {
                    throw new Exception("ulclass not implemented");
                }
            }
        }