Esempio n. 1
0
        /// <summary>
        /// Loads a model and its data loader, scores the loaded data with the predictor, evaluates the
        /// scores, prints warnings / per-fold / overall / additional metrics, sends the metrics as
        /// telemetry, and, when <c>Args.OutputDataFile</c> is set, saves per-instance metrics to it.
        /// </summary>
        /// <param name="ch">Channel used for tracing, assertions and printed results.</param>
        private void RunCore(IChannel ch)
        {
            ch.Trace("Constructing data pipeline");
            IDataLoader      loader;
            IPredictor       predictor;
            RoleMappedSchema trainSchema;

            // NOTE(review): the meaning of the two 'true' flags is defined by LoadModelObjects (not visible
            // here); presumably they request that the predictor and training schema be loaded — confirm.
            LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
            ch.AssertValue(predictor);
            // The training schema is optional; the scorer below receives it possibly null.
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            // Resolve each role's column name against the loader schema. Judging by the helper's name and
            // its DefaultColumnNames argument, a missing explicit name falls back to the default or to null.
            ch.Trace("Binding columns");
            ISchema schema = loader.Schema;
            string  label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn),
                                                                 Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);

            // Score.
            ch.Trace("Scoring and evaluating");
            ch.Assert(Args.Scorer == null || Args.Scorer is ICommandLineComponentFactory, "TestCommand should only be used from the command line.");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader, features, group, customCols, Host, trainSchema);

            // Evaluate.
            // Use the explicitly configured evaluator if there is one; otherwise pick one from the score schema.
            var evaluator = Args.Evaluator?.CreateComponent(Host) ??
                            EvaluateUtils.GetEvaluator(Host, scorePipe.Schema);
            // NOTE(review): the null positional argument is presumably the feature column, which evaluation
            // does not need — confirm against the RoleMappedData constructor.
            var data    = new RoleMappedData(scorePipe, label, null, group, weight, name, customCols);
            var metrics = evaluator.Evaluate(data);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            // The overall-metrics view is mandatory; its absence is reported as an error.
            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall))
            {
                throw ch.Except("No overall metrics found");
            }
            overall = evaluator.GetOverallResults(overall);
            MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1);
            evaluator.PrintAdditionalMetrics(ch, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            // Per-instance results are only computed and saved when an output file was requested.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(data);
                var perInstData = new RoleMappedData(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }
 /// <summary>
 /// Supplies the default XGBoost objective ("reg:linear") when the caller did not set one.
 /// An existing "objective" entry in <paramref name="options"/> is never overwritten.
 /// </summary>
 /// <param name="ch">Channel used for assertions.</param>
 /// <param name="options">XGBoost option map, updated in place.</param>
 /// <param name="labels">Training labels (only asserted non-null here).</param>
 /// <param name="groups">Group ids (not used by this override).</param>
 protected override void UpdateXGBoostOptions(IChannel ch, Dictionary <string, string> options, Float[] labels, uint[] groups)
 {
     Contracts.AssertValue(ch, nameof(ch));
     ch.AssertValue(options, nameof(options));
     ch.AssertValue(labels, nameof(labels));

     // An explicit user choice always wins; only fill the gap.
     if (options.ContainsKey("objective"))
     {
         return;
     }

     options["objective"] = "reg:linear";
 }
Esempio n. 3
0
        /// <summary>
        /// Loads a model and its data loader, scores the loaded data, evaluates the scores, prints
        /// warnings / per-fold / overall results, sends the metrics as telemetry, and, when
        /// <c>Args.OutputDataFile</c> is set, saves per-instance metrics to it.
        /// </summary>
        /// <param name="ch">Channel used for tracing, assertions and printed results.</param>
        private void RunCore(IChannel ch)
        {
            ch.Trace("Constructing data pipeline");
            IDataLoader      loader;
            IPredictor       predictor;
            RoleMappedSchema trainSchema;

            // NOTE(review): the two 'true' flags are interpreted by LoadModelObjects (not visible here);
            // presumably they request that the predictor and training schema be loaded — confirm.
            LoadModelObjects(ch, true, out predictor, true, out trainSchema, out loader);
            ch.AssertValue(predictor);
            // The training schema is optional and may be null.
            ch.AssertValueOrNull(trainSchema);
            ch.AssertValue(loader);

            // Resolve each role's column name against the loader schema (by the helper's name, falling
            // back to the DefaultColumnNames entry or to null when no explicit name is usable).
            ch.Trace("Binding columns");
            ISchema schema = loader.Schema;
            string  label  = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.LabelColumn),
                                                                 Args.LabelColumn, DefaultColumnNames.Label);
            string features = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.FeatureColumn),
                                                                  Args.FeatureColumn, DefaultColumnNames.Features);
            string group = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.GroupColumn),
                                                               Args.GroupColumn, DefaultColumnNames.GroupId);
            string weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.WeightColumn),
                                                                Args.WeightColumn, DefaultColumnNames.Weight);
            string name = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Args.NameColumn),
                                                              Args.NameColumn, DefaultColumnNames.Name);
            var customCols = TrainUtils.CheckAndGenerateCustomColumns(ch, Args.CustomColumn);

            // Score.
            ch.Trace("Scoring and evaluating");
            IDataScorerTransform scorePipe = ScoreUtils.GetScorer(Args.Scorer, predictor, loader, features, group, customCols, Host, trainSchema);

            // Evaluate.
            var evalComp = Args.Evaluator;

            // When no usable evaluator was configured, infer one from the score schema.
            if (!evalComp.IsGood())
            {
                evalComp = EvaluateUtils.GetEvaluatorType(ch, scorePipe.Schema);
            }
            var evaluator = evalComp.CreateInstance(Host);
            // NOTE(review): the null positional argument is presumably the feature column — confirm
            // against TrainUtils.CreateExamples.
            var data      = TrainUtils.CreateExamples(scorePipe, label, null, group, weight, name, customCols);
            var metrics   = evaluator.Evaluate(data);

            MetricWriter.PrintWarnings(ch, metrics);
            evaluator.PrintFoldResults(ch, metrics);
            evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics);
            Dictionary <string, IDataView>[] metricValues = { metrics };
            SendTelemetryMetric(metricValues);
            // Per-instance results are only computed and saved when an output file was requested.
            if (!string.IsNullOrWhiteSpace(Args.OutputDataFile))
            {
                var perInst     = evaluator.GetPerInstanceMetrics(data);
                var perInstData = TrainUtils.CreateExamples(perInst, label, null, group, weight, name, customCols);
                var idv         = evaluator.GetPerInstanceDataViewToSave(perInstData);
                MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, idv);
            }
        }
