示例#1
0
        /// <summary>
        /// Generates output predictions for the input samples. Computation is done in batches.
        /// </summary>
        /// <remarks>
        /// Each batch produced by a <c>DataFrameIter</c> over <paramref name="x"/> is fed through every
        /// layer in <c>Layers</c> whose <c>SkipPred</c> flag is not set; each layer's <c>Output</c> becomes
        /// the next layer's input. The final output of every batch is flattened to <see cref="float"/>
        /// values and appended, in batch order, to a single buffer that is loaded into the result.
        /// NOTE(review): the result is loaded from a flat float array, so any per-sample output shape
        /// is not preserved here — confirm callers reshape if they need it.
        /// </remarks>
        /// <param name="x">The input data frame to run prediction.</param>
        /// <param name="batch_size">Number of samples per batch; must be greater than zero.</param>
        /// <returns>A data frame holding the flattened predictions for all batches.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="x"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="batch_size"/> is zero or negative.</exception>
        public DataFrame Predict(DataFrame x, int batch_size)
        {
            // Validate at the public boundary instead of failing somewhere inside the iterator.
            if (x is null)
            {
                throw new System.ArgumentNullException(nameof(x));
            }

            if (batch_size <= 0)
            {
                throw new System.ArgumentOutOfRangeException(nameof(batch_size), batch_size, "Batch size must be greater than zero.");
            }

            DataFrameIter dataFrameIter = new DataFrameIter(x);
            List<float> predictions = new List<float>();

            dataFrameIter.SetBatchSize(batch_size);

            while (dataFrameIter.Next())
            {
                Tensor output = dataFrameIter.GetBatchX();

                foreach (var layer in Layers)
                {
                    // Layers flagged SkipPred take no part in inference.
                    if (layer.SkipPred)
                    {
                        continue;
                    }

                    layer.Forward(output);
                    output = layer.Output;
                }

                // Cast flattens whatever array ToArray() returns into a float sequence.
                predictions.AddRange(output.ToArray().Cast<float>());
            }

            DataFrame result = new DataFrame();

            result.Load(predictions.ToArray());

            return result;
        }
示例#2
0
        /// <summary>
        /// Generates output predictions for the input samples.
        /// </summary>
        /// <remarks>
        /// The whole of <paramref name="x"/> is processed in one pass: its tensor is fed through every
        /// layer in <c>Layers</c> whose <c>SkipPred</c> flag is not set, with each layer's <c>Output</c>
        /// becoming the next layer's input. The final output is flattened to <see cref="float"/> values
        /// and loaded into the result.
        /// NOTE(review): the result is loaded from a flat float array, so any per-sample output shape
        /// is not preserved here — confirm callers reshape if they need it.
        /// </remarks>
        /// <param name="x">The input data frame to run prediction.</param>
        /// <returns>A data frame holding the flattened predictions.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="x"/> is null.</exception>
        public DataFrame Predict(DataFrame x)
        {
            // Fail with a descriptive exception instead of a NullReferenceException on GetTensor().
            if (x is null)
            {
                throw new System.ArgumentNullException(nameof(x));
            }

            Tensor output = x.GetTensor();

            foreach (var layer in Layers)
            {
                // Layers flagged SkipPred take no part in inference.
                if (layer.SkipPred)
                {
                    continue;
                }

                layer.Forward(output);
                output = layer.Output;
            }

            // Flatten directly into the final array; no intermediate list is needed for a single pass.
            float[] predictions = output.ToArray().Cast<float>().ToArray();

            DataFrame result = new DataFrame();

            result.Load(predictions);

            return result;
        }