/// <summary>
/// Creates the <see cref="IDataLoader"/> used to read stopwords from <paramref name="dataFile"/>.
/// </summary>
/// <remarks>
/// When <paramref name="loader"/> is not null it is used as-is. Otherwise a default loader is chosen
/// from the file extension: ".idv" uses <see cref="BinaryLoader"/>, ".tdv" uses
/// <see cref="TransposeLoader"/>, and any other extension uses a tab-separated
/// <see cref="TextLoader"/> with a single text column named "Stopwords".
/// </remarks>
/// <param name="env">The host environment; checked for null.</param>
/// <param name="ch">The channel used for checks, asserts and warnings; checked for null.</param>
/// <param name="dataFile">Path of the file containing the stopwords.</param>
/// <param name="loader">Optional loader factory. When null, a default loader is selected.</param>
/// <param name="stopwordsCol">
/// Name of the column holding the stopwords. It must be specified for the default binary and
/// transposed loaders. When the default text loader is used, any supplied value is ignored
/// (with a warning) and the argument is overwritten with "Stopwords".
/// </param>
/// <returns>The loader created over <paramref name="dataFile"/>.</returns>
private static IDataLoader LoadStopwords(IHostEnvironment env, IChannel ch, string dataFile,
    IComponentFactory<IMultiStreamSource, IDataLoader> loader, ref string stopwordsCol)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));

    MultiFileSource fileSource = new MultiFileSource(dataFile);

    // A user-supplied loader factory is used without touching stopwordsCol.
    if (loader != null)
    {
        return loader.CreateComponent(env, fileSource);
    }

    // Determine the default loader from the extension.
    IDataLoader dataLoader;
    var ext = Path.GetExtension(dataFile);
    bool isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
    bool isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
    if (isBinary || isTranspose)
    {
        ch.Assert(isBinary != isTranspose);
        ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Arguments.StopwordsColumn),
            "stopwordsColumn should be specified");
        if (isBinary)
        {
            dataLoader = new BinaryLoader(env, new BinaryLoader.Arguments(), fileSource);
        }
        else
        {
            ch.Assert(isTranspose);
            dataLoader = new TransposeLoader(env, new TransposeLoader.Arguments(), fileSource);
        }
    }
    else
    {
        if (!string.IsNullOrWhiteSpace(stopwordsCol))
        {
            // Report the argument name first and its value second; the message previously
            // substituted the column value in both places.
            ch.Warning(
                "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
                nameof(Arguments.StopwordsColumn), stopwordsCol);
        }

        dataLoader = TextLoader.Create(
            env,
            new TextLoader.Arguments()
            {
                Separator = "tab",
                Column = new[]
                {
                    new TextLoader.Column("Stopwords", DataKind.TX, 0)
                }
            },
            fileSource);
        stopwordsCol = "Stopwords";
    }

    ch.AssertNonEmpty(stopwordsCol);
    return dataLoader;
}
/// <summary>
/// Runs the base argument checks and then validates <c>Args.EarlyStoppingMetrics</c>.
/// </summary>
/// <param name="ch">The channel used to report an invalid user argument.</param>
protected override void CheckArgs(IChannel ch)
{
    Contracts.AssertValue(ch);
    base.CheckArgs(ch);

    // The metric index is only constrained when an early stopping rule is set or pruning is enabled.
    bool metricIsAcceptable = Args.EarlyStoppingRule == null && !Args.EnablePruning;
    if (!metricIsAcceptable)
    {
        metricIsAcceptable = Args.EarlyStoppingMetrics >= 1 && Args.EarlyStoppingMetrics <= 2;
    }

    ch.CheckUserArg(metricIsAcceptable, nameof(Args.EarlyStoppingMetrics),
        "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
/// <summary>
/// Runs the base option checks and then validates <c>FastTreeTrainerOptions.EarlyStoppingMetrics</c>.
/// </summary>
/// <param name="ch">The channel used to report an invalid user argument.</param>
private protected override void CheckOptions(IChannel ch)
{
    Contracts.AssertValue(ch);
    base.CheckOptions(ch);

    // The metric index is only constrained when an early stopping rule is set or pruning is enabled.
    bool metricIsAcceptable = FastTreeTrainerOptions.EarlyStoppingRule == null && !FastTreeTrainerOptions.EnablePruning;
    if (!metricIsAcceptable)
    {
        metricIsAcceptable = FastTreeTrainerOptions.EarlyStoppingMetrics >= 1
            && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2;
    }

    ch.CheckUserArg(metricIsAcceptable, nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
        "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
/// <summary>
/// Wraps <paramref name="view"/> in a MinMax normalizer over <paramref name="featureColumn"/>
/// when the normalization option and the trainer call for it.
/// </summary>
/// <param name="env">The host environment; checked for null.</param>
/// <param name="ch">The channel used for checks and for info/warning messages; checked for null.</param>
/// <param name="trainer">The trainer whose <c>Info.NeedNormalization</c> is consulted; checked for null.</param>
/// <param name="view">The data view; replaced with the normalized view when a normalizer is added.</param>
/// <param name="featureColumn">Name of the feature column; may be null or empty, in which case nothing is added.</param>
/// <param name="autoNorm">The normalization option; must be a defined <see cref="NormalizeOption"/> value.</param>
/// <returns>True if a normalizer was added; otherwise false.</returns>
public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValue(view, nameof(view));
    ch.CheckValueOrNull(featureColumn);
    ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
        "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

    if (autoNorm == NormalizeOption.No)
    {
        ch.Info("Not adding a normalizer.");
        return false;
    }

    if (string.IsNullOrEmpty(featureColumn))
    {
        return false;
    }

    var schema = view.Schema;
    if (!schema.TryGetColumnIndex(featureColumn, out int featCol))
    {
        // The feature column is not present in the schema, so there is nothing to normalize.
        return false;
    }

    // 'Yes' forces a normalizer; every other remaining option first asks the trainer and the schema.
    if (autoNorm != NormalizeOption.Yes)
    {
        if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol))
        {
            ch.Info("Not adding a normalizer.");
            return false;
        }

        if (autoNorm == NormalizeOption.Warn)
        {
            ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
            return false;
        }
    }

    ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

    IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
        => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);

    // A loader is extended through CompositeDataLoader; any other view is wrapped directly.
    if (view is IDataLoader loader)
    {
        view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
    }
    else
    {
        view = ApplyNormalizer(env, view);
    }

    return true;
}
/// <summary>
/// Runs the base option checks and, when early stopping is in effect, validates
/// <c>FastTreeTrainerOptions.EarlyStoppingMetrics</c>.
/// </summary>
/// <param name="ch">The channel used to report an invalid user argument.</param>
private protected override void CheckOptions(IChannel ch)
{
    Contracts.AssertValue(ch);
    base.CheckOptions(ch);

    // Nothing to validate unless an early stopping rule factory is set or pruning is enabled.
    if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null && !FastTreeTrainerOptions.EnablePruning)
    {
        return;
    }

    ch.CheckUserArg(
        FastTreeTrainerOptions.EarlyStoppingMetrics >= 1 && FastTreeTrainerOptions.EarlyStoppingMetrics <= 2,
        nameof(FastTreeTrainerOptions.EarlyStoppingMetrics),
        "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
/// <summary>
/// Lazily produces a sparse sequence of integers: runs of zeros separated by values drawn
/// with <c>rgen.Next(_bins)</c>.
/// </summary>
/// <remarks>
/// <c>_param</c> must lie in [0, 1). With <c>_param == 0</c> the logarithm below is zero and the
/// sequence is <c>_len</c> zeros. If the logarithm is negative infinity the parameter is treated as
/// effectively dense and the values of <c>CreateDense</c> are passed through unchanged.
/// Because this is an iterator, the argument check runs when enumeration starts.
/// </remarks>
/// <param name="ch">The channel used for the argument check and asserts.</param>
/// <param name="rgen">The random number generator used for sampling.</param>
/// <returns>The generated values.</returns>
private IEnumerable<int> CreateSparse(IChannel ch, Random rgen)
{
    ch.CheckUserArg(0 <= _param && _param < 1, nameof(Arguments.Parameter), "For sparse ararys");

    // The parameter is the level of sparsity. The length of each run of zeros is sampled from a
    // geometric distribution (with 0 support) by inverting its CDF: floor(log(r) / log(1 - _param)).
    double denom = Math.Log(1 - _param);
    if (double.IsNegativeInfinity(denom))
    {
        // The parameter is so high, it's effectively dense: forward the dense values themselves.
        foreach (int v in CreateDense(ch, rgen))
        {
            yield return v;
        }
        yield break;
    }

    if (denom == 0)
    {
        // The parameter must have been so small that we effectively will never have an "on" entry.
        for (int i = 0; i < _len; ++i)
        {
            yield return 0;
        }
        yield break;
    }

    ch.Assert(FloatUtils.IsFinite(denom) && denom < 0);
    int remaining = _len;
    while (remaining > 0)
    {
        // A value being sparse or not we view as a Bernoulli trial, so we can more efficiently
        // model the number of sparse values as being a geometric distribution. This reduces the
        // number of calls to the random number generator considerably vs. the naive sampling.
        // NextDouble has support in [0,1); using 1 - NextDouble moves the support to (0,1],
        // which keeps Math.Log(r) finite.
        double r = 1 - rgen.NextDouble();
        int numZeros = (int)Math.Min(Math.Floor(Math.Log(r) / denom), remaining);
        for (int i = 0; i < numZeros; ++i)
        {
            yield return 0;
        }

        // Emit one sampled entry after the run, unless the run already used up the sequence.
        if ((remaining -= numZeros) > 0)
        {
            --remaining;
            yield return rgen.Next(_bins);
        }
    }

    ch.Assert(remaining == 0);
}
/// <summary>
/// Runs the base argument checks, warns when an early stopping metric is set, and validates
/// <c>Args.EarlyStoppingMetrics</c>.
/// </summary>
/// <param name="ch">The channel used for the warning and for reporting an invalid user argument.</param>
protected override void CheckArgs(IChannel ch)
{
    Contracts.AssertValue(ch);
    base.CheckArgs(ch);

    // REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just
    // a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
    // we just leave the existing regression checks, though with a warning.
    bool metricRequested = Args.EarlyStoppingMetrics > 0;
    if (metricRequested)
    {
        ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");
    }

    // The metric index is only constrained when an early stopping rule is set or pruning is enabled.
    bool metricIsAcceptable = Args.EarlyStoppingRule == null && !Args.EnablePruning;
    if (!metricIsAcceptable)
    {
        metricIsAcceptable = Args.EarlyStoppingMetrics >= 1 && Args.EarlyStoppingMetrics <= 2;
    }

    ch.CheckUserArg(metricIsAcceptable, nameof(Args.EarlyStoppingMetrics),
        "earlyStoppingMetrics should be 1 or 2. (1: L1, 2: L2)");
}
/// <summary>
/// Creates the <see cref="IDataLoader"/> used to read stopwords from <paramref name="dataFile"/>.
/// </summary>
/// <remarks>
/// When <paramref name="loader"/> is good (per <c>IsGood()</c>) it is used as-is. Otherwise a
/// default loader is chosen from the file extension: ".idv" uses "BinaryLoader", ".tdv" uses
/// "TransposeLoader", and any other extension uses "TextLoader" with the settings
/// "sep=tab col=Stopwords:TX:0".
/// </remarks>
/// <param name="env">The host environment; checked for null.</param>
/// <param name="ch">The channel used for checks, asserts and warnings; checked for null.</param>
/// <param name="dataFile">Path of the file containing the stopwords.</param>
/// <param name="loader">Optional loader sub-component. When not good, a default loader is selected.</param>
/// <param name="stopwordsCol">
/// Name of the column holding the stopwords. It must be specified for the default binary and
/// transposed loaders. When the default text loader is used, any supplied value is ignored
/// (with a warning) and the argument is overwritten with "Stopwords".
/// </param>
/// <returns>The loader instance created over <paramref name="dataFile"/>.</returns>
private static IDataLoader LoadStopwords(IHostEnvironment env, IChannel ch, string dataFile,
    SubComponent<IDataLoader, SignatureDataLoader> loader, ref string stopwordsCol)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));

    if (!loader.IsGood())
    {
        // Determine the default loader from the extension.
        var ext = Path.GetExtension(dataFile);
        bool isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
        bool isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
        if (isBinary || isTranspose)
        {
            ch.Assert(isBinary != isTranspose);
            ch.CheckUserArg(!string.IsNullOrWhiteSpace(stopwordsCol), nameof(Arguments.StopwordsColumn),
                "stopwordsColumn should be specified");
            if (isBinary)
            {
                loader = new SubComponent<IDataLoader, SignatureDataLoader>("BinaryLoader");
            }
            else
            {
                ch.Assert(isTranspose);
                loader = new SubComponent<IDataLoader, SignatureDataLoader>("TransposeLoader");
            }
        }
        else
        {
            if (!string.IsNullOrWhiteSpace(stopwordsCol))
            {
                // Report the argument name first and its value second; the message previously
                // substituted the column value in both places.
                ch.Warning(
                    "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
                    nameof(Arguments.StopwordsColumn), stopwordsCol);
            }

            loader = new SubComponent<IDataLoader, SignatureDataLoader>("TextLoader", "sep=tab col=Stopwords:TX:0");
            stopwordsCol = "Stopwords";
        }
    }

    // Asserted for user-supplied loaders as well as for the defaults chosen above.
    ch.AssertNonEmpty(stopwordsCol);
    return loader.CreateInstance(env, new MultiFileSource(dataFile));
}
/// <summary>
/// Runs the base option checks, warns when an early stopping metric is set, and, when early
/// stopping is in effect, validates <c>FastTreeTrainerOptions.EarlyStoppingMetrics</c>.
/// </summary>
/// <param name="ch">The channel used for the warning and for reporting an invalid user argument.</param>
private protected override void CheckOptions(IChannel ch)
{
    Contracts.AssertValue(ch);
    base.CheckOptions(ch);

    // REVIEW: In order to properly support early stopping, the early stopping metric should be a subcomponent, not just
    // a simple integer, because the metric that we might want is parameterized by this floating point "index" parameter. For now
    // we just leave the existing regression checks, though with a warning.
    if (FastTreeTrainerOptions.EarlyStoppingMetrics > 0)
    {
        ch.Warning("For Tweedie regression, early stopping does not yet use the Tweedie distribution.");
    }

    // Nothing further to validate unless an early stopping rule factory is set or pruning is enabled.
    if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null && !FastTreeTrainerOptions.EnablePruning)
    {
        return;
    }

    // Please do not remove it! See comment above.
    bool isL1 = FastTreeTrainerOptions.EarlyStoppingMetrics == 1;
    ch.CheckUserArg(isL1 || FastTreeTrainerOptions.EarlyStoppingMetrics == 2,
        nameof(FastTreeTrainerOptions.EarlyStoppingMetrics), "should be 1 (L1-norm) or 2 (L2-norm).");
}
/// <summary>
/// Trains a linear model by handing the loaded data to the native SymSGD learner.
/// </summary>
/// <remarks>
/// The learner state object is pinned with a <see cref="GCHandle"/> for the duration of training
/// and the handle is always freed in the <c>finally</c> block. When the data fits entirely in
/// memory a single native call runs all iterations; otherwise the data is processed batch by
/// batch until <c>_options.NumberOfIterations</c> passes have been completed.
/// </remarks>
/// <param name="ch">The channel used for checks and informational messages.</param>
/// <param name="data">The training data; its feature column gives the number of features.</param>
/// <param name="predictor">Optional initial predictor whose weights and bias seed training; may be null.</param>
/// <param name="weightSetCount">Not read by this method.</param>
/// <returns>The predictor created from the learned weights and bias.</returns>
private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParameters predictor, int weightSetCount)
{
    int numFeatures = data.Schema.Feature.Value.Type.GetVectorSize();
    var cursorFactory = new FloatLabelCursor.Factory(data, CursOpt.Label | CursOpt.Features);
    int numThreads = 1;
    ch.CheckUserArg(numThreads > 0, nameof(_options.NumberOfThreads),
        "The number of threads must be either null or a positive integer.");

    VBuffer<float> weights = default;
    float bias = 0.0f;

    if (predictor != null)
    {
        // Warm start: densify the existing weights so they can be edited in place.
        predictor.GetFeatureWeights(ref weights);
        VBufferUtils.Densify(ref weights);
        bias = predictor.Bias;
    }
    else
    {
        weights = VBufferUtils.CreateDense<float>(numFeatures);
    }

    var weightsEditor = VBufferEditor.CreateFromBuffer(ref weights);

    // Reference: Parasail. SymSGD.
    // A null learning rate / update frequency means the value is to be tuned.
    bool tuneLR = _options.LearningRate == null;
    var lr = _options.LearningRate ?? 1.0f;

    bool tuneNumLocIter = (_options.UpdateFrequency == null);
    var numLocIter = _options.UpdateFrequency ?? 1;

    var l2Const = _options.L2Regularization;
    var piw = _options.PositiveInstanceWeight;

    // This is state of the learner that is shared with the native code.
    State state = new State();
    GCHandle stateGCHandle = default;
    try
    {
        stateGCHandle = GCHandle.Alloc(state, GCHandleType.Pinned);

        state.TotalInstancesProcessed = 0;
        using (InputDataManager inputDataManager = new InputDataManager(this, cursorFactory, ch))
        {
            bool shouldInitialize = true;
            using (var pch = Host.StartProgressChannel("Preprocessing"))
            {
                inputDataManager.LoadAsMuchAsPossible();
            }

            int iter = 0;
            if (inputDataManager.IsFullyLoaded)
            {
                ch.Info("Data fully loaded into memory.");
            }

            using (var pch = Host.StartProgressChannel("Training"))
            {
                if (inputDataManager.IsFullyLoaded)
                {
                    pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                        entry => entry.SetProgress(0, state.PassIteration, _options.NumberOfIterations));
                    // If fully loaded, call the SymSGDNative and do not come back until learned for all iterations.
                    Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                        _options.NumberOfIterations, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                        stateGCHandle, ch.Info);
                    shouldInitialize = false;
                }
                else
                {
                    pch.SetHeader(new ProgressHeader(new[] { "iterations" }),
                        entry => entry.SetProgress(0, iter, _options.NumberOfIterations));

                    // Since we loaded data in batch sizes, multiple passes over the loaded data is feasible.
                    int numPassesForABatch = inputDataManager.Count / 10000;
                    while (iter < _options.NumberOfIterations)
                    {
                        // We want to train on the final passes thoroughly (without learning on the same batch multiple times)
                        // This is for fine tuning the AUC. Experimentally, we found that 1 or 2 passes is enough
                        int numFinalPassesToTrainThoroughly = 2;
                        // We also do not want to learn for more passes than what the user asked
                        int numPassesForThisBatch = Math.Min(numPassesForABatch, _options.NumberOfIterations - iter - numFinalPassesToTrainThoroughly);
                        // If all of this leaves us with 0 passes, then set numPassesForThisBatch to 1
                        numPassesForThisBatch = Math.Max(1, numPassesForThisBatch);
                        state.PassIteration = iter;
                        Native.LearnAll(inputDataManager, tuneLR, ref lr, l2Const, piw, weightsEditor.Values, ref bias, numFeatures,
                            numPassesForThisBatch, numThreads, tuneNumLocIter, ref numLocIter, _options.Tolerance, _options.Shuffle, shouldInitialize,
                            stateGCHandle, ch.Info);
                        shouldInitialize = false;

                        // Check if we are done with going through the data
                        if (inputDataManager.FinishedTheLoad)
                        {
                            iter += numPassesForThisBatch;
                            // Check if more passes are left
                            if (iter < _options.NumberOfIterations)
                            {
                                inputDataManager.RestartLoading(_options.Shuffle, Host);
                            }
                        }

                        // If more passes are left, load as much as possible
                        if (iter < _options.NumberOfIterations)
                        {
                            inputDataManager.LoadAsMuchAsPossible();
                        }
                    }
                }

                // Maps back the dense features that are mislocated
                if (numThreads > 1)
                {
                    Native.MapBackWeightVector(weightsEditor.Values, stateGCHandle);
                }

                Native.DeallocateSequentially(stateGCHandle);
            }
        }
    }
    finally
    {
        // Release the pinned handle whether or not training succeeded.
        if (stateGCHandle.IsAllocated)
        {
            stateGCHandle.Free();
        }
    }

    return CreatePredictor(weights, bias);
}
/// <summary>
/// Builds one <see cref="TermMap"/> per entry of <paramref name="infos"/>.
/// </summary>
/// <remarks>
/// For each column the source of terms is chosen in this order: an explicit term list (from the
/// column, falling back to <paramref name="args"/>), then the data file named by
/// <c>args.DataFile</c> (loaded once and shared), and finally terms accumulated from
/// <paramref name="trainingData"/> up to the column's term limit.
/// </remarks>
/// <param name="env">The host environment.</param>
/// <param name="ch">The channel used for asserts, warnings and user-argument errors.</param>
/// <param name="infos">Per-column information (name, source index, source type).</param>
/// <param name="args">The transform-level arguments.</param>
/// <param name="column">The per-column arguments, indexed in step with <paramref name="infos"/>.</param>
/// <param name="trainingData">The data scanned for columns that need to be trained.</param>
/// <returns>The term maps, indexed in step with <paramref name="infos"/>.</returns>
private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] infos, ArgumentsBase args, ColumnBase[] column, IDataView trainingData)
{
    Contracts.AssertValue(env);
    env.AssertValue(ch);
    ch.AssertValue(infos);
    ch.AssertValue(args);
    ch.AssertValue(column);
    ch.AssertValue(trainingData);

    bool explicitTermsGiven = args.Term != null || !string.IsNullOrEmpty(args.Terms);
    if (explicitTermsGiven &&
        (!string.IsNullOrWhiteSpace(args.DataFile) || args.Loader.IsGood() || !string.IsNullOrWhiteSpace(args.TermsColumn)))
    {
        ch.Warning("Explicit term list specified. Data file arguments will be ignored");
    }

    if (!Enum.IsDefined(typeof(SortOrder), args.Sort))
    {
        throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected", args.Sort);
    }

    var termMap = new TermMap[infos.Length];
    var lims = new int[infos.Length];
    TermMap termsFromFile = null;
    HashSet<int> toTrain = null;
    int trainsNeeded = 0;

    for (int iinfo = 0; iinfo < infos.Length; iinfo++)
    {
        // First check whether we have a terms argument, and handle it appropriately.
        // Column-level terms take precedence; the transform-level ones are the fallback.
        var terms = new DvText(column[iinfo].Terms);
        var termsArray = column[iinfo].Term;
        if (!terms.HasChars && termsArray == null)
        {
            terms = new DvText(args.Terms);
            termsArray = args.Term;
        }

        terms = terms.Trim();
        if (terms.HasChars || (termsArray != null && termsArray.Length > 0))
        {
            // We have terms! Pass it in.
            var sortOrder = column[iinfo].Sort ?? args.Sort;
            if (!Enum.IsDefined(typeof(SortOrder), sortOrder))
            {
                throw ch.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, infos[iinfo].Name);
            }

            var explicitBuilder = Builder.Create(infos[iinfo].TypeSrc, sortOrder);
            if (terms.HasChars)
            {
                explicitBuilder.ParseAddTermArg(ref terms, ch);
            }
            else
            {
                explicitBuilder.ParseAddTermArg(termsArray, ch);
            }

            termMap[iinfo] = explicitBuilder.Finish();
            continue;
        }

        if (!string.IsNullOrWhiteSpace(args.DataFile))
        {
            // First column using this file: load it once, then share the map.
            if (termsFromFile == null)
            {
                var fileBuilder = Builder.Create(infos[iinfo].TypeSrc, column[iinfo].Sort ?? args.Sort);
                termsFromFile = CreateFileTermMap(env, ch, args, fileBuilder);
            }

            if (!termsFromFile.ItemType.Equals(infos[iinfo].TypeSrc.ItemType))
            {
                // We have no current plans to support re-interpretation based on different column
                // type, not only because it's unclear what realistic customer use-cases for such
                // a complicated feature would be, and also because it's difficult to see how we
                // can logically reconcile "reinterpretation" for different types with the resulting
                // data view having an actual type.
                throw ch.ExceptUserArg(nameof(args.DataFile), "Data file terms loaded as type '{0}' but mismatches column '{1}' item type '{2}'",
                    termsFromFile.ItemType, infos[iinfo].Name, infos[iinfo].TypeSrc.ItemType);
            }

            termMap[iinfo] = termsFromFile;
            continue;
        }

        // Auto train this column. Leave the term map null for now, but set the lim appropriately.
        lims[iinfo] = column[iinfo].MaxNumTerms ?? args.MaxNumTerms;
        ch.CheckUserArg(lims[iinfo] > 0, nameof(Column.MaxNumTerms), "Must be positive");
        Utils.Add(ref toTrain, infos[iinfo].Source);
        ++trainsNeeded;
    }

    ch.Assert((Utils.Size(toTrain) == 0) == (trainsNeeded == 0));
    ch.Assert(Utils.Size(toTrain) <= trainsNeeded);

    if (trainsNeeded == 0)
    {
        return termMap;
    }

    Trainer[] trainer = new Trainer[trainsNeeded];
    int[] trainerInfo = new int[trainsNeeded];

    // Open the cursor, then instantiate the trainers.
    using (var cursor = trainingData.GetRowCursor(toTrain.Contains))
    using (var pch = env.StartProgressChannel("Building term dictionary"))
    {
        long rowCur = 0;
        double rowCount = trainingData.GetRowCount(true) ?? double.NaN;
        var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });

        int created = 0;
        for (int iinfo = 0; iinfo < infos.Length; ++iinfo)
        {
            if (termMap[iinfo] != null)
            {
                continue;
            }

            var trainBuilder = Builder.Create(infos[iinfo].TypeSrc, column[iinfo].Sort ?? args.Sort);
            trainerInfo[created] = iinfo;
            trainer[created] = Trainer.Create(cursor, infos[iinfo].Source, false, lims[iinfo], trainBuilder);
            created++;
        }

        ch.Assert(created == trainer.Length);

        pch.SetHeader(header,
            e =>
            {
                e.SetProgress(0, rowCur, rowCount);
                // Purely feedback for the user. That the other thread might be
                // working in the background is not a problem.
                e.SetMetric(0, trainer.Sum(t => t.Count));
            });

        // The [0,tmin) trainers are finished.
        int tmin = 0;
        // We might exit early if all trainers reach their maximum.
        while (tmin < trainer.Length && cursor.MoveNext())
        {
            rowCur++;
            for (int t = tmin; t < trainer.Length; ++t)
            {
                if (trainer[t].ProcessRow())
                {
                    continue;
                }

                // This trainer is done: move it into the finished prefix.
                Utils.Swap(ref trainerInfo[t], ref trainerInfo[tmin]);
                Utils.Swap(ref trainer[t], ref trainer[tmin]);
                tmin++;
            }
        }

        pch.Checkpoint(trainer.Sum(t => t.Count), rowCur);
    }

    for (int itrainer = 0; itrainer < trainer.Length; ++itrainer)
    {
        int iinfo = trainerInfo[itrainer];
        ch.Assert(termMap[iinfo] == null);
        if (trainer[itrainer].Count == 0)
        {
            ch.Warning("Term map for output column '{0}' contains no entries.", infos[iinfo].Name);
        }

        termMap[iinfo] = trainer[itrainer].Finish();
        // Allow the intermediate structures in the trainer and builder to be released as we iterate
        // over the columns, as the Finish operation can potentially result in the allocation of
        // additional structures.
        trainer[itrainer] = null;
    }

    ch.Assert(termMap.All(tm => tm != null));
    ch.Assert(termMap.Zip(infos, (tm, info) => tm.ItemType.Equals(info.TypeSrc.ItemType)).All(x => x));

    return termMap;
}
/// <summary>
/// Builds the file-based <see cref="TermMap"/> from the file named by the
/// <see cref="ArgumentsBase.DataFile"/> argument of <paramref name="args"/>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="ch">The channel used for asserts, warnings and user-argument errors.</param>
/// <param name="args">The arguments supplying the data file, the terms column and the optional loader.</param>
/// <param name="bldr">The builder that accumulates the terms and determines the expected item type.</param>
/// <returns>The term map produced from the rows of the file.</returns>
private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, ArgumentsBase args, Builder bldr)
{
    Contracts.AssertValue(ch);
    ch.AssertValue(env);
    ch.AssertValue(args);
    ch.Assert(!string.IsNullOrWhiteSpace(args.DataFile));
    ch.AssertValue(bldr);

    string file = args.DataFile;
    string src = args.TermsColumn;
    var sub = args.Loader;

    // If the user manually specifies a loader, or this is already a pre-processed binary
    // file, then we assume the user knows what they're doing and do not attempt to convert
    // to the desired type ourselves.
    bool autoConvert = false;
    if (!sub.IsGood())
    {
        // Determine the default loader from the extension.
        var ext = Path.GetExtension(file);
        bool isBinary = string.Equals(ext, ".idv", StringComparison.OrdinalIgnoreCase);
        bool isTranspose = string.Equals(ext, ".tdv", StringComparison.OrdinalIgnoreCase);
        if (!isBinary && !isTranspose)
        {
            // Plain text file: read a single text column and convert it as needed.
            if (!string.IsNullOrWhiteSpace(src))
            {
                ch.Warning(
                    "{0} should not be specified when default loader is TextLoader. Ignoring {0}={1}",
                    nameof(Arguments.TermsColumn), src);
            }

            sub = new SubComponent<IDataLoader, SignatureDataLoader>("TextLoader", "sep=tab col=Term:TX:0");
            src = "Term";
            autoConvert = true;
        }
        else
        {
            ch.Assert(isBinary != isTranspose);
            ch.CheckUserArg(!string.IsNullOrWhiteSpace(src), nameof(args.TermsColumn), "Must be specified");
            if (!isBinary)
            {
                ch.Assert(isTranspose);
            }

            sub = new SubComponent<IDataLoader, SignatureDataLoader>(isBinary ? "BinaryLoader" : "TransposeLoader");
        }
    }

    ch.AssertNonEmpty(src);

    var loader = sub.CreateInstance(env, new MultiFileSource(file));
    if (!loader.Schema.TryGetColumnIndex(src, out int colSrc))
    {
        throw ch.ExceptUserArg(nameof(args.TermsColumn), "Unknown column '{0}'", src);
    }

    // Without auto-conversion the column type has to match the builder's item type exactly.
    var typeSrc = loader.Schema.GetColumnType(colSrc);
    if (!autoConvert && !typeSrc.Equals(bldr.ItemType))
    {
        throw ch.ExceptUserArg(nameof(args.TermsColumn), "Must be of type '{0}' but was '{1}'", bldr.ItemType, typeSrc);
    }

    using (var cursor = loader.GetRowCursor(col => col == colSrc))
    using (var pch = env.StartProgressChannel("Building term dictionary from file"))
    {
        var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" });
        var trainer = Trainer.Create(cursor, colSrc, autoConvert, int.MaxValue, bldr);
        double rowCount = loader.GetRowCount(true) ?? double.NaN;
        long rowCur = 0;

        pch.SetHeader(header,
            e =>
            {
                e.SetProgress(0, rowCur, rowCount);
                // Purely feedback for the user. That the other thread might be
                // working in the background is not a problem.
                e.SetMetric(0, trainer.Count);
            });

        // Stop at the end of the data or as soon as the trainer declines further rows.
        while (cursor.MoveNext() && trainer.ProcessRow())
        {
            rowCur++;
        }

        if (trainer.Count == 0)
        {
            ch.Warning("Term map loaded from file resulted in an empty map.");
        }

        pch.Checkpoint(trainer.Count, rowCur);
        return trainer.Finish();
    }
}
/// <summary>
/// Trains the ensemble: for every batch, trains one model per subset, optionally prints metrics,
/// prunes the models, optionally trains the stacking combiner, and accumulates the results.
/// </summary>
/// <param name="ch">The channel used for checks and for info/warning messages.</param>
/// <param name="data">The training data handed to the subset selector.</param>
/// <returns>The predictor created from all models kept across the batches.</returns>
private TPredictor TrainCore(IChannel ch, RoleMappedData data)
{
    Host.AssertValue(ch);
    ch.AssertValue(data);

    // 1. Subset Selection
    var stacker = Combiner as IStackingTrainer<TOutput>;

    //REVIEW: Implement stacking for Batch mode.
    ch.CheckUserArg(stacker == null || Args.BatchSize <= 0, nameof(Args.BatchSize), "Stacking works only with Non-batch mode");

    var validationProportion = SubModelSelector.ValidationDatasetProportion;
    if (stacker != null)
    {
        validationProportion = Math.Max(validationProportion, stacker.ValidationDatasetProportion);
    }

    var needMetrics = Args.ShowMetrics || Combiner is IWeightedAverager;
    var models = new List<FeatureSubsetModel<IPredictorProducing<TOutput>>>();

    _subsetSelector.Initialize(data, NumModels, Args.BatchSize, validationProportion);
    int batchNumber = 1;
    foreach (var batch in _subsetSelector.GetBatches(Host.Rand))
    {
        // 2. Core train
        ch.Info("Training {0} learners for the batch {1}", Trainers.Length, batchNumber++);
        var batchModels = new FeatureSubsetModel<IPredictorProducing<TOutput>>[Trainers.Length];

        var subsets = _subsetSelector.GetSubsets(batch, Host.Rand);
        var parallelOptions = new ParallelOptions() { MaxDegreeOfParallelism = Args.TrainParallel ? -1 : 1 };
        Parallel.ForEach(subsets, parallelOptions, (subset, state, index) =>
        {
            int slot = (int)index;
            ch.Info("Beginning training model {0} of {1}", index + 1, Trainers.Length);
            Stopwatch timer = Stopwatch.StartNew();
            try
            {
                if (EnsureMinimumFeaturesSelected(subset))
                {
                    var trained = new FeatureSubsetModel<IPredictorProducing<TOutput>>(
                        Trainers[slot].Train(subset.Data),
                        subset.SelectedFeatures,
                        null);
                    SubModelSelector.CalculateMetrics(trained, _subsetSelector, subset, batch, needMetrics);
                    batchModels[slot] = trained;
                }
            }
            catch (Exception ex)
            {
                // A failed trainer leaves its slot empty; the remaining trainers still run.
                ch.Assert(batchModels[slot] == null);
                ch.Warning(ex.Sensitivity(), "Trainer {0} of {1} was not learned properly due to the exception '{2}' and will not be added to models.",
                    index + 1, Trainers.Length, ex.Message);
            }

            ch.Info("Trainer {0} of {1} finished in {2}", index + 1, Trainers.Length, timer.Elapsed);
        });

        // Drop the slots that produced no model.
        var modelsList = batchModels.Where(m => m != null).ToList();
        if (Args.ShowMetrics)
        {
            PrintMetrics(ch, modelsList);
        }

        modelsList = SubModelSelector.Prune(modelsList).ToList();

        if (stacker != null)
        {
            stacker.Train(modelsList, _subsetSelector.GetTestData(null, batch), Host);
        }

        models.AddRange(modelsList);

        int modelSize = Utils.Size(models);
        int trainerCount = Utils.Size(Trainers);
        if (modelSize < trainerCount)
        {
            ch.Warning("{0} of {1} trainings failed.", Utils.Size(Trainers) - modelSize, Utils.Size(Trainers));
        }

        ch.Check(modelSize > 0, "Ensemble training resulted in no valid models.");
    }

    return CreatePredictor(models);
}
/// <summary>
/// Fills a new <see cref="NormStr.Pool"/> with lower-cased stopwords taken from, in order of
/// precedence, the comma-separated <c>Stopwords</c> string, the <c>Stopword</c> array, or the
/// data file described by <paramref name="loaderArgs"/>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="ch">The channel used for asserts, warnings and user-argument errors.</param>
/// <param name="loaderArgs">The arguments describing where the stopwords come from.</param>
/// <param name="stopWordsMap">Receives the pool holding the loaded stopwords.</param>
private void LoadStopWords(IHostEnvironment env, IChannel ch, ArgumentsBase loaderArgs, out NormStr.Pool stopWordsMap)
{
    Contracts.AssertValue(env);
    env.AssertValue(ch);
    ch.AssertValue(loaderArgs);

    bool explicitListGiven = !string.IsNullOrEmpty(loaderArgs.Stopwords) || Utils.Size(loaderArgs.Stopword) > 0;
    if (explicitListGiven &&
        (!string.IsNullOrWhiteSpace(loaderArgs.DataFile) || loaderArgs.Loader != null ||
            !string.IsNullOrWhiteSpace(loaderArgs.StopwordsColumn)))
    {
        ch.Warning("Explicit stopwords list specified. Data file arguments will be ignored");
    }

    stopWordsMap = new NormStr.Pool();
    var buffer = new StringBuilder();

    var stopwords = ReadOnlyMemoryUtils.TrimSpaces(loaderArgs.Stopwords.AsMemory());
    if (!stopwords.IsEmpty)
    {
        // Comma-separated list: peel off one entry at a time until nothing is left.
        bool warnEmpty = true;
        bool more;
        do
        {
            ReadOnlyMemory<char> stopword;
            more = ReadOnlyMemoryUtils.SplitOne(stopwords, ',', out stopword, out stopwords);
            stopword = ReadOnlyMemoryUtils.TrimSpaces(stopword);
            if (stopword.IsEmpty)
            {
                if (warnEmpty)
                {
                    ch.Warning("Empty strings ignored in 'stopwords' specification");
                    warnEmpty = false;
                }
            }
            else
            {
                buffer.Clear();
                ReadOnlyMemoryUtils.AddLowerCaseToStringBuilder(stopword.Span, buffer);
                stopWordsMap.Add(buffer);
            }
        }
        while (more);

        ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Arguments.Stopwords), "stopwords is empty");
        return;
    }

    if (Utils.Size(loaderArgs.Stopword) > 0)
    {
        // Array of individual stopwords.
        bool warnEmpty = true;
        foreach (string word in loaderArgs.Stopword)
        {
            var stopword = word.AsSpan().Trim(' ');
            if (stopword.IsEmpty)
            {
                if (warnEmpty)
                {
                    ch.Warning("Empty strings ignored in 'stopword' specification");
                    warnEmpty = false;
                }
            }
            else
            {
                buffer.Clear();
                ReadOnlyMemoryUtils.AddLowerCaseToStringBuilder(stopword, buffer);
                stopWordsMap.Add(buffer);
            }
        }

        return;
    }

    // No explicit list: read the stopwords from the data file.
    string srcCol = loaderArgs.StopwordsColumn;
    var loader = LoadStopwords(env, ch, loaderArgs.DataFile, loaderArgs.Loader, ref srcCol);
    if (!loader.Schema.TryGetColumnIndex(srcCol, out int colSrc))
    {
        throw ch.ExceptUserArg(nameof(Arguments.StopwordsColumn), "Unknown column '{0}'", srcCol);
    }

    var typeSrc = loader.Schema[colSrc].Type;
    ch.CheckUserArg(typeSrc.IsText, nameof(Arguments.StopwordsColumn), "Must be a scalar text column");

    // Accumulate the stopwords.
    using (var cursor = loader.GetRowCursor(col => col == colSrc))
    {
        bool warnEmpty = true;
        var src = default(ReadOnlyMemory<char>);
        var getter = cursor.GetGetter<ReadOnlyMemory<char>>(colSrc);
        while (cursor.MoveNext())
        {
            getter(ref src);
            if (src.IsEmpty)
            {
                if (warnEmpty)
                {
                    ch.Warning("Empty rows ignored in data file");
                    warnEmpty = false;
                }
            }
            else
            {
                buffer.Clear();
                ReadOnlyMemoryUtils.AddLowerCaseToStringBuilder(src.Span, buffer);
                stopWordsMap.Add(buffer);
            }
        }
    }

    ch.CheckUserArg(stopWordsMap.Count > 0, nameof(Arguments.DataFile), "dataFile is empty");
}
/// <summary>
/// Wraps <paramref name="view"/> in a "MinMax" normalizing transform over
/// <paramref name="featureColumn"/> when the normalization option and the trainer call for it.
/// </summary>
/// <param name="env">The host environment; checked for null.</param>
/// <param name="ch">The channel used for checks and for info/warning messages; checked for null.</param>
/// <param name="trainer">The trainer; consulted through <c>ITrainerEx.NeedNormalization</c> when it implements that interface.</param>
/// <param name="view">The data view; replaced with the normalized view when a normalizer is added.</param>
/// <param name="featureColumn">Name of the feature column; may be null or empty, in which case nothing is added.</param>
/// <param name="autoNorm">The normalization option; must be a defined <see cref="NormalizeOption"/> value.</param>
/// <returns>True if a normalizer was added; otherwise false.</returns>
public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ch, nameof(ch));
    ch.CheckValue(trainer, nameof(trainer));
    ch.CheckValue(view, nameof(view));
    ch.CheckValueOrNull(featureColumn);
    ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
        "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

    if (autoNorm == NormalizeOption.No)
    {
        ch.Info("Not adding a normalizer.");
        return false;
    }

    if (string.IsNullOrEmpty(featureColumn))
    {
        return false;
    }

    var schema = view.Schema;
    if (!schema.TryGetColumnIndex(featureColumn, out int featCol))
    {
        return false;
    }

    if (autoNorm != NormalizeOption.Yes)
    {
        // Skip when the trainer does not ask for normalization...
        var normAware = trainer as ITrainerEx;
        bool skip = normAware == null || !normAware.NeedNormalization;
        if (!skip)
        {
            // ...or when the column metadata already marks it as normalized.
            DvBool isNormalized = DvBool.False;
            skip = schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized)
                && isNormalized.IsTrue;
        }

        if (skip)
        {
            ch.Info("Not adding a normalizer.");
            return false;
        }

        if (autoNorm == NormalizeOption.Warn)
        {
            ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
            return false;
        }
    }

    ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

    // Quote the feature column name
    var quoted = new StringBuilder();
    string columnName = CmdQuoter.QuoteValue(featureColumn, quoted) ? quoted.ToString() : featureColumn;
    var component = new SubComponent<IDataTransform, SignatureDataTransform>("MinMax",
        string.Format("col={{ name={0} source={0} }}", columnName));

    // A loader is extended through CompositeDataLoader; any other view is wrapped directly.
    if (view is IDataLoader loader)
    {
        view = CompositeDataLoader.Create(env, loader,
            new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(null, component));
    }
    else
    {
        view = component.CreateInstance(env, view);
    }

    return true;
}
/// <summary>
/// Validates the column arguments and builds the <see cref="Bindings"/> that map each source
/// column to its tokens output column and, when requested, its types output column.
/// </summary>
/// <param name="args">The transform arguments; at least one column must be specified.</param>
/// <param name="input">The input schema the source and language columns are bound against.</param>
/// <param name="ch">The channel used for asserts and user-argument checks.</param>
/// <param name="exceptUser">Forwarded to <c>Bind</c> and to the <see cref="Bindings"/> constructor.</param>
/// <returns>The bindings for the requested columns.</returns>
public static Bindings Create(Arguments args, ISchema input, IChannel ch, bool exceptUser = true)
{
    Contracts.AssertValue(ch);
    ch.AssertValue(args);
    ch.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column), "Column is missing");
    ch.CheckUserArg(Enum.IsDefined(typeof(Language), args.Language), nameof(args.Language), "'lang' has invalid value");

    // First pass: validate each column and count the output columns (tokens, plus types if requested).
    int namesLen = 0;
    foreach (var col in args.Column)
    {
        ch.CheckUserArg(!string.IsNullOrWhiteSpace(col.Source), nameof(col.Source), "Cannot be empty");
        ch.CheckUserArg(!string.IsNullOrWhiteSpace(col.TokensColumn), nameof(col.TokensColumn), "Cannot be empty");
        ch.CheckUserArg(!col.Language.HasValue || Enum.IsDefined(typeof(Language), col.Language), nameof(col.Language), "Value is invalid");

        bool wantsTypes = !string.IsNullOrWhiteSpace(col.TypesColumn) || args.CreateTypesColumn;
        namesLen += wantsTypes ? 2 : 1;
    }

    var names = new string[namesLen];
    var infos = new ColInfo[names.Length];
    var groups = new ColGroupInfo[args.Column.Length];

    // Second pass: bind the inputs and lay out the output columns.
    int index = 0;
    for (int i = 0; i < args.Column.Length; i++)
    {
        var col = args.Column[i];

        int srcIdx;
        ColumnType srcType;
        Bind(input, col.Source, t => t.ItemType.IsText, SrcTypeName, out srcIdx, out srcType, exceptUser);
        ch.Assert(srcIdx >= 0);

        // The column-level languages column takes precedence over the transform-level one.
        int langsColIdx;
        string langsCol = string.IsNullOrWhiteSpace(col.LanguagesColumn) ? args.LanguagesColumn : col.LanguagesColumn;
        if (string.IsNullOrWhiteSpace(langsCol))
        {
            langsCol = null;
            langsColIdx = -1;
        }
        else
        {
            ColumnType langsColType;
            Bind(input, langsCol, t => t.IsText, LangTypeName, out langsColIdx, out langsColType, exceptUser);
            ch.Assert(langsColIdx >= 0);
        }

        var lang = col.Language ?? args.Language;
        bool requireTypes = !string.IsNullOrWhiteSpace(col.TypesColumn) || args.CreateTypesColumn;
        groups[i] = new ColGroupInfo(lang, srcIdx, col.Source, srcType, langsColIdx, langsCol, requireTypes);

        names[index] = col.TokensColumn;
        infos[index] = new ColInfo(i);
        index++;

        if (requireTypes)
        {
            if (!string.IsNullOrWhiteSpace(col.TypesColumn))
            {
                names[index] = col.TypesColumn;
            }
            else
            {
                names[index] = string.Format("{0}_Tokens", col.Source);
            }

            infos[index] = new ColInfo(i, isTypes: true);
            index++;
        }
    }

    ch.Assert(index == namesLen);

    return new Bindings(groups, infos.ToArray(), input, exceptUser, names.ToArray());
}