示例#1
0
        /// <summary>
        /// Rebuilds a <c>DataSourceSet</c> from its serializable dictionary form.
        /// </summary>
        /// <param name="obj">
        /// Map from feature name to a (data, shape) pair. <c>Item1</c> is the flat float data and
        /// <c>Item2</c> is the shape; both are passed unchanged to <c>DataSourceFactory.Create</c>.
        /// </param>
        /// <returns>A new set with one data source per dictionary entry, added in enumeration order.</returns>
        private static DataSourceSet ConvertFromSerializableObject(Dictionary <string, Tuple <float[], int[]> > obj)
        {
            var result = new DataSourceSet();

            foreach (KeyValuePair <string, Tuple <float[], int[]> > pair in obj)
            {
                float[] values = pair.Value.Item1;
                int[]   shape  = pair.Value.Item2;

                result.Add(pair.Key, DataSourceFactory.Create(values, shape));
            }

            return result;
        }
示例#2
0
        /// <summary>
        /// Initializes a <c>NoiseSampler</c>. It stores the configuration and allocates a flat float buffer
        /// holding <paramref name="minibatchSize"/> samples, then wraps it with <c>DataSourceFactory.Create</c>.
        /// </summary>
        /// <param name="name">Stored in <c>Name</c>.</param>
        /// <param name="shape">Dimensions of a single sample. Must not be null.</param>
        /// <param name="minibatchSize">Number of samples per minibatch; also the length of the last axis.</param>
        /// <param name="iterationsPerEpoch">Stored in <c>IterationsPerEpoch</c>.</param>
        /// <param name="min">Stored in <c>Min</c> (presumably the lower bound of generated values — not used here).</param>
        /// <param name="max">Stored in <c>Max</c> (presumably the upper bound of generated values — not used here).</param>
        /// <param name="seed">Optional seed forwarded to <c>Random.GetInstance</c>.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="shape"/> is null.</exception>
        public NoiseSampler(string name, int[] shape, int minibatchSize, int iterationsPerEpoch, double min, double max, int?seed = null)
        {
            if (shape == null)
            {
                throw new System.ArgumentNullException("shape");
            }

            Name               = name;
            Shape              = shape;
            MinibatchSize      = minibatchSize;
            IterationsPerEpoch = iterationsPerEpoch;
            Min = min;
            Max = max;

            _random   = Random.GetInstance(seed);
            // NOTE(review): GetSize(-1) is presumably the element count of one sample — confirm.
            _dataSize = Shape.GetSize(-1);
            _data     = new float[_dataSize * minibatchSize];

            // Minibatch shape = sample dimensions, then a sequence axis of length 1, then the sample axis.
            // The buffer is sized from shape.Length, the array actually copied below, so the copy always fits.
            // (If Shape were exposed as int[], Shape.Rank would be Array.Rank, which is always 1.)
            var minibatchShape = new int[shape.Length + 2];

            shape.CopyTo(minibatchShape, 0);
            minibatchShape[minibatchShape.Length - 2] = 1;             // sequence
            minibatchShape[minibatchShape.Length - 1] = minibatchSize; // sample
            _samples = DataSourceFactory.Create(_data, minibatchShape);

            Iterations = 0;
        }
