コード例 #1
0
ファイル: TypedCursor.cs プロジェクト: zyw400/machinelearning
            /// <summary>
            /// Selects the setter factory matching the destination field described by <paramref name="column"/>
            /// and the type of column <paramref name="index"/> in <paramref name="input"/>, then returns the
            /// action that factory produces.
            /// </summary>
            /// <param name="input">Row whose schema supplies the source column type.</param>
            /// <param name="index">Index of the source column in <paramref name="input"/>.</param>
            /// <param name="column">Schema column whose <c>FieldInfo</c> identifies the destination field.</param>
            /// <param name="poke">Delegate forwarded unchanged to the selected factory.</param>
            /// <param name="peek">Delegate forwarded to the factories that accept it; asserted to be null in the
            /// scalar string and bool cases, whose factories do not take it.</param>
            /// <returns>The setter created by the selected factory.</returns>
            private Action <TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinition.Column column, Delegate poke, Delegate peek)
            {
                var columnType  = input.Schema.GetColumnType(index);
                var targetField = column.FieldInfo;
                var targetType  = targetField.FieldType;

                // The <int> type argument used when binding the method groups below is only a placeholder:
                // the generic definition is recovered and re-closed over the column's item type at the end.
                Func <IRow, int, Delegate, Delegate, Action <TRow> > factory;

                if (targetType.IsArray)
                {
                    Ch.Assert(columnType.IsVector);
                    var elementType = targetType.GetElementType();

                    // VBuffer<DvText> -> String[] has a dedicated, non-generic factory.
                    if (elementType == typeof(string))
                    {
                        Ch.Assert(columnType.ItemType.IsText);
                        return CreateVBufferToStringArraySetter(input, index, poke, peek);
                    }

                    // VBuffer<T> -> T[]
                    Ch.Assert(elementType == columnType.ItemType.RawType);
                    factory = CreateVBufferToArraySetter <int>;
                }
                else if (columnType.IsVector)
                {
                    // VBuffer<T> -> VBuffer<T>
                    // REVIEW: Do we care about accommodating VBuffer<string> -> VBuffer<DvText>?
                    Ch.Assert(targetType.IsGenericType);
                    Ch.Assert(targetType.GetGenericTypeDefinition() == typeof(VBuffer <>));
                    Ch.Assert(targetType.GetGenericArguments()[0] == columnType.ItemType.RawType);
                    factory = CreateVBufferToVBufferSetter <int>;
                }
                else if (columnType.IsPrimitive)
                {
                    // DvText -> String
                    if (targetType == typeof(string))
                    {
                        Ch.Assert(columnType.IsText);
                        Ch.Assert(peek == null);
                        return CreateTextToStringSetter(input, index, poke);
                    }

                    // DvBool -> bool
                    if (targetType == typeof(bool))
                    {
                        Ch.Assert(columnType.IsBool);
                        Ch.Assert(peek == null);
                        return CreateDvBoolToBoolSetter(input, index, poke);
                    }

                    // T -> T
                    Ch.Assert(columnType.RawType == targetType);
                    factory = CreateDirectSetter <int>;
                }
                else
                {
                    // REVIEW: Is this even possible?
                    throw Ch.ExceptNotImpl("Type '{0}' is not yet supported.", targetType.FullName);
                }

                // Re-close the chosen generic factory over the actual item type and invoke it on this instance.
                MethodInfo openFactory   = factory.GetMethodInfo().GetGenericMethodDefinition();
                MethodInfo closedFactory = openFactory.MakeGenericMethod(columnType.ItemType.RawType);
                object[]   factoryArgs   = { input, index, poke, peek };

                return (Action <TRow>)closedFactory.Invoke(this, factoryArgs);
            }
コード例 #2
0
        /// <summary>
        /// Builds the file-based <see cref="TermMap"/> from the file named by <see cref="ArgumentsBase.DataFile"/>
        /// on <paramref name="args"/>. The caller must already have established that the data file argument
        /// is present (this is only asserted here).
        /// </summary>
        /// <param name="env">Environment used to create the loader and the progress channel.</param>
        /// <param name="ch">Channel used for assertions, warnings and user-argument errors.</param>
        /// <param name="args">Supplies the data file, the optional loader factory and the optional terms column.</param>
        /// <param name="bldr">Builder handed to the trainer; unless the default text loader is used, its item type
        /// must equal the type of the terms column.</param>
        /// <returns>The term map obtained by finishing the trainer once the file has been cursored.</returns>
        private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, ArgumentsBase args, Builder bldr)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(env);
            ch.AssertValue(args);
            ch.Assert(!string.IsNullOrWhiteSpace(args.DataFile));
            ch.AssertValue(bldr);

            string path = args.DataFile;
            // Name of the column holding the terms; replaced below when the default text loader is used.
            string termColumnName = args.TermsColumn;
            IMultiStreamSource source = new MultiFileSource(path);
            var customLoader = args.Loader;

            // Conversion to the builder's type is only requested when we fall back to the default text
            // loader. With a user-specified loader or a pre-processed binary/transposed file we assume the
            // user knows what they are doing and require the types to match instead.
            bool convertTerms = false;
            IDataLoader loader;

            if (customLoader != null)
            {
                loader = customLoader.CreateComponent(env, source);
            }
            else
            {
                // Pick the default loader from the file extension.
                string extension    = Path.GetExtension(path);
                bool   isBinaryFile = string.Equals(extension, ".idv", StringComparison.OrdinalIgnoreCase);
                bool   isTransposed = string.Equals(extension, ".tdv", StringComparison.OrdinalIgnoreCase);

                if (!isBinaryFile && !isTransposed)
                {
                    // Plain text: a single tab-separated text column named "Term", read from position 0.
                    if (!string.IsNullOrWhiteSpace(termColumnName))
                    {
                        ch.Warning(
                            "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
                            nameof(Arguments.TermsColumn), termColumnName);
                    }

                    var loaderArgs = new TextLoader.Arguments();
                    loaderArgs.Separator = "tab";
                    loaderArgs.Column    = new[]
                    {
                        new TextLoader.Column()
                        {
                            Name   = "Term",
                            Type   = DataKind.TX,
                            Source = new[] { new TextLoader.Range() { Min = 0 } }
                        }
                    };

                    loader         = new TextLoader(env, loaderArgs, source);
                    termColumnName = "Term";
                    convertTerms   = true;
                }
                else
                {
                    ch.Assert(isBinaryFile != isTransposed);
                    ch.CheckUserArg(!string.IsNullOrWhiteSpace(termColumnName), nameof(args.TermsColumn),
                                    "Must be specified");
                    if (isBinaryFile)
                    {
                        loader = new BinaryLoader(env, new BinaryLoader.Arguments(), source);
                    }
                    else
                    {
                        ch.Assert(isTransposed);
                        loader = new TransposeLoader(env, new TransposeLoader.Arguments(), source);
                    }
                }
            }
            ch.AssertNonEmpty(termColumnName);

            int  termColumnIndex;
            bool columnFound = loader.Schema.TryGetColumnIndex(termColumnName, out termColumnIndex);

            if (!columnFound)
            {
                throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", termColumnName);
            }

            var termColumnType = loader.Schema.GetColumnType(termColumnIndex);

            // Without conversion the column must already have exactly the builder's type.
            if (!(convertTerms || termColumnType.Equals(bldr.ItemType)))
            {
                throw ch.ExceptUserArg(nameof(args.TermsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, termColumnType);
            }

            using (var cursor = loader.GetRowCursor(c => c == termColumnIndex))
            using (var progress = env.StartProgressChannel("Building term dictionary from file"))
            {
                var    progressHeader = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });
                var    trainer        = Trainer.Create(cursor, termColumnIndex, convertTerms, int.MaxValue, bldr);
                double totalRows      = loader.GetRowCount(true) ?? double.NaN;
                long   rowsAccepted   = 0;

                progress.SetHeader(progressHeader,
                                   entry =>
                {
                    entry.SetProgress(0, rowsAccepted, totalRows);
                    // Purely feedback for the user. That the other thread might be
                    // working in the background is not a problem.
                    entry.SetMetric(0, trainer.Count);
                });

                // Stop at the end of the data or as soon as the trainer declines a row; only rows the
                // trainer accepted are counted.
                for (;;)
                {
                    if (!cursor.MoveNext() || !trainer.ProcessRow())
                    {
                        break;
                    }
                    rowsAccepted++;
                }

                if (trainer.Count == 0)
                {
                    ch.Warning("Term map loaded from file resulted in an empty map.");
                }
                progress.Checkpoint(trainer.Count, rowsAccepted);
                return trainer.Finish();
            }
        }
