示例#1
0
        /// <summary>
        /// Utility method to create the file-based <see cref="TermMap"/> if the <see cref="ArgumentsBase.DataFile"/>
        /// argument of <paramref name="args"/> was present.
        /// </summary>
        /// <remarks>
        /// When <see cref="ArgumentsBase.Loader"/> is not set, the loader is picked from the file extension:
        /// <c>.idv</c> uses <see cref="BinaryLoader"/>, <c>.tdv</c> uses <see cref="TransposeLoader"/>, and any other
        /// extension is read with a tab-separated <see cref="TextLoader"/> exposing a single text column named "Term".
        /// Only in that last case is the trainer allowed to convert values to the builder's item type; in all other
        /// cases the terms column must already have <c>bldr.ItemType</c>.
        /// </remarks>
        /// <param name="env">Host environment used to create the loader and the progress channel.</param>
        /// <param name="ch">Channel used for assertions, user-argument checks and warnings.</param>
        /// <param name="args">Transform arguments; <see cref="ArgumentsBase.DataFile"/> must be non-empty.</param>
        /// <param name="bldr">Builder that receives the terms read from the file.</param>
        /// <returns>The term map produced by <paramref name="bldr"/> after the file has been read.</returns>
        private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, ArgumentsBase args, Builder bldr)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(env);
            ch.AssertValue(args);
            ch.Assert(!string.IsNullOrWhiteSpace(args.DataFile));
            ch.AssertValue(bldr);

            string file = args.DataFile;
            // Name of the column holding the terms; the default text loader below substitutes its own column.
            string src = args.TermsColumn;
            IMultiStreamSource fileSource = new MultiFileSource(file);

            // If the user manually specifies a loader, or this is already a pre-processed binary
            // file, then we assume the user knows what they're doing and do not attempt to convert
            // to the desired type ourselves.
            bool autoConvert = false;
            IDataLoader loader;

            var loaderFactory = args.Loader;
            if (loaderFactory != null)
            {
                // A custom loader offers no default column to fall back on, so the terms column must be
                // named explicitly (same requirement as for .idv/.tdv files); otherwise the schema lookup
                // below would be handed a null column name.
                ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn), "Must be specified");
                loader = loaderFactory.CreateComponent(env, fileSource);
            }
            else
            {
                // Determine the default loader from the extension.
                var ext = Path.GetExtension(file);
                bool isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
                bool isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
                if (isBinary || isTranspose)
                {
                    ch.Assert(isBinary != isTranspose);
                    ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn), "Must be specified");
                    if (isBinary)
                    {
                        loader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
                    }
                    else
                    {
                        ch.Assert(isTranspose);
                        loader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
                    }
                }
                else
                {
                    if (!string.IsNullOrWhiteSpace(src))
                    {
                        ch.Warning(
                            "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
                            nameof(Arguments.TermsColumn), src);
                    }
                    var textArgs = new TextLoader.Arguments()
                    {
                        Separator = "tab",
                        Column = new[] { new TextLoader.Column("Term", DataKind.TX, 0) }
                    };
                    loader = new TextLoader(env, textArgs, fileSource);
                    src = "Term";
                    autoConvert = true;
                }
            }
            ch.AssertNonEmpty(src);

            if (!loader.Schema.TryGetColumnIndex(src, out int colSrc))
            {
                throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", src);
            }

            var typeSrc = loader.Schema.GetColumnType(colSrc);
            if (!autoConvert && !typeSrc.Equals(bldr.ItemType))
            {
                throw ch.ExceptUserArg(nameof(args.TermsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc);
            }

            using (var cursor = loader.GetRowCursor(col => col == colSrc))
            using (var pch = env.StartProgressChannel("Building term dictionary from file"))
            {
                var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });
                var trainer = Trainer.Create(cursor, colSrc, autoConvert, int.MaxValue, bldr);
                double rowCount = loader.GetRowCount(true) ?? double.NaN;
                long rowCur = 0;
                pch.SetHeader(header,
                    e =>
                    {
                        e.SetProgress(0, rowCur, rowCount);
                        // Purely feedback for the user. That the other thread might be
                        // working in the background is not a problem.
                        e.SetMetric(0, trainer.Count);
                    });

                // Stop at end of data or as soon as the trainer reports it accepts no more rows.
                while (cursor.MoveNext() && trainer.ProcessRow())
                {
                    rowCur++;
                }
                if (trainer.Count == 0)
                {
                    ch.Warning("Term map loaded from file resulted in an empty map.");
                }
                pch.Checkpoint(trainer.Count, rowCur);
                return trainer.Finish();
            }
        }
