예제 #1
0
        /// <summary>
        /// Sends the samples of <paramref name="ts"/> to Lenet as test data and then runs
        /// one test epoch. Samples whose class is not a key of <c>C2i</c> are skipped.
        /// </summary>
        /// <param name="ts">Dataset providing the test inputs and their classes.</param>
        /// <returns>
        /// The <see cref="TrainInfo"/> instance that was passed by reference to
        /// <c>EndAndRunTestEpoch</c>.
        /// </returns>
        /// <exception cref="Exception">
        /// Thrown when adding a sample or running the epoch fails. The message is prefixed
        /// with the failing step and the original exception is kept as
        /// <see cref="Exception.InnerException"/>.
        /// </exception>
        public TrainInfo TestDense(IDataset ts)
        {
            TrainInfo  trInfo = new TrainInfo();
            bool       first  = true;
            Floatarray v      = new Floatarray();

            // Send Test Dataset to Lenet
            // NOTE(review): for an empty dataset BeginTestEpoch is never called before
            // EndAndRunTestEpoch below - confirm that this is acceptable.
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                {
                    v.Reshape(csize, csize, 0, 0);
                }
                StdInput linput = new StdInput(v);
                if (first)
                {
                    first = false;
                    BeginTestEpoch(HrLenet, linput.Height, linput.Width, ts.nSamples());   // init test
                }
                try
                {
                    if (C2i.ContainsKey(ts.Cls(i)))
                    {
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ts.Cls(i)]);
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get test messages
                    if (sbout.Length > 0)
                    {
                        // Pass the captured output as an argument, not as the format string,
                        // so braces inside it cannot be taken for format placeholders.
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTestOfEpoch", sbout.ToString());
                    }
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // do test one epoch
            try
            {
                EndAndRunTestEpoch(HrLenet, ref trInfo);
            }
            catch (Exception e)
            {
                GetStdout(sbout);   // get test messages
                if (sbout.Length > 0)
                {
                    Global.Debugf("error", "{0}\r\nException in EndAndRunTestEpoch", sbout.ToString());
                }
                throw new Exception("Exception in EndAndRunTestEpoch\r\n" + e.Message, e);
            }

            return(trInfo);
        }
예제 #2
0
        /// <summary>
        /// Runs the char classifier on <paramref name="v"/> and, when the junk classifier is
        /// enabled and not empty, combines its output into <paramref name="result"/>.
        /// </summary>
        /// <param name="result">Output vector; cleared first and then filled.</param>
        /// <param name="v">
        /// Input array; reshaped in place with (csize, csize, 0, 0) when its rank is 1.
        /// </param>
        /// <returns>Always <c>0.0f</c>.</returns>
        protected override float Outputs(OutputVector result, Floatarray v)
        {
            result.Clear();
            if (v.Rank() == 1)
            {
                v.Reshape(csize, csize, 0, 0);
            }

            // wrap the array and take its byte buffer for the classifiers
            StdInput stdIn = new StdInput(v);
            byte[]   data  = stdIn.GetDataBuffer();

            // char classifier
            if (!CharClass.AsciiTarget)
            {
                // net output 0..1; (higher - winner)
                CharClass.ComputeOutputsRaw(data, stdIn.Length, stdIn.Height, stdIn.Width, result);
            }
            else
            {
                // net output 0..~ (lower - winner)
                CharClass.ComputeOutputs(data, stdIn.Length, stdIn.Height, stdIn.Width, result);
            }

            // stop here unless the junk classifier is switched on and not empty
            if (!PGetb("junk") || DisableJunk || JunkClass.IsEmpty)
            {
                return(0.0f);
            }

            OutputVector junkOut = new OutputVector();
            if (JunkClass.AsciiTarget)
            {
                JunkClass.ComputeOutputs(data, stdIn.Length, stdIn.Height, stdIn.Width, junkOut);
            }
            else
            {
                // normalize both vectors, then scale every char output by junk value 0
                result.Normalize2();
                JunkClass.ComputeOutputsRaw(data, stdIn.Length, stdIn.Height, stdIn.Width, junkOut);
                junkOut.Normalize2();
                for (int k = 0; k < result.nKeys(); k++)
                {
                    result.Values[k] *= junkOut.Value(0);
                }
            }

            // in both modes the junk entry receives junk value 1
            result[jc()] = junkOut.Value(1);

            return(0.0f);
        }