Esempio n. 4
0
                /// <summary>
                /// Creates a cursor over <paramref name="dataView"/>: forwards the view's schema, schema
                /// definition and peek delegates to the base class, opens a "Cursor" channel on
                /// <paramref name="env"/>, and starts in the NotStarted state at position -1.
                /// </summary>
                /// <param name="env">Host environment; the cursor's channel is started on it.</param>
                /// <param name="dataView">The data view being cursored.</param>
                /// <param name="predicate">Column-activity predicate, forwarded to the base constructor.</param>
                protected DataViewCursorBase(IHostEnvironment env, DataViewBase <TRow> dataView,
                                             Func <int, bool> predicate)
                    : base(env, dataView.Schema, dataView._schemaDefn, dataView._peeks, predicate)
                {
                    Contracts.AssertValue(env);
                    Ch = env.Start("Cursor");
                    // NOTE(review): the base-constructor call above already dereferenced dataView, so this
                    // assertion runs too late to catch a null dataView; it only documents the expectation.
                    Ch.AssertValue(dataView);
                    Ch.AssertValue(predicate);

                    DataView  = dataView;
                    // -1 marks "not yet positioned on any row".
                    _position = -1;
                    State     = CursorState.NotStarted;
                }
Esempio n. 5
0
        /// <summary>
        /// Stops the stopwatch and reports run statistics for the current process: peak working set
        /// and peak virtual memory (in MB) and the elapsed time in seconds. The statistics are written
        /// to the channel and also sent as metrics on the "TelemetryPipe" pipe.
        /// </summary>
        public void Dispose()
        {
            _watch.Stop();

            long physicalMemoryUsageInMB;
            long virtualMemoryUsageInMB;
            // Process is IDisposable: take a single handle, read both peaks, and release it
            // instead of leaking one Process object per property read.
            using (var process = System.Diagnostics.Process.GetCurrentProcess())
            {
                physicalMemoryUsageInMB = process.PeakWorkingSet64 / 1024 / 1024;
                virtualMemoryUsageInMB = process.PeakVirtualMemorySize64 / 1024 / 1024;
            }

            _ch.Info("Physical memory usage(MB): {0}", physicalMemoryUsageInMB);
            _ch.Info("Virtual memory usage(MB): {0}", virtualMemoryUsageInMB);

            // Keep fractions of a second when the elapsed time is short enough for them to matter.
            Double elapsedSeconds = (Double)_watch.ElapsedMilliseconds / 1000;

            if (elapsedSeconds > 99)
            {
                elapsedSeconds = Math.Round(elapsedSeconds);
            }

            // REVIEW: The trailing \n\n exists only to avoid changing a large number of baseline files.
            // Ideally the baseline comparison would ignore empty lines instead.
            _ch.Info("{0}\t Time elapsed(s): {1}\n\n", DateTime.UtcNow, elapsedSeconds);

            using (var pipe = _host.StartPipe <TelemetryMessage>("TelemetryPipe"))
            {
                _ch.AssertValue(pipe);

                pipe.Send(TelemetryMessage.CreateMetric("TLC_RunTime", elapsedSeconds));
                pipe.Send(TelemetryMessage.CreateMetric("TLC_PhysicalMemoryUsageInMB", physicalMemoryUsageInMB));
                pipe.Send(TelemetryMessage.CreateMetric("TLC_VirtualMemoryUsageInMB", virtualMemoryUsageInMB));
            }
        }
Esempio n. 6
0
        /// <summary>
        /// Writes the role-to-column-name associations of <paramref name="schema"/> into
        /// <paramref name="rep"/> as a dense two-column ("Role", "Column") text table, ordered by role name.
        /// </summary>
        /// <param name="env">Host environment used to build and save the table.</param>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="schema">The role-mapped schema whose mappings are saved.</param>
        /// <param name="rep">Repository receiving the training-info entry.</param>
        internal static void SaveRoleMappings(IHostEnvironment env, IChannel ch, RoleMappedSchema schema, RepositoryWriter rep)
        {
            // REVIEW: Should we also save this stuff, for instance, in some portion of the
            // score command or transform?
            Contracts.AssertValue(env);
            env.AssertValue(ch);
            ch.AssertValue(schema);

            var tableBuilder = new ArrayDataViewBuilder(env);

            var roleNames = new List<string>();
            var columnNames = new List<string>();
            // OrderBy is a stable sort: when one role is filled by several columns, those columns keep
            // their relative order.
            var orderedPairs = schema.GetColumnRoleNames().OrderBy(r => r.Key.Value);
            foreach (var pair in orderedPairs)
            {
                roleNames.Add(pair.Key.Value);
                columnNames.Add(pair.Value);
            }
            tableBuilder.AddColumn("Role", roleNames.ToArray());
            tableBuilder.AddColumn("Column", columnNames.ToArray());

            using (var entry = rep.CreateEntry(DirTrainingInfo, RoleMappingFile))
            {
                // REVIEW: The text saver keeps the role mappings human readable and hand-editable. The
                // trade-off is that names containing special characters such as '\n' would not round-trip;
                // role and column names are not expected to contain them.
                var textSaver = new TextSaver(env, new TextSaver.Arguments() { Dense = true, Silent = true });
                var tableView = tableBuilder.GetDataView();
                var allColumns = Utils.GetIdentityPermutation(tableView.Schema.ColumnCount);
                textSaver.SaveData(entry.Stream, tableView, allColumns);
            }
        }
