Inheritance: IComponent
示例#1
0
        /// <summary>
        /// Runs <paramref name="epochs"/> single-epoch training passes and keeps the network state
        /// that scored best on the test set. The test error rate decides; an energy tie-break is
        /// used only when both errors are exactly zero. The best state is restored before returning.
        /// </summary>
        /// <returns>The TrainInfo of the best state (the initial TestDense result if no epoch improved on it).</returns>
        public override TrainInfo TrainBatch(IDataset ds, IDataset ts, int epochs)
        {
            // Establish a baseline by testing the current network before any training.
            TrainInfo bestinfo = TestDense(ts);
            float bestError = bestinfo.terror / (float)bestinfo.tsize;
            double bestEnergy = bestinfo.tenergy;
            Global.Debugf("info", "     BEST errors={0:0.00#%} energy={1:0.#####}", bestError, bestEnergy);

            // Snapshot the baseline so it can be restored if training never improves on it.
            SaveNetworkToBuffer(out netsize, out netstate);

            int epoch = 0;
            while (epoch < epochs)
            {
                Stopwatch roundTimer = Stopwatch.StartNew();
                TrainInfo roundInfo = base.TrainBatch(ds, ts, 1);
                float roundError = roundInfo.terror / (float)roundInfo.tsize;

                bool lowerError = roundError < bestError;
                bool perfectWithLowerEnergy = bestError == 0 && roundError == 0 && roundInfo.tenergy < bestEnergy;
                if (lowerError || perfectWithLowerEnergy)
                {
                    bestError = roundError;
                    bestEnergy = roundInfo.tenergy;
                    bestinfo = roundInfo;
                    // Re-snapshot: this epoch produced the best network so far.
                    SaveNetworkToBuffer(out netsize, out netstate);
                    Global.Debugf("info", "     ==>  best selected");
                }

                roundTimer.Stop();
                OnTrainRound(this, new TrainEventArgs(
                    epoch, roundInfo.tenergy, roundInfo.tcorrect, roundInfo.tsize, bestEnergy, roundTimer.Elapsed, TimeSpan.Zero));
                epoch++;
            }

            // Roll the network back to the best snapshot seen.
            LoadNetworkFromBuffer(netstate, netsize);
            return bestinfo;
        }
示例#2
0
文件: IModel.cs 项目: nickun/OCRonet
 /// <summary>
 /// Trains on <paramref name="ds"/>, first wrapping it in an ExtractedDataset when a
 /// feature extractor component is attached.
 /// </summary>
 public void XTrain(IDataset ds)
 {
     IDataset source = _extractor.IsEmpty
         ? ds
         : new ExtractedDataset(ds, _extractor.Object);
     Train(source);
 }
示例#3
0
        /// <summary>
        /// Collects the distinct class labels of <paramref name="ds"/> in ascending order and prints
        /// a per-class sample count to the console, eight entries per line.
        /// </summary>
        /// <returns>The distinct class labels, sorted ascending.</returns>
        private int[] CreateClassesFromDataset(IDataset ds)
        {
            // Count samples per class; SortedDictionary keeps the keys in ascending order.
            SortedDictionary<int, int> keymap = new SortedDictionary<int, int>();
            int nsamples = ds.nSamples();
            for (int i = 0; i < nsamples; i++)
            {
                int cls = ds.Cls(i);
                int count;
                keymap.TryGetValue(cls, out count);   // count is 0 when the class is new
                keymap[cls] = count + 1;
            }
            int[] classes = new int[keymap.Count];
            keymap.Keys.CopyTo(classes, 0);

            // Show class counts, eight 9-column cells per line.
            Console.WriteLine("Classes counts:");
            System.Text.StringBuilder showline = new System.Text.StringBuilder();
            int ishow = 0;
            foreach (KeyValuePair<int, int> entry in keymap)
            {
                // Codes >= 32 are shown as characters; lower (control) codes as numbers.
                string cell = entry.Key >= 32
                    ? String.Format("{0}[{1}]", (char)entry.Key, entry.Value)
                    : String.Format("{0}[{1}]", entry.Key, entry.Value);
                showline.AppendFormat("{0,-9}", cell);
                ishow++;
                if (ishow % 8 == 0)
                {
                    Console.WriteLine(showline.ToString());
                    showline.Length = 0;
                }
            }
            // Flush a partial last line only; avoids printing a stray blank line.
            if (showline.Length > 0)
                Console.WriteLine(showline.ToString());
            return classes;
        }
