コード例 #1
0
ファイル: Program.cs プロジェクト: play3577/mxnet.csharp
        /// <summary>
        /// Counts the samples in a batch whose predicted sequence equals its label sequence.
        /// </summary>
        /// <param name="label">Labels; <c>label[(Slice)i].Data</c> supplies the expected values for sample <c>i</c>.</param>
        /// <param name="pred">
        /// Predictions; entry <c>k * batchSize + i</c> is reduced with <c>Argmax()</c> to obtain
        /// the predicted value at position <c>k</c> of sample <c>i</c>.
        /// </param>
        /// <param name="batchSize">Number of samples to evaluate.</param>
        /// <returns>
        /// A result whose <c>SumMetric</c> is the number of fully matching samples and whose
        /// <c>NumInst</c> is <paramref name="batchSize"/>.
        /// </returns>
        private static CustomMetricResult Accuracy(SingleNArray label, SingleNArray pred, int batchSize)
        {
            // Number of prediction positions read per sample.
            const int positionsPerSample = 4;

            int correct = 0;

            for (int sample = 0; sample < batchSize; sample++)
            {
                var expected = label[(Slice)sample].Data;

                // Gather every position's prediction up front, before any comparison.
                var predicted = new int[positionsPerSample];
                for (int pos = 0; pos < positionsPerSample; pos++)
                {
                    predicted[pos] = (int)pred[(Slice)(pos * batchSize + sample)].Argmax();
                }

                // A label of a different length can never count as a hit.
                if (expected.Length != predicted.Length)
                {
                    continue;
                }

                // Advance while positions agree; reaching the end means a full match.
                int cursor = 0;
                while (cursor < predicted.Length && predicted[cursor] == (int)(expected[cursor]))
                {
                    cursor++;
                }

                if (cursor == predicted.Length)
                {
                    correct++;
                }
            }

            return new CustomMetricResult
            {
                SumMetric = correct,
                NumInst = batchSize
            };
        }
コード例 #2
0
        /// <summary>
        /// Runs the prediction executor over the batches of <paramref name="inputX"/> and returns
        /// the collected outputs, one concatenated array per executor output.
        /// </summary>
        /// <param name="inputX">Iterator that supplies the input batches.</param>
        /// <param name="numBatch">
        /// When set, iteration stops once this many batches have been processed;
        /// <c>null</c> processes every batch the iterator yields.
        /// </param>
        /// <param name="returnData">
        /// When <c>true</c>, the unpadded data and label slices of every batch are also collected.
        /// NOTE(review): the return type only carries the outputs, so the collected data/labels
        /// are currently not handed back to the caller.
        /// </param>
        /// <param name="reset">When <c>true</c>, <paramref name="inputX"/> is reset before iterating.</param>
        /// <returns>
        /// One array per executor output, built by concatenating the per-batch results
        /// (padding rows removed) with <c>SingleNArray.Concatenate(0, ...)</c>.
        /// </returns>
        public List <SingleNArray> Predict(IDataIter inputX, int? numBatch = null, bool returnData = false, bool reset = true)
        {
            if (reset)
            {
                inputX.Reset();
            }

            var dataShapes = inputX.ProvideData;
            var dataNames  = dataShapes.Select(s => s.Key).ToList();

            InitPredictor(dataShapes);

            var batchSize  = inputX.BatchSize;
            var dataArrays = dataNames.Select(name => this._predExec.ArgDict[name]).ToList();
            var outputList = this._predExec.Outputs.Select(s => new List <SingleNArray>()).ToList();

            List <List <SingleNArray> > dataList  = null;
            List <List <SingleNArray> > labelList = null;

            if (returnData)
            {
                dataList  = inputX.ProvideData.Select(s => new List <SingleNArray>()).ToList();
                labelList = inputX.ProvideLabel.Select(s => new List <SingleNArray>()).ToList();
            }

            int processed = 0;

            foreach (var batch in inputX)
            {
                ExecutorManager.LoadData(batch, dataArrays);
                this._predExec.Forward(isTrain: false);

                // Only the first realSize rows are kept; the remaining batch.Pad rows are dropped.
                var padded   = batch.Pad;
                var realSize = batchSize - padded;

                foreach (var pair in outputList.Zip(this._predExec.Outputs, Tuple.Create))
                {
                    pair.Item1.Add(pair.Item2.Slice(0, (uint)realSize).AsNumerics());
                }

                if (returnData)
                {
                    for (int j = 0; j < batch.Data.Count; j++)
                    {
                        var x = batch.Data[j];
                        dataList[j].Add(x.Slice(0, (uint)realSize).AsNumerics());
                    }

                    // Bounded by the label count (not the data count): the two collections
                    // may differ in size, and this loop indexes batch.Label.
                    for (int j = 0; j < batch.Label.Count; j++)
                    {
                        var y = batch.Label[j];
                        labelList[j].Add(y.Slice(0, (uint)realSize).AsNumerics());
                    }
                }

                processed += 1;
                if (numBatch != null && processed == numBatch.Value)
                {
                    break;
                }
            }

            return outputList.Select(s => SingleNArray.Concatenate(0, s.ToArray())).ToList();
        }