コード例 #3
0
        /// <summary>
        /// This builds the <see cref="TermMap"/> instances per column.
        /// </summary>
        /// <remarks>
        /// Each column is handled by the first of these that applies:
        /// (1) an explicit term list, taken from the column settings or, failing that, from
        /// <paramref name="args"/>; (2) the data file named in <paramref name="args"/>, which is loaded once
        /// and shared by every column that uses it; (3) auto-training over <paramref name="trainingData"/>.
        /// </remarks>
        /// <param name="env">Environment used to start the progress channel and passed on to the file loader.</param>
        /// <param name="ch">Channel used for assertions, warnings and user-argument errors.</param>
        /// <param name="infos">Per-column information (name, source column index, source type).</param>
        /// <param name="args">Transform-level arguments, used as the fallback for per-column settings.</param>
        /// <param name="column">Per-column settings, indexed in parallel with <paramref name="infos"/>.</param>
        /// <param name="trainingData">Data cursored to build term maps for the auto-trained columns.</param>
        /// <returns>One term map per entry of <paramref name="infos"/>, in the same order.</returns>
        private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] infos,
                                       ArgumentsBase args, ColumnBase[] column, IDataView trainingData)
        {
            Contracts.AssertValue(env);
            env.AssertValue(ch);
            ch.AssertValue(infos);
            ch.AssertValue(args);
            ch.AssertValue(column);
            ch.AssertValue(trainingData);

            // An explicit transform-level term list takes precedence over any data file settings, so warn
            // when both kinds of argument were supplied.
            if ((args.Term != null || !string.IsNullOrEmpty(args.Terms)) &&
                (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader != null ||
                 !string.IsNullOrWhiteSpace(args.TermsColumn)))
            {
                ch.Warning("Explicit term list specified. Data file arguments will be ignored");
            }

            if (!Enum.IsDefined(typeof(SortOrder), args.Sort))
            {
                throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected", args.Sort);
            }

            // Term map loaded from the data file; created lazily by the first column that needs it.
            TermMap termsFromFile = null;
            var     termMap       = new TermMap[infos.Length];

            // lims[i] is the term limit for column i; it is only set for columns that are auto-trained.
            int[]         lims         = new int[infos.Length];
            // Number of auto-trained columns. toTrain holds their source column indices; being a set, it can
            // be smaller than trainsNeeded when several columns share a source.
            int           trainsNeeded = 0;
            HashSet <int> toTrain      = null;

            for (int iinfo = 0; iinfo < infos.Length; iinfo++)
            {
                // First check whether we have a terms argument, and handle it appropriately.
                var terms      = new DvText(column[iinfo].Terms);
                var termsArray = column[iinfo].Term;
                // Fall back to the transform-level term list only when the column specifies neither form.
                if (!terms.HasChars && termsArray == null)
                {
                    terms      = new DvText(args.Terms);
                    termsArray = args.Term;
                }

                terms = terms.Trim();
                if (terms.HasChars || (termsArray != null && termsArray.Length > 0))
                {
                    // We have terms! Pass it in.
                    var sortOrder = column[iinfo].Sort ?? args.Sort;
                    if (!Enum.IsDefined(typeof(SortOrder), sortOrder))
                    {
                        throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, infos[iinfo].Name);
                    }

                    var bldr = Builder.Create(infos[iinfo].TypeSrc, sortOrder);
                    // The single-string form wins over the array form when both are present.
                    if (terms.HasChars)
                    {
                        bldr.ParseAddTermArg(ref terms, ch);
                    }
                    else
                    {
                        bldr.ParseAddTermArg(termsArray, ch);
                    }
                    termMap[iinfo] = bldr.Finish();
                }
                else if (!string.IsNullOrWhiteSpace(args.DataFile))
                {
                    // First column using this file.
                    // The builder, and hence the loaded map's type and sort order, comes from this first column.
                    if (termsFromFile == null)
                    {
                        var bldr = Builder.Create(infos[iinfo].TypeSrc, column[iinfo].Sort ?? args.Sort);
                        termsFromFile = CreateFileTermMap(env, ch, args, bldr);
                    }
                    if (!termsFromFile.ItemType.Equals(infos[iinfo].TypeSrc.ItemType))
                    {
                        // We have no current plans to support re-interpretation based on different column
                        // type, not only because it's unclear what realistic customer use-cases for such
                        // a complicated feature would be, and also because it's difficult to see how we
                        // can logically reconcile "reinterpretation" for different types with the resulting
                        // data view having an actual type.
                        throw ch.ExceptUserArg(nameof(args.DataFile), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'",
                                               termsFromFile.ItemType, infos[iinfo].Name, infos[iinfo].TypeSrc.ItemType);
                    }
                    termMap[iinfo] = termsFromFile;
                }
                else
                {
                    // Auto train this column. Leave the term map null for now, but set the lim appropriately.
                    lims[iinfo] = column[iinfo].MaxNumTerms ?? args.MaxNumTerms;
                    ch.CheckUserArg(lims[iinfo] > 0, nameof(Column.MaxNumTerms), "Must be positive");
                    Utils.Add(ref toTrain, infos[iinfo].Source);
                    ++trainsNeeded;
                }
            }

            ch.Assert((Utils.Size(toTrain) == 0) == (trainsNeeded == 0));
            ch.Assert(Utils.Size(toTrain) <= trainsNeeded);
            if (trainsNeeded > 0)
            {
                Trainer[] trainer     = new Trainer[trainsNeeded];
                // trainerInfo[k] is the index into infos/termMap of the column that trainer[k] is building.
                int[]     trainerInfo = new int[trainsNeeded];
                // Open the cursor, then instantiate the trainers.
                int itrainer;
                using (var cursor = trainingData.GetRowCursor(toTrain.Contains))
                    using (var pch = env.StartProgressChannel("Building term dictionary"))
                    {
                        long   rowCur   = 0;
                        double rowCount = trainingData.GetRowCount(true) ?? double.NaN;
                        var    header   = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });

                        itrainer = 0;
                        for (int iinfo = 0; iinfo < infos.Length; ++iinfo)
                        {
                            // Columns that already have a map (explicit terms or file) need no trainer.
                            if (termMap[iinfo] != null)
                            {
                                continue;
                            }
                            var bldr = Builder.Create(infos[iinfo].TypeSrc, column[iinfo].Sort ?? args.Sort);
                            trainerInfo[itrainer] = iinfo;
                            trainer[itrainer++]   = Trainer.Create(cursor, infos[iinfo].Source, false, lims[iinfo], bldr);
                        }
                        ch.Assert(itrainer == trainer.Length);
                        pch.SetHeader(header,
                                      e =>
                        {
                            e.SetProgress(0, rowCur, rowCount);
                            // Purely feedback for the user. That the other thread might be
                            // working in the background is not a problem.
                            e.SetMetric(0, trainer.Sum(t => t.Count));
                        });

                        // The [0,tmin) trainers are finished.
                        int tmin = 0;
                        // We might exit early if all trainers reach their maximum.
                        while (tmin < trainer.Length && cursor.MoveNext())
                        {
                            rowCur++;
                            for (int t = tmin; t < trainer.Length; ++t)
                            {
                                // A trainer whose ProcessRow returns false is moved into the finished prefix
                                // and is not offered further rows. trainerInfo is swapped in step so that
                                // each trainer stays paired with its column index.
                                if (!trainer[t].ProcessRow())
                                {
                                    Utils.Swap(ref trainerInfo[t], ref trainerInfo[tmin]);
                                    Utils.Swap(ref trainer[t], ref trainer[tmin++]);
                                }
                            }
                        }

                        pch.Checkpoint(trainer.Sum(t => t.Count), rowCur);
                    }
                // Finish every trainer, including those still active when the data ran out.
                for (itrainer = 0; itrainer < trainer.Length; ++itrainer)
                {
                    int iinfo = trainerInfo[itrainer];
                    ch.Assert(termMap[iinfo] == null);
                    if (trainer[itrainer].Count == 0)
                    {
                        ch.Warning("Term map for output column '{0}' contains no entries.", infos[iinfo].Name);
                    }
                    termMap[iinfo] = trainer[itrainer].Finish();
                    // Allow the intermediate structures in the trainer and builder to be released as we iterate
                    // over the columns, as the Finish operation can potentially result in the allocation of
                    // additional structures.
                    trainer[itrainer] = null;
                }
                ch.Assert(termMap.All(tm => tm != null));
                ch.Assert(termMap.Zip(infos, (tm, info) => tm.ItemType.Equals(info.TypeSrc.ItemType)).All(x => x));
            }

            return(termMap);
        }
