示例#1
0
 /// <summary>
 /// Hands the given batch to the executor manager's loaders: first
 /// <c>ExecutorManager.LoadData</c> with this instance's data arrays, then
 /// <c>ExecutorManager._load_label</c> with this instance's label arrays.
 /// </summary>
 /// <param name="dataBatch">
 /// The batch forwarded to both loaders (presumably its contents are copied
 /// into the target arrays — confirm against the loader implementations).
 /// </param>
 public void LoadDataBatch(IDataBatch dataBatch)
 {
     // Data first, labels second — same call order as before.
     var dataTargets = _dataArrays;
     ExecutorManager.LoadData(dataBatch, dataTargets);

     var labelTargets = _labelArrays;
     ExecutorManager._load_label(dataBatch, labelTargets);
 }
        /// <summary>
        /// Runs the prediction executor over the batches supplied by
        /// <paramref name="inputX"/> and returns the per-output results
        /// concatenated across batches.
        /// </summary>
        /// <param name="inputX">
        /// Iterator that yields the batches to predict on. Its <c>ProvideData</c>
        /// entries are used to initialise the predictor and to look up the
        /// executor's input arrays by name.
        /// </param>
        /// <param name="numBatch">
        /// Number of batches after which iteration stops; <c>null</c> processes
        /// every batch. The limit is checked after a batch has been processed,
        /// so values below 1 never trigger the early exit.
        /// </param>
        /// <param name="returnData">
        /// When <c>true</c>, the sliced input data and labels of every batch are
        /// gathered as well. The return value still contains only the outputs,
        /// because the signature has no way to carry the gathered arrays.
        /// </param>
        /// <param name="reset">
        /// When <c>true</c>, <paramref name="inputX"/> is reset before iterating.
        /// </param>
        /// <returns>
        /// One array per executor output, produced by
        /// <c>SingleNArray.Concatenate(0, ...)</c> over the per-batch slices.
        /// </returns>
        public List <SingleNArray> Predict(IDataIter inputX, int?numBatch = null, bool returnData = false, bool reset = true)
        {
            if (reset)
            {
                inputX.Reset();
            }

            var dataShapes = inputX.ProvideData;
            var dataNames  = dataShapes.Select(s => s.Key).ToList();

            InitPredictor(dataShapes);

            var batchSize  = inputX.BatchSize;
            var dataArrays = dataNames.Select(name => this._predExec.ArgDict[name]).ToList();
            var outputList = this._predExec.Outputs.Select(s => new List <SingleNArray>()).ToList();

            List <List <SingleNArray> > dataList  = null;
            List <List <SingleNArray> > labelList = null;

            if (returnData)
            {
                dataList  = inputX.ProvideData.Select(s => new List <SingleNArray>()).ToList();
                labelList = inputX.ProvideLabel.Select(s => new List <SingleNArray>()).ToList();
            }

            int processed = 0;

            foreach (var batch in inputX)
            {
                ExecutorManager.LoadData(batch, dataArrays);
                this._predExec.Forward(isTrain: false);

                // Keep only the leading (batchSize - Pad) entries of each array,
                // i.e. drop the trailing entries the batch reports as padding.
                var realSize = (uint)(batchSize - batch.Pad);

                foreach (var pair in outputList.Zip(this._predExec.Outputs, Tuple.Create))
                {
                    pair.Item1.Add(pair.Item2.Slice(0, realSize).AsNumerics());
                }

                if (returnData)
                {
                    for (int j = 0; j < batch.Data.Count; j++)
                    {
                        dataList[j].Add(batch.Data[j].Slice(0, realSize).AsNumerics());
                    }

                    // Bound by the label count: the previous code reused
                    // batch.Data.Count here, which over- or under-runs
                    // batch.Label whenever the two counts differ.
                    for (int j = 0; j < batch.Label.Count; j++)
                    {
                        labelList[j].Add(batch.Label[j].Slice(0, realSize).AsNumerics());
                    }
                }

                processed += 1;
                if (numBatch.HasValue && processed == numBatch.Value)
                {
                    break;
                }
            }

            // NOTE(review): dataList / labelList are gathered but not returned —
            // the return type only has room for the outputs. The former
            // "data" / "label" locals were deferred LINQ queries that were never
            // enumerated, so removing them does not change behaviour.
            var outputs = outputList.Select(s => SingleNArray.Concatenate(0, s.ToArray())).ToList();

            return(outputs);
        }