예제 #3
0
        /// <summary>
        /// Sends the train and test datasets to Lenet and then runs the requested number
        /// of train epochs, logging the statistics of each epoch.
        /// </summary>
        /// <param name="ds">
        /// Dataset whose samples are added as train samples. A sample whose class is not a
        /// key of <c>C2i</c> is reported through <c>Global.Debugf</c> and skipped.
        /// </param>
        /// <param name="ts">
        /// Dataset whose samples are added as test samples. A sample whose class is not a
        /// key of <c>C2i</c> is skipped without a message.
        /// </param>
        /// <param name="epochs">Number of times <c>EndAndRunTrainEpoch</c> is called.</param>
        /// <returns>
        /// The <see cref="TrainInfo"/> passed to the last <c>EndAndRunTrainEpoch</c> call, or
        /// a freshly constructed instance when no epoch is run.
        /// </returns>
        /// <exception cref="Exception">
        /// Thrown when adding a sample or running an epoch fails. The message is prefixed
        /// with the failing step and the original exception is kept as
        /// <see cref="Exception.InnerException"/>.
        /// </exception>
        public virtual TrainInfo TrainBatch(IDataset ds, IDataset ts, int epochs)
        {
            TrainInfo  trInfo = new TrainInfo();
            bool       first  = true;
            Floatarray v      = new Floatarray();

            // Send Train Dataset to Lenet
            for (int i = 0; i < ds.nSamples(); i++)
            {
                ds.Input(v, i);
                if (v.Rank() == 1)
                {
                    v.Reshape(csize, csize, 0, 0);
                }
                StdInput linput = new StdInput(v);
                try
                {
                    if (first)
                    {
                        first = false;
                        //StartRedirectStdout();  // start redirect cout to string buffer
                        BeginTrainEpoch(HrLenet, linput.Height, linput.Width, ds.nSamples(), ts.nSamples());   // init train
                    }
                    if (C2i.ContainsKey(ds.Cls(i)))
                    {
                        AddSampleToTrainOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ds.Cls(i)]);
                    }
                    else
                    {
                        Global.Debugf("error", "class '{0}' is not contained in the char list", (char)ds.Cls(i));
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                    {
                        // Pass the captured output as an argument, not as the format string,
                        // so braces inside it cannot be taken for format placeholders.
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTrainOfEpoch", sbout.ToString());
                    }
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTrainOfEpoch\r\n" + e.Message, e);
                }
            }

            // Send Test Dataset to Lenet
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                {
                    v.Reshape(csize, csize, 0, 0);
                }
                StdInput linput = new StdInput(v);
                try
                {
                    if (C2i.ContainsKey(ts.Cls(i)))
                    {
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ts.Cls(i)]);
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get test messages
                    if (sbout.Length > 0)
                    {
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTestOfEpoch", sbout.ToString());
                    }
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // debug save mnist
            //SaveTrainMnist(HrLenet, "debug-images-idx3-ubyte", "debug-labels-idx1-ubyte");

            // do train epochs
            for (int epoch = 0; epoch < epochs; epoch++)
            {
                trInfo = new TrainInfo();
                try
                {
                    DateTime startDate = DateTime.Now;
                    EndAndRunTrainEpoch(HrLenet, ref trInfo);   // do train one epoch
                    // show train info
                    Global.Debugf("info",
                                  String.Format("|{0,7}| Energy:{1:0.#####} Correct:{2:0.00#%} Errors:{3:0.00#%} Count:{4} ",
                                                trInfo.age, trInfo.energy, (trInfo.correct / (float)trInfo.size),
                                                (trInfo.error / (float)trInfo.size), trInfo.size));
                    Global.Debugf("info",
                                  String.Format("     TEST Energy={0:0.#####} Correct={1:0.00#%} Errors={2:0.00#%} Count={3} ",
                                                trInfo.tenergy, (trInfo.tcorrect / (float)trInfo.tsize),
                                                (trInfo.terror / (float)trInfo.tsize), trInfo.tsize));
                    TimeSpan spanTrain = DateTime.Now - startDate;
                    Global.Debugf("info", String.Format("          training time: {0} minutes, {1} seconds",
                                                        (int)spanTrain.TotalMinutes, spanTrain.Seconds));

                    // get dll stdout messages
                    GetStdout(sbout);
                    if (sbout.Length > 0)
                    {
                        Console.Write(sbout.ToString());
                    }
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                    {
                        Global.Debugf("error", "{0}\r\nException in EndAndRunTrainEpoch", sbout.ToString());
                    }
                    throw new Exception("Exception in EndAndRunTrainEpoch\r\n" + e.Message, e);
                }
            }
            return(trInfo);
        }