コード例 #4
0
        /// <summary>
        /// Trains the ensemble batch by batch: for every batch one model is trained per subset, the surviving
        /// models are pruned by the sub-model selector, optionally fed to the stacking combiner, and collected
        /// into the list from which the final predictor is created.
        /// </summary>
        /// <param name="ch">Channel used for assertions, progress messages, warnings and checks.</param>
        /// <param name="data">Training data handed to the subset selector.</param>
        /// <returns>The predictor created from all models kept across the batches.</returns>
        private TPredictor TrainCore(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);

            // 1. Subset Selection
            var  stacker    = Combiner as IStackingTrainer <TOutput>;
            bool isStacking = stacker != null;

            //REVIEW: Implement stacking for Batch mode.
            ch.CheckUserArg(!isStacking || Args.BatchSize <= 0, nameof(Args.BatchSize), "Stacking works only with Non-batch mode");

            // Use the larger of the proportions requested by the sub-model selector and the stacking combiner.
            var validationProportion = SubModelSelector.ValidationDatasetProportion;
            if (isStacking)
            {
                validationProportion = Math.Max(validationProportion, stacker.ValidationDatasetProportion);
            }

            var computeMetrics = Args.ShowMetrics || Combiner is IWeightedAverager;
            var keptModels     = new List <FeatureSubsetModel <TOutput> >();

            _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationProportion);

            int batchOrdinal = 1;
            foreach (var batch in _subsetSelector.GetBatches(Host.Rand))
            {
                // 2. Core train
                ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchOrdinal);
                batchOrdinal++;

                // One slot per trainer; a slot stays null when that trainer is skipped or fails.
                var slots    = new FeatureSubsetModel <TOutput> [Trainers.Length];
                var subsets  = _subsetSelector.GetSubsets(batch, Host.Rand);
                var parallel = new ParallelOptions();
                parallel.MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1;

                Parallel.ForEach(subsets, parallel, (subset, loopState, position) =>
                {
                    int slot = (int)position;
                    ch.Info("Beginning training model {0} of {1}", position + 1, Trainers.Length);
                    Stopwatch timer = Stopwatch.StartNew();
                    try
                    {
                        if (EnsureMinimumFeaturesSelected(subset))
                        {
                            var predictor = Trainers[slot].Train(subset.Data);
                            var trained   = new FeatureSubsetModel <TOutput>(predictor, subset.SelectedFeatures, null);
                            SubModelSelector.CalculateMetrics(trained, _subsetSelector, subset, batch, computeMetrics);
                            slots[slot] = trained;
                        }
                    }
                    catch (Exception ex)
                    {
                        // A failing trainer is reported and left out; it does not abort the others.
                        ch.Assert(slots[slot] == null);
                        ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
                                   position + 1, Trainers.Length, ex.Message);
                    }
                    ch.Info("Trainer {0} of {1} finished in {2}", position + 1, Trainers.Length, timer.Elapsed);
                });

                var batchTrained = slots.Where(m => m != null).ToList();
                if (Args.ShowMetrics)
                {
                    PrintMetrics(ch, batchTrained);
                }

                var batchKept = SubModelSelector.Prune(batchTrained).ToList();

                if (isStacking)
                {
                    stacker.Train(batchKept, _subsetSelector.GetTestData(null, batch), Host);
                }

                keptModels.AddRange(batchKept);

                // NOTE(review): keptModels accumulates across batches and is measured after pruning, so this
                // compares the running total with the trainer count rather than counting the failures of
                // the current batch - confirm that this is the intended meaning of the warning.
                int keptCount = Utils.Size(keptModels);
                if (keptCount < Utils.Size(Trainers))
                {
                    ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - keptCount, Utils.Size(Trainers));
                }
                ch.Check(keptCount > 0, "Ensemble training resulted in no valid models.");
            }

            return CreatePredictor(keptModels);
        }