示例#4
0
        /// <summary>
        /// Trains the content classifier (CharClass) and, when junk handling is enabled (the "junk"
        /// parameter is set and DisableJunk is false), the junk classifier (JunkClass).
        /// The "%nsamples" counter is increased by ds.nSamples().
        /// </summary>
        protected override void Train(IDataset ds)
        {
            bool junkEnabled = PGetb("junk") && !DisableJunk;
            int sampleTotal = ds.nSamples();
            if (PExists("%nsamples"))
                sampleTotal += PGeti("%nsamples");

            Global.Debugf("info", "Training content classifier");

            // Build the class list from the data the first time the classifier is trained.
            if (CharClass.IsEmpty)
                Initialize(CreateClassesFromDataset(ds));

            // With junk handling on, the content classifier only sees non-junk samples.
            // (The original port had an additional `&& !JunkClass.IsEmpty` condition disabled here.)
            IDataset contentSet = ds;
            if (junkEnabled)
            {
                Intarray keptRows = new Intarray();
                for (int row = 0; row < ds.nSamples(); row++)
                {
                    if (ds.Cls(row) != jc())
                        keptRows.Push(row);
                }
                contentSet = new Datasubset(ds, keptRows);
            }
            CharClass.TrainDense(contentSet, PGeti("epochs"));

            if (junkEnabled)
            {
                Global.Debugf("info", "Training junk classifier");
                // Relabel every sample with the junk classifier's class for "junk" (index 1)
                // or "not junk" (index 0).
                Intarray junkLabels = new Intarray();
                int junkTotal = 0;
                for (int row = 0; row < ds.nSamples(); row++)
                {
                    bool isJunk = ds.Cls(row) == jc();
                    junkLabels.Push(JunkClass.Classes[isJunk ? 1 : 0]);
                    if (isJunk)
                        junkTotal++;
                }

                if (junkTotal == 0)
                {
                    Global.Debugf("warn", "you are training a junk class but there are no samples to train on");
                    JunkClass.DeleteLenet();
                }
                else
                {
                    MappedDataset junkSet = new MappedDataset(ds, junkLabels);
                    JunkClass.TrainDense(junkSet, PGeti("epochs"));
                }
            }
            PSet("%nsamples", sampleTotal);
        }
示例#5
0
 /// <summary>
 /// Creates a subset of <paramref name="ds"/> described by the sample indices in
 /// <paramref name="samples"/>. Both arguments are stored by reference, not copied.
 /// </summary>
 public Datasubset(IDataset ds, Intarray samples)
 {
     _samples = samples;
     _ds = ds;
 }
示例#6
0
 /// <summary>
 /// Builds the class maps c2i/i2c from the dataset labels if they are still empty, then trains
 /// on a TranslatedDataset view of <paramref name="ds"/> that uses those maps.
 /// </summary>
 /// <exception cref="ArgumentNullException"><paramref name="ds"/> is null.</exception>
 /// <exception cref="ArgumentException">The dataset has no samples or no features.</exception>
 protected override void Train(IDataset ds)
 {
     // ArgumentException derives from Exception, so callers catching Exception are unaffected;
     // messages are unchanged.
     if (ds == null)
         throw new ArgumentNullException("ds");
     if (ds.nSamples() <= 0)
         throw new ArgumentException("nSamples of IDataset must be > 0");
     if (ds.nFeatures() <= 0)
         throw new ArgumentException("nFeatures of IDataset must be > 0");

     // First training: derive the class <-> index maps from the raw labels.
     if (c2i.Length() < 1)
     {
         int nsamples = ds.nSamples();
         Intarray raw_classes = new Intarray();
         raw_classes.ReserveTo(nsamples);
         for (int i = 0; i < nsamples; i++)
             raw_classes.Push(ds.Cls(i));
         ClassMap(c2i, i2c, raw_classes);
     }
     TranslatedDataset mds = new TranslatedDataset(ds, c2i);
     TrainDense(mds);
 }
