Пример #1
0
        /// <summary>
        /// Trains a predictor with <paramref name="trainer"/> and then hands it to the calibration step.
        /// </summary>
        /// <param name="env">The host environment; also used to instantiate the calibrator trainer.</param>
        /// <param name="ch">Channel used for argument checks, tracing and warnings.</param>
        /// <param name="data">Training data; passed by ref to AddCacheIfWanted, which may replace it.</param>
        /// <param name="trainer">The trainer to run.</param>
        /// <param name="validData">Optional validation data; cached the same way as <paramref name="data"/>.</param>
        /// <param name="calibrator">Optional factory; when non-null its component is created and used for calibration.</param>
        /// <param name="maxCalibrationExamples">Forwarded to CalibratorUtils.TrainCalibratorIfNeeded.</param>
        /// <param name="cacheData">Caching preference forwarded to AddCacheIfWanted.</param>
        /// <param name="inputPredictor">Optional warm-start predictor; dropped with a warning when the trainer
        /// does not support incremental training.</param>
        /// <param name="testData">Optional test data placed into the TrainContext (not cached here).</param>
        /// <returns>The result of CalibratorUtils.TrainCalibratorIfNeeded for the trained predictor.</returns>
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
                                            IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inputPredictor);

            // Training data is always offered to the cache helper; validation data only when present.
            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            if (validData != null)
            {
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
            }

            // A warm-start predictor is only meaningful for trainers that can continue training.
            if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
            {
                ch.Warning("Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
                           ": Trainer does not support incremental training.");
                inputPredictor = null;
            }
            ch.Assert(validData == null || trainer.Info.SupportsValidation);

            var context = new TrainContext(data, validData, testData, inputPredictor);
            var trained = trainer.Train(context);

            // The calibrator trainer is materialized only when a factory was supplied.
            var calibratorTrainer = calibrator?.CreateComponent(env);
            return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibratorTrainer, maxCalibrationExamples, trainer, trained, data);
        }
Пример #2
0
        /// <summary>
        /// Save the model to the output path.
        /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
        /// and command. It also uses featureColumn, if provided, to extract feature names.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="ch">The communication channel to use.</param>
        /// <param name="output">The output file handle.</param>
        /// <param name="predictor">The predictor.</param>
        /// <param name="data">The training examples.</param>
        /// <param name="command">The command string.</param>
        public static void SaveModel(IHostEnvironment env, IChannel ch, IFileHandle output,
                                     IPredictor predictor, RoleMappedData data, string command = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckParam(output != null && output.CanWrite, nameof(output));
            ch.CheckValueOrNull(predictor);
            ch.CheckValue(data, nameof(data));
            ch.CheckValueOrNull(command);

            // This method owns the write stream; it is disposed once the stream-based overload
            // returns or throws.
            using (var writeStream = output.CreateWriteStream())
            {
                SaveModel(env, ch, writeStream, predictor, data, command);
            }
        }