コード例 #5
0
        /// <summary>
        /// Writes the selected columns of <paramref name="data"/> to <paramref name="stream"/>.
        /// </summary>
        /// <remarks>
        /// Written in this order: a zeroed placeholder for the header; one sub-IDV with the schema and
        /// metadata (plus the row-wise data when <c>_writeRowData</c> is set); one sub-IDV per column in
        /// <paramref name="cols"/>; the offset/length table for those sub-IDVs; the tail signature. Only then
        /// is the real header written over the placeholder at the start of the stream.
        /// </remarks>
        /// <param name="ch">Channel used for assertions and informational messages.</param>
        /// <param name="stream">Destination stream; asserted to be seekable.</param>
        /// <param name="data">Transposable view to save.</param>
        /// <param name="cols">Indices of the columns to write; asserted to be non-empty.</param>
        private void SaveTransposedData(IChannel ch, Stream stream, ITransposeDataView data, int[] cols)
        {
            _host.AssertValue(ch);
            ch.AssertValue(stream);
            ch.AssertValue(data);
            ch.AssertNonEmpty(cols);
            ch.Assert(stream.CanSeek);

            // Initialize what we can in the header, though we will not be writing out things in the
            // header until we have confidence that things were written out correctly.
            TransposeLoader.Header header = default(TransposeLoader.Header);
            header.Signature         = TransposeLoader.Header.SignatureValue;
            header.Version           = TransposeLoader.Header.WriterVersion;
            header.CompatibleVersion = TransposeLoader.Header.WriterVersion;
            // The row count is taken from the slot type of the first selected column.
            var slotType = data.GetSlotType(cols[0]);

            ch.AssertValue(slotType);
            header.RowCount    = slotType.Size;
            header.ColumnCount = cols.Length;

            // We keep track of the offsets of the start of each sub-IDV, for use in writing out the
            // offsets/length table later.
            List <long> offsets = new List <long>();

            // First write a bunch of zeros at the head, as a placeholder for the header that
            // will go there once everything else has been written successfully. We'll keep this array
            // around for the real marshalling and writing of the header bytes structure.
            byte[] headerBytes = new byte[TransposeLoader.Header.HeaderSize];
            stream.Write(headerBytes, 0, headerBytes.Length);
            offsets.Add(stream.Position);

            // This is a convenient delegate to write out an IDV substream, then save the offsets
            // where writing stopped to the offsets list.
            Action <string, IDataView> viewAction =
                (name, view) =>
            {
                using (var substream = new SubsetStream(stream))
                {
                    _internalSaver.SaveData(substream, view, Utils.GetIdentityPermutation(view.Schema.Count));
                    // Seek to the end of the substream before recording the position of the outer stream.
                    // NOTE(review): presumably this also leaves the outer stream at the end of this
                    // sub-IDV - confirm against SubsetStream.
                    substream.Seek(0, SeekOrigin.End);
                    ch.Info("Wrote {0} data view in {1} bytes", name, substream.Length);
                }
                offsets.Add(stream.Position);
            };

            // First write out the no-row data, limited to these columns.
            IDataView subdata = new ChooseColumnsByIndexTransform(_host,
                                                                  new ChooseColumnsByIndexTransform.Arguments()
            {
                Index = cols
            }, data);

            // If we want the "dual mode" row-wise and slot-wise file, don't filter out anything.
            // Otherwise take zero rows, so that only the schema and metadata are written.
            if (!_writeRowData)
            {
                subdata = SkipTakeFilter.Create(_host, new SkipTakeFilter.TakeArguments()
                {
                    Count = 0
                }, subdata);
            }

            string msg = _writeRowData ? "row-wise data, schema, and metadata" : "schema and metadata";

            viewAction(msg, subdata);
            // One sub-IDV per selected column, in the order given by cols.
            foreach (var col in cols)
            {
                viewAction(data.Schema[col].Name, new TransposerUtils.SlotDataView(_host, data, col));
            }

            // Wrote out the dataview. Write out the table offset.
            using (var writer = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
            {
                // Format of the table is offset, length, both as 8-byte integers.
                // As it happens we wrote things out as adjacent sub-IDVs, so the
                // length can be derived from the offsets. The first will be the
                // start of the first sub-IDV, and all subsequent entries will be
                // the start/end of the current/next sub-IDV, respectively, so a total
                // of cols.Length + 2 entries.
                ch.Assert(offsets.Count == cols.Length + 2);
                ch.Assert(offsets[offsets.Count - 1] == stream.Position);
                header.SubIdvTableOffset = stream.Position;
                for (int c = 1; c < offsets.Count; ++c)
                {
                    // 8-byte int for offsets, 8-byte int for length.
                    writer.Write(offsets[c - 1]);
                    writer.Write(offsets[c] - offsets[c - 1]);
                }
                header.TailOffset = stream.Position;
                writer.Write(TransposeLoader.Header.TailSignatureValue);

                // Now we are confident that things will work, so write it out.
                // NOTE(review): Header is referenced unqualified here while the local is declared as
                // TransposeLoader.Header - presumably the same type; confirm. This also assumes that
                // headerBytes (HeaderSize bytes) is at least Marshal.SizeOf(typeof(Header)) long.
                unsafe
                {
                    Marshal.Copy(new IntPtr(&header), headerBytes, 0, Marshal.SizeOf(typeof(Header)));
                }
                // Go back to the start of the stream and replace the zeroed placeholder with the real header.
                writer.Seek(0, SeekOrigin.Begin);
                writer.Write(headerBytes);
            }
        }
コード例 #6
0
        protected virtual void TrainCore(IChannel ch, RoleMappedData data)
        {
            Host.AssertValue(ch);
            ch.AssertValue(data);

            // Compute the number of threads to use. The ctor should have verified that this will
            // produce a positive value.
            int numThreads = !UseThreads ? 1 : (NumThreads ?? Environment.ProcessorCount);

            if (Host.ConcurrencyFactor > 0 && numThreads > Host.ConcurrencyFactor)
            {
                numThreads = Host.ConcurrencyFactor;
                ch.Warning("The number of threads specified in trainer arguments is larger than the concurrency factor "
                           + "setting of the environment. Using {0} training threads instead.", numThreads);
            }

            ch.Assert(numThreads > 0);

            NumGoodRows = 0;
            WeightSum   = 0;

            _features = null;
            _labels   = null;
            _weights  = null;
            if (numThreads > 1)
            {
                ch.Info("LBFGS multi-threading will attempt to load dataset into memory. In case of out-of-memory " +
                        "issues, add 'numThreads=1' to the trainer arguments and 'cache=-' to the command line " +
                        "arguments to turn off multi-threading.");
                _features = new VBuffer <float> [1000];
                _labels   = new float[1000];
                if (data.Schema.Weight != null)
                {
                    _weights = new float[1000];
                }
            }

            var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Features | CursOpt.Label | CursOpt.Weight);

            long numBad;

            // REVIEW: This pass seems overly expensive for the benefit when multi-threading is off....
            using (var cursor = cursorFactory.Create())
                using (var pch = Host.StartProgressChannel("LBFGS data prep"))
                {
                    // REVIEW: maybe it makes sense for the factory to capture the good row count after
                    // the first successful cursoring?
                    Double totalCount = data.Data.GetRowCount(true) ?? Double.NaN;

                    long exCount = 0;
                    pch.SetHeader(new ProgressHeader(null, new[] { "examples" }),
                                  e => e.SetProgress(0, exCount, totalCount));
                    while (cursor.MoveNext())
                    {
                        WeightSum += cursor.Weight;
                        if (ShowTrainingStats)
                        {
                            ProcessPriorDistribution(cursor.Label, cursor.Weight);
                        }

                        PreTrainingProcessInstance(cursor.Label, ref cursor.Features, cursor.Weight);
                        exCount++;
                        if (_features != null)
                        {
                            ch.Assert(cursor.KeptRowCount <= int.MaxValue);
                            int index = (int)cursor.KeptRowCount - 1;
                            Utils.EnsureSize(ref _features, index + 1);
                            Utils.EnsureSize(ref _labels, index + 1);
                            if (_weights != null)
                            {
                                Utils.EnsureSize(ref _weights, index + 1);
                                _weights[index] = cursor.Weight;
                            }
                            Utils.Swap(ref _features[index], ref cursor.Features);
                            _labels[index] = cursor.Label;

                            if (cursor.KeptRowCount >= int.MaxValue)
                            {
                                ch.Warning("Limiting data size for multi-threading");
                                break;
                            }
                        }
                    }
                    NumGoodRows = cursor.KeptRowCount;
                    numBad      = cursor.SkippedRowCount;
                }
            ch.Check(NumGoodRows > 0, NoTrainingInstancesMessage);
            if (numBad > 0)
            {
                ch.Warning("Skipped {0} instances with missing features/label/weight during training", numBad);
            }

            if (_features != null)
            {
                ch.Assert(numThreads > 1);

                // If there are so many threads that each only gets a small number (less than 10) of instances, trim
                // the number of threads so each gets a more reasonable number (100 or so). These numbers are pretty arbitrary,
                // but avoid the possibility of having no instances on some threads.
                if (numThreads > 1 && NumGoodRows / numThreads < 10)
                {
                    int numNew = Math.Max(1, (int)NumGoodRows / 100);
                    ch.Warning("Too few instances to use {0} threads, decreasing to {1} thread(s)", numThreads, numNew);
                    numThreads = numNew;
                }
                ch.Assert(numThreads > 0);

                // Divide up the instances among the threads.
                _numChunks = numThreads;
                _ranges    = new int[_numChunks + 1];
                int cinstTot = (int)NumGoodRows;
                for (int ichk = 0, iinstMin = 0; ichk < numThreads; ichk++)
                {
                    int cchkLeft = numThreads - ichk;                                // Number of chunks left to fill.
                    ch.Assert(0 < cchkLeft && cchkLeft <= numThreads);
                    int cinstThis = (cinstTot - iinstMin + cchkLeft - 1) / cchkLeft; // Size of this chunk.
                    ch.Assert(0 < cinstThis && cinstThis <= cinstTot - iinstMin);
                    iinstMin         += cinstThis;
                    _ranges[ichk + 1] = iinstMin;
                }

                _localLosses    = new float[numThreads];
                _localGradients = new VBuffer <float> [numThreads - 1];
                int size = BiasCount + WeightCount;
                for (int i = 0; i < _localGradients.Length; i++)
                {
                    _localGradients[i] = VBufferUtils.CreateEmpty <float>(size);
                }

                ch.Assert(_numChunks > 0 && _data == null);
            }
            else
            {
                // Streaming, single-threaded case.
                _data          = data;
                _cursorFactory = cursorFactory;
                ch.Assert(_numChunks == 0 && _data != null);
            }

            VBuffer <float>       initWeights;
            ITerminationCriterion terminationCriterion;
            Optimizer             opt = InitializeOptimizer(ch, cursorFactory, out initWeights, out terminationCriterion);

            opt.Quiet = Quiet;

            float loss;

            try
            {
                opt.Minimize(DifferentiableFunction, ref initWeights, terminationCriterion, ref CurrentWeights, out loss);
            }
            catch (Optimizer.PrematureConvergenceException e)
            {
                if (!Quiet)
                {
                    ch.Warning("Premature convergence occurred. The OptimizationTolerance may be set too small. {0}", e.Message);
                }
                CurrentWeights = e.State.X;
                loss           = e.State.Value;
            }

            ch.Assert(CurrentWeights.Length == BiasCount + WeightCount);

            int numParams = BiasCount;

            if ((L1Weight > 0 && !Quiet) || ShowTrainingStats)
            {
                VBufferUtils.ForEachDefined(ref CurrentWeights, (index, value) => { if (index >= BiasCount && value != 0)
                                                                                    {
                                                                                        numParams++;
                                                                                    }
                                            });
                if (L1Weight > 0 && !Quiet)
                {
                    ch.Info("L1 regularization selected {0} of {1} weights.", numParams, BiasCount + WeightCount);
                }
            }

            if (ShowTrainingStats)
            {
                ComputeTrainingStatistics(ch, cursorFactory, loss, numParams);
            }
        }