示例#7
0
文件: IBatch.cs 项目: nickun/OCRonet
 /// <summary>Not implemented for this type: always throws <see cref="NotImplementedException"/>.</summary>
 protected override void Train(IDataset dataset) { throw new NotImplementedException(); }
示例#8
0
 /// <summary>
 /// Wraps <paramref name="ds"/> with the class-translation table <paramref name="c2i"/>
 /// (stored by reference). The class count is taken as one more than the largest entry of c2i.
 /// </summary>
 public TranslatedDataset(IDataset ds, Intarray c2i)
 {
     this._ds = ds;
     this._c2i = c2i;
     var largestIndex = NarrayUtil.Max(c2i);
     this._nc = largestIndex + 1;
 }
示例#9
0
 /// <summary>
 /// Sizes the network for <paramref name="ds"/>: w1 is nhidden x nFeatures, w2 is nClasses x nhidden.
 /// Each w1 row is seeded from a randomly chosen training sample v scaled by 1/|v|^2.
 /// b1/b2 get tiny random values and w2 gets values in [-1/nhidden, 1/nhidden].
 /// Optionally copies in the given class maps.
 /// </summary>
 /// <param name="nhidden">Number of hidden units; must satisfy 1 &lt; nhidden &lt; 1000000.</param>
 /// <param name="newc2i">Class-to-index map to copy into c2i, or null to keep the current one.</param>
 /// <param name="newi2c">Index-to-class map to copy into i2c, or null to keep the current one.</param>
 public void InitData(IDataset ds, int nhidden, Intarray newc2i = null, Intarray newi2c = null)
 {
     CHECK_ARG(nhidden > 1 && nhidden < 1000000, "nhidden > 1 && nhidden < 1000000");
     int nsamples = ds.nSamples();
     // Seeding the hidden layer needs at least one sample.
     CHECK_ARG(nsamples > 0, "ds.nSamples() > 0");
     int ninput = ds.nFeatures();
     int noutput = ds.nClasses();
     w1.Resize(nhidden, ninput);
     b1.Resize(nhidden);
     w2.Resize(noutput, nhidden);
     b2.Resize(noutput);
     Intarray indexes = new Intarray();
     NarrayUtil.RPermutation(indexes, nsamples);
     Floatarray v = new Floatarray();
     for (int i = 0; i < w1.Dim(0); i++)
     {
         // When there are more hidden units than samples, wrap around the permutation
         // instead of indexing past its end.
         int row = indexes[i % indexes.Length()];
         ds.Input1d(v, row);
         float normv = (float)NarrayUtil.Norm2(v);
         // Scale by 1/|v|^2 so the row's dot product with its seed sample is 1.
         // An all-zero sample is left unscaled rather than turned into NaN/Infinity.
         if (normv > 0)
             v /= normv * normv;
         NarrayRowUtil.RowPut(w1, i, v);
     }
     ClassifierUtil.fill_random(b1, -1e-6f, 1e-6f);
     ClassifierUtil.fill_random(w2, -1.0f / nhidden, 1.0f / nhidden);
     ClassifierUtil.fill_random(b2, -1e-6f, 1e-6f);
     if (newc2i != null)
         c2i.Copy(newc2i);
     if (newi2c != null)
         i2c.Copy(newi2c);
 }