예제 #4
0
        /// <summary>
        /// Fills <paramref name="result"/> with the char classifier output for
        /// <paramref name="v"/>, optionally combined with the junk classifier output.
        /// </summary>
        /// <param name="result">Output vector; cleared before anything is written to it.</param>
        /// <param name="v">
        /// Input array; when its rank is 1 it is reshaped in place with (csize, csize, 0, 0).
        /// </param>
        /// <returns>Always <c>0.0f</c>.</returns>
        protected override float Outputs(OutputVector result, Floatarray v)
        {
            result.Clear();
            if (v.Rank() == 1)
            {
                v.Reshape(csize, csize, 0, 0);
            }

            // byte array input shared by both classifiers
            StdInput netInput = new StdInput(v);
            byte[] bytes = netInput.GetDataBuffer();

            // char classifier compute outputs
            if (CharClass.AsciiTarget)
            {
                // net output 0..~ (lower - winner)
                CharClass.ComputeOutputs(bytes, netInput.Length, netInput.Height, netInput.Width, result);
            }
            else
            {
                // net output 0..1; (higher - winner)
                CharClass.ComputeOutputsRaw(bytes, netInput.Length, netInput.Height, netInput.Width, result);
            }

            // junk classifier runs only when requested, not disabled and not empty
            bool useJunk = PGetb("junk") && !DisableJunk && !JunkClass.IsEmpty;
            if (useJunk)
            {
                OutputVector junkVector = new OutputVector();
                if (JunkClass.AsciiTarget)
                {
                    JunkClass.ComputeOutputs(bytes, netInput.Length, netInput.Height, netInput.Width, junkVector);
                    result[jc()] = junkVector.Value(1);
                }
                else
                {
                    result.Normalize2();
                    JunkClass.ComputeOutputsRaw(bytes, netInput.Length, netInput.Height, netInput.Width, junkVector);
                    junkVector.Normalize2();

                    // scale every char output by junk value 0
                    int idx = 0;
                    while (idx < result.nKeys())
                    {
                        result.Values[idx] *= junkVector.Value(0);
                        idx++;
                    }
                    result[jc()] = junkVector.Value(1);
                }
            }

            return 0.0f;
        }
예제 #5
0
 /// <summary>
 /// Calls <c>Input(v, i)</c> and then reshapes <paramref name="v"/> with its own
 /// <c>Length()</c> as the single dimension argument.
 /// </summary>
 /// <param name="v">Array that receives the input and is reshaped in place.</param>
 /// <param name="i">Index forwarded to <c>Input</c>.</param>
 public void Input1d(Floatarray v, int i)
 {
     Input(v, i);

     // take the length after Input has filled the array
     var flatLength = v.Length();
     v.Reshape(flatLength);
 }