示例#3
0
        /// <summary>
        /// Builds a data source from the inputs of the active parameter set and writes it to the pipeline.
        /// </summary>
        /// <typeparam name="T">Element type of the data source built by every path except "load".</typeparam>
        /// <param name="converter">Converts one input value to <typeparamref name="T"/>.</param>
        /// <remarks>
        /// Handled parameter sets: "load", "rows", "columns", "psobjects" and "datatable".
        /// Any other value (including null) takes the "new" path, which converts the elements of Data.
        /// </remarks>
        private void ProcessInternal <T>(Func <object, T> converter)
        {
            switch (ParameterSetName)
            {
                case "load":
                {
                    Path = IO.GetAbsolutePath(this, Path);
                    // NOTE(review): always loads float data regardless of T — confirm this is intended.
                    var result = DataSourceFactory.Load <float>(Path, !NoDecompress);
                    WriteObject(result);
                    break;
                }

                case "rows":
                {
                    var data = new List <T[]>();
                    foreach (var row in Rows)
                    {
                        data.Add(UnwrapAndConvert <T>(row, converter));
                    }

                    var result = DataSourceFactory.FromRows <T>(data, Dimensions);
                    WriteObject(result);
                    break;
                }

                case "columns":
                {
                    var data = new List <T[]>();
                    foreach (var column in Columns)
                    {
                        data.Add(UnwrapAndConvert <T>(column, converter));
                    }

                    var result = DataSourceFactory.FromColumns(data, Dimensions);
                    WriteObject(result);
                    break;
                }

                case "psobjects":
                {
                    var result = DataSourceFactory.FromPSObjects(PSObjects, converter);

                    // Dimensions is optional for this path; reshape only when it was supplied.
                    if (Dimensions != null)
                    {
                        result.Reshape(Dimensions);
                    }

                    WriteObject(result);
                    break;
                }

                case "datatable":
                {
                    var result = DataSourceFactory.FromDataTable(DataTable, converter);

                    // Dimensions is optional for this path; reshape only when it was supplied.
                    if (Dimensions != null)
                    {
                        result.Reshape(Dimensions);
                    }

                    WriteObject(result);
                    break;
                }

                default:
                {
                    // new
                    var result = DataSourceFactory.Create(Data.Select(x => converter.Invoke(x)).ToArray(), Dimensions);
                    WriteObject(result);
                    break;
                }
            }
        }

        /// <summary>
        /// Converts every element of <paramref name="items"/> with <paramref name="converter"/>.
        /// PSObject wrappers are first replaced by their BaseObject.
        /// </summary>
        private static T[] UnwrapAndConvert <T>(IEnumerable <object> items, Func <object, T> converter)
        {
            return items.Select(x =>
            {
                var wrapped = x as PSObject;
                return converter.Invoke(wrapped != null ? wrapped.BaseObject : x);
            }).ToArray();
        }
