Ejemplo n.º 1
0
        private static InferenceResult InferTextFileColumnTypesCore(IHostEnvironment env, IMultiStreamSource fileSource, Arguments args, IChannel ch)
        {
            // Loads up to args.MaxRowsToRead rows of the source as one text vector column "C".
            // It then runs every expert from GetExperts() over the per-field columns and decides
            // whether the first row is a header. Finally it suggests a unique name per column.
            // Problems are reported through 'ch' and returned as InferenceResult.Fail().
            Contracts.AssertValue(ch);
            ch.AssertValue(env);
            ch.AssertValue(fileSource);
            ch.AssertValue(args);

            int columnCount = args.ColumnCount;

            if (columnCount == 0)
            {
                ch.Error("Too many empty columns for automatic inference.");
                return InferenceResult.Fail();
            }
            if (columnCount >= SmartColumnsLim)
            {
                ch.Error("Too many columns for automatic inference.");
                return InferenceResult.Fail();
            }

            // Every field of a row lands in the single text column "C" (slots 0..columnCount-1).
            var loaderArgs = new TextLoader.Arguments
            {
                Column = new[] { TextLoader.Column.Parse(string.Format("C:TX:0-{0}", columnCount - 1)) },
                Separator = args.Separator,
                AllowSparse = args.AllowSparse,
                AllowQuoting = args.AllowQuote,
            };
            var view = TextLoader.ReadFile(env, loaderArgs, fileSource).Take(args.MaxRowsToRead);

            // Buffer the whole (row-limited) dataset: one array of text fields per row.
            var rows = new List<ReadOnlyMemory<char>[]>();
            using (var cursor = view.GetRowCursor(col => true))
            {
                bool found = cursor.Schema.TryGetColumnIndex("C", out int textColumn);
                Contracts.Assert(found);
                var textType = cursor.Schema.GetColumnType(textColumn);
                Contracts.Assert(textType.ItemType.IsText);

                // Choose how a row is read once, so the cursor loop below has no branching.
                Func<ReadOnlyMemory<char>[]> readRow;
                if (textType.IsVector)
                {
                    var vectorGetter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(textColumn);
                    var buffer = default(VBuffer<ReadOnlyMemory<char>>);
                    readRow = () =>
                    {
                        vectorGetter(ref buffer);
                        Contracts.Assert(buffer.Length == columnCount);
                        var fields = new ReadOnlyMemory<char>[columnCount];
                        buffer.CopyTo(fields);
                        return fields;
                    };
                }
                else
                {
                    // A scalar text column only happens when there is exactly one field.
                    Contracts.Assert(columnCount == 1);
                    var scalarGetter = cursor.GetGetter<ReadOnlyMemory<char>>(textColumn);
                    var cell = default(ReadOnlyMemory<char>);
                    readRow = () =>
                    {
                        scalarGetter(ref cell);
                        return new[] { cell };
                    };
                }

                while (cursor.MoveNext())
                {
                    rows.Add(readRow());
                }
            }

            if (rows.Count < 2)
            {
                ch.Error("Too few rows ({0}) for automatic inference.", rows.Count);
                return InferenceResult.Fail();
            }

            // Transpose the row-major buffer into one IntermediateColumn per field.
            var cols = new IntermediateColumn[columnCount];
            for (int c = 0; c < columnCount; c++)
            {
                var cells = new ReadOnlyMemory<char>[rows.Count];
                for (int r = 0; r < rows.Count; r++)
                {
                    cells[r] = rows[r][c];
                }
                cols[c] = new IntermediateColumn(cells, c);
            }

            foreach (var expert in GetExperts())
            {
                expert.Apply(cols);
            }

            Contracts.Check(Array.TrueForAll(cols, x => x.SuggestedType != null), "Column type inference must be conclusive");

            // Header vote: each column's verdict adds or removes a point.
            // If two columns share the same first-row value, that strongly suggests the first
            // row is data, not a header, so the score is pushed down by the column count.
            int headerScore = 0;
            var firstRowValues = new HashSet<string>();
            foreach (var column in cols)
            {
                switch (column.HasHeader)
                {
                    case true:
                        headerScore += firstRowValues.Add(column.RawData[0].ToString()) ? 1 : -columnCount;
                        break;
                    case false:
                        headerScore--;
                        break;
                }
            }

            // A schema embedded in the file wins over the heuristic vote.
            // REVIEW: that schema could also supply the column names.
            bool hasHeader = TextLoader.FileContainsValidSchema(env, fileSource, out var fileArgs)
                ? fileArgs.HasHeader
                : headerScore > 0;

            // Pick a unique name for every column; clashes get a two-digit "_NN" suffix.
            var takenNames = new HashSet<string>();
            var names = new string[cols.Length];
            for (int c = 0; c < cols.Length; c++)
            {
                string baseName = SuggestName(cols[c], hasHeader);
                string candidate = baseName;
                for (int suffix = 0; !takenNames.Add(candidate); suffix++)
                {
                    candidate = string.Format("{0}_{1:00}", baseName, suffix);
                }
                names[c] = candidate;
            }

            var outCols = new Column[cols.Length];
            for (int c = 0; c < cols.Length; c++)
            {
                outCols[c] = new Column(cols[c].ColumnId, names[c], cols[c].SuggestedType);
            }

            int numericCount = outCols.Count(x => x.ItemType.IsNumber);
            ch.Info("Detected {0} numeric and {1} text columns.", numericCount, outCols.Length - numericCount);
            if (hasHeader)
            {
                ch.Info("Generated column names from the file header.");
            }

            var rawColumns = cols.Select(x => x.RawData).ToArray();
            return InferenceResult.Success(outCols, hasHeader, rawColumns);
        }
        private static InferenceResult InferTextFileColumnTypesCore(MLContext context, IMultiStreamSource fileSource, Arguments args)
        {
            // Loads up to args.MaxRowsToRead rows of the source as one text vector column "C".
            // It runs every expert from GetExperts() over the per-field columns and suggests a
            // unique name per column, treating the first row as a header only when args.HasHeader
            // says so. Then it validates the label column and promotes it to Boolean when all of
            // its values are Boolean.
            // Returns InferenceResult.Fail() when there are no columns, too many columns, or fewer
            // than two rows.
            if (args.ColumnCount == 0)
            {
                // no columns to infer types for
                return InferenceResult.Fail();
            }

            if (args.ColumnCount >= SmartColumnsLim)
            {
                // too many columns for automatic inference
                return InferenceResult.Fail();
            }

            // read the file as the specified number of text columns
            var textLoaderOptions = new TextLoader.Options
            {
                Columns      = new[] { new TextLoader.Column("C", DataKind.String, 0, args.ColumnCount - 1) },
                Separators   = new[] { args.Separator },
                AllowSparse  = args.AllowSparse,
                AllowQuoting = args.AllowQuote,
            };
            var textLoader = context.Data.CreateTextLoader(textLoaderOptions);
            var idv        = textLoader.Load(fileSource);

            idv = context.Data.TakeRows(idv, args.MaxRowsToRead);

            // read all the data into memory.
            // list items are rows of the dataset.
            var data = new List<ReadOnlyMemory<char>[]>();

            using (var cursor = idv.GetRowCursor(idv.Schema))
            {
                var column  = cursor.Schema.GetColumnOrNull("C").Value;
                var colType = column.Type;
                ValueGetter<VBuffer<ReadOnlyMemory<char>>> vecGetter = null;
                ValueGetter<ReadOnlyMemory<char>>          oneGetter = null;
                bool isVector = colType.IsVector();
                if (isVector)
                {
                    vecGetter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(column);
                }
                else
                {
                    // scalar text column: a single field per row
                    oneGetter = cursor.GetGetter<ReadOnlyMemory<char>>(column);
                }

                VBuffer<ReadOnlyMemory<char>> line    = default;
                ReadOnlyMemory<char>          tsValue = default;
                while (cursor.MoveNext())
                {
                    if (isVector)
                    {
                        vecGetter(ref line);
                        var values = new ReadOnlyMemory<char>[args.ColumnCount];
                        line.CopyTo(values);
                        data.Add(values);
                    }
                    else
                    {
                        oneGetter(ref tsValue);
                        var values = new[] { tsValue };
                        data.Add(values);
                    }
                }
            }

            if (data.Count < 2)
            {
                // too few rows for automatic inference
                return InferenceResult.Fail();
            }

            // transpose rows into one IntermediateColumn per field
            var cols = new IntermediateColumn[args.ColumnCount];

            for (int i = 0; i < args.ColumnCount; i++)
            {
                cols[i] = new IntermediateColumn(data.Select(x => x[i]).ToArray(), i);
            }

            foreach (var expert in GetExperts())
            {
                expert.Apply(cols);
            }

            // Header presence comes from args.HasHeader here. The per-column header votes are not
            // aggregated, because the previous aggregation result was never read.

            // suggest names, disambiguating clashes with a two-digit "_NN" suffix
            var usedNames = new HashSet<string>();
            foreach (var col in cols)
            {
                string name0;
                string name;
                name0 = name = SuggestName(col, args.HasHeader);
                int i = 0;
                while (!usedNames.Add(name))
                {
                    name = string.Format("{0}_{1:00}", name0, i++);
                }
                col.Name = name;
            }

            // validate & retrieve label column
            var labelColumn = GetAndValidateLabelColumn(args, cols);

            // if label column has all Boolean values, set its type as Boolean
            if (labelColumn.HasAllBooleanValues())
            {
                labelColumn.SuggestedType = BooleanDataViewType.Instance;
            }

            var outCols = cols.Select(x => new Column(x.ColumnId, x.Name, x.SuggestedType)).ToArray();

            return InferenceResult.Success(outCols, args.HasHeader, cols.Select(col => col.RawData).ToArray());
        }