示例#2
0
        /// <summary>
        /// This builds the <see cref="TermMap"/> instances per column.
        /// </summary>
        /// <remarks>
        /// A column's terms come from, in order of preference: explicit terms on the column, explicit terms on
        /// <paramref name="args"/>, the data file named by <see cref="ArgumentsBase.DataFile"/> (loaded once and
        /// shared by every column that uses it), or a pass over <paramref name="trainingData"/>. All columns that
        /// need that pass are trained together from a single cursor.
        /// </remarks>
        /// <param name="env">Host environment used to open the progress channel.</param>
        /// <param name="ch">Channel used for assertions, user-argument checks and warnings.</param>
        /// <param name="infos">Per-column information (source column index, source type, name).</param>
        /// <param name="args">Transform-wide arguments that act as defaults for each column.</param>
        /// <param name="column">Per-column arguments, parallel to <paramref name="infos"/>.</param>
        /// <param name="trainingData">Data scanned for columns without explicit or file-based terms.</param>
        /// <returns>One term map per entry of <paramref name="infos"/>.</returns>
        private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] infos,
            ArgumentsBase args, ColumnBase[] column, IDataView trainingData)
        {
            Contracts.AssertValue(env);
            env.AssertValue(ch);
            ch.AssertValue(infos);
            ch.AssertValue(args);
            ch.AssertValue(column);
            ch.AssertValue(trainingData);

            if ((args.Term != null || !string.IsNullOrEmpty(args.Terms)) &&
                (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader != null ||
                    !string.IsNullOrWhiteSpace(args.TermsColumn)))
            {
                ch.Warning("Explicit term list specified. Data file arguments will be ignored");
            }

            if (!Enum.IsDefined(typeof(SortOrder), args.Sort))
            {
                throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected", args.Sort);
            }

            // Resolves the effective sort order of column 'index' (its own setting, else the global default) and
            // rejects values that are not defined members of SortOrder. It is used for every Builder.Create call so
            // that data-file and auto-trained columns are validated as well, not only columns with explicit terms.
            SortOrder GetSortOrder(int index)
            {
                var sortOrder = column[index].Sort ?? args.Sort;
                if (!Enum.IsDefined(typeof(SortOrder), sortOrder))
                {
                    throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, infos[index].Name);
                }
                return sortOrder;
            }

            TermMap termsFromFile = null;
            var termMap = new TermMap[infos.Length];

            int[] lims = new int[infos.Length];
            int trainsNeeded = 0;
            HashSet<int> toTrain = null;

            for (int iinfo = 0; iinfo < infos.Length; iinfo++)
            {
                // First check whether we have a terms argument, and handle it appropriately.
                var terms = new DvText(column[iinfo].Terms);
                var termsArray = column[iinfo].Term;
                if (!terms.HasChars && termsArray == null)
                {
                    // Nothing set on the column itself: fall back to the transform-wide terms.
                    terms = new DvText(args.Terms);
                    termsArray = args.Term;
                }

                terms = terms.Trim();
                if (terms.HasChars || (termsArray != null && termsArray.Length > 0))
                {
                    // We have terms! Pass it in.
                    var bldr = Builder.Create(infos[iinfo].TypeSrc, GetSortOrder(iinfo));
                    if (terms.HasChars)
                    {
                        bldr.ParseAddTermArg(ref terms, ch);
                    }
                    else
                    {
                        bldr.ParseAddTermArg(termsArray, ch);
                    }
                    termMap[iinfo] = bldr.Finish();
                }
                else if (!string.IsNullOrWhiteSpace(args.DataFile))
                {
                    // The file is loaded once, with the builder of the first column that needs it.
                    if (termsFromFile == null)
                    {
                        var bldr = Builder.Create(infos[iinfo].TypeSrc, GetSortOrder(iinfo));
                        termsFromFile = CreateFileTermMap(env, ch, args, bldr);
                    }
                    if (!termsFromFile.ItemType.Equals(infos[iinfo].TypeSrc.ItemType))
                    {
                        // We have no current plans to support re-interpretation based on different column
                        // type, not only because it's unclear what realistic customer use-cases for such
                        // a complicated feature would be, and also because it's difficult to see how we
                        // can logically reconcile "reinterpretation" for different types with the resulting
                        // data view having an actual type.
                        throw ch.ExceptUserArg(nameof(args.DataFile), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'",
                            termsFromFile.ItemType, infos[iinfo].Name, infos[iinfo].TypeSrc.ItemType);
                    }
                    termMap[iinfo] = termsFromFile;
                }
                else
                {
                    // Auto train this column. Leave the term map null for now, but set the lim appropriately.
                    lims[iinfo] = column[iinfo].MaxNumTerms ?? args.MaxNumTerms;
                    ch.CheckUserArg(lims[iinfo] > 0, nameof(Column.MaxNumTerms), "Must be positive");
                    Utils.Add(ref toTrain, infos[iinfo].Source);
                    ++trainsNeeded;
                }
            }

            // Several output columns may share a source column, so the set can be smaller than the trainer count.
            ch.Assert((Utils.Size(toTrain) == 0) == (trainsNeeded == 0));
            ch.Assert(Utils.Size(toTrain) <= trainsNeeded);
            if (trainsNeeded > 0)
            {
                Trainer[] trainer = new Trainer[trainsNeeded];
                int[] trainerInfo = new int[trainsNeeded];
                // Open the cursor, then instantiate the trainers.
                int itrainer;
                using (var cursor = trainingData.GetRowCursor(toTrain.Contains))
                using (var pch = env.StartProgressChannel("Building term dictionary"))
                {
                    long rowCur = 0;
                    double rowCount = trainingData.GetRowCount(true) ?? double.NaN;
                    var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });

                    itrainer = 0;
                    for (int iinfo = 0; iinfo < infos.Length; ++iinfo)
                    {
                        if (termMap[iinfo] != null)
                        {
                            continue;
                        }
                        var bldr = Builder.Create(infos[iinfo].TypeSrc, GetSortOrder(iinfo));
                        trainerInfo[itrainer] = iinfo;
                        trainer[itrainer++] = Trainer.Create(cursor, infos[iinfo].Source, false, lims[iinfo], bldr);
                    }
                    ch.Assert(itrainer == trainer.Length);
                    pch.SetHeader(header,
                        e =>
                        {
                            e.SetProgress(0, rowCur, rowCount);
                            // Purely feedback for the user. That the other thread might be
                            // working in the background is not a problem.
                            e.SetMetric(0, trainer.Sum(t => t.Count));
                        });

                    // The [0,tmin) trainers are finished.
                    int tmin = 0;
                    // We might exit early if all trainers reach their maximum.
                    while (tmin < trainer.Length && cursor.MoveNext())
                    {
                        rowCur++;
                        for (int t = tmin; t < trainer.Length; ++t)
                        {
                            if (!trainer[t].ProcessRow())
                            {
                                // Move the finished trainer (and its column index) into the finished prefix.
                                Utils.Swap(ref trainerInfo[t], ref trainerInfo[tmin]);
                                Utils.Swap(ref trainer[t], ref trainer[tmin++]);
                            }
                        }
                    }

                    pch.Checkpoint(trainer.Sum(t => t.Count), rowCur);
                }
                for (itrainer = 0; itrainer < trainer.Length; ++itrainer)
                {
                    int iinfo = trainerInfo[itrainer];
                    ch.Assert(termMap[iinfo] == null);
                    if (trainer[itrainer].Count == 0)
                    {
                        ch.Warning("Term map for output column '{0}' contains no entries.", infos[iinfo].Name);
                    }
                    termMap[iinfo] = trainer[itrainer].Finish();
                    // Allow the intermediate structures in the trainer and builder to be released as we iterate
                    // over the columns, as the Finish operation can potentially result in the allocation of
                    // additional structures.
                    trainer[itrainer] = null;
                }
                ch.Assert(termMap.All(tm => tm != null));
                ch.Assert(termMap.Zip(infos, (tm, info) => tm.ItemType.Equals(info.TypeSrc.ItemType)).All(x => x));
            }

            return termMap;
        }