コード例 #7
0
        /// <summary>
        /// Tries to read a sample of <paramref name="source"/> with the loader settings in
        /// <paramref name="args"/> and determine the most common number of slots in column "C".
        /// </summary>
        /// <param name="ch">Channel that receives the trace message and the replayed loader messages on success.</param>
        /// <param name="args">Text loader arguments (separator, quoting, sparsity) to try.</param>
        /// <param name="source">The file source to sample.</param>
        /// <param name="skipStrictValidation">When true, the uniformity and "more than one column" checks are bypassed.</param>
        /// <param name="result">On success, the detected split description; otherwise the default value.</param>
        /// <returns>True if the sample parsed and passed validation; false otherwise.</returns>
        private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiStreamSource source, bool skipStrictValidation, out ColumnSplitResult result)
        {
            result = default(ColumnSplitResult);
            try
            {
                // The loader runs in a throwaway environment so that messages from a failed attempt
                // are dropped; they are forwarded to the caller's channel only on success.
                using (var loaderEnv = new ConsoleEnvironment(0, true))
                {
                    var captured = new ConcurrentBag<ChannelMessage>();
                    loaderEnv.AddListener<ChannelMessage>((src, msg) => captured.Add(msg));

                    var sample = TextLoader.ReadFile(loaderEnv, args, source).Take(1000);
                    var slotCounts = new List<int>();
                    bool hasColumn = sample.Schema.TryGetColumnIndex("C", out int colIndex);
                    ch.Assert(hasColumn);

                    // Record the vector length of column "C" for every sampled row.
                    using (var rows = sample.GetRowCursor(c => c == colIndex))
                    {
                        var fetch = rows.GetGetter<VBuffer<ReadOnlyMemory<char>>>(colIndex);
                        VBuffer<ReadOnlyMemory<char>> current = default;
                        while (rows.MoveNext())
                        {
                            fetch(ref current);
                            slotCounts.Add(current.Length);
                        }
                    }

                    Contracts.Check(slotCounts.Count > 0);
                    var dominant = slotCounts.GroupBy(n => n).OrderByDescending(g => g.Count()).First();

                    if (!skipStrictValidation)
                    {
                        // Reject when too few rows agree on the column count.
                        if (dominant.Count() < UniformColumnCountThreshold * slotCounts.Count)
                        {
                            return false;
                        }

                        // If user explicitly specified separator we're allowing "single" column case;
                        // otherwise user will see message informing that we were not able to detect any columns.
                        if (dominant.Key <= 1)
                        {
                            return false;
                        }
                    }

                    result = new ColumnSplitResult(true, args.Separator, args.AllowQuoting, args.AllowSparse, dominant.Key);
                    ch.Trace("Discovered {0} columns using separator '{1}'", dominant.Key, args.Separator);
                    foreach (var pending in captured)
                    {
                        ch.Send(pending);
                    }
                    return true;
                }
            }
            catch (Exception ex) when (ex.IsMarked())
            {
                // For known (marked) exceptions, we just continue to the next separator candidate.
                // Unmarked exceptions are not caught here and propagate to the caller.
            }
            return false;
        }