예제 #6
0
        /// <summary>
        /// Sends the train and test datasets to Lenet and then runs the requested number
        /// of train epochs, logging the statistics of each epoch.
        /// </summary>
        /// <param name="ds">
        /// Dataset whose samples are added as train samples. A sample whose class is not a
        /// key of <c>C2i</c> is reported through <c>Global.Debugf</c> and skipped.
        /// </param>
        /// <param name="ts">
        /// Dataset whose samples are added as test samples. A sample whose class is not a
        /// key of <c>C2i</c> is skipped without a message.
        /// </param>
        /// <param name="epochs">Number of times <c>EndAndRunTrainEpoch</c> is called.</param>
        /// <returns>
        /// The <see cref="TrainInfo"/> passed to the last <c>EndAndRunTrainEpoch</c> call, or
        /// a freshly constructed instance when no epoch is run.
        /// </returns>
        /// <exception cref="Exception">
        /// Thrown when adding a sample or running an epoch fails. The message is prefixed
        /// with the failing step and the original exception is kept as
        /// <see cref="Exception.InnerException"/>.
        /// </exception>
        public virtual TrainInfo TrainBatch(IDataset ds, IDataset ts, int epochs)
        {
            TrainInfo trInfo = new TrainInfo();
            bool first = true;
            Floatarray v = new Floatarray();

            // Send Train Dataset to Lenet
            for (int i = 0; i < ds.nSamples(); i++)
            {
                ds.Input(v, i);
                if (v.Rank() == 1)
                    v.Reshape(csize, csize, 0, 0);
                StdInput linput = new StdInput(v);
                try
                {
                    if (first)
                    {
                        first = false;
                        //StartRedirectStdout();  // start redirect cout to string buffer
                        BeginTrainEpoch(HrLenet, linput.Height, linput.Width, ds.nSamples(), ts.nSamples());   // init train
                    }
                    if (C2i.ContainsKey(ds.Cls(i)))
                        AddSampleToTrainOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ds.Cls(i)]);
                    else
                        Global.Debugf("error", "class '{0}' is not contained in the char list", (char)ds.Cls(i));
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    // The captured output goes in as an argument, not as the format string,
                    // so braces inside it cannot be taken for format placeholders.
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTrainOfEpoch", sbout.ToString());
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTrainOfEpoch\r\n" + e.Message, e);
                }
            }

            // Send Test Dataset to Lenet
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                    v.Reshape(csize, csize, 0, 0);
                StdInput linput = new StdInput(v);
                try
                {
                    if (C2i.ContainsKey(ts.Cls(i)))
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ts.Cls(i)]);
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get test messages
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTestOfEpoch", sbout.ToString());
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // debug save mnist
            //SaveTrainMnist(HrLenet, "debug-images-idx3-ubyte", "debug-labels-idx1-ubyte");

            // do train epochs
            for (int epoch = 0; epoch < epochs; epoch++)
            {
                trInfo = new TrainInfo();
                try
                {
                    DateTime startDate = DateTime.Now;
                    EndAndRunTrainEpoch(HrLenet, ref trInfo);   // do train one epoch
                    // show train info
                    Global.Debugf("info",
                        String.Format("|{0,7}| Energy:{1:0.#####} Correct:{2:0.00#%} Errors:{3:0.00#%} Count:{4} ",
                            trInfo.age, trInfo.energy, (trInfo.correct / (float)trInfo.size),
                            (trInfo.error / (float)trInfo.size), trInfo.size) );
                    Global.Debugf("info",
                        String.Format("     TEST Energy={0:0.#####} Correct={1:0.00#%} Errors={2:0.00#%} Count={3} ",
                            trInfo.tenergy, (trInfo.tcorrect / (float)trInfo.tsize),
                            (trInfo.terror / (float)trInfo.tsize), trInfo.tsize));
                    TimeSpan spanTrain = DateTime.Now - startDate;
                    Global.Debugf("info", String.Format("          training time: {0} minutes, {1} seconds",
                        (int)spanTrain.TotalMinutes, spanTrain.Seconds));

                    // get dll stdout messages
                    GetStdout(sbout);
                    if (sbout.Length > 0)
                        Console.Write(sbout.ToString());
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get train messages
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in EndAndRunTrainEpoch", sbout.ToString());
                    throw new Exception("Exception in EndAndRunTrainEpoch\r\n" + e.Message, e);
                }
            }
            return trInfo;
        }