示例#4
0
文件: CTFTools.cs 项目: horker/pscntk
        /// <summary>
        /// Lazily parses CNTK Text Format (CTF) lines from <paramref name="reader"/>.
        /// It yields one <c>CTFSample</c> per sequence, where a sequence is a run of consecutive lines
        /// sharing the same sequence ID (the trimmed text before the first '|').
        /// </summary>
        /// <param name="reader">Source of CTF lines. It is read forward only; <c>Peek</c> is not used.</param>
        /// <returns>
        /// Samples in input order, including the final sequence. Each sample owns its own Comments list,
        /// so previously yielded samples are not modified as enumeration continues.
        /// Whitespace-only lines are skipped and do not delimit sequences.
        /// </returns>
        /// <exception cref="InvalidDataException">
        /// A feature has no values, or a feature's dimension changes between lines of one sequence.
        /// The message carries the 1-based line number of the offending line.
        /// </exception>
        public static IEnumerable <CTFSample> GetSampleReader(TextReader reader)
        {
            int lineCount     = 0;
            int sequenceCount = 0;

            string line;
            string seqId = null;

            // Lines of the sequence currently being collected, with their 1-based line numbers.
            var splitLines  = new List <string[]>();
            var lineNumbers = new List <int>();

            while ((line = reader.ReadLine()) != null)
            {
                ++lineCount;

                if (string.IsNullOrWhiteSpace(line))
                {
                    continue;
                }

                var splitLine = line.Split(new char[] { '|' });
                var n         = splitLine[0].Trim();

                // A different sequence ID closes the sequence collected so far.
                if (seqId != null && n != seqId)
                {
                    ++sequenceCount;
                    yield return(CreateSampleFromSequence(splitLines, lineNumbers, sequenceCount, seqId));

                    // Safe to reuse: the helper copies everything it needs into the new sample.
                    splitLines.Clear();
                    lineNumbers.Clear();
                }

                splitLines.Add(splitLine);
                lineNumbers.Add(lineCount);
                seqId = n;
            }

            // Flush the last sequence, which has no following line to trigger it.
            if (splitLines.Count > 0)
            {
                ++sequenceCount;
                yield return(CreateSampleFromSequence(splitLines, lineNumbers, sequenceCount, seqId));
            }
        }

        /// <summary>
        /// Builds one <c>CTFSample</c> from the '|'-split lines of a single sequence.
        /// <paramref name="lineNumbers"/>[i] is the 1-based input line number of <paramref name="splitLines"/>[i].
        /// </summary>
        private static CTFSample CreateSampleFromSequence(List <string[]> splitLines, List <int> lineNumbers, int sequenceCount, string seqId)
        {
            var seqDim = splitLines.Count;

            var dss           = new DataSourceSet();
            var comments      = new List <string>();
            var startIndexMap = new Dictionary <string, int>();
            var endIndexMap   = new Dictionary <string, int>();
            var featureDimMap = new Dictionary <string, int>();

            for (var i = 0; i < seqDim; ++i)
            {
                var columns = splitLines[i];
                for (var j = 1; j < columns.Length; ++j)
                {
                    var feature = columns[j].Trim();

                    // An empty field (e.g. from a trailing '|') carries no data.
                    if (feature.Length == 0)
                    {
                        continue;
                    }

                    if (feature[0] == '#')
                    {
                        // Comment field: drop the '#' and any whitespace right after it.
                        comments.Add(feature.Substring(1).TrimStart());
                        continue;
                    }

                    // Split on runs of whitespace so tabs or repeated spaces do not yield empty tokens.
                    var items = feature.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
                    if (items.Length < 2)
                    {
                        throw new InvalidDataException(string.Format("line {0}: Invalid feature", lineNumbers[i]));
                    }

                    var featureDim = items.Length - 1;
                    var name       = items[0];

                    DataSourceBase <float, float[]> ds;

                    float[] data;
                    int     knownDim;
                    if (featureDimMap.TryGetValue(name, out knownDim))
                    {
                        // The buffer was sized from the first occurrence; a different width would overrun its slot.
                        if (knownDim != featureDim)
                        {
                            throw new InvalidDataException(string.Format("line {0}: Inconsistent feature dimension", lineNumbers[i]));
                        }

                        ds   = (DataSourceBase <float, float[]>)dss.Features[name];
                        data = ds.TypedData;
                    }
                    else
                    {
                        data = new float[featureDim * seqDim];
                        ds   = DataSourceFactory.Create(data, new int[] { featureDim, seqDim, 1 });
                        dss.Add(name, ds);
                        startIndexMap[name] = i;
                        featureDimMap[name] = featureDim;
                    }

                    var baseIndex = ds.Shape.GetSequentialIndex(new int[] { 0, i, 0 });
                    for (var k = 0; k < featureDim; ++k)
                    {
                        data[baseIndex + k] = Converter.ToFloat(items[k + 1]);
                    }
                    endIndexMap[name] = i;
                }
            }

            // Trim features that do not span the whole sequence to the range of lines where they appear.
            // The trim is along axis -2, presumably the seqDim axis of { featureDim, seqDim, 1 }.
            foreach (var name in dss.Features.Keys.ToArray())
            {
                var start = startIndexMap[name];
                var end   = endIndexMap[name];
                if (start == 0 && end == seqDim - 1)
                {
                    continue;
                }

                var ds = dss[name];
                dss.Features[name] = ds.Subset(start, end - start + 1, -2);
            }

            return(new CTFSample()
            {
                LineCount = lineNumbers[0],
                SequenceCount = sequenceCount,
                SequenceId = seqId,
                DataSet = dss,
                Comments = comments
            });
        }