Esempio n. 7
0
            /// <summary>
            /// Initializes the online-training state and chooses the initial weights and bias. The sources,
            /// in priority order, are: the weights and bias of <paramref name="predictor"/> (when non-null);
            /// the comma-separated InitialWeights option; random values scaled by InitialWeightsDiameter
            /// (when positive); otherwise default weights, dense for up to 1000 features and empty above that.
            /// </summary>
            /// <param name="ch">Channel for checks, tracing and messages; must not be null.</param>
            /// <param name="numFeatures">Number of features; must be positive.</param>
            /// <param name="predictor">Optional linear model to warm-start from, or null.</param>
            /// <param name="parent">Owning trainer; supplies the host, the options and the random source.</param>
            protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters predictor, OnlineLinearTrainer <TTransformer, TModel> parent)
            {
                Contracts.CheckValue(ch, nameof(ch));
                ch.Check(numFeatures > 0, "Cannot train with zero features!");
                ch.AssertValueOrNull(predictor);
                ch.AssertValue(parent);
                // The state is expected to start fresh: iteration 0 and zero bias.
                ch.Assert(Iteration == 0);
                ch.Assert(Bias == 0);

                ParentHost = parent.Host;

                ch.Trace("{0} Initializing {1} on {2} features", DateTime.UtcNow, parent.Name, numFeatures);

                // We want a dense vector, to prevent memory creation during training
                // unless we have a lot of features.
                if (predictor != null)
                {
                    // Warm start from an existing model.
                    // NOTE(review): the copied weight length is not checked against numFeatures here.
                    ((IHaveFeatureWeights)predictor).GetFeatureWeights(ref Weights);
                    VBufferUtils.Densify(ref Weights);
                    Bias = predictor.Bias;
                }
                else if (!string.IsNullOrWhiteSpace(parent.OnlineLinearTrainerOptions.InitialWeights))
                {
                    // Expected format: numFeatures weights followed by the bias, comma-separated and
                    // parsed with the invariant culture.
                    ch.Info("Initializing weights and bias to " + parent.OnlineLinearTrainerOptions.InitialWeights);
                    string[] weightStr = parent.OnlineLinearTrainerOptions.InitialWeights.Split(',');
                    if (weightStr.Length != numFeatures + 1)
                    {
                        throw ch.Except(
                                  "Could not initialize weights from 'initialWeights': expecting {0} values to initialize {1} weights and the intercept",
                                  numFeatures + 1, numFeatures);
                    }

                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = float.Parse(weightStr[i], CultureInfo.InvariantCulture);
                    }
                    Weights = new VBuffer <float>(numFeatures, weightValues);
                    // The last value is the bias (intercept).
                    Bias    = float.Parse(weightStr[numFeatures], CultureInfo.InvariantCulture);
                }
                else if (parent.OnlineLinearTrainerOptions.InitialWeightsDiameter > 0)
                {
                    // Each weight (and the bias) is diameter * (NextSingle() - 0.5). Presumably NextSingle
                    // is uniform on [0, 1), which centers the values on zero — confirm.
                    var weightValues = new float[numFeatures];
                    for (int i = 0; i < numFeatures; i++)
                    {
                        weightValues[i] = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                    }
                    Weights = new VBuffer <float>(numFeatures, weightValues);
                    Bias    = parent.OnlineLinearTrainerOptions.InitialWeightsDiameter * (parent.Host.Rand.NextSingle() - (float)0.5);
                }
                else if (numFeatures <= 1000)
                {
                    Weights = VBufferUtils.CreateDense <float>(numFeatures);
                }
                else
                {
                    // Above 1000 features start from an empty buffer rather than a dense one.
                    Weights = VBufferUtils.CreateEmpty <float>(numFeatures);
                }
                WeightsScale = 1;
            }
        /// <summary>
        /// Appends a float-mapping transform over the key-typed label column <paramref name="labelName"/>.
        /// A non-key label is returned unchanged, with a warning if a non-zero permutation seed was given.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="ch">Channel for assertions, warnings and errors.</param>
        /// <param name="input">The data to extend.</param>
        /// <param name="labelName">Name of the label column; must exist in <paramref name="input"/>.</param>
        /// <param name="labelPermutationSeed">Seed forwarded to the float mapper (only meaningful for key labels).</param>
        /// <returns>The transformed view, or <paramref name="input"/> itself for non-key labels.</returns>
        private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, IDataView input, string labelName, int labelPermutationSeed)
        {
            Contracts.AssertValue(env);
            env.AssertValue(ch);
            ch.AssertValue(input);
            ch.AssertNonWhiteSpace(labelName);

            if (!input.Schema.TryGetColumnIndex(labelName, out int labelIndex))
            {
                throw ch.Except("Label column '{0}' not found.", labelName);
            }

            ColumnType labelType = input.Schema[labelIndex].Type;
            if (labelType.IsKey)
            {
                // Dispatch on the key's raw representation type.
                return Utils.MarshalInvoke(AppendFloatMapper <int>, labelType.RawType, env, ch, input, labelName, (KeyType)labelType,
                                           labelPermutationSeed);
            }

            if (labelPermutationSeed != 0)
            {
                ch.Warning(
                    "labelPermutationSeed != 0 only applies on a multi-class learning problem when the label type is a key.");
            }
            return input;
        }
        /// <summary>
        /// Appends a float-mapping transform over the key-typed label column <paramref name="labelName"/>.
        /// A non-key label is returned unchanged, with a warning if a non-zero permutation seed was given.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="ch">Channel for assertions, warnings and errors.</param>
        /// <param name="input">The data to extend.</param>
        /// <param name="labelName">Name of the label column; a schema-mismatch error is thrown if absent.</param>
        /// <param name="labelPermutationSeed">Seed forwarded to the float mapper (only meaningful for key labels).</param>
        /// <returns>The transformed view, or <paramref name="input"/> itself for non-key labels.</returns>
        private static IDataView AppendLabelTransform(IHostEnvironment env, IChannel ch, IDataView input, string labelName, int labelPermutationSeed)
        {
            Contracts.AssertValue(env);
            env.AssertValue(ch);
            ch.AssertValue(input);
            ch.AssertNonWhiteSpace(labelName);

            var labelColumn = input.Schema.GetColumnOrNull(labelName);
            if (!labelColumn.HasValue)
            {
                throw ch.ExceptSchemaMismatch(nameof(input), "Label", labelName);
            }

            ColumnType labelType = labelColumn.Value.Type;
            if (labelType.IsKey)
            {
                // Dispatch on the key's raw representation type.
                return Utils.MarshalInvoke(AppendFloatMapper <int>, labelType.RawType, env, ch, input, labelName, (KeyType)labelType,
                                           labelPermutationSeed);
            }

            if (labelPermutationSeed != 0)
            {
                ch.Warning(
                    "labelPermutationSeed != 0 only applies on a multi-class learning problem when the label type is a key.");
            }
            return input;
        }
        /// <summary>
        /// Wraps <paramref name="input"/>, opening a "Cursor" channel on <paramref name="provider"/> and
        /// caching the input's root cursor.
        /// </summary>
        /// <param name="provider">Channel provider; must not be null.</param>
        /// <param name="input">The cursor being wrapped; must not be null.</param>
        protected SynchronizedCursorBase(IChannelProvider provider, TBase input)
        {
            Contracts.AssertValue(provider, nameof(provider));
            Ch = provider.Start("Cursor");

            Ch.AssertValue(input, nameof(input));
            Input = input;
            // Cache the root once rather than asking the input for it on every use.
            _root = Input.GetRootCursor();
        }
            /// <summary>
            /// Deserializes column bindings from <paramref name="ctx"/> (format described below) and binds
            /// each group's source and optional languages columns against <paramref name="input"/>.
            /// </summary>
            /// <param name="ctx">Model load context positioned at the bindings data.</param>
            /// <param name="input">Input schema the source columns are resolved against.</param>
            /// <param name="ch">Channel used for assertions and decode checks.</param>
            /// <returns>The reconstructed bindings.</returns>
            public static Bindings Create(ModelLoadContext ctx, ISchema input, IChannel ch)
            {
                Contracts.AssertValue(ch);
                ch.AssertValue(ctx);

                // *** Binary format ***
                // int: count of group column infos (ie, count of source columns)
                // For each group column info
                //     int: the tokenizer language
                //     int: the id of source column name
                //     int: the id of languages column name
                //     bool: whether the types output is required
                //     For each column info that belongs to this group column info
                //     (either one column info for tokens or two for tokens and types)
                //          int: the id of the column name

                // At least one group is required.
                int groupsLen = ctx.Reader.ReadInt32();
                ch.CheckDecode(groupsLen > 0);

                var names = new List<string>();
                var infos = new List<ColInfo>();
                var groups = new ColGroupInfo[groupsLen];
                // Reads must follow the binary format order exactly.
                for (int i = 0; i < groups.Length; i++)
                {
                    // The stored language must be a defined Language enum value.
                    int lang = ctx.Reader.ReadInt32();
                    ch.CheckDecode(Enum.IsDefined(typeof(Language), lang));

                    // The source column must have a text item type.
                    // NOTE(review): the meaning of Bind's trailing 'false' argument is not visible here — confirm.
                    string srcName = ctx.LoadNonEmptyString();
                    int srcIdx;
                    ColumnType srcType;
                    Bind(input, srcName, t => t.ItemType.IsText, SrcTypeName, out srcIdx, out srcType, false);

                    // The languages column is optional; when present it must be a text column.
                    // -1 marks "no languages column".
                    string langsName = ctx.LoadStringOrNull();
                    int langsIdx;
                    if (langsName != null)
                    {
                        ColumnType langsType;
                        Bind(input, langsName, t => t.IsText, LangTypeName, out langsIdx, out langsType, false);
                    }
                    else
                        langsIdx = -1;

                    bool requireTypes = ctx.Reader.ReadBoolByte();
                    groups[i] = new ColGroupInfo((Language)lang, srcIdx, srcName, srcType, langsIdx, langsName, requireTypes);

                    // Every group produces a tokens output column; a types column follows only if required.
                    infos.Add(new ColInfo(i));
                    names.Add(ctx.LoadNonEmptyString());
                    if (requireTypes)
                    {
                        infos.Add(new ColInfo(i, isTypes: true));
                        names.Add(ctx.LoadNonEmptyString());
                    }
                }

                return new Bindings(groups, infos.ToArray(), input, false, names.ToArray());
            }