예제 #7
0
        /// <summary>
        /// Sends the samples of <paramref name="ts"/> to Lenet as test data and then runs
        /// one test epoch. Samples whose class is not a key of <c>C2i</c> are skipped.
        /// </summary>
        /// <param name="ts">Dataset providing the test inputs and their classes.</param>
        /// <returns>
        /// The <see cref="TrainInfo"/> instance that was passed by reference to
        /// <c>EndAndRunTestEpoch</c>.
        /// </returns>
        /// <exception cref="Exception">
        /// Thrown when adding a sample or running the epoch fails. The message is prefixed
        /// with the failing step and the original exception is kept as
        /// <see cref="Exception.InnerException"/>.
        /// </exception>
        public TrainInfo TestDense(IDataset ts)
        {
            TrainInfo trInfo = new TrainInfo();
            bool first = true;
            Floatarray v = new Floatarray();
            // Send Test Dataset to Lenet
            // NOTE(review): for an empty dataset BeginTestEpoch is never called before
            // EndAndRunTestEpoch below - confirm that this is acceptable.
            for (int i = 0; i < ts.nSamples(); i++)
            {
                ts.Input(v, i);
                if (v.Rank() == 1)
                    v.Reshape(csize, csize, 0, 0);
                StdInput linput = new StdInput(v);
                if (first)
                {
                    first = false;
                    BeginTestEpoch(HrLenet, linput.Height, linput.Width, ts.nSamples());   // init test
                }
                try
                {
                    if (C2i.ContainsKey(ts.Cls(i)))
                        AddSampleToTestOfEpoch(HrLenet, linput.GetDataBuffer(), linput.Length, C2i[ts.Cls(i)]);
                }
                catch (Exception e)
                {
                    GetStdout(sbout);   // get test messages
                    // The captured output goes in as an argument, not as the format string,
                    // so braces inside it cannot be taken for format placeholders.
                    if (sbout.Length > 0)
                        Global.Debugf("error", "{0}\r\nException in AddSampleToTestOfEpoch", sbout.ToString());
                    // Keep the original exception as InnerException so its stack trace survives.
                    throw new Exception("Exception in AddSampleToTestOfEpoch\r\n" + e.Message, e);
                }
            }

            // do test one epoch
            try
            {
                EndAndRunTestEpoch(HrLenet, ref trInfo);
            }
            catch (Exception e)
            {
                GetStdout(sbout);   // get test messages
                if (sbout.Length > 0)
                    Global.Debugf("error", "{0}\r\nException in EndAndRunTestEpoch", sbout.ToString());
                throw new Exception("Exception in EndAndRunTestEpoch\r\n" + e.Message, e);
            }

            return trInfo;
        }
예제 #8
0
파일: IDataset.cs 프로젝트: nickun/OCRonet
 /// <summary>
 /// Loads the input for index <paramref name="i"/> through <c>Input</c> and then passes
 /// the array's <c>Length()</c> to its <c>Reshape</c> as the only argument.
 /// </summary>
 /// <param name="v">Array handed to <c>Input</c> and reshaped in place afterwards.</param>
 /// <param name="i">Index handed to <c>Input</c>.</param>
 public void Input1d(Floatarray v, int i)
 {
     Input(v, i);

     // the length is read only after Input has run
     var total = v.Length();
     v.Reshape(total);
 }