示例#10
0
        /// <summary>
        /// Trains an ensemble of "nensemble" MLPs on <paramref name="ds"/> for "rounds" rounds.
        /// Each net is scored on <paramref name="ts"/> with ClassifierUtil.estimate_errors. The best
        /// net is copied into this instance whenever its error beats the best seen so far; the best
        /// starts from the stored "%error" parameter when present, otherwise 1e30.
        /// </summary>
        /// <remarks>
        /// Unless "noopt" is non-zero, after each round the second half of the error-sorted ensemble
        /// is overwritten with copies of the first half. The copies get a new hidden size (via
        /// rLogNormal, clamped to [hidden_min, hidden_max]) and a new learning rate (via rLogNormal).
        /// On exit "%error" is set to the best error and "%nsamples" is increased by
        /// ds.nSamples() * rounds.
        /// </remarks>
        public virtual void TrainBatch(IDataset ds, IDataset ts)
        {
            Stopwatch sw = Stopwatch.StartNew();
            bool parallel = PGetb("parallel");
            float eta_init = PGetf("eta_init"); // 0.5
            float eta_varlog = PGetf("eta_varlog"); // 1.5
            float hidden_varlog = PGetf("hidden_varlog"); // 1.2
            int hidden_lo = PGeti("hidden_lo");
            int hidden_hi = PGeti("hidden_hi");
            int rounds = PGeti("rounds");
            int mlp_noopt = PGeti("noopt");
            int hidden_min = PGeti("hidden_min");
            int hidden_max = PGeti("hidden_max");
            CHECK_ARG(hidden_min > 1 && hidden_max < 1000000, "hidden_min > 1 && hidden_max < 1000000");
            CHECK_ARG(hidden_hi >= hidden_lo, "hidden_hi >= hidden_lo");
            CHECK_ARG(hidden_max >= hidden_min, "hidden_max >= hidden_min");
            CHECK_ARG(hidden_lo >= hidden_min && hidden_hi <= hidden_max, "hidden_lo >= hidden_min && hidden_hi <= hidden_max");
            int nn = PGeti("nensemble");
            ObjList<MlpClassifier> nets = new ObjList<MlpClassifier>();
            nets.Resize(nn);
            for (int i = 0; i < nn; i++)
                nets[i] = new MlpClassifier(i);
            Floatarray errs = new Floatarray(nn);
            Floatarray etas = new Floatarray(nn);
            Intarray index = new Intarray();
            // Best error so far; resumes from a previous run's "%error" when available.
            float best = 1e30f;
            if (PExists("%error"))
                best = PGetf("%error");
            int nclasses = ds.nClasses();

            /*Floatarray v = new Floatarray();
            for (int i = 0; i < ds.nSamples(); i++)
            {
                ds.Input1d(v, i);
                CHECK_ARG(NarrayUtil.Min(v) > -100 && NarrayUtil.Max(v) < 100, "min(v)>-100 && max(v)<100");
            }*/
            CHECK_ARG(ds.nSamples() >= 10 && ds.nSamples() < 100000000, "ds.nSamples() >= 10 && ds.nSamples() < 100000000");

            // Seed the ensemble. If this net already has weights, each member starts as a copy with a
            // randomized learning rate. Otherwise each member is freshly initialized with a hidden size
            // spread between hidden_lo and hidden_hi by logspace.
            for (int i = 0; i < nn; i++)
            {
                // nets(i).init(data.dim(1),logspace(i,nn,hidden_lo,hidden_hi),nclasses);
                if (w1.Length() > 0)
                {
                    nets[i].Copy(this);
                    etas[i] = ClassifierUtil.rLogNormal(eta_init, eta_varlog);
                }
                else
                {
                    nets[i].InitData(ds, (int)(logspace(i, nn, hidden_lo, hidden_hi)), c2i, i2c);
                    etas[i] = PGetf("eta");
                }
            }
            etas[0] = PGetf("eta");     // zero position is identical to itself

            Global.Debugf("info", "mlp training n {0} nc {1}", ds.nSamples(), nclasses);
            for (int round = 0; round < rounds; round++)
            {
                Stopwatch swRound = Stopwatch.StartNew();
                errs.Fill(-1);
                if (parallel)
                {
                    // For each network i (each iteration touches only nets[i], etas[i] and errs[i])
                    Parallel.For(0, nn, i =>
                    {
                        nets[i].PSet("eta", etas[i]);
                        nets[i].TrainDense(ds);     // was XTrain
                        errs[i] = ClassifierUtil.estimate_errors(nets[i], ts);
                    });
                }
                else
                {
                    for (int i = 0; i < nn; i++)
                    {
                        nets[i].PSet("eta", etas[i]);
                        nets[i].TrainDense(ds);     // was XTrain
                        errs[i] = ClassifierUtil.estimate_errors(nets[i], ts);
                        //Global.Debugf("detail", "net({0}) {1} {2} {3}", i,
                        //       errs[i], nets[i].Complexity(), etas[i]);
                    }
                }
                // NOTE(review): index[0] is treated below as the lowest-error net, i.e. Quicksort is
                // assumed to fill index with net indices in ascending errs order - confirm.
                NarrayUtil.Quicksort(index, errs);
                if (errs[index[0]] < best)
                {
                    best = errs[index[0]];
                    cv_error = best;
                    this.Copy(nets[index[0]]);
                    this.PSet("eta", etas[index[0]]);
                    Global.Debugf("info", "  best mlp[{0}] update errors={1} {2}", index[0], best, crossvalidate ? "cv" : "");
                }
                // Evolutionary step: overwrite the worse half with mutated copies of the better half.
                if (mlp_noopt == 0)
                {
                    for (int i = 0; i < nn / 2; i++)
                    {
                        int j = i + nn / 2;
                        nets[index[j]].Copy(nets[index[i]]);
                        int n = nets[index[j]].nHidden();
                        int nm = Math.Min(Math.Max(hidden_min, (int)(ClassifierUtil.rLogNormal(n, hidden_varlog))), hidden_max);
                        nets[index[j]].ChangeHidden(nm);
                        etas[index[j]] = ClassifierUtil.rLogNormal(etas[index[i]], eta_varlog);
                    }
                }
                Global.Debugf("info", " end mlp round {0} err {1} nHidden {2}", round, best, nHidden());
                swRound.Stop();
                // Report an approximate error count on ts, assuming best is an error fraction.
                int totalTest= ts.nSamples();
                int errCnt = Convert.ToInt32(best * totalTest);
                OnTrainRound(this, new TrainEventArgs(
                    round, best, totalTest - errCnt, totalTest, best, swRound.Elapsed, TimeSpan.Zero
                    ));
            }

            sw.Stop();
            Global.Debugf("info", String.Format("          training time: {0} minutes, {1} seconds",
                (int)sw.Elapsed.TotalMinutes, sw.Elapsed.Seconds));
            PSet("%error", best);
            // Running total of samples seen across all TrainBatch calls.
            int nsamples = ds.nSamples() * rounds;
            if (PExists("%nsamples"))
                nsamples += PGeti("%nsamples");
            PSet("%nsamples", nsamples);
        }