示例#3
0
        /// <summary>
        /// Train and return a booster.
        /// </summary>
        /// <param name="ch">Channel used for warnings, info messages and assertions.</param>
        /// <param name="pch">Progress channel; it receives a checkpoint every 50 iterations.</param>
        /// <param name="parameters">Booster parameters. The optional "metric" entry labels the evaluation
        /// columns (required when <paramref name="verboseEval"/> is set) and, for "auc", "ndcg" and "map",
        /// makes early stopping treat larger scores as better.</param>
        /// <param name="dtrain">Training dataset.</param>
        /// <param name="dvalid">Optional validation dataset; early stopping is disabled without it.</param>
        /// <param name="numIteration">Maximum number of boosting iterations.</param>
        /// <param name="verboseEval">Whether checkpoints also report the training (and validation) metric.</param>
        /// <param name="earlyStoppingRound">Stop once the validation score has not improved for this many
        /// iterations; 0 disables early stopping.</param>
        /// <returns>The trained booster. <c>BestIteration</c> is set when early stopping is enabled and the loop
        /// ended before <paramref name="numIteration"/> iterations.</returns>
        public static Booster Train(IChannel ch, IProgressChannel pch,
            Dictionary<string, object> parameters, Dataset dtrain, Dataset dvalid = null, int numIteration = 100,
            bool verboseEval = true, int earlyStoppingRound = 0)
        {
            // create Booster.
            Booster bst = new Booster(parameters, dtrain, dvalid);

            // Disable early stopping if we don't have validation data.
            if (dvalid == null && earlyStoppingRound > 0)
            {
                earlyStoppingRound = 0;
                ch.Warning("Validation dataset not present, early stopping will be disabled.");
            }

            int bestIter = 0;
            double bestScore = double.MaxValue;
            double factorToSmallerBetter = 1.0;

            // The metric is only required for verbose evaluation (asserted below). Look it up without the
            // indexer so callers that neither report metrics nor use early stopping do not fail with a
            // KeyNotFoundException.
            parameters.TryGetValue("metric", out object metricValue);
            var metric = (string)metricValue;

            if (earlyStoppingRound > 0 && (metric == "auc" || metric == "ndcg" || metric == "map"))
            {
                // These metrics are "larger is better"; negate them so the comparison below always minimizes.
                factorToSmallerBetter = -1.0;
            }

            const int evalFreq = 50;

            var metrics = new List<string>()
            {
                "Iteration"
            };
            var units = new List<string>()
            {
                "iterations"
            };

            if (verboseEval)
            {
                ch.Assert(parameters.ContainsKey("metric"));
                metrics.Add("Training-" + metric);
                if (dvalid != null)
                {
                    metrics.Add("Validation-" + metric);
                }
            }

            var header = new ProgressHeader(metrics.ToArray(), units.ToArray());

            int iter = 0;
            double trainError = double.NaN;
            double validError = double.NaN;

            pch.SetHeader(header, e =>
            {
                e.SetProgress(0, iter, numIteration);
                if (verboseEval)
                {
                    e.SetProgress(1, trainError);
                    if (dvalid != null)
                    {
                        e.SetProgress(2, validError);
                    }
                }
            });
            for (iter = 0; iter < numIteration; ++iter)
            {
                if (bst.Update())
                {
                    break;
                }

                if (earlyStoppingRound > 0)
                {
                    validError = bst.EvalValid();
                    if (validError * factorToSmallerBetter < bestScore)
                    {
                        bestScore = validError * factorToSmallerBetter;
                        bestIter = iter;
                    }
                    if (iter - bestIter >= earlyStoppingRound)
                    {
                        ch.Info($"Met early stopping, best iteration: {bestIter + 1}, best score: {bestScore / factorToSmallerBetter}");
                        break;
                    }
                }
                if ((iter + 1) % evalFreq == 0)
                {
                    if (verboseEval)
                    {
                        trainError = bst.EvalTrain();
                        if (dvalid == null)
                        {
                            pch.Checkpoint(new double?[] { iter + 1, trainError });
                        }
                        else
                        {
                            // With early stopping the validation error was already computed this iteration.
                            if (earlyStoppingRound == 0)
                            {
                                validError = bst.EvalValid();
                            }
                            pch.Checkpoint(new double?[] { iter + 1, trainError, validError });
                        }
                    }
                    else
                    {
                        pch.Checkpoint(new double?[] { iter + 1 });
                    }
                }
            }
            // Set the BestIteration.
            if (iter != numIteration && earlyStoppingRound > 0)
            {
                bst.BestIteration = bestIter + 1;
            }
            return bst;
        }