Пример #3
0
        /// <summary>
        /// Save the model in text format (if it can save itself)
        /// </summary>
        public static void SaveIni(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
        {
            Contracts.CheckValue(ch, nameof(ch));
            ch.CheckValue(predictor, nameof(predictor));
            ch.CheckValueOrNull(schema);
            ch.CheckValue(writer, nameof(writer));

            if (predictor is ICanSaveInIniFormat iniSaver)
            {
                iniSaver.SaveAsIni(writer, schema);
                return;
            }

            string typeName = predictor.GetType().Name;

            // No INI support and no summary support: leave a notice in the output and report it.
            if (!(predictor is ICanSaveSummary summarySaver))
            {
                writer.WriteLine("'{0}' does not support saving in INI format", typeName);
                ch.Error("'{0}' doesn't currently have standardized INI format output", typeName);
                return;
            }

            // No INI support, but a summary is available: note the substitution, then write the summary.
            writer.WriteLine("'{0}' does not support saving in INI format, writing out model summary instead", typeName);
            ch.Error("'{0}' doesn't currently have standardized INI format output, will save model summary instead",
                     typeName);
            summarySaver.SaveSummary(writer, schema);
        }
Пример #4
0
        /// <summary>
        /// Save the model to the stream.
        /// The method saves the loader and the transformations of dataPipe and saves optionally predictor
        /// and command. It also uses featureColumn, if provided, to extract feature names.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="ch">The communication channel to use.</param>
        /// <param name="outputStream">The output model stream.</param>
        /// <param name="predictor">The predictor.</param>
        /// <param name="data">The training examples.</param>
        /// <param name="command">The command string.</param>
        public static void SaveModel(IHostEnvironment env, IChannel ch, Stream outputStream, IPredictor predictor, RoleMappedData data, string command = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(outputStream, nameof(outputStream));
            ch.CheckValueOrNull(predictor);
            ch.CheckValue(data, nameof(data));
            ch.CheckValueOrNull(command);

            using (var saveChannel = env.Start("SaveModel"))
            using (var progress = env.StartProgressChannel("Saving model"))
            {
                using (var repository = RepositoryWriter.CreateNew(outputStream, saveChannel))
                {
                    if (predictor != null)
                    {
                        saveChannel.Trace("Saving predictor");
                        ModelSaveContext.SaveModel(repository, predictor, ModelFileUtils.DirPredictor);
                    }

                    saveChannel.Trace("Saving loader and transformations");
                    var pipeline = data.Data;

                    // A loader is saved as a single model entry; any other view goes through SaveDataPipe.
                    if (!(pipeline is IDataLoader))
                    {
                        SaveDataPipe(env, repository, pipeline);
                    }
                    else
                    {
                        ModelSaveContext.SaveModel(repository, pipeline, ModelFileUtils.DirDataLoaderModel);
                    }

                    // REVIEW: Handle statistics.
                    // ModelSaveContext.SaveModel(rep, dataStats, DirDataStats);
                    bool hasCommand = !string.IsNullOrWhiteSpace(command);
                    if (hasCommand)
                    {
                        using (var entry = repository.CreateEntry(ModelFileUtils.DirTrainingInfo, "Command.txt"))
                        using (var textWriter = Utils.OpenWriter(entry.Stream))
                        {
                            textWriter.WriteLine(command);
                        }
                    }

                    // NOTE(review): this step uses the caller's channel `ch`, while the steps above use the
                    // local channel; kept as in the original.
                    ModelFileUtils.SaveRoleMappings(env, ch, data.Schema, repository);

                    // The repository is committed only after every entry above was written.
                    repository.Commit();
                }
                saveChannel.Done();
            }
        }
Пример #5
0
        // Returns true if a normalizer was added.
        /// <summary>
        /// Appends a MinMax normalization transform over <paramref name="featureColumn"/> to
        /// <paramref name="view"/> when the normalization policy calls for it.
        /// </summary>
        /// <param name="env">The host environment used to build the normalizer.</param>
        /// <param name="ch">Channel used for argument checks and informational/warning messages.</param>
        /// <param name="trainer">The trainer whose <c>Info.NeedNormalization</c> flag is consulted.</param>
        /// <param name="view">The data view; replaced by the normalized view when a normalizer is added.</param>
        /// <param name="featureColumn">Feature column name; null or empty means no normalizer is added.</param>
        /// <param name="autoNorm">
        /// Policy: <c>No</c> never normalizes; <c>Yes</c> always normalizes when the column exists; any other
        /// value normalizes only when the trainer needs it and the column is not already normalized, except
        /// <c>Warn</c>, which only emits a warning in that case.
        /// </param>
        /// <returns>True if a normalizer was added; otherwise false.</returns>
        public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValue(view, nameof(view));
            ch.CheckValueOrNull(featureColumn);
            ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
                            "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

            if (autoNorm == NormalizeOption.No)
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }

            if (string.IsNullOrEmpty(featureColumn))
            {
                return false;
            }

            var schema = view.Schema;
            if (!schema.TryGetColumnIndex(featureColumn, out int featCol))
            {
                // The feature column is not present in the view, so there is nothing to normalize.
                return false;
            }

            if (autoNorm != NormalizeOption.Yes)
            {
                // Skip when the trainer does not need normalization or the column is already normalized.
                if (!trainer.Info.NeedNormalization || schema.IsNormalized(featCol))
                {
                    ch.Info("Not adding a normalizer.");
                    return false;
                }
                if (autoNorm == NormalizeOption.Warn)
                {
                    ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                    return false;
                }
            }

            ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

            // A single normalizer factory, applied either on top of a loader (via CompositeDataLoader)
            // or directly on a plain view.
            IDataView ApplyNormalizer(IHostEnvironment innerEnv, IDataView input)
                => NormalizeTransform.CreateMinMaxNormalizer(innerEnv, input, featureColumn);

            if (view is IDataLoader loader)
            {
                view = CompositeDataLoader.ApplyTransform(env, loader, tag: null, creationArgs: null, ApplyNormalizer);
            }
            else
            {
                view = ApplyNormalizer(env, view);
            }
            return true;
        }