示例#11
0
 /// <summary>
 /// Trains via TrainBatch. With cross-validation on, the samples are split into training and
 /// test subsets by sample id, so every sample sharing an id lands on the same side even if the
 /// dataset was resampled. Otherwise the whole dataset is used for both.
 /// </summary>
 public override void TrainDense(IDataset ds)
 {
     //PSet("%nsamples", ds.nSamples());
     float split = PGetf("cv_split");
     int mlp_cv_max = PGeti("cv_max");
     if (!crossvalidate)
     {
         TrainBatch(ds, ds);
         return;
     }

     // Distinct ids, shuffled, so the held-out set is a random selection of whole ids.
     Intarray uniqueIds = new Intarray();
     for (int s = 0; s < ds.nSamples(); s++)
         uniqueIds.Push(ds.Id(s));
     NarrayUtil.Uniq(uniqueIds);
     Global.Debugf("cvdetail", "reduced {0} ids to {1} ids", ds.nSamples(), uniqueIds.Length());
     NarrayUtil.Shuffle(uniqueIds);

     // Hold out a (1 - cv_split) fraction of the ids for testing, capped at cv_max.
     int holdout = Math.Min((int)((1.0 - split) * uniqueIds.Length()), mlp_cv_max);
     Intarray heldOutIds = new Intarray();
     for (int s = 0; s < holdout; s++)
         heldOutIds.Push(uniqueIds[s]);
     // Sorted, presumably because Bincontains does a binary search.
     NarrayUtil.Quicksort(heldOutIds);

     Intarray trainRows = new Intarray();
     Intarray testRows = new Intarray();
     for (int s = 0; s < ds.nSamples(); s++)
     {
         Intarray side = ClassifierUtil.Bincontains(heldOutIds, ds.Id(s)) ? testRows : trainRows;
         side.Push(s);
     }
     Global.Debugf("cvdetail", "#training {0} #testing {1}",
            trainRows.Length(), testRows.Length());
     PSet("%ntraining", trainRows.Length());
     PSet("%ntesting", testRows.Length());
     TrainBatch(new Datasubset(ds, trainRows), new Datasubset(ds, testRows));
 }
