コード例 #1
0
        /// <summary>
        /// Trains on <paramref name="ds"/>, holding part of it out as a test set when
        /// <c>crossvalidate</c> is enabled.
        /// </summary>
        /// <remarks>
        /// With cross-validation the split is made on distinct sample ids rather than on
        /// sample positions, so a sample that occurs several times (for example after
        /// resampling) never ends up in both the training and the test subset. The sizes
        /// of the two subsets are stored under "%ntraining" and "%ntesting". Without
        /// cross-validation the full data set is handed to <c>TrainBatch</c> in both roles.
        /// </remarks>
        /// <param name="ds">Data set to train on.</param>
        public override void TrainDense(IDataset ds)
        {
            // Both parameters are read up front, whether or not cross-validation is used.
            float trainFraction = PGetf("cv_split");
            int   maxTestIds    = PGeti("cv_max");

            if (!crossvalidate)
            {
                TrainBatch(ds, ds);
                return;
            }

            // Collect the distinct ids present in the data set and randomize their order.
            Intarray distinctIds = new Intarray();
            for (int sample = 0; sample < ds.nSamples(); sample++)
            {
                distinctIds.Push(ds.Id(sample));
            }
            NarrayUtil.Uniq(distinctIds);
            Global.Debugf("cvdetail", "reduced {0} ids to {1} ids", ds.nSamples(), distinctIds.Length());
            NarrayUtil.Shuffle(distinctIds);

            // The first (1 - cv_split) share of the shuffled ids, capped at cv_max,
            // forms the held-out set.
            int heldOutCount = Math.Min((int)((1.0 - trainFraction) * distinctIds.Length()), maxTestIds);
            Intarray heldOutIds = new Intarray();
            for (int k = 0; k < heldOutCount; k++)
            {
                heldOutIds.Push(distinctIds[k]);
            }
            // Sorted so that membership can be tested with Bincontains below.
            NarrayUtil.Quicksort(heldOutIds);

            // Route every sample position to one side according to its id.
            Intarray trainRows = new Intarray();
            Intarray testRows  = new Intarray();
            for (int sample = 0; sample < ds.nSamples(); sample++)
            {
                bool heldOut = ClassifierUtil.Bincontains(heldOutIds, ds.Id(sample));
                Intarray side = heldOut ? testRows : trainRows;
                side.Push(sample);
            }

            Global.Debugf("cvdetail", "#training {0} #testing {1}",
                          trainRows.Length(), testRows.Length());
            PSet("%ntraining", trainRows.Length());
            PSet("%ntesting", testRows.Length());

            Datasubset trainSet = new Datasubset(ds, trainRows);
            Datasubset testSet  = new Datasubset(ds, testRows);
            TrainBatch(trainSet, testSet);
        }