Esempio n. 12
0
        /// <summary>
        /// Asserts the argument invariants and forwards to RunGraphCore to execute
        /// <paramref name="graph"/> with <paramref name="cdata"/> data-source blocks.
        /// </summary>
        /// <param name="penv">Native environment block, passed through unchanged.</param>
        /// <param name="host">Host used to run the graph.</param>
        /// <param name="ch">Channel used only for the assertions here.</param>
        /// <param name="graph">Graph description; must be non-empty.</param>
        /// <param name="cdata">Number of data-source blocks; must be non-negative.</param>
        /// <param name="ppdata">Data-source block array; may be null only when <paramref name="cdata"/> is 0.</param>
        private static void ExecCore(EnvironmentBlock *penv, IHost host, IChannel ch, string graph, int cdata, DataSourceBlock **ppdata)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(host);
            ch.AssertNonEmpty(graph);
            ch.Assert(cdata >= 0);
            // A missing data-source array is only acceptable when there is nothing to read.
            ch.Assert(cdata == 0 || ppdata != null);

            RunGraphCore(penv, host, graph, cdata, ppdata);
        }
Esempio n. 13
0
        /// <summary>
        /// Decides whether boosting should stop early by feeding the first validation and training test
        /// results to the configured early-stopping rule.
        /// </summary>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="earlyStoppingRule">The rule instance; created from the configured factory on first use.</param>
        /// <param name="bestIteration">Set to the current tree count when this iteration is the best so far.</param>
        /// <returns><c>true</c> if the rule says to stop; <c>false</c> otherwise, or when no rule factory is configured.</returns>
        private protected override bool ShouldStop(IChannel ch, ref EarlyStoppingRuleBase earlyStoppingRule, ref int bestIteration)
        {
            // Without a rule factory, early stopping is disabled.
            if (FastTreeTrainerOptions.EarlyStoppingRuleFactory == null)
            {
                return false;
            }

            ch.AssertValue(ValidTest);
            ch.AssertValue(TrainTest);

            var validationResult = ValidTest.ComputeTests().First();

            ch.Assert(validationResult.FinalValue >= 0);
            bool lowerIsBetter = validationResult.LowerIsBetter;

            var trainingResult = TrainTest.ComputeTests().First();

            ch.Assert(trainingResult.FinalValue >= 0);

            // Create the rule lazily. The factory is guaranteed non-null by the guard at the top.
            if (earlyStoppingRule == null)
            {
                earlyStoppingRule = FastTreeTrainerOptions.EarlyStoppingRuleFactory.CreateComponent(Host, lowerIsBetter);
            }

            // Early stopping rule cannot be null!
            ch.Assert(earlyStoppingRule != null);

            bool shouldStop = earlyStoppingRule.CheckScore((float)validationResult.FinalValue,
                                                           (float)trainingResult.FinalValue, out bool isBestCandidate);

            if (isBestCandidate)
            {
                bestIteration = Ensemble.NumTrees;
            }

            return shouldStop;
        }