示例#12
0
        /// <summary>
        /// Queues the samples of <paramref name="ds"/> as training samples and those of
        /// <paramref name="ts"/> as test samples, then runs <paramref name="epochs"/> training epochs,
        /// logging train/test statistics after each. Only samples whose class is in C2i are queued.
        /// Unknown training classes are logged; unknown test classes are skipped silently.
        /// </summary>
        /// <returns>Statistics of the last epoch run (a fresh TrainInfo when epochs &lt;= 0).</returns>
        /// <exception cref="Exception">Wraps any failure while queuing samples or running an epoch;
        /// the original exception is preserved as InnerException.</exception>
        public virtual TrainInfo TrainBatch(IDataset ds, IDataset ts, int epochs)
        {
            TrainInfo trInfo = new TrainInfo();
            bool first = true;
            Floatarray v = new Floatarray();

            // Send the training dataset to Lenet.
            for (int i = 0; i < ds.nSamples(); i++)
            {
                ds.Input(v, i);
                if (v.Rank() == 1)
                    v.Reshape(csize, csize, 0, 0);   // flat vectors become csize x csize arrays
                StdInput linput = new StdInput(v);
                try
                {
                    if (first)
                    {
                        first = false;
                        //StartRedirectStdout();  // start redirect cout to string buffer
                        // The first sample's geometry initialises the epoch.
                        // NOTE(review): when ds is empty BeginTrainEpoch is never called, yet test
                        // samples are still added and epochs run - confirm the native side allows that.
                        BeginTrainEpoch(HrLenet, linput.Height, linput.Width, ds.nSamples(), ts.nSamples());   // init train
                    }
                    int cls = ds.Cls(i);
                    if (C2i.ContainsKey(cls))
                        AddSampleToTrainOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[cls]);
                    else
                        Global.Debugf("error", "class '{0}' is not contained in the char list", (char)cls);
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    // Pass the native text as an argument, not as the format string, so braces in it
                    // cannot trigger a FormatException that would hide the real error.
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTrainOfEpoch", sbout.ToString());
                    throw new Exception("Exception in AddSampleToTrainOfEpoch\r\n" + e.Message, e);
                }
            }

            // Send the test dataset to Lenet.
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                    v.Reshape(csize, csize, 0, 0);
                StdInput linput = new StdInput(v);
                try
                {
                    int cls = ts.Cls(i);
                    if (C2i.ContainsKey(cls))
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[cls]);
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTestOfEpoch", sbout.ToString());
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // debug save mnist
            //SaveTrainMnist(HrLenet, "debug-images-idx3-ubyte", "debug-labels-idx1-ubyte");

            // Run the training epochs.
            for (int epoch = 0; epoch < epochs; epoch++)
            {
                trInfo = new TrainInfo();
                try
                {
                    // Monotonic timer; unaffected by wall-clock adjustments.
                    System.Diagnostics.Stopwatch swTrain = System.Diagnostics.Stopwatch.StartNew();
                    EndAndRunTrainEpoch(HrLenet, ref trInfo);   // do train one epoch
                    // show train info
                    Global.Debugf("info",
                        String.Format("|{0,7}| Energy:{1:0.#####} Correct:{2:0.00#%} Errors:{3:0.00#%} Count:{4} ",
                            trInfo.age, trInfo.energy, (trInfo.correct / (float)trInfo.size),
                            (trInfo.error / (float)trInfo.size), trInfo.size) );
                    Global.Debugf("info",
                        String.Format("     TEST Energy={0:0.#####} Correct={1:0.00#%} Errors={2:0.00#%} Count={3} ",
                            trInfo.tenergy, (trInfo.tcorrect / (float)trInfo.tsize),
                            (trInfo.terror / (float)trInfo.tsize), trInfo.tsize));
                    swTrain.Stop();
                    TimeSpan spanTrain = swTrain.Elapsed;
                    Global.Debugf("info", String.Format("          training time: {0} minutes, {1} seconds",
                        (int)spanTrain.TotalMinutes, spanTrain.Seconds));

                    // Forward any messages the native library wrote to stdout.
                    GetStdout(sbout);
                    if (sbout.Length > 0)
                        Console.Write(sbout.ToString());
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in EndAndRunTrainEpoch", sbout.ToString());
                    throw new Exception("Exception in EndAndRunTrainEpoch\r\n" + e.Message, e);
                }
            }
            return trInfo;
        }