Пример #6
0
        /// <summary>
        /// Trains a predictor with a trainer that must implement ITrainer&lt;RoleMappedData&gt;, builds the
        /// predictor with CreatePredictor, and then hands it to CalibratorUtils.TrainCalibratorIfNeeded.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="ch">Channel used for argument checks, tracing and the user-argument exception.</param>
        /// <param name="data">Training data; passed by ref to AddCacheIfWanted, which may replace it.</param>
        /// <param name="trainer">The trainer; a user-argument exception is thrown if it is not an ITrainer&lt;RoleMappedData&gt;.</param>
        /// <param name="name">Trainer name; must be non-empty and is only used in the error message.</param>
        /// <param name="validData">Optional validation data; cached the same way as <paramref name="data"/>.</param>
        /// <param name="calibrator">Calibrator trainer forwarded to the calibration step.</param>
        /// <param name="maxCalibrationExamples">Forwarded to CalibratorUtils.TrainCalibratorIfNeeded.</param>
        /// <param name="cacheData">Caching preference forwarded to AddCacheIfWanted.</param>
        /// <param name="inpPredictor">Optional initial predictor. It is forwarded without any check that the
        /// trainer supports incremental training.</param>
        /// <returns>The result of CalibratorUtils.TrainCalibratorIfNeeded for the constructed predictor.</returns>
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, string name, RoleMappedData validData,
                                            ICalibratorTrainer calibrator, int maxCalibrationExamples, bool?cacheData, IPredictor inpPredictor = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckNonEmpty(name, nameof(name));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inpPredictor);

            // Only trainers that consume RoleMappedData are supported by this path.
            var trainerRmd = trainer as ITrainer <RoleMappedData>;

            if (trainerRmd == null)
            {
                throw ch.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Trainer '{0}' does not accept known training data type", name);
            }

            // This delegate is never invoked. It only selects a TrainCore overload so that its MethodInfo can be
            // obtained below. That overload must be generic (GetGenericMethodDefinition is called on it) and is
            // defined elsewhere.
            Action <IChannel, ITrainer, Action <object>, object, object, object> trainCoreAction = TrainCore;
            IPredictor predictor;

            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            if (validData != null)
            {
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
            }

            // Close the generic definition over RoleMappedData and the runtime type of the input predictor
            // (falling back to IPredictor when there is none).
            var genericExam = trainCoreAction.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(
                typeof(RoleMappedData),
                inpPredictor != null ? inpPredictor.GetType() : typeof(IPredictor));
            // Bind the trainer's Train(RoleMappedData) method as the training callback.
            Action <RoleMappedData> trainExam = trainerRmd.Train;

            // Static invocation (null target). NOTE(review): per MethodInfo.Invoke semantics, exceptions thrown
            // by the invoked method surface wrapped in TargetInvocationException.
            genericExam.Invoke(null, new object[] { ch, trainerRmd, trainExam, data, validData, inpPredictor });

            // The predictor is produced by the trainer after the training call above.
            ch.Trace("Constructing predictor");
            predictor = trainerRmd.CreatePredictor();
            return(CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, predictor, data));
        }
Пример #7
0
        /// <summary>
        /// Trains a model.
        /// </summary>
        /// <param name="env">host</param>
        /// <param name="ch">channel</param>
        /// <param name="data">traing data</param>
        /// <param name="validData">validation data</param>
        /// <param name="calibrator">calibrator</param>
        /// <param name="maxCalibrationExamples">number of examples used to calibrate</param>
        /// <param name="cacheData">cache training data</param>
        /// <param name="inputPredictor">for continuous training, initial state</param>
        /// <returns>predictor</returns>
        public IPredictor Train(IHostEnvironment env, IChannel ch, RoleMappedData data, RoleMappedData validData = null,
                                ICalibratorTrainer calibrator = null, int maxCalibrationExamples = 0,
                                bool? cacheData = null, IPredictor inputPredictor = null)
        {
            // Snapshot this instance's trainer and load name before validating them.
            var trainer = Trainer;
            var name = LoadName;

            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckNonEmpty(name, nameof(name));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inputPredictor);

            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace(MessageSensitivity.None, "Training");
            if (validData != null)
            {
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
            }

            // Drop the warm-start predictor when the trainer cannot continue from it.
            if (inputPredictor != null && !trainer.Info.SupportsIncrementalTraining)
            {
                ch.Warning(MessageSensitivity.None, "Ignoring " + nameof(TrainCommand.Arguments.InputModelFile) +
                           ": Trainer does not support incremental training.");
                inputPredictor = null;
            }
            ch.Assert(validData == null || trainer.Info.SupportsValidation);

            // No test data is supplied on this path.
            var context = new TrainContext(data, validData, null, inputPredictor);
            var trained = trainer.Train(context);

            return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibrator, maxCalibrationExamples, trainer, trained, data);
        }