Esempio n. 14
0
        /// <summary>
        /// Returns the value getter for output column <paramref name="iinfo"/> over <paramref name="input"/>.
        /// </summary>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="input">The row to read source values from.</param>
        /// <param name="iinfo">Index into <c>Infos</c>; must be in range.</param>
        /// <param name="disposer">Always set to null; the getter needs no cleanup.</param>
        /// <returns>The getter produced by <c>GetGetter</c>.</returns>
        protected override Delegate GetGetterCore(IChannel ch, Row input,
                                                  int iinfo, out Action disposer)
        {
            Host.AssertValue(ch);
            ch.AssertValue(input);
            ch.Assert(0 <= iinfo && iinfo < Infos.Length);
            disposer = null;

            return GetGetter(ch, input, iinfo);
        }
Esempio n. 15
0
        /// <summary>
        /// Wraps <paramref name="input"/>, opening a "Cursor" channel on <paramref name="provider"/> and
        /// recording the root cursor that operations can be forwarded to.
        /// </summary>
        /// <param name="provider">Channel provider; must not be null.</param>
        /// <param name="input">The cursor being wrapped; must not be null.</param>
        protected SynchronizedCursorBase(IChannelProvider provider, DataViewRowCursor input)
        {
            Contracts.AssertValue(provider);
            Ch = provider.Start("Cursor");

            Ch.AssertValue(input);
            Input = input;

            // A synchronized cursor wrapping another one (the common case) can share that cursor's root,
            // which avoids long chains of nested calls for frequent operations such as movement.
            if (Input is SynchronizedCursorBase wrapped)
            {
                Root = wrapped.Root;
            }
            else
            {
                Root = input;
            }
        }
        /// <summary>
        /// Prints the per-fold results computed by <c>MetricWriter.GetPerFoldResults</c> from the
        /// overall-metrics view. Weighted results, when present, are printed before the unweighted ones.
        /// Override if something else is needed.
        /// </summary>
        /// <param name="ch">Channel to print to.</param>
        /// <param name="metrics">Metrics by kind; must contain <c>MetricKinds.OverallMetrics</c>.</param>
        /// <remarks>Throws through <paramref name="ch"/> when the overall metrics are missing.</remarks>
        protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary <string, IDataView> metrics)
        {
            // Assert the channel through the host; asserting it through itself checks nothing.
            Host.AssertValue(ch);
            ch.AssertValue(metrics);

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView fold))
            {
                throw ch.Except("No overall metrics found");
            }

            string unweightedMetrics = MetricWriter.GetPerFoldResults(Host, fold, out string weightedMetrics);

            if (!string.IsNullOrEmpty(weightedMetrics))
            {
                ch.Info(weightedMetrics);
            }
            ch.Info(unweightedMetrics);
        }
        /// <summary>
        /// Checks that <paramref name="examples"/> meets this trainer's requirements (label check and
        /// float-vector features), makes sure the data can be shuffled (wrapping it in a shuffle transform
        /// when it cannot), and rebinds the original column roles onto the resulting view.
        /// </summary>
        /// <param name="ch">The channel.</param>
        /// <param name="examples">The training examples.</param>
        /// <param name="weightSetCount">Receives the length of the weights and bias array: 1 for binary
        /// classification and regression, the number of label classes for multi-class classification.</param>
        /// <returns>A potentially modified version of <paramref name="examples"/> with the same role mappings.</returns>
        /// <remarks>Throws through <paramref name="ch"/> when the training set has no features.</remarks>
        private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();

            IDataView source = examples.Data;
            // Data that cannot shuffle natively is wrapped in a shuffle transform, which can.
            IDataView trainView = source.CanShuffle
                ? source
                : new ShuffleTransform(Host, new ShuffleTransform.Arguments { PoolOnly = false, ForceShuffle = _args.Shuffle }, source);

            ch.Assert(trainView.CanShuffle);

            var trainData = new RoleMappedData(trainView, examples.Schema.GetColumnRoleNames());

            ch.AssertValue(trainData.Schema.Label);
            ch.AssertValue(trainData.Schema.Feature);
            if (examples.Schema.Weight != null)
            {
                // A weight role on the input must survive the rebinding.
                ch.AssertValue(trainData.Schema.Weight);
            }

            int featureCount = trainData.Schema.Feature.Type.VectorSize;
            ch.Check(featureCount > 0, "Training set has no features, aborting training.");
            return trainData;
        }
        /// <summary>
        /// Builds a boolean getter for output column <paramref name="iinfo"/> that reports whether the source
        /// label of the current row equals this column's class, taken from <c>_classIndex[iinfo]</c>.
        /// </summary>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="input">Row supplying the source label column.</param>
        /// <param name="iinfo">Index into <c>Infos</c>; must be in range.</param>
        /// <returns>A getter writing <c>true</c> when the label matches the class, <c>false</c> otherwise.</returns>
        /// <remarks>Supported source types: key (with a positive key count), float and double; any other
        /// type throws a not-supported exception.</remarks>
        private ValueGetter <bool> GetGetter(IChannel ch, DataViewRow input, int iinfo)
        {
            Host.AssertValue(ch);
            ch.AssertValue(input);
            ch.Assert(0 <= iinfo && iinfo < Infos.Length);

            var info   = Infos[iinfo];
            var column = input.Schema[info.Source];

            // Presumably a null result means the source type was accepted as a label type — confirm.
            ch.Assert(TestIsMulticlassLabel(info.TypeSrc) == null);

            if (info.TypeSrc.GetKeyCount() > 0)
            {
                // Key labels are compared against _classIndex[iinfo] + 1, i.e. the class index shifted by one
                // into the key's value range.
                // NOTE(review): values are read as uint; keys with another raw type would fail here — confirm
                // that TestIsMulticlassLabel only admits uint keys.
                var  srcGetter = input.GetGetter <uint>(column);
                var  src       = default(uint);
                uint cls       = (uint)(_classIndex[iinfo] + 1);

                return
                    ((ref bool dst) =>
                {
                    srcGetter(ref src);
                    dst = src == cls;
                });
            }
            if (info.TypeSrc == NumberDataViewType.Single)
            {
                // Float labels are compared directly against the class index (exact equality).
                var srcGetter = input.GetGetter <float>(column);
                var src       = default(float);

                return
                    ((ref bool dst) =>
                {
                    srcGetter(ref src);
                    dst = src == _classIndex[iinfo];
                });
            }
            if (info.TypeSrc == NumberDataViewType.Double)
            {
                // Double labels are compared directly against the class index (exact equality).
                var srcGetter = input.GetGetter <double>(column);
                var src       = default(double);

                return
                    ((ref bool dst) =>
                {
                    srcGetter(ref src);
                    dst = src == _classIndex[iinfo];
                });
            }
            throw Host.ExceptNotSupp($"Label column type is not supported for binary remapping: {info.TypeSrc}. Supported types: key, float, double.");
        }