示例#13
0
        /// <summary>
        /// Evaluates the network on <paramref name="ts"/>. Samples whose class is in C2i are queued
        /// as test samples (others are skipped silently), then a single test epoch is run.
        /// </summary>
        /// <returns>The statistics filled in by EndAndRunTestEpoch.</returns>
        /// <exception cref="Exception">Wraps any failure while queuing samples or running the test
        /// epoch; the original exception is preserved as InnerException.</exception>
        public TrainInfo TestDense(IDataset ts)
        {
            TrainInfo trInfo = new TrainInfo();
            bool first = true;
            Floatarray v = new Floatarray();
            // Send the test dataset to Lenet.
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                    v.Reshape(csize, csize, 0, 0);   // flat vectors become csize x csize arrays
                StdInput linput = new StdInput(v);
                if (first)
                {
                    first = false;
                    BeginTestEpoch(HrLenet, linput.Height, linput.Width, ts.nSamples());   // init test
                }
                try
                {
                    int cls = ts.Cls(i);
                    if (C2i.ContainsKey(cls))
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[cls]);
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    // Native text is an argument, not the format string, so braces in it are harmless.
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTestOfEpoch", sbout.ToString());
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // Run the single test epoch.
            try
            {
                EndAndRunTestEpoch(HrLenet, ref trInfo);
            }
            catch (Exception e)
            {
                GetStdout(sbout);   // get test messages
                if (sbout.Length > 0)
                    Global.Debugf("error", "{0}\r\nException in EndAndRunTestEpoch", sbout.ToString());
                throw new Exception("Exception in EndAndRunTestEpoch\r\n" + e.Message, e);
            }

            return trInfo;
        }
示例#14
0
文件: IModel.cs 项目: nickun/OCRonet
 /// <summary>
 /// Default training: passes each sample's input and class label to Add, one at a time.
 /// The same Floatarray buffer is reused for every sample.
 /// </summary>
 protected virtual void Train(IDataset ds)
 {
     Floatarray features = new Floatarray();
     int index = 0;
     while (index < ds.nSamples())
     {
         ds.Input(features, index);
         Add(features, ds.Cls(index));
         index++;
     }
 }
示例#15
0
 /// <summary>
 /// Runs online training of this MLP over <paramref name="ds"/>, cycling through the samples in
 /// order. The iteration count is nSamples * "miters", clamped to [1000, 10000000]. Logs the mean
 /// squared distance between the network output and the target vector.
 /// </summary>
 /// <exception cref="ArgumentException">The dataset has no samples.</exception>
 public override void TrainDense(IDataset ds)
 {
     int nsamples = ds.nSamples();
     // Without this guard `i % nsamples` below would throw DivideByZeroException.
     if (nsamples <= 0)
         throw new ArgumentException("TrainDense requires a dataset with at least one sample");
     int nclasses = ds.nClasses();
     float miters = PGetf("miters");
     int niters = (int)(nsamples * miters);
     niters = Math.Max(1000, Math.Min(10000000, niters));
     // The learning rate is fixed for the whole pass; read it once instead of every iteration.
     float eta = PGetf("eta");
     double err = 0.0;
     Floatarray x = new Floatarray();
     Floatarray z = new Floatarray();
     Floatarray target = new Floatarray(nclasses);
     for (int i = 0; i < niters; i++)
     {
         int row = i % nsamples;
         ds.Output(target, row);
         ds.Input1d(x, row);
         TrainOne(z, target, x, eta);
         err += NarrayUtil.Dist2Squared(z, target);
     }
     err /= niters;   // niters >= 1000, so never zero
     Global.Debugf("info", "   {4} n {0} niters={1} eta={2:0.#####} errors={3:0.########}",
            nsamples, niters, eta, err, FullName);
 }
示例#16
0
 /// <summary>
 /// Pairs <paramref name="ds"/> with the label array <paramref name="classes"/>; both are stored by
 /// reference. Callers build one entry per sample of ds. Presumably these labels replace ds.Cls()
 /// (confirm against the accessors).
 /// </summary>
 public MappedDataset(IDataset ds, Intarray classes)
 {
     _classes = classes;
     _ds = ds;
 }
示例#17
0
 /// <summary>Base implementation: always throws <see cref="NotImplementedException"/>; override to support dense training.</summary>
 public virtual void TrainDense(IDataset dataset) { throw new NotImplementedException(); }
示例#18
0
        /// <summary>
        /// Trains the character classifier component and, when the "junk" parameter is set and
        /// DisableJunk is false, a separate junk classifier component. Missing components are
        /// created from the "charclass", "junkclass" and "ulclass" parameters; event handlers are
        /// attached to the char and junk classifiers.
        /// </summary>
        /// <exception cref="Exception">Thrown when, inside the junk-training branch, the "ul"
        /// parameter is positive and a ulclass component exists (ulclass training is not implemented).</exception>
        protected override void Train(IDataset ds)
        {
            bool use_junk = PGetb("junk") && !DisableJunk;

            // Create any components that are still missing.
            if (charclass.IsEmpty)
            {
                charclass.SetComponent(ComponentCreator.MakeComponent(PGet("charclass")));
                TryAttachCharClassifierEvent(charclass.Object);
            }
            if (junkclass.IsEmpty)
            {
                junkclass.SetComponent(ComponentCreator.MakeComponent(PGet("junkclass")));
                TryAttachJunkClassifierEvent(junkclass.Object);
            }
            if (ulclass.IsEmpty)
                ulclass.SetComponent(ComponentCreator.MakeComponent(PGet("ulclass")));

            Global.Debugf("info", "Training content classifier");
            // With a junk classifier in use, the content classifier is trained on non-junk samples only.
            if (use_junk && !junkclass.IsEmpty)
            {
                Intarray nonjunk = new Intarray();
                for (int i = 0; i < ds.nSamples(); i++)
                    if (ds.Cls(i) != jc())
                        nonjunk.Push(i);
                Datasubset nonjunkds = new Datasubset(ds, nonjunk);
                charclass.Object.XTrain(nonjunkds);
            }
            else
            {
                charclass.Object.XTrain(ds);
            }

            if (use_junk && !junkclass.IsEmpty)
            {
                Global.Debugf("info", "Training junk classifier");
                // Binary relabelling: 1 for junk samples, 0 for everything else.
                Intarray isjunk = new Intarray();
                int njunk = 0;
                for (int i = 0; i < ds.nSamples(); i++)
                {
                    bool j = (ds.Cls(i) == jc());
                    isjunk.Push(Convert.ToInt32(j));
                    if (j) njunk++;
                }
                if (njunk > 0)
                {
                    MappedDataset junkds = new MappedDataset(ds, isjunk);
                    junkclass.Object.XTrain(junkds);
                }
                else
                {
                    // Nothing to learn from: drop the junk classifier component.
                    Global.Debugf("warn", "you are training a junk class but there are no samples to train on");
                    junkclass.SetComponent(null);
                }

                // NOTE(review): this check only runs inside the junk branch - confirm that is intended.
                if (PGeti("ul") > 0 && !ulclass.IsEmpty)
                {
                    throw new Exception("ulclass not implemented");
                }
            }
        }
示例#19
0
 /// <summary>
 /// Pairs <paramref name="ds"/> with the feature extractor <paramref name="ex"/> (both stored by
 /// reference). Presumably features are extracted on access; confirm against the accessors.
 /// </summary>
 public ExtractedDataset(IDataset ds, IExtractor ex)
 {
     this._ex = ex;
     this._ds = ds;
 }