コード例 #8
0
        /// <summary>
        /// Trains one LDA state per configured column using two sweeps over <paramref name="trainingData"/>:
        /// the first sweep counts documents, tokens and vocabulary size so memory can be allocated up front,
        /// the second sweep feeds the documents into the allocated states.
        /// </summary>
        /// <param name="ch">Channel used for assertions, warnings and errors.</param>
        /// <param name="trainingData">The data view to train on.</param>
        /// <param name="states">Output array (one slot per entry of <c>Infos</c>) that receives the trained states.</param>
        /// <remarks>
        /// Throws (via <c>ch.Except</c>) when a column contains no usable, non-empty document.
        /// </remarks>
        private void Train(IChannel ch, IDataView trainingData, LdaState[] states)
        {
            Host.AssertValue(ch);
            ch.AssertValue(trainingData);
            ch.AssertValue(states);
            ch.Assert(states.Length == Infos.Length);

            // Newly allocated arrays are zero-initialized, so no explicit clearing is needed.
            bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
            int[]  numVocabs     = new int[Infos.Length];

            for (int i = 0; i < Infos.Length; i++)
            {
                activeColumns[Infos[i].Source] = true;
            }

            // The current LDA needs its memory allocated before data is fed in, so the data is swept twice:
            // once to pre-compute the memory requirements, and once to actually feed the data.
            // Another solution would be to prepare these two values externally and put them at the
            // beginning of the input file.
            long[] corpusSize  = new long[Infos.Length];
            int[]  numDocArray = new int[Infos.Length];

            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                var getters = new ValueGetter<VBuffer<Double>>[Utils.Size(Infos)];
                for (int i = 0; i < Infos.Length; i++)
                {
                    getters[i] = RowCursorUtils.GetVecGetterAs<Double>(NumberType.R8, cursor, Infos[i].Source);
                }
                VBuffer<Double> src      = default(VBuffer<Double>);
                long            rowCount = 0;

                while (cursor.MoveNext())
                {
                    ++rowCount;
                    for (int i = 0; i < Infos.Length; i++)
                    {
                        int docSize = 0;
                        getters[i](ref src);

                        // Compute the number of term instances in this document.
                        for (int termID = 0; termID < src.Count; termID++)
                        {
                            int termFreq = GetFrequency(src.Values[termID]);
                            if (termFreq < 0)
                            {
                                // Ignore this row.
                                docSize = 0;
                                break;
                            }

                            if (docSize >= _exes[i].NumMaxDocToken - termFreq)
                            {
                                break; // Control the document length.
                            }
                            // If legal then add the term.
                            docSize += termFreq;
                        }

                        // Ignore empty doc.
                        if (docSize == 0)
                        {
                            continue;
                        }

                        numDocArray[i]++;
                        // At the beginning of each doc there is a cursor variable, hence the "+ 1".
                        // Computed in 64-bit so a large per-document size cannot overflow before widening.
                        corpusSize[i] += docSize * 2L + 1;

                        // Increase numVocab if needed.
                        if (numVocabs[i] < src.Length)
                        {
                            numVocabs[i] = src.Length;
                        }
                    }
                }

                for (int i = 0; i < Infos.Length; ++i)
                {
                    if (numDocArray[i] != rowCount)
                    {
                        ch.Assert(numDocArray[i] < rowCount);
                        ch.Warning($"Column '{Infos[i].Name}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values.");
                    }
                }
            }

            // Initialize all LDA states.
            for (int i = 0; i < Infos.Length; i++)
            {
                var state = new LdaState(Host, _exes[i], numVocabs[i]);
                if (numDocArray[i] == 0 || corpusSize[i] == 0)
                {
                    throw ch.Except("The specified documents are all empty in column '{0}'.", Infos[i].Name);
                }

                state.AllocateDataMemory(numDocArray[i], corpusSize[i]);
                states[i] = state;
            }

            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                // Accumulated in 64-bit to match corpusSize; a 32-bit accumulator would wrap on
                // corpora larger than int.MaxValue and make the consistency assert below fail spuriously.
                long[] docSizeCheck = new long[Infos.Length];
                // This could be optimized so that if multiple trainers consume the same column, it is
                // fed into the train method once.
                var getters = new ValueGetter<VBuffer<Double>>[Utils.Size(Infos)];
                for (int i = 0; i < Infos.Length; i++)
                {
                    getters[i] = RowCursorUtils.GetVecGetterAs<Double>(NumberType.R8, cursor, Infos[i].Source);
                }

                VBuffer<Double> src = default(VBuffer<Double>);

                while (cursor.MoveNext())
                {
                    for (int i = 0; i < Infos.Length; i++)
                    {
                        getters[i](ref src);
                        docSizeCheck[i] += states[i].FeedTrain(Host, ref src);
                    }
                }
                for (int i = 0; i < Infos.Length; i++)
                {
                    // The amount fed in the second sweep must match what was pre-computed in the first.
                    Host.Assert(corpusSize[i] == docSizeCheck[i]);
                    states[i].CompleteTrain();
                }
            }
        }