Esempio n. 19
0
        /// <summary>
        /// Optionally wraps <paramref name="data"/> in a cache. An explicit <paramref name="cacheData"/> value
        /// decides; when it is null, caching happens if the trainer wants it and the data is not already a
        /// <c>BinaryLoader</c>.
        /// </summary>
        /// <param name="env">Host environment used to create the cache view.</param>
        /// <param name="ch">Channel for assertions and tracing.</param>
        /// <param name="trainer">Trainer whose <c>Info.WantCaching</c> is consulted.</param>
        /// <param name="data">Training data; replaced by the cached view (same role mappings) when caching.</param>
        /// <param name="cacheData">Explicit caching choice, or null to decide automatically.</param>
        /// <returns><c>true</c> if a cache was added.</returns>
        private static bool AddCacheIfWanted(IHostEnvironment env, IChannel ch, ITrainer trainer, ref RoleMappedData data, bool? cacheData)
        {
            Contracts.AssertValue(env, nameof(env));
            env.AssertValue(ch, nameof(ch));
            ch.AssertValue(trainer, nameof(trainer));
            ch.AssertValue(data, nameof(data));

            // '??' binds looser than '&&'; the parentheses spell out the grouping the expression already has.
            bool shouldCache = cacheData ?? (!(data.Data is BinaryLoader) && trainer.Info.WantCaching);

            if (!shouldCache)
            {
                ch.Trace("Not caching");
                return false;
            }

            ch.Trace("Caching");
            // Prefetch every column that has a role.
            int[] roleColumns = data.Schema.GetColumnRoles().Select(kc => kc.Value.Index).ToArray();
            var cached = new CacheDataView(env, data.Data, roleColumns);
            // Because the prefetching worked, we know that these are valid columns.
            data = new RoleMappedData(cached, data.Schema.GetColumnRoleNames());
            return true;
        }
        /// <summary>
        /// Decides whether boosting should stop early by feeding the first validation and training test
        /// results to the early-stopping criterion configured in <c>Args.EarlyStoppingRule</c>.
        /// </summary>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="earlyStoppingRule">The criterion; created from the configured component on first use.</param>
        /// <param name="bestIteration">Set to the current tree count when this iteration is the best so far.</param>
        /// <returns><c>true</c> to stop; <c>false</c> otherwise or when no rule is configured.</returns>
        protected override bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion earlyStoppingRule, ref int bestIteration)
        {
            // No configured rule means early stopping is off.
            if (Args.EarlyStoppingRule == null)
            {
                return false;
            }

            ch.AssertValue(ValidTest);
            ch.AssertValue(TrainTest);

            var validation = ValidTest.ComputeTests().First();
            ch.Assert(validation.FinalValue >= 0);
            bool lowerIsBetter = validation.LowerIsBetter;

            var training = TrainTest.ComputeTests().First();
            ch.Assert(training.FinalValue >= 0);

            // Build the criterion lazily, the first time it is needed.
            if (earlyStoppingRule == null)
            {
                earlyStoppingRule = Args.EarlyStoppingRule.CreateComponent(Host, lowerIsBetter);
                ch.Assert(earlyStoppingRule != null);
            }

            bool stop = earlyStoppingRule.CheckScore((Float)validation.FinalValue,
                                                     (Float)training.FinalValue, out bool isBest);
            if (isBest)
            {
                bestIteration = Ensemble.NumTrees;
            }
            return stop;
        }
Esempio n. 21
0
        /// <summary>
        /// Renders the data attached to <paramref name="ex"/> with the writer-based overload and, if that
        /// produced any text, reports it as an error on <paramref name="ch"/> with the exception's sensitivity.
        /// </summary>
        /// <param name="ch">Channel receiving the error message.</param>
        /// <param name="ex">Exception whose data is printed.</param>
        /// <param name="includeComponents">Forwarded to the writer-based overload.</param>
        private static void PrintExceptionData(IChannel ch, Exception ex, bool includeComponents)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(ex);

            var buffer = new StringBuilder();
            using (var writer = new StringWriter(buffer, CultureInfo.InvariantCulture))
            {
                PrintExceptionData(writer, ex, includeComponents);
            }

            // Nothing rendered, nothing to report.
            if (buffer.Length == 0)
            {
                return;
            }
            ch.Error(ex.Sensitivity(), buffer.ToString());
        }