示例#4
0
            /// <summary>
            /// Computes mutual-information scores between the label column and each of the requested columns.
            /// </summary>
            /// <param name="input">Data view the columns are read from.</param>
            /// <param name="labelColumnName">Name of the label column; its type must pass <c>IsValidColumnType</c>.</param>
            /// <param name="columns">Names of the columns to score. Each must exist, must not be a variable-length
            /// vector, and must have an item type that passes <c>IsValidColumnType</c>.</param>
            /// <param name="numBins">Number of bins; stored in <c>_numBins</c> before scoring starts.</param>
            /// <param name="colSizes">Caller-allocated array that receives the value count of each entry of
            /// <paramref name="columns"/>.</param>
            /// <returns>One score array per entry of <paramref name="columns"/>, in the same order.</returns>
            public Single[][] GetScores(IDataView input, string labelColumnName, string[] columns, int numBins, int[] colSizes)
            {
                _numBins = numBins;
                var schema = input.Schema;
                int columnCount = columns.Length;
                string labelArg = nameof(MutualInformationFeatureSelectionTransform.Arguments.LabelColumn);
                string columnArg = nameof(MutualInformationFeatureSelectionTransform.Arguments.Column);

                if (!schema.TryGetColumnIndex(labelColumnName, out int labelIndex))
                {
                    throw _host.ExceptUserArg(labelArg, "Label column '{0}' not found", labelColumnName);
                }

                var labelType = schema.GetColumnType(labelIndex);
                if (!IsValidColumnType(labelType))
                {
                    throw _host.ExceptUserArg(labelArg, "Label column '{0}' does not have compatible type", labelColumnName);
                }

                // The label goes in the last slot so the transposer covers the features and the label together.
                var sourceIndices = new int[columnCount + 1];
                sourceIndices[columnCount] = labelIndex;

                for (int c = 0; c < columnCount; c++)
                {
                    string name = columns[c];
                    if (!schema.TryGetColumnIndex(name, out int index))
                    {
                        throw _host.ExceptUserArg(columnArg, "Source column '{0}' not found", name);
                    }

                    var type = schema.GetColumnType(index);
                    bool isVariableLengthVector = type.IsVector && !type.IsKnownSizeVector;
                    if (isVariableLengthVector)
                    {
                        throw _host.ExceptUserArg(columnArg, "Variable length column '{0}' is not allowed", name);
                    }

                    if (!IsValidColumnType(type.ItemType))
                    {
                        throw _host.ExceptUserArg(columnArg, "Column '{0}' of type '{1}' does not have compatible type.", name, type);
                    }

                    sourceIndices[c] = index;
                    colSizes[c] = type.ValueCount;
                }

                var scores = new Single[columnCount][];

                using (var ch = _host.Start("Computing mutual information scores"))
                using (var pch = _host.StartProgressChannel("Computing mutual information scores"))
                using (var trans = Transposer.Create(_host, input, false, sourceIndices))
                {
                    // Number of columns scored so far; read by the progress callback.
                    int done = 0;
                    var header = new ProgressHeader(new[] { "columns" });

                    bool found = trans.Schema.TryGetColumnIndex(labelColumnName, out int transLabelIndex);
                    Contracts.Assert(found);

                    GetLabels(trans, labelType, transLabelIndex);
                    _contingencyTable = new int[_numLabels][];
                    _labelSums = new int[_numLabels];
                    pch.SetHeader(header, e => e.SetProgress(0, done, columnCount));

                    while (done < columnCount)
                    {
                        found = trans.Schema.TryGetColumnIndex(columns[done], out int transIndex);
                        Contracts.Assert(found);
                        ch.Trace("Computing scores for column '{0}'", columns[done]);
                        scores[done] = ComputeMutualInformation(trans, transIndex);
#if DEBUG
                        ch.Trace("Scores for column '{0}': {1}", columns[done], string.Join(", ", scores[done]));
#endif
                        pch.Checkpoint(done + 1);
                        done++;
                    }
                }

                return scores;
            }