コード例 #9
0
        /// <summary>
        /// Prints the per-label metrics of a single fold. The multi-output regression evaluator
        /// prints only these per-label metrics for each fold.
        /// </summary>
        /// <param name="ch">Channel that receives the formatted table and any errors.</param>
        /// <param name="metrics">Metric data views keyed by metric kind; must contain the overall metrics.</param>
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView fold))
            {
                throw ch.Except("No overall metrics found");
            }

            // Locate the optional bookkeeping columns.
            bool needWeighted = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out int isWeightedCol);
            bool hasStrats    = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int stratCol);
            bool hasStratVals = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out int stratVal);
            ch.Assert(hasStrats == hasStratVals);

            var colCount      = fold.Schema.ColumnCount;
            var vectorGetters = new ValueGetter<VBuffer<double>>[colCount];

            using (var cursor = fold.GetRowCursor(col => true))
            {
                // Fall back to constant getters when the bookkeeping columns are absent.
                ValueGetter<bool> isWeightedGetter = (ref bool dst) => dst = false;
                if (needWeighted)
                {
                    isWeightedGetter = cursor.GetGetter<bool>(isWeightedCol);
                }

                ValueGetter<uint> stratGetter = (ref uint dst) => dst = 0;
                if (hasStrats)
                {
                    var stratType = cursor.Schema.GetColumnType(stratCol);
                    stratGetter = RowCursorUtils.GetGetterAs<uint>(stratType, cursor, stratCol);
                }

                // Collect a getter for every visible, known-size R8 vector column; all such
                // columns must have the same number of slots (one per label).
                int labelCount = 0;
                for (int col = 0; col < colCount; col++)
                {
                    if (fold.Schema.IsHidden(col))
                    {
                        continue;
                    }
                    if (needWeighted && col == isWeightedCol)
                    {
                        continue;
                    }
                    if (hasStrats && (col == stratCol || col == stratVal))
                    {
                        continue;
                    }

                    var colType = fold.Schema.GetColumnType(col);
                    if (!(colType.IsKnownSizeVector && colType.ItemType == NumberType.R8))
                    {
                        continue;
                    }

                    vectorGetters[col] = cursor.GetGetter<VBuffer<double>>(col);
                    if (labelCount != 0)
                    {
                        ch.Check(labelCount == colType.VectorSize, "All vector metrics should contain the same number of slots");
                    }
                    else
                    {
                        labelCount = colType.VectorSize;
                    }
                }

                var labelNames = new ReadOnlyMemory<char>[labelCount];
                for (int slot = 0; slot < labelNames.Length; slot++)
                {
                    labelNames[slot] = string.Format("Label_{0}", slot).AsMemory();
                }

                // Header row: an empty cell followed by one column per label.
                var table = new StringBuilder();
                table.AppendLine("Per-label metrics:");
                table.AppendFormat("{0,12} ", " ");
                foreach (var labelName in labelNames)
                {
                    table.AppendFormat(" {0,20}", labelName);
                }
                table.AppendLine();

                VBuffer<Double> metricVals      = default(VBuffer<Double>);
                bool            isWeighted      = false;
                bool            foundWeighted   = !needWeighted;
                bool            foundUnweighted = false;
                uint            strat           = 0;
                while (cursor.MoveNext())
                {
                    isWeightedGetter(ref isWeighted);

                    // At most one weighted and one unweighted row may appear.
                    bool duplicate = isWeighted ? foundWeighted : foundUnweighted;
                    if (duplicate)
                    {
                        throw ch.Except("Multiple {0} rows found in overall metrics data view",
                                        isWeighted ? "weighted" : "unweighted");
                    }
                    foundWeighted   |= isWeighted;
                    foundUnweighted |= !isWeighted;

                    // Rows with a non-zero stratification value are not printed.
                    stratGetter(ref strat);
                    if (strat > 0)
                    {
                        continue;
                    }

                    for (int col = 0; col < colCount; col++)
                    {
                        var getter = vectorGetters[col];
                        if (getter == null)
                        {
                            continue;
                        }

                        getter(ref metricVals);
                        ch.Assert(metricVals.Length == labelCount);

                        table.AppendFormat("{0}{1,12}:", isWeighted ? "Weighted " : "", fold.Schema.GetColumnName(col));
                        foreach (var metric in metricVals.Items(all: true))
                        {
                            table.AppendFormat(" {0,20:G20}", metric.Value);
                        }
                        table.AppendLine();
                    }

                    if (foundUnweighted && foundWeighted)
                    {
                        break;
                    }
                }
                ch.Assert(foundUnweighted && foundWeighted);
                ch.Info(table.ToString());
            }
        }