Esempio n. 22
0
        /// <summary>
        /// Prints exception type, message, stack trace and data for every exception in the
        /// <see cref="Exception.InnerException"/> chain, numbering them from 1.
        /// </summary>
        /// <param name="ch">Channel receiving the error messages.</param>
        /// <param name="ex">Outermost exception of the chain.</param>
        private static void PrintFullExceptionDetails(IChannel ch, Exception ex)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(ex);
            int index = 0;

            for (var e = ex; e != null; e = e.InnerException)
            {
                index++;
                ch.Error(e.Sensitivity(), "({0}) Unexpected exception: {1}, '{2}'", index, e.Message, e.GetType());
                PrintExceptionData(ch, e, true);

                // StackTrace is null for exceptions that were constructed but never thrown (common for
                // inner exceptions), so skip it instead of sending a null/empty error message.
                string stackTrace = e.StackTrace;
                if (!string.IsNullOrEmpty(stackTrace))
                {
                    // While the message can be sensitive, we suppose the stack trace itself is not.
                    ch.Error(MessageSensitivity.None, stackTrace);
                }
            }
        }
        /// <summary>
        /// Returns the getter for output column <paramref name="iinfo"/>. The source column must be
        /// vector-valued; a scalar source is rejected with a parameter exception on <paramref name="input"/>.
        /// </summary>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="input">Row supplying the source values.</param>
        /// <param name="iinfo">Index into <c>Infos</c>; must be in range.</param>
        /// <param name="disposer">Always set to null; no cleanup is required.</param>
        /// <returns>The getter produced by <c>GetGetterVec</c>.</returns>
        protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer)
        {
            Host.AssertValue(ch);
            ch.AssertValue(input);
            ch.Assert(0 <= iinfo && iinfo < Infos.Length);
            disposer = null;

            bool sourceIsVector = Infos[iinfo].TypeSrc.IsVector;
            if (sourceIsVector)
            {
                return GetGetterVec(ch, input, iinfo);
            }

            throw Host.ExceptParam(nameof(input),
                                   "Text input given, expects a text vector");
        }
Esempio n. 24
0
            /// <summary>
            /// Wraps <paramref name="input"/> for typed access, opening a channel named
            /// <paramref name="channelMessage"/> on the parent's host and generating one setter per mapped column.
            /// </summary>
            /// <param name="parent">The typed cursorable supplying the column mappings and host.</param>
            /// <param name="input">The underlying row.</param>
            /// <param name="channelMessage">Name of the channel to start.</param>
            public TypedRowBase(TypedCursorable <TRow> parent, Row input, string channelMessage)
                : base(input)
            {
                Contracts.AssertValue(parent);
                Contracts.AssertValue(parent._host);
                Ch = parent._host.Start(channelMessage);
                Ch.AssertValue(input);

                int columnCount = parent._pokes.Length;
                // The parent's per-column arrays must be parallel.
                Ch.Assert(columnCount == parent._columns.Length);
                Ch.Assert(columnCount == parent._columnIndices.Length);

                // One Action<TRow> per column, built by GenerateSetter from the column's index, mapping and
                // poke/peek delegates (presumably it transfers the column value into the TRow object — confirm).
                _setters = new Action <TRow> [columnCount];
                for (int col = 0; col < columnCount; col++)
                {
                    _setters[col] = GenerateSetter(Input, parent._columnIndices[col], parent._columns[col], parent._pokes[col], parent._peeks[col]);
                }
            }
Esempio n. 25
0
        /// <summary>
        /// Describes how NA values work for <paramref name="type"/>. It reports whether an NA predicate
        /// exists for <typeparamref name="T"/> and, when one does, whether <c>default(T)</c> satisfies it.
        /// </summary>
        /// <param name="ch">Channel used for assertions.</param>
        /// <param name="type">A standard scalar type.</param>
        private TypeNaInfo KindReport<T>(IChannel ch, PrimitiveType type)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(type);
            ch.Assert(type.IsStandardScalar());

            InPredicate<T> isNaDel;
            bool hasNaPred = Conversions.Instance.TryGetIsNAPredicate(type, out isNaDel);
            if (!hasNaPred)
            {
                // Without a predicate nothing can be NA, including the default value.
                return new TypeNaInfo(false, false);
            }

            T defaultValue = default(T);
            return new TypeNaInfo(true, isNaDel(in defaultValue));
        }
        /// <summary>
        /// Reads a serialized vector of text values from the model stream in <paramref name="ctx"/>
        /// into <paramref name="values"/>. Any malformed content raises a decode error through <paramref name="ch"/>.
        /// </summary>
        /// <param name="ch">Channel used for validation and decode errors.</param>
        /// <param name="ctx">Load context; its model version is checked against <c>GetVersionInfo()</c>.</param>
        /// <param name="factory">Factory used to read the codec description from the stream.</param>
        /// <param name="values">Receives the decoded vector; it is reset to default before reading.</param>
        private static void Load(IChannel ch, ModelLoadContext ctx, CodecFactory factory, ref VBuffer<ReadOnlyMemory<char>> values)
        {
            Contracts.AssertValue(ch);
            ch.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // Codec parameterization: A codec parameterization that should be a ReadOnlyMemory codec
            // int: n, the number of bytes used to write the values
            // byte[n]: As encoded using the codec

            // The codec is read from the stream, not just looked up in the factory, because codecs can be
            // versioned through their parameterization. Reading can fail when older code meets a codec
            // written by newer code. Handling that case is why the codec specification is persisted.
            IValueCodec codec;
            if (!factory.TryReadCodec(ctx.Reader.BaseStream, out codec))
            {
                throw ch.ExceptDecode();
            }

            ch.AssertValue(codec);
            ch.CheckDecode(codec.Type.IsVector);
            ch.CheckDecode(codec.Type.ItemType.IsText);
            var vectorTextCodec = (IValueCodec<VBuffer<ReadOnlyMemory<char>>>)codec;

            int byteCount = ctx.Reader.ReadInt32();
            ch.CheckDecode(byteCount >= 0);

            using (var payload = new SubsetStream(ctx.Reader.BaseStream, byteCount))
            {
                using (var valueReader = vectorTextCodec.OpenReader(payload, 1))
                {
                    valueReader.MoveNext();
                    values = default(VBuffer<ReadOnlyMemory<char>>);
                    valueReader.Get(ref values);
                }

                // The length-limited payload stream must be exhausted after reading the single value.
                ch.CheckDecode(payload.ReadByte() == -1);
            }
        }