示例#5
0
        /// <summary>
        /// Aggregates per-slot counts for each requested column in a single pass over <paramref name="input"/>
        /// and returns them as the feature selection scores.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The input dataview.</param>
        /// <param name="columns">Columns to score; at least one is required and none may be a variable-length vector.</param>
        /// <param name="colSizes">Receives the value count of each column in <paramref name="columns"/>.</param>
        /// <returns>One count array per column, in the order of <paramref name="columns"/>.</returns>
        public static long[][] Train(IHostEnvironment env, IDataView input, string[] columns, out int[] colSizes)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            env.CheckParam(Utils.Size(columns) > 0, nameof(columns));

            var schema = input.Schema;
            int count = columns.Length;
            var activeColumns = new List<DataViewSchema.Column>();
            var sourceIndices = new int[count];
            var sourceTypes = new DataViewType[count];

            colSizes = new int[count];
            for (int i = 0; i < count; i++)
            {
                string name = columns[i];
                if (!schema.TryGetColumnIndex(name, out int index))
                {
                    throw env.ExceptUserArg(nameof(CountFeatureSelectingEstimator.Options.Columns), "Source column '{0}' not found", name);
                }

                DataViewSchema.Column schemaColumn = schema[index];
                DataViewType type = schemaColumn.Type;
                if (type is VectorType vector && !vector.IsKnownSize)
                {
                    throw env.ExceptUserArg(nameof(CountFeatureSelectingEstimator.Options.Columns), "Variable length column '{0}' is not allowed", name);
                }

                activeColumns.Add(schemaColumn);
                sourceIndices[i] = index;
                sourceTypes[i] = type;
                colSizes[i] = type.GetValueCount();
            }

            var aggregators = new CountAggregator[count];
            long rowCur = 0;
            double rowCount = input.GetRowCount() ?? double.NaN;

            using (var pch = env.StartProgressChannel("Aggregating counts"))
            using (var cursor = input.GetRowCursor(activeColumns))
            {
                var header = new ProgressHeader(new[] { "rows" });
                pch.SetHeader(header, e => { e.SetProgress(0, rowCur, rowCount); });

                // Vector columns and scalar columns need different aggregator implementations.
                for (int i = 0; i < count; i++)
                {
                    if (sourceTypes[i] is VectorType vectorType)
                    {
                        aggregators[i] = GetVecAggregator(cursor, vectorType, sourceIndices[i]);
                    }
                    else
                    {
                        aggregators[i] = GetOneAggregator(cursor, sourceTypes[i], sourceIndices[i]);
                    }
                }

                while (cursor.MoveNext())
                {
                    for (int i = 0; i < count; i++)
                    {
                        aggregators[i].ProcessValue();
                    }
                    rowCur++;
                }
                pch.Checkpoint(rowCur);
            }

            var scores = new long[count][];
            for (int i = 0; i < count; i++)
            {
                scores[i] = aggregators[i].Count;
            }
            return scores;
        }
        /// <summary>
        /// Runs the "LBFGS Optimizer" on <paramref name="function"/>, starting from <paramref name="initial"/>,
        /// until <paramref name="term"/> reports that the optimization has finished.
        /// </summary>
        /// <param name="function">The function to minimize.</param>
        /// <param name="initial">The starting point.</param>
        /// <param name="term">Termination criterion; it is reset before the first iteration.</param>
        /// <param name="result">Receives the point reached when the optimization stops.</param>
        /// <param name="optimum">Receives the function value at <paramref name="result"/>.</param>
        /// <exception cref="PrematureConvergenceException">Thrown if successive points are within numeric precision of each other, but termination condition is still unsatisfied.</exception>
        public void Minimize(DifferentiableFunction function, ref VBuffer<Float> initial, ITerminationCriterion term, ref VBuffer<Float> result, out Float optimum)
        {
            const string computationName = "LBFGS Optimizer";

            using (var pch = Env.StartProgressChannel(computationName))
            using (var ch = Env.Start(computationName))
            {
                ch.Info("Beginning optimization");
                ch.Info("num vars: {0}", initial.Length);
                ch.Info("improvement criterion: {0}", term.FriendlyName);

                OptimizerState state = MakeState(ch, pch, function, ref initial);
                term.Reset();

                var header = new ProgressHeader(new[] { "Loss", "Improvement" }, new[] { "iterations", "gradients" });
                Action<IProgressEntry> reportProgress = entry =>
                {
                    entry.SetProgress(0, (double)(state.Iter - 1));
                    entry.SetProgress(1, state.GradientCalculations);
                };
                pch.SetHeader(header, reportProgress);

                pch.Checkpoint(state.Value, null, 0);
                state.UpdateDir();

                bool finished;
                do
                {
                    if (!state.LineSearch(ch, false))
                    {
                        // The failure may come from numerical errors in earlier gradients: drop the stored
                        // history and retry the line search along a plain gradient-descent direction.
                        state.DiscardOldVectors();
                        state.UpdateDir();
                        state.LineSearch(ch, true);
                    }

                    finished = term.Terminate(state, out string message);

                    // The criterion's message may be a number; if so, report it as the improvement.
                    double? improvement = null;
                    if (message != null && DoubleParser.TryParse(out double parsed, message, 0, message.Length, out int consumed))
                    {
                        improvement = parsed;
                    }

                    pch.Checkpoint(state.Value, improvement, state.Iter);

                    if (!finished)
                    {
                        state.Shift();
                        state.UpdateDir();
                    }
                }
                while (!finished);

                state.X.CopyTo(ref result);
                optimum = state.Value;
                ch.Done();
            }
        }