コード例 #10
0
        /// <summary>
        /// Runs the full train-and-test flow: builds the trainer and training pipeline, trains
        /// (optionally with a validation set and an input predictor), saves the model, reloads it
        /// as the test pipeline, scores, evaluates, and prints/saves the resulting metrics.
        /// </summary>
        /// <param name="ch">Channel used for tracing, warnings and errors.</param>
        /// <param name="cmd">The command string, passed along when saving the model.</param>
        private void RunCore(IChannel ch, string cmd)
        {
            Host.AssertValue(ch);
            Host.AssertNonEmpty(cmd);

            ch.Trace("Constructing trainer");
            ITrainer trainer = Args.Trainer.CreateInstance(Host);

            // When continuing training, try to pick up the predictor from the input model.
            IPredictor inputPredictor = null;
            if (Args.ContinueTrain)
            {
                bool loaded = TrainUtils.TryLoadPredictor(ch, Host, Args.InputModelFile, out inputPredictor);
                if (!loaded)
                {
                    ch.Warning("No input model file specified or model file did not contain a predictor. The model state cannot be initialized.");
                }
            }

            ch.Trace("Constructing the training pipeline");
            IDataView trainPipe = CreateLoader();

            // Resolve the role columns against the loader's schema.
            ISchema schema     = trainPipe.Schema;
            string  labelCol   = TrainUtils.MatchNameOrDefaultOrNull(
                ch, schema, nameof(Arguments.LabelColumn), Args.LabelColumn, DefaultColumnNames.Label);
            string  featureCol = TrainUtils.MatchNameOrDefaultOrNull(
                ch, schema, nameof(Arguments.FeatureColumn), Args.FeatureColumn, DefaultColumnNames.Features);
            string  groupCol   = TrainUtils.MatchNameOrDefaultOrNull(
                ch, schema, nameof(Arguments.GroupColumn), Args.GroupColumn, DefaultColumnNames.GroupId);
            string  weightCol  = TrainUtils.MatchNameOrDefaultOrNull(
                ch, schema, nameof(Arguments.WeightColumn), Args.WeightColumn, DefaultColumnNames.Weight);
            string  nameCol    = TrainUtils.MatchNameOrDefaultOrNull(
                ch, schema, nameof(Arguments.NameColumn), Args.NameColumn, DefaultColumnNames.Name);

            TrainUtils.AddNormalizerIfNeeded(Host, ch, trainer, ref trainPipe, featureCol, Args.NormalizeFeatures);

            ch.Trace("Binding columns");
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);
            var trainData  = new RoleMappedData(trainPipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols);

            // Optional validation set: only built when a file is given and the trainer supports it.
            RoleMappedData validData = null;
            bool hasValidationFile   = !string.IsNullOrWhiteSpace(Args.ValidationFile);
            if (hasValidationFile && !trainer.Info.SupportsValidation)
            {
                ch.Warning("Ignoring validationFile: Trainer does not accept validation dataset.");
            }
            else if (hasValidationFile)
            {
                ch.Trace("Constructing the validation pipeline");
                IDataView validPipe = CreateRawLoader(dataFile: Args.ValidationFile);
                validPipe = ApplyTransformUtils.ApplyAllTransformsToData(Host, trainPipe, validPipe);
                validData = new RoleMappedData(validPipe, trainData.Schema.GetColumnRoleNames());
            }

            var predictor = TrainUtils.Train(
                Host, ch, trainData, trainer, validData,
                Args.Calibrator, Args.MaxCalibrationExamples, Args.CacheData, inputPredictor);

            // Save the model (to the requested output file, or a temp file) and reload it
            // to construct the test pipeline.
            IDataLoader testPipe;
            using (var modelFile = string.IsNullOrEmpty(Args.OutputModelFile)
                                   ? Host.CreateTempFile(".zip")
                                   : Host.CreateOutputFile(Args.OutputModelFile))
            {
                TrainUtils.SaveModel(Host, ch, modelFile, predictor, trainData, cmd);

                ch.Trace("Constructing the testing pipeline");
                using (var modelStream = modelFile.OpenReadStream())
                {
                    using (var repository = RepositoryReader.Open(modelStream, ch))
                    {
                        testPipe = LoadLoader(repository, Args.TestFile, true);
                    }
                }
            }

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TrainTestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(
                Args.Scorer, predictor, testPipe, featureCol, groupCol, customCols, Host, trainData.Schema);

            // Evaluate, falling back to an evaluator chosen from the scored schema when none was specified.
            var evaluatorFactory = Args.Evaluator;
            if (!evaluatorFactory.IsGood())
            {
                evaluatorFactory = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
            }
            var evaluator = evaluatorFactory.CreateInstance(Host);
            var evalData  = new RoleMappedData(
                scorePipe, labelCol, featureCol, groupCol, weightCol, nameCol, customCols, opt: true);
            var metrics = evaluator.Evaluate(evalData);

            // Report.
            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);

            Dictionary<string, IDataView>[] allFoldMetrics = { metrics };
            SendTelemetryMetric(allFoldMetrics);

            // Optionally save the per-instance results.
            if (string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                return;
            }

            var perInstance     = evaluator.GetPerInstanceMetrics(evalData);
            var perInstanceData = new RoleMappedData(perInstance, labelCol, null, groupCol, weightCol, nameCol, customCols);
            var toSave          = evaluator.GetPerInstanceDataViewToSave(perInstanceData);
            MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, toSave);
        }