コード例 #2
0
ファイル: SegmRoutine.cs プロジェクト: liaoheping/OCRonet
        /// <summary>
        /// Renumbers the non-zero labels of <paramref name="segmentation"/> in place,
        /// ordering them by the left edge (<c>x0</c>) of each label's bounding box.
        /// </summary>
        /// <param name="segmentation">
        /// Label array to rewrite. Entries equal to 0 are left untouched.
        /// </param>
        /// <exception cref="Exception">
        /// Thrown when the largest label in <paramref name="segmentation"/> exceeds 100000.
        /// </exception>
        public static void line_segmentation_sort_x(Intarray segmentation)
        {
            if (NarrayUtil.Max(segmentation) > 100000)
            {
                // Exception type kept as-is for callers; the message now names this routine.
                throw new Exception("line_segmentation_sort_x: too many segments");
            }

            Narray<Rect> bboxes = new Narray<Rect>();
            ImgLabels.bounding_boxes(ref bboxes, segmentation);

            // Build one sort key per label: slot 0 gets a very small key so it stays
            // first, and labels with an empty box get a very large key so they go last.
            Floatarray x0s = new Floatarray();
            x0s.Push(-999999f);
            for (int i = 1; i < bboxes.Length(); i++)
            {
                if (bboxes[i].Empty())
                {
                    x0s.Push(999999);
                }
                else
                {
                    x0s.Push(bboxes[i].x0);
                }
            }

            // NOTE(review): Quicksort(permutation, keys) presumably fills `permutation`
            // with the index order that sorts `keys` ascending - confirm in NarrayUtil.
            Narray<int> permutation  = new Intarray();
            Narray<int> rpermutation = new Intarray();
            NarrayUtil.Quicksort(permutation, x0s);

            // Invert the permutation: rpermutation[oldLabel] is the label's new rank.
            rpermutation.Resize(permutation.Length());
            for (int i = 0; i < permutation.Length(); i++)
            {
                rpermutation[permutation[i]] = i;
            }

            // Relabel every non-zero entry with its rank.
            for (int i = 0; i < segmentation.Length1d(); i++)
            {
                if (segmentation.At1d(i) == 0)
                {
                    continue;
                }
                segmentation.Put1d(i, rpermutation[segmentation.At1d(i)]);
            }
        }
コード例 #3
0
        /// <summary>
        /// Ensures the property identified by <paramref name="flag"/> holds: arcs of every
        /// state sorted by input, arcs sorted by output, or heuristics computed.
        /// </summary>
        /// <param name="flag">
        /// One of <c>SORTED_BY_INPUT</c>, <c>SORTED_BY_OUTPUT</c> or <c>HAS_HEURISTICS</c>.
        /// </param>
        /// <exception cref="Exception">Thrown when <paramref name="flag"/> is none of the three values.</exception>
        private void achieve(int flag)
        {
            if (!(flag == SORTED_BY_INPUT ||
                  flag == SORTED_BY_OUTPUT ||
                  flag == HAS_HEURISTICS))
            {
                throw new Exception("CHECK_ARG: flag == SORTED_BY_INPUT || flag == SORTED_BY_OUTPUT || flag == HAS_HEURISTICS");
            }

            // Bit test: skip the work only when THIS property is already recorded.
            // (The previous test `flags > 0 & flag > 0` returned as soon as any flag at
            // all was set, so e.g. a request for output order after input order was ignored.)
            if ((flags & flag) != 0)
            {
                return;
            }

            if (flag == HAS_HEURISTICS)
            {
                // NOTE(review): this path returns without recording HAS_HEURISTICS in
                // `flags`, so the heuristics are recomputed on every request - confirm
                // whether that is intended.
                AStarUtil.a_star_backwards(m_heuristics, this);
                return;
            }

            // Sort each state's arcs by the requested key and apply the same permutation
            // to all four parallel arrays so that arcs stay aligned.
            for (int node = 0; node < nStates(); node++)
            {
                Intarray permutation = new Intarray();
                if (flag == OcroFST.SORTED_BY_INPUT)
                {
                    NarrayUtil.Quicksort(permutation, m_inputs[node]);
                }
                else
                {
                    NarrayUtil.Quicksort(permutation, m_outputs[node]);
                }
                NarrayUtil.Permute(m_inputs[node], permutation);
                NarrayUtil.Permute(m_outputs[node], permutation);
                NarrayUtil.Permute(m_targets[node], permutation);
                NarrayUtil.Permute(m_costs[node], permutation);
            }

            // NOTE(review): the other sort flag is not cleared here, although re-sorting
            // by one key invalidates the order by the other - verify how `flags` is
            // reset elsewhere in the class.
            flags |= flag;
        }