Пример #8
0
        /// <summary>
        /// Save the model summary.
        /// </summary>
        public static void SaveSummary(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
        {
            Contracts.CheckValue(ch, nameof(ch));
            ch.CheckValue(predictor, nameof(predictor));
            ch.CheckValueOrNull(schema);
            ch.CheckValue(writer, nameof(writer));

            if (!(predictor is ICanSaveSummary summarySaver))
            {
                // Unsupported: leave a notice in the output and report the problem on the channel.
                string typeName = predictor.GetType().Name;
                writer.WriteLine("'{0}' does not support saving summary", typeName);
                ch.Error("'{0}' does not support saving summary", typeName);
                return;
            }

            summarySaver.SaveSummary(writer, schema);
        }
Пример #9
0
        /// <summary>
        /// Save the model in text format (if it can save itself)
        /// </summary>
        public static void SaveCode(IChannel ch, IPredictor predictor, RoleMappedSchema schema, TextWriter writer)
        {
            Contracts.CheckValue(ch, nameof(ch));
            ch.CheckValue(predictor, nameof(predictor));
            ch.CheckValueOrNull(schema);
            ch.CheckValue(writer, nameof(writer));

            if (predictor is ICanSaveInSourceCode codeSaver)
            {
                codeSaver.SaveAsCode(writer, schema);
                return;
            }

            // Unsupported: leave a notice in the output and report the problem on the channel.
            string typeName = predictor.GetType().Name;
            writer.WriteLine("'{0}' does not support saving in code.", typeName);
            ch.Error("'{0}' doesn't currently support saving the model as code", typeName);
        }
Пример #10
0
        // Returns true if a normalizer was added.
        public static bool AddNormalizerIfNeeded(IHostEnvironment env, IChannel ch, ITrainer trainer, ref IDataView view, string featureColumn, NormalizeOption autoNorm)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValue(view, nameof(view));
            ch.CheckValueOrNull(featureColumn);
            ch.CheckUserArg(Enum.IsDefined(typeof(NormalizeOption), autoNorm), nameof(TrainCommand.Arguments.NormalizeFeatures),
                            "Normalize option is invalid. Specify one of 'norm=No', 'norm=Warn', 'norm=Auto', or 'norm=Yes'.");

            if (autoNorm == NormalizeOption.No)
            {
                ch.Info("Not adding a normalizer.");
                return false;
            }

            if (string.IsNullOrEmpty(featureColumn))
            {
                return false;
            }

            int featCol;
            var schema = view.Schema;
            if (!schema.TryGetColumnIndex(featureColumn, out featCol))
            {
                // No such column in the view: nothing to normalize.
                return false;
            }

            if (autoNorm != NormalizeOption.Yes)
            {
                // Skip when the trainer is not an ITrainerEx, does not need normalization, or the column's
                // IsNormalized metadata is present and true. The metadata is only queried when the trainer
                // actually needs normalization.
                var trainerEx = trainer as ITrainerEx;
                DvBool isNormalized = DvBool.False;
                bool skip = trainerEx == null || !trainerEx.NeedNormalization;
                if (!skip)
                {
                    skip = schema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, featCol, ref isNormalized)
                           && isNormalized.IsTrue;
                }

                if (skip)
                {
                    ch.Info("Not adding a normalizer.");
                    return false;
                }
                if (autoNorm == NormalizeOption.Warn)
                {
                    ch.Warning("A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.");
                    return false;
                }
            }

            ch.Info("Automatically adding a MinMax normalization transform, use 'norm=Warn' or 'norm=No' to turn this behavior off.");

            // Quote the column name for the settings string when CmdQuoter deems it necessary.
            var quoteBuffer = new StringBuilder();
            string columnName = CmdQuoter.QuoteValue(featureColumn, quoteBuffer) ? quoteBuffer.ToString() : featureColumn;
            string settings = string.Format("col={{ name={0} source={0} }}", columnName);
            var component = new SubComponent<IDataTransform, SignatureDataTransform>("MinMax", settings);

            if (view is IDataLoader loader)
            {
                view = CompositeDataLoader.Create(env, loader,
                    new KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>(null, component));
            }
            else
            {
                view = component.CreateInstance(env, view);
            }
            return true;
        }