示例#7
0
 /// <summary>
 /// Stores <paramref name="header"/> together with <paramref name="fillAction"/>, replacing any pair stored earlier.
 /// </summary>
 /// <param name="header">The progress header to remember.</param>
 /// <param name="fillAction">Callback that fills an <see cref="IProgressEntry"/> for that header.</param>
 public void SetHeader(ProgressHeader header, Action<IProgressEntry> fillAction)
     => _headerAndAction = Tuple.Create(header, fillAction);
        /// <summary>
        /// Returns the embedding <see cref="Model"/> for the file at <c>_modelFileNameWithPath</c>. A model cached in
        /// <c>_vocab</c> is reused while it is still alive; otherwise the file is parsed and the new model is cached
        /// through a weak reference.
        /// </summary>
        /// <remarks>
        /// Each line is split on spaces and tabs into a key followed by its vector components. Apart from lines
        /// skipped according to <c>_linesToSkip</c>, a line whose component count differs from the model's
        /// dimension, or that holds a component that does not parse (or parses to NaN), is skipped with a warning.
        /// The first line is handled last and is dropped when its dimension does not match the rest of the file,
        /// because some formats (fastText, for example) start with a "count dimension" header line. Numbers are
        /// parsed with the invariant culture so a file loads the same way under any current culture.
        /// </remarks>
        /// <returns>The cached or newly loaded model.</returns>
        private Model GetVocabularyDictionary()
        {
            // Parses every token after the key as a float using the invariant culture. Returns null when any
            // token fails to parse or is NaN; callers treat that as a malformed line.
            float[] ParseVector(string[] tokens)
            {
                var values = new float[tokens.Length - 1];
                for (int t = 1; t < tokens.Length; t++)
                {
                    if (!float.TryParse(tokens[t],
                            System.Globalization.NumberStyles.Float | System.Globalization.NumberStyles.AllowThousands,
                            System.Globalization.CultureInfo.InvariantCulture, out float parsed) || float.IsNaN(parsed))
                    {
                        return null;
                    }
                    values[t - 1] = parsed;
                }
                return values;
            }

            if (!File.Exists(_modelFileNameWithPath))
            {
                throw Host.Except("Custom word embedding model file '{0}' could not be found for Word Embeddings transform.", _modelFileNameWithPath);
            }

            // NOTE(review): this fast path reads _vocab without holding _embeddingsLock while other threads may
            // write to it inside the lock; that is only safe if _vocab is a thread-safe collection - confirm.
            if (_vocab.TryGetValue(_modelFileNameWithPath, out var cachedRef) && cachedRef != null &&
                cachedRef.TryGetTarget(out Model cachedModel))
            {
                return cachedModel;
            }

            lock (_embeddingsLock)
            {
                // Another thread may have loaded the file while this one was waiting for the lock.
                if (_vocab.TryGetValue(_modelFileNameWithPath, out var lockedRef) && lockedRef != null &&
                    lockedRef.TryGetTarget(out Model lockedModel))
                {
                    return lockedModel;
                }

                Model model = null;
                using (StreamReader sr = File.OpenText(_modelFileNameWithPath))
                using (var ch = Host.Start(LoaderSignature))
                using (var pch = Host.StartProgressChannel("Building Vocabulary from Model File for Word Embeddings Transform"))
                {
                    string line;
                    // Counts lines after the first one; warnings report lineNumber + 1 as the physical line.
                    int lineNumber = 1;
                    char[] delimiters = { ' ', '\t' };
                    var header = new ProgressHeader(new[] { "lines" });
                    pch.SetHeader(header, e => e.SetProgress(0, lineNumber));

                    string firstLine = sr.ReadLine();
                    if (firstLine == null)
                    {
                        throw Host.Except("Custom word embedding model file '{0}' is empty.", _modelFileNameWithPath);
                    }

                    while ((line = sr.ReadLine()) != null)
                    {
                        if (lineNumber >= _linesToSkip)
                        {
                            string[] words = line.TrimEnd().Split(delimiters);
                            int dimension = words.Length - 1;
                            if (model == null)
                            {
                                model = new Model(dimension);
                            }
                            if (model.Dimension != dimension)
                            {
                                ch.Warning($"Dimension mismatch while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}, expected dimension = {model.Dimension}, received dimension = {dimension}");
                            }
                            else
                            {
                                float[] value = ParseVector(words);
                                if (value != null)
                                {
                                    model.AddWordVector(ch, words[0], value);
                                }
                                else
                                {
                                    ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number {lineNumber + 1}");
                                }
                            }
                        }
                        lineNumber++;
                    }

                    // Handle first line of the embedding file separately since some embedding files including fastText
                    // have a single-line header. A first line whose dimension differs from the rest of the file is such a
                    // header and must not be added as a word vector.
                    string[] wordsInFirstLine = firstLine.TrimEnd().Split(delimiters);
                    int firstDimension = wordsInFirstLine.Length - 1;
                    if (model == null)
                    {
                        model = new Model(firstDimension);
                    }
                    if (model.Dimension == firstDimension)
                    {
                        float[] firstValue = ParseVector(wordsInFirstLine);
                        if (firstValue != null)
                        {
                            model.AddWordVector(ch, wordsInFirstLine[0], firstValue);
                        }
                        else
                        {
                            ch.Warning($"Parsing error while reading model file: '{_modelFileNameWithPath}', line number 1");
                        }
                    }
                    pch.Checkpoint(lineNumber);
                }
                _vocab[_modelFileNameWithPath] = new WeakReference<Model>(model, false);
                return model;
            }
        }
示例#9
0
        /// <summary>
        /// Trains and returns a booster.
        /// </summary>
        /// <param name="ch">Channel used for logging and consistency checks.</param>
        /// <param name="pch">Progress channel receiving per-iteration progress and checkpoints.</param>
        /// <param name="numberOfTrees">Number of trained trees, computed as <paramref name="numBoostRound"/>
        ///     times the <c>num_parallel_tree</c> parameter (1 when that parameter is absent).</param>
        /// <param name="parameters">Parameters see <see cref="XGBoostArguments"/>. May be null.</param>
        /// <param name="dtrain">Training set</param>
        /// <param name="numBoostRound">Number of boosting rounds to run.</param>
        /// <param name="obj">Custom objective, forwarded to each booster update.</param>
        /// <param name="maximize">Whether to maximize feval. Currently not used by this method.</param>
        /// <param name="verboseEval">If true, the training error on <paramref name="dtrain"/> is evaluated
        ///     periodically and reported through <paramref name="pch"/>.</param>
        /// <param name="xgbModel">For continuous training.</param>
        /// <param name="saveBinaryDMatrix">If not empty, <paramref name="dtrain"/> is saved in binary format
        ///     to this path before training (for debugging purpose).</param>
        /// <returns>The trained booster.</returns>
        public static Booster Train(IChannel ch, IProgressChannel pch, out int numberOfTrees,
                                    Dictionary <string, string> parameters, DMatrix dtrain, int numBoostRound = 10,
                                    Booster.FObjType obj     = null, bool maximize    = false,
                                    bool verboseEval         = true, Booster xgbModel = null,
                                    string saveBinaryDMatrix = null)
        {
#if (!XGBOOST_RABIT)
            if (WrappedXGBoostInterface.RabitIsDistributed() == 1)
            {
                var pname = WrappedXGBoostInterface.RabitGetProcessorName();
                ch.Info("[WrappedXGBoostTraining.Train] start {0}:{1}", pname, WrappedXGBoostInterface.RabitGetRank());
            }
#endif

            if (!string.IsNullOrEmpty(saveBinaryDMatrix))
            {
                dtrain.SaveBinary(saveBinaryDMatrix);
            }

            Booster bst = new Booster(parameters, dtrain, xgbModel);

            int    numParallelTree = 1;
            string numParallelTreeText;
            if (parameters != null && parameters.TryGetValue("num_parallel_tree", out numParallelTreeText))
            {
                numParallelTree = Convert.ToInt32(numParallelTreeText);
            }

            var prediction = new VBuffer <Float>();
            var grad       = new VBuffer <Float>();
            var hess       = new VBuffer <Float>();
            // Monotonic clock: elapsed time must not jump with wall-clock or DST changes.
            var watch      = System.Diagnostics.Stopwatch.StartNew();

#if (!XGBOOST_RABIT)
            // `version` advances by one per saved checkpoint (it is checked against RabitVersionNumber()
            // below) and each iteration saves two, so version / 2 is the number of completed iterations.
            int version = bst.LoadRabitCheckpoint();
            ch.Check(WrappedXGBoostInterface.RabitGetWorldSize() != 1 || version == 0);
#else
            int version = 0;
#endif
            int startIteration = version / 2;

            // Evaluation period: 10^(d - 2), where d is the number of decimal digits of 5 * numBoostRound
            // (so every iteration is evaluated for up to 19 rounds). `long` avoids overflow of 5 * numBoostRound.
            int digits = 0;
            for (long scaled = 5L * numBoostRound; scaled > 0; scaled /= 10)
            {
                digits++;
            }
            int evalPeriod = 1;
            for (int k = Math.Max(digits - 2, 0); k > 0; k--)
            {
                evalPeriod *= 10;
            }

            var metrics = new List <string>()
            {
                "Iteration", "Training Time"
            };
            var units = new List <string>()
            {
                "iterations", "seconds"
            };
            if (verboseEval)
            {
                metrics.Add("Training Error");
                string objective;
                if (parameters != null && parameters.TryGetValue("objective", out objective))
                {
                    metrics.Add(objective);
                }
            }
            var header = new ProgressHeader(metrics.ToArray(), units.ToArray());

            int    iter       = 0;
            double trainTime  = 0;
            double trainError = double.NaN;

            pch.SetHeader(header, e =>
            {
                // `iter` counts absolute iterations, so the limit is the absolute number of rounds.
                e.SetProgress(0, iter, numBoostRound);
                e.SetProgress(1, trainTime);
                if (verboseEval)
                {
                    // Only two progress units exist; the training error is metric #2 ("Training Error").
                    e.SetMetric(2, trainError);
                }
            });
            for (iter = startIteration; iter < numBoostRound; ++iter)
            {
                // An odd version (possible only when resuming from LoadRabitCheckpoint) means this
                // iteration's update was already checkpointed and must not be applied twice.
                if (version % 2 == 0)
                {
                    bst.Update(dtrain, iter, ref grad, ref hess, ref prediction, obj);
#if (!XGBOOST_RABIT)
                    bst.SaveRabitCheckpoint();
#endif
                    version += 1;
                }

#if (!XGBOOST_RABIT)
                ch.Check(WrappedXGBoostInterface.RabitGetWorldSize() == 1 ||
                         version == WrappedXGBoostInterface.RabitVersionNumber());
#endif

                trainTime = watch.Elapsed.TotalSeconds;

                if (verboseEval)
                {
                    if (iter == startIteration || iter == numBoostRound - 1 || iter % evalPeriod == 0 ||
                        watch.Elapsed > TimeSpan.FromMinutes(2))
                    {
                        string strainError = bst.EvalSet(new[] { dtrain }, new[] { "Train" }, iter);
                        // Example: "[0]\tTrain-error:0.028612"
                        if (!string.IsNullOrEmpty(strainError) && strainError.Contains(":"))
                        {
                            double val;
                            // The native library prints numbers with '.', so parse culture-independently.
                            if (double.TryParse(strainError.Split(':').Last(),
                                                System.Globalization.NumberStyles.Float,
                                                System.Globalization.CultureInfo.InvariantCulture,
                                                out val))
                            {
                                trainError = val;
                            }
                        }
                    }
                    // Checkpoint after the evaluation so it carries this iteration's error, not the previous one.
                    pch.Checkpoint(new double?[] { iter, trainTime, trainError });
                }
                else
                {
                    pch.Checkpoint(new double?[] { iter, trainTime });
                }

#if (!XGBOOST_RABIT)
                // Second checkpoint of the iteration (after evaluation), keeping `version` equal to
                // RabitVersionNumber() when it is checked on the next iteration.
                bst.SaveRabitCheckpoint();
#endif
                version += 1;
            }
            numberOfTrees = numBoostRound * numParallelTree;
#if (!XGBOOST_RABIT)
            if (WrappedXGBoostInterface.RabitIsDistributed() == 1)
            {
                var pname = WrappedXGBoostInterface.RabitGetProcessorName();
                ch.Info("[WrappedXGBoostTraining.Train] end {0}:{1}", pname, WrappedXGBoostInterface.RabitGetRank());
            }
#endif
            return(bst);
        }