Esempio n. 27
0
        /// <summary>
        /// Dispatches training to the most specific interface <paramref name="trainer"/> supports for the
        /// given inputs. When <paramref name="predictor"/> is supplied, an incremental trainer (validating or
        /// not, matching whether <paramref name="validData"/> is present) continues from it. Otherwise a
        /// warning is logged and the predictor is ignored. Without a usable incremental path, training runs
        /// either as a validating trainer (when <paramref name="validData"/> is present) or through
        /// <paramref name="train"/>.
        /// </summary>
        private static void TrainCore<TDataSet, TPredictor>(IChannel ch, ITrainer trainer, Action<TDataSet> train, TDataSet data, TDataSet validData = null, TPredictor predictor = null)
            where TDataSet : class
            where TPredictor : class
        {
            const string inputModelArg = nameof(TrainCommand.Arguments.InputModelFile);
            bool hasValidation = validData != null;

            if (predictor != null)
            {
                if (hasValidation && trainer is IIncrementalValidatingTrainer<TDataSet, TPredictor> incValidTrainer)
                {
                    incValidTrainer.Train(data, validData, predictor);
                    return;
                }

                if (!hasValidation && trainer is IIncrementalTrainer<TDataSet, TPredictor> incTrainer)
                {
                    incTrainer.Train(data, predictor);
                    return;
                }

                ch.Warning("Ignoring " + inputModelArg + ": Trainer is not an incremental trainer.");
            }

            if (!hasValidation)
            {
                train(data);
                return;
            }

            // Callers are expected to pass validation data only to trainers that accept it.
            var validTrainer = trainer as IValidatingTrainer<TDataSet>;
            ch.AssertValue(validTrainer);
            validTrainer.Train(data, validData);
        }
Esempio n. 28
0
        /// <summary>
        /// Creates the optimizer state for <paramref name="function"/>, starting from <paramref name="initial"/>.
        /// When <c>EnforceNonNegativity</c> is set, negative entries of <paramref name="initial"/> at indices
        /// at or beyond <c>_biasCount</c> are clamped to zero in place first. An L1 state is returned when
        /// <c>_l1weight</c> is positive and there is at least one non-bias entry; otherwise a plain function
        /// optimizer state is returned.
        /// </summary>
        internal override OptimizerState MakeState(IChannel ch, IProgressChannelProvider progress, DifferentiableFunction function, ref VBuffer<Float> initial)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(progress);

            if (EnforceNonNegativity)
            {
                VBufferUtils.Apply(ref initial, (int index, ref Float weight) =>
                {
                    bool isNonBias = index >= _biasCount;
                    if (isNonBias && weight < 0.0)
                    {
                        weight = 0;
                    }
                });
            }

            bool hasL1Penalty = _l1weight > 0 && _biasCount < initial.Length;
            if (!hasL1Penalty)
            {
                return new FunctionOptimizerState(ch, progress, function, in initial, M, TotalMemoryLimit, KeepDense, EnforceNonNegativity);
            }

            return new L1OptimizerState(ch, progress, function, in initial, M, TotalMemoryLimit, _biasCount, _l1weight, KeepDense, EnforceNonNegativity);
        }
        /// <summary>
        /// Logs the per-fold overall metrics, restricted to the requested index. Weighted metrics are
        /// logged first when present, then the unweighted metrics.
        /// </summary>
        /// <param name="ch">Channel that receives the output.</param>
        /// <param name="metrics">Metric views keyed by metric kind; must contain <c>MetricKinds.OverallMetrics</c>.</param>
        protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
        {
            ch.AssertValue(metrics);

            if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out IDataView overall))
            {
                throw ch.Except("No overall metrics found");
            }

            // Keep only the metrics belonging to the requested index.
            IDataView relevant = ExtractRelevantIndex(overall);
            string unweighted = MetricWriter.GetPerFoldResults(Host, relevant, out string weighted);

            if (!string.IsNullOrEmpty(weighted))
            {
                ch.Info(weighted);
            }
            ch.Info(unweighted);
        }
        /// <summary>
        /// Validates the training examples for this trainer and returns a shuffle-capable view of them.
        /// The label and the float feature vector are checked. A view that cannot shuffle is wrapped in a
        /// <c>RowShufflingTransformer</c> whose <c>ForceShuffle</c> comes from <c>_options.Shuffle</c>.
        /// The method throws if the resulting feature vector has no slots.
        /// </summary>
        /// <param name="ch">The channel</param>
        /// <param name="examples">The training examples</param>
        /// <param name="weightSetCount">Length of the weights and bias array, as reported by <c>CheckLabel</c>.
        /// For binary classification and regression this is 1; for multi-class classification it equals the
        /// number of classes on the label.</param>
        /// <returns>A potentially modified version of <paramref name="examples"/> that supports shuffling.</returns>
        private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();

            IDataView source = examples.Data;
            IDataView shuffleable = source.CanShuffle
                ? source
                : new RowShufflingTransformer(Host, new RowShufflingTransformer.Options
                {
                    PoolOnly = false,
                    ForceShuffle = _options.Shuffle
                }, source);

            ch.Assert(shuffleable.CanShuffle);

            // Re-attach the original column roles to the (possibly wrapped) view.
            var prepared = new RoleMappedData(shuffleable, examples.Schema.GetColumnRoleNames());

            ch.Assert(prepared.Schema.Label.HasValue);
            ch.Assert(prepared.Schema.Feature.HasValue);
            if (examples.Schema.Weight.HasValue)
            {
                ch.Assert(prepared.Schema.Weight.HasValue);
            }

            bool hasFeatures = prepared.Schema.Feature.Value.Type is VectorType featureType && featureType.Size > 0;
            ch.Check(hasFeatures, "Training set has no features, aborting training.");
            return prepared;
        }