コード例 #4
0
        /// <summary>
        /// Appends the outgoing arcs of composite state <paramref name="node"/> to the
        /// given arrays, combining the arcs of <c>l1</c> and <c>l2</c>.
        /// </summary>
        /// <remarks>
        /// Three groups of arcs are emitted, in this order: arcs of <c>l1</c> whose output
        /// is 0 (only <c>l1</c> moves), arcs of <c>l2</c> whose input is 0 (only <c>l2</c>
        /// moves), and every pairing of an <c>l1</c> arc with an <c>l2</c> arc where the
        /// <c>l1</c> output equals the <c>l2</c> input; paired arcs have their costs added.
        /// </remarks>
        /// <param name="ids">Receives the input label of each emitted arc.</param>
        /// <param name="targets">Receives the composite target state of each emitted arc.</param>
        /// <param name="outputs">Receives the output label of each emitted arc.</param>
        /// <param name="costs">Receives the cost of each emitted arc.</param>
        /// <param name="node">Composite state whose arcs are requested.</param>
        public override void Arcs(Intarray ids, Intarray targets, Intarray outputs, Floatarray costs, int node)
        {
            // Decompose the composite state number into the two component states.
            int leftState  = node / l2.nStates();
            int rightState = node % l2.nStates();

            // Fetch the arcs of the left component ...
            Intarray   leftIn   = new Intarray();
            Intarray   leftTo   = new Intarray();
            Intarray   leftOut  = new Intarray();
            Floatarray leftCost = new Floatarray();
            // ... and of the right component.
            Intarray   rightIn   = new Intarray();
            Intarray   rightTo   = new Intarray();
            Intarray   rightOut  = new Intarray();
            Floatarray rightCost = new Floatarray();

            l1.Arcs(leftIn, leftTo, leftOut, leftCost, leftState);
            l2.Arcs(rightIn, rightTo, rightOut, rightCost, rightState);

            // Order the left arcs by output label; the key array is permuted after the
            // ones that do not drive the sort, matching the order of the permutation.
            Intarray leftOrder = new Intarray();
            NarrayUtil.Quicksort(leftOrder, leftOut);
            NarrayUtil.Permute(leftIn, leftOrder);
            NarrayUtil.Permute(leftTo, leftOrder);
            NarrayUtil.Permute(leftOut, leftOrder);
            NarrayUtil.Permute(leftCost, leftOrder);

            // Order the right arcs by input label.
            Intarray rightOrder = new Intarray();
            NarrayUtil.Quicksort(rightOrder, rightIn);
            NarrayUtil.Permute(rightIn, rightOrder);
            NarrayUtil.Permute(rightTo, rightOrder);
            NarrayUtil.Permute(rightOut, rightOrder);
            NarrayUtil.Permute(rightCost, rightOrder);

            int leftCount  = leftOut.Length();
            int rightCount = rightIn.Length();

            // Left arcs with output 0 come first after sorting: only the left state moves.
            int a = 0;
            while (a < leftCount && leftOut.At1d(a) == 0)
            {
                ids.Push(leftIn.At1d(a));
                targets.Push(Combine(leftTo.At1d(a), rightState));
                outputs.Push(0);
                costs.Push(leftCost.At1d(a));
                a++;
            }

            // Right arcs with input 0: only the right state moves. The bound is the
            // length of the right output array (presumably equal to the input array's
            // length - TODO confirm against l2.Arcs).
            int b = 0;
            while (b < rightOut.Length() && rightIn.At1d(b) == 0)
            {
                ids.Push(0);
                targets.Push(Combine(leftState, rightIn == null ? 0 : rightTo.At1d(b)));
                outputs.Push(rightOut.At1d(b));
                costs.Push(rightCost.At1d(b));
                b++;
            }

            // Remaining arcs: walk both sorted lists in step and pair equal labels.
            while (a < leftCount && b < rightCount)
            {
                // Advance the left cursor up to the current right label.
                while (a < leftCount && leftOut.At1d(a) < rightIn.At1d(b))
                {
                    a++;
                }
                if (a >= leftCount)
                {
                    break;
                }
                // Advance the right cursor up to the current left label.
                while (b < rightCount && leftOut.At1d(a) > rightIn.At1d(b))
                {
                    b++;
                }
                // For each left arc carrying this label, pair it with the whole run of
                // right arcs carrying the same label; the right cursor stays at the
                // start of the run so the next left arc sees it again.
                while (a < leftCount && b < rightCount && leftOut.At1d(a) == rightIn.At1d(b))
                {
                    int label = leftOut.At1d(a);
                    for (int j = b; j < rightCount && label == rightIn.At1d(j); j++)
                    {
                        ids.Push(leftIn.At1d(a));
                        targets.Push(Combine(leftTo.At1d(a), rightTo.At1d(j)));
                        outputs.Push(rightOut.At1d(j));
                        costs.Push(leftCost.At1d(a) + rightCost.At1d(j));
                    }
                    a++;
                }
            }
        }
コード例 #5
0
        /// <summary>
        /// Trains an ensemble of <c>MlpClassifier</c> networks for a number of rounds and
        /// keeps the best one in this classifier.
        /// </summary>
        /// <remarks>
        /// Each round trains every network on <paramref name="ds"/> and scores it on
        /// <paramref name="ts"/>. If the lowest error beats the best seen so far, that
        /// network and its learning rate are copied into this instance. Unless "noopt"
        /// is non-zero, every network in the second half of the error ranking is then
        /// replaced by a copy of one from the first half with a randomly perturbed
        /// hidden-unit count (clamped to [hidden_min, hidden_max]) and learning rate.
        /// On completion "%error" and the cumulative "%nsamples" are updated.
        /// </remarks>
        /// <param name="ds">Data set the networks are trained on.</param>
        /// <param name="ts">Data set the error estimates are computed on.</param>
        public virtual void TrainBatch(IDataset ds, IDataset ts)
        {
            Stopwatch sw            = Stopwatch.StartNew();
            bool      parallel      = PGetb("parallel");
            float     eta_init      = PGetf("eta_init");      // 0.5
            float     eta_varlog    = PGetf("eta_varlog");    // 1.5
            float     hidden_varlog = PGetf("hidden_varlog"); // 1.2
            int       hidden_lo     = PGeti("hidden_lo");
            int       hidden_hi     = PGeti("hidden_hi");
            int       rounds        = PGeti("rounds");
            int       mlp_noopt     = PGeti("noopt");
            int       hidden_min    = PGeti("hidden_min");
            int       hidden_max    = PGeti("hidden_max");

            CHECK_ARG(hidden_min > 1 && hidden_max < 1000000, "hidden_min > 1 && hidden_max < 1000000");
            CHECK_ARG(hidden_hi >= hidden_lo, "hidden_hi >= hidden_lo");
            CHECK_ARG(hidden_max >= hidden_min, "hidden_max >= hidden_min");
            CHECK_ARG(hidden_lo >= hidden_min && hidden_hi <= hidden_max, "hidden_lo >= hidden_min && hidden_hi <= hidden_max");
            int nn = PGeti("nensemble");
            // The code below indexes slot 0 of the ensemble unconditionally (etas[0],
            // index[0]); reject an empty ensemble here instead of failing there.
            CHECK_ARG(nn > 0, "nn > 0");
            ObjList <MlpClassifier> nets = new ObjList <MlpClassifier>();

            nets.Resize(nn);
            for (int i = 0; i < nn; i++)
            {
                nets[i] = new MlpClassifier(i);
            }
            Floatarray errs  = new Floatarray(nn);
            Floatarray etas  = new Floatarray(nn);
            Intarray   index = new Intarray();
            float      best  = 1e30f;

            // Continue from the error of a previous training run, if one was recorded.
            if (PExists("%error"))
            {
                best = PGetf("%error");
            }
            int nclasses = ds.nClasses();

            CHECK_ARG(ds.nSamples() >= 10 && ds.nSamples() < 100000000, "ds.nSamples() >= 10 && ds.nSamples() < 100000000");

            for (int i = 0; i < nn; i++)
            {
                if (w1.Length() > 0)
                {
                    // This classifier already has weights: start every network from a
                    // copy of it, each with a randomly drawn learning rate.
                    nets[i].Copy(this);
                    etas[i] = ClassifierUtil.rLogNormal(eta_init, eta_varlog);
                }
                else
                {
                    // No weights yet: initialize fresh networks whose hidden-unit
                    // counts are spread between hidden_lo and hidden_hi.
                    nets[i].InitData(ds, (int)(logspace(i, nn, hidden_lo, hidden_hi)), c2i, i2c);
                    etas[i] = PGetf("eta");
                }
            }
            etas[0] = PGetf("eta");     // zero position is identical to itself

            Global.Debugf("info", "mlp training n {0} nc {1}", ds.nSamples(), nclasses);
            for (int round = 0; round < rounds; round++)
            {
                Stopwatch swRound = Stopwatch.StartNew();
                errs.Fill(-1);

                // Train network i with its own learning rate and record its error on ts.
                // Each invocation only touches slot i of nets/etas/errs.
                Action<int> trainAndScore = i =>
                {
                    nets[i].PSet("eta", etas[i]);
                    nets[i].TrainDense(ds);     // formerly XTrain
                    errs[i] = ClassifierUtil.estimate_errors(nets[i], ts);
                };
                if (parallel)
                {
                    Parallel.For(0, nn, trainAndScore);
                }
                else
                {
                    for (int i = 0; i < nn; i++)
                    {
                        trainAndScore(i);
                    }
                }

                // NOTE(review): Quicksort(index, errs) presumably yields the network
                // indices ordered by ascending error - confirm in NarrayUtil.
                NarrayUtil.Quicksort(index, errs);
                if (errs[index[0]] < best)
                {
                    best     = errs[index[0]];
                    cv_error = best;
                    this.Copy(nets[index[0]]);
                    this.PSet("eta", etas[index[0]]);
                    Global.Debugf("info", "  best mlp[{0}] update errors={1} {2}", index[0], best, crossvalidate ? "cv" : "");
                }
                if (mlp_noopt == 0)
                {
                    // Overwrite the network ranked i + nn/2 with a perturbed copy of the
                    // network ranked i.
                    for (int i = 0; i < nn / 2; i++)
                    {
                        int j = i + nn / 2;
                        nets[index[j]].Copy(nets[index[i]]);
                        int n  = nets[index[j]].nHidden();
                        int nm = Math.Min(Math.Max(hidden_min, (int)(ClassifierUtil.rLogNormal(n, hidden_varlog))), hidden_max);
                        nets[index[j]].ChangeHidden(nm);
                        etas[index[j]] = ClassifierUtil.rLogNormal(etas[index[i]], eta_varlog);
                    }
                }
                Global.Debugf("info", " end mlp round {0} err {1} nHidden {2}", round, best, nHidden());
                swRound.Stop();

                // Report progress; the best error is scaled by the test-set size to
                // obtain an error count.
                int totalTest = ts.nSamples();
                int errCnt    = Convert.ToInt32(best * totalTest);
                OnTrainRound(this, new TrainEventArgs(
                                 round, best, totalTest - errCnt, totalTest, best, swRound.Elapsed, TimeSpan.Zero
                                 ));
            }

            sw.Stop();
            Global.Debugf("info", String.Format("          training time: {0} minutes, {1} seconds",
                                                (int)sw.Elapsed.TotalMinutes, sw.Elapsed.Seconds));
            PSet("%error", best);

            // Accumulate the number of samples presented across training runs.
            int nsamples = ds.nSamples() * rounds;

            if (PExists("%nsamples"))
            {
                nsamples += PGeti("%nsamples");
            }
            PSet("%nsamples", nsamples);
        }