Example #1
0
            /// <summary>
            /// Captures the input schema and resolves each user column spec into a <c>RawColInfo</c>
            /// (output name, source name, hidden-column policy), then builds the derived lookup tables.
            /// </summary>
            /// <param name="args">Transform arguments; asserted non-null.</param>
            /// <param name="schemaInput">Schema of the input data; asserted non-null.</param>
            public Bindings(Arguments args, ISchema schemaInput)
            {
                Contracts.AssertValue(args);
                Contracts.AssertValue(schemaInput);

                Input = schemaInput;

                // The transform-wide hidden-column policy must be a defined enum value.
                Contracts.Check(Enum.IsDefined(typeof(HiddenColumnOption), args.Hidden), "hidden");
                HidDefault = args.Hidden;

                int columnCount = Utils.Size(args.Column);
                RawInfos = new RawColInfo[columnCount];

                // Output names claimed so far; a repeat is reported as a user error.
                var seenOutputs = new HashSet<string>();
                for (int index = 0; index < columnCount; index++)
                {
                    var column = args.Column[index];
                    string outputName = column.Name;
                    string sourceName = column.Source;

                    // Whichever of Source/Name is blank is filled in from the other one.
                    if (string.IsNullOrWhiteSpace(sourceName))
                    {
                        sourceName = outputName;
                    }
                    else if (string.IsNullOrWhiteSpace(outputName))
                    {
                        outputName = sourceName;
                    }
                    Contracts.CheckUserArg(!string.IsNullOrWhiteSpace(outputName), nameof(Column.Name));

                    if (!seenOutputs.Add(outputName))
                    {
                        throw Contracts.ExceptUserArg(nameof(args.Column), "New column '{0}' specified multiple times", outputName);
                    }

                    // A per-column Hidden setting overrides the transform-wide default.
                    var hidden = column.Hidden ?? args.Hidden;
                    Contracts.CheckUserArg(Enum.IsDefined(typeof(HiddenColumnOption), hidden), nameof(args.Hidden));

                    RawInfos[index] = new RawColInfo(outputName, sourceName, hidden);
                }

                BuildInfos(out Infos, out NameToInfoIndex, user: true);
            }
Example #2
0
        /// <summary>
        /// Validates the L-BFGS settings in <paramref name="args"/> and copies them into the trainer's fields.
        /// </summary>
        /// <remarks>
        /// <paramref name="args"/> is dereferenced while the base-constructor arguments are built, so it is
        /// null-checked there; a check in the constructor body would only run after a NullReferenceException.
        /// </remarks>
        /// <param name="env">Host environment; must be non-null.</param>
        /// <param name="args">Trainer arguments; must be non-null.</param>
        /// <param name="labelColumn">Label column shape forwarded to the base trainer.</param>
        internal LbfgsTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column labelColumn)
            : base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName),
                   TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
                   labelColumn, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
        {
            Args = args;

            // With threading enabled, an explicit thread count must be positive.
            Contracts.CheckUserArg(!Args.UseThreads || Args.NumThreads > 0 || Args.NumThreads == null,
                                   nameof(Args.NumThreads), "numThreads must be positive (or empty for default)");
            Contracts.CheckUserArg(Args.L2Weight >= 0, nameof(Args.L2Weight), "Must be non-negative");
            Contracts.CheckUserArg(Args.L1Weight >= 0, nameof(Args.L1Weight), "Must be non-negative");
            Contracts.CheckUserArg(Args.OptTol > 0, nameof(Args.OptTol), "Must be positive");
            Contracts.CheckUserArg(Args.MemorySize > 0, nameof(Args.MemorySize), "Must be positive");
            Contracts.CheckUserArg(Args.MaxIterations > 0, nameof(Args.MaxIterations), "Must be positive");
            Contracts.CheckUserArg(Args.SgdInitializationTolerance >= 0, nameof(Args.SgdInitializationTolerance), "Must be non-negative");
            Contracts.CheckUserArg(Args.NumThreads == null || Args.NumThreads.Value >= 0, nameof(Args.NumThreads), "Must be non-negative");

            L2Weight      = Args.L2Weight;
            L1Weight      = Args.L1Weight;
            OptTol        = Args.OptTol;
            MemorySize    = Args.MemorySize;
            MaxIterations = Args.MaxIterations;
            SgdInitializationTolerance = Args.SgdInitializationTolerance;
            Quiet                = Args.Quiet;
            InitWtsDiameter      = Args.InitWtsDiameter;
            UseThreads           = Args.UseThreads;
            NumThreads           = Args.NumThreads;
            DenseOptimizer       = Args.DenseOptimizer;
            EnforceNonNegativity = Args.EnforceNonNegativity;

            // NOTE(review): ShowTrainingStats is not assigned earlier in this constructor, and it is reset to
            // false unconditionally below, so this branch only emits the warning — confirm this is intended.
            if (EnforceNonNegativity && ShowTrainingStats)
            {
                ShowTrainingStats = false;
                using (var ch = Host.Start("Initialization"))
                {
                    ch.Warning("The training statistics cannot be computed with non-negativity constraint.");
                    ch.Done();
                }
            }

            ShowTrainingStats = false;
            _srcPredictor     = default;
        }
Example #3
0
 /// <summary>
 /// Resolves the chosen column indices into the input columns that are kept and builds a
 /// name-to-output-position map for them.
 /// </summary>
 /// <param name="indexCopy">Column indices to keep, or to drop when <paramref name="drop"/> is true.</param>
 /// <param name="drop">When true, every input column except those in <paramref name="indexCopy"/> is kept.</param>
 /// <param name="sources">Receives the input column index backing each output column.</param>
 /// <param name="dropped">Receives <paramref name="indexCopy"/> when dropping, otherwise null.</param>
 /// <param name="nameToCol">Receives output position by column name; for repeated names the last position wins.</param>
 /// <param name="user">True to report failures as user-argument errors, false to report them as decode errors.</param>
 private void BuildNameDict(int[] indexCopy, bool drop, out int[] sources, out int[] dropped, out Dictionary<string, int> nameToCol, bool user)
 {
     Contracts.AssertValue(indexCopy);

     // Every requested index must address an existing input column.
     foreach (int col in indexCopy)
     {
         if (0 <= col && col < _input.ColumnCount)
         {
             continue;
         }

         const string fmt = "Column index {0} invalid for input with {1} columns";
         if (!user)
         {
             throw Contracts.ExceptDecode(fmt, col, _input.ColumnCount);
         }
         throw Contracts.ExceptUserArg(nameof(Arguments.Index), fmt, col, _input.ColumnCount);
     }

     // Either keep exactly the listed columns, or keep everything except them.
     sources = drop ? Enumerable.Range(0, _input.ColumnCount).Except(indexCopy).ToArray() : indexCopy;
     dropped = drop ? indexCopy : null;

     const string noOutputMsg = "Choose columns by index has no output columns";
     bool hasOutput = sources.Length > 0;
     if (!user)
     {
         Contracts.CheckDecode(hasOutput, noOutputMsg);
     }
     else
     {
         Contracts.CheckUserArg(hasOutput, nameof(Arguments.Index), noOutputMsg);
     }

     nameToCol = new Dictionary<string, int>();
     int position = 0;
     foreach (int sourceIndex in sources)
     {
         nameToCol[_input.GetColumnName(sourceIndex)] = position++;
     }
 }
Example #4
0
            /// <summary>
            /// Resolves one pixel-extraction column: which color planes to use, whether to interleave them,
            /// and the optional offset/scale conversion. Per-column settings override those in <paramref name="args"/>.
            /// </summary>
            /// <param name="item">Per-column settings; must be non-null.</param>
            /// <param name="args">Transform-wide defaults; must be non-null.</param>
            internal ColumnInfo(Column item, Arguments args)
            {
                Contracts.CheckValue(item, nameof(item));
                Contracts.CheckValue(args, nameof(args));

                Output = item.Name;
                Input  = item.Source ?? item.Name;

                bool useAlpha = item.UseAlpha ?? args.UseAlpha;
                bool useRed   = item.UseRed ?? args.UseRed;
                bool useGreen = item.UseGreen ?? args.UseGreen;
                bool useBlue  = item.UseBlue ?? args.UseBlue;

                // Each selected plane sets its bit in Colors and bumps the plane count.
                if (useAlpha)
                {
                    Colors |= ColorBits.Alpha;
                    Planes++;
                }
                if (useRed)
                {
                    Colors |= ColorBits.Red;
                    Planes++;
                }
                if (useGreen)
                {
                    Colors |= ColorBits.Green;
                    Planes++;
                }
                if (useBlue)
                {
                    Colors |= ColorBits.Blue;
                    Planes++;
                }
                Contracts.CheckUserArg(Planes > 0, nameof(item.UseRed), "Need to use at least one color plane");

                Interleave = item.InterleaveArgb ?? args.InterleaveArgb;

                Convert = item.Convert ?? args.Convert;
                if (Convert)
                {
                    // Offset defaults to 0 and scale to 1 when neither the column nor args sets them.
                    Offset = item.Offset ?? args.Offset ?? 0;
                    Scale  = item.Scale ?? args.Scale ?? 1;
                    Contracts.CheckUserArg(FloatUtils.IsFinite(Offset), nameof(item.Offset));
                    Contracts.CheckUserArg(FloatUtils.IsFiniteNonZero(Scale), nameof(item.Scale));
                }
                else
                {
                    // Without conversion the transform is the identity.
                    Offset = 0;
                    Scale  = 1;
                }
            }
Example #5
0
        /// <summary>
        /// Resolves the invert-hash limit for a column (the column setting overrides <paramref name="args"/>)
        /// and validates it against the column's hash width.
        /// </summary>
        /// <returns>0 when invert hashing is off; <c>int.MaxValue</c> for -1; otherwise the positive limit.</returns>
        private static int GetAndVerifyInvertHashMaxCount(Arguments args, Column col, ColInfoEx ex)
        {
            var invertHashMaxCount = col.InvertHash ?? args.InvertHash;

            // Zero disables invert hashing, so there is nothing to validate.
            if (invertHashMaxCount == 0)
            {
                return invertHashMaxCount;
            }

            // -1 is shorthand for "no limit".
            if (invertHashMaxCount == -1)
            {
                invertHashMaxCount = int.MaxValue;
            }
            Contracts.CheckUserArg(invertHashMaxCount > 0, nameof(args.InvertHash), "Value too small, must be -1 or larger");

            // 31 or more hash bits would need a KeyValues VBuffer of length 1u << 31, which exceeds int.MaxValue.
            if (ex.HashBits >= 31)
            {
                throw Contracts.ExceptUserArg(nameof(args.InvertHash), "Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", ex.HashBits);
            }
            return invertHashMaxCount;
        }
            /// <summary>
            /// Builds a <see cref="KeyDataViewType"/> whose raw type comes from <paramref name="dbType"/>.
            /// </summary>
            /// <param name="dbType">Database type mapped to the key's raw CLR type.</param>
            /// <param name="keyCount">Key count; when its <c>Count</c> is null, <c>rawType.ToMaxInt()</c> is used instead.</param>
            /// <returns>The constructed key type.</returns>
            private static KeyDataViewType ConstructKeyType(DbType dbType, KeyCount keyCount)
            {
                Contracts.CheckValue(keyCount, nameof(keyCount));

                Type rawType = dbType.ToType();
                Contracts.CheckUserArg(KeyDataViewType.IsValidDataType(rawType), nameof(DatabaseLoader.Column.Type), "Bad item type for Key");

                var count = keyCount.Count ?? rawType.ToMaxInt();
                return new KeyDataViewType(rawType, count);
            }
Example #7
0
        /// <summary>
        /// Builds a <see cref="KeyType"/> from an optional data kind and a key count.
        /// </summary>
        /// <param name="type">Raw data kind of the key; <c>DataKind.U8</c> is used when null.</param>
        /// <param name="keyCount">Key count; when its <c>Count</c> is null, <c>rawType.ToMaxInt()</c> is used instead.</param>
        /// <returns>The constructed key type.</returns>
        public static KeyType ConstructKeyType(DataKind? type, KeyCount keyCount)
        {
            Contracts.CheckValue(keyCount, nameof(keyCount));

            Type rawType = (type ?? DataKind.U8).ToType();
            Contracts.CheckUserArg(KeyType.IsValidDataType(rawType), nameof(TextLoader.Column.Type), "Bad item type for Key");

            var count = keyCount.Count ?? rawType.ToMaxInt();
            return new KeyType(rawType, count);
        }
Example #8
0
            /// <summary>
            /// Resolves the n-gram hashing settings for one column; each per-column value overrides the one in <paramref name="args"/>.
            /// </summary>
            public ColInfoEx(Column item, Arguments args)
            {
                var maxLength = NgramBufferBuilder.MaxSkipNgramLength;

                // N-gram and skip lengths are individually bounded, and so is their sum.
                NgramLength = item.NgramLength ?? args.NgramLength;
                Contracts.CheckUserArg(0 < NgramLength && NgramLength <= maxLength, nameof(item.NgramLength));

                SkipLength = item.SkipLength ?? args.SkipLength;
                Contracts.CheckUserArg(0 <= SkipLength && SkipLength <= maxLength, nameof(item.SkipLength));

                if (NgramLength + SkipLength > maxLength)
                {
                    throw Contracts.ExceptUserArg(nameof(item.SkipLength),
                                                  "The sum of skipLength and ngramLength must be less than or equal to {0}",
                                                  maxLength);
                }

                HashBits = item.HashBits ?? args.HashBits;
                Contracts.CheckUserArg(1 <= HashBits && HashBits <= 30, nameof(item.HashBits));

                Seed       = item.Seed ?? args.Seed;
                Rehash     = item.RehashUnigrams ?? args.RehashUnigrams;
                Ordered    = item.Ordered ?? args.Ordered;
                AllLengths = item.AllLengths ?? args.AllLengths;
            }
            /// <summary>
            /// Resolves the whitening settings for one column (per-column values override <paramref name="args"/>)
            /// and picks the output vector type.
            /// </summary>
            public ColInfoEx(Column item, Arguments args, ColInfo info)
            {
                Kind = item.Kind ?? args.Kind;
                Contracts.CheckUserArg(Kind == WhiteningKind.Pca || Kind == WhiteningKind.Zca, nameof(item.Kind));

                Epsilon = item.Eps ?? args.Eps;
                Contracts.CheckUserArg(0 <= Epsilon && Epsilon < Float.PositiveInfinity, nameof(item.Eps));

                MaxRow = item.MaxRows ?? args.MaxRows;
                Contracts.CheckUserArg(MaxRow > 0, nameof(item.MaxRows));

                SaveInv = item.SaveInverse ?? args.SaveInverse;

                PcaNum = item.PcaNum ?? args.PcaNum;
                Contracts.CheckUserArg(PcaNum >= 0, nameof(item.PcaNum));

                // Non-ZCA whitening with a nonzero PcaNum emits a Float vector of length PcaNum;
                // otherwise the source type's vector form is used.
                if (Kind != WhiteningKind.Zca && PcaNum != 0)
                {
                    Type = new VectorType(NumberType.Float, PcaNum); // REVIEW: make it work with pcaNum == 1.
                }
                else
                {
                    Type = info.TypeSrc.AsVector;
                }
            }
Example #10
0
        public TypeName(IHostEnvironment env, float p, int foo)
        { // NOTE(review): appears to be analyzer test input; the trailing "Should ..." comments state each call's expected outcome. Comments here are kept at line ends so line numbers do not shift.
            Contracts.CheckValue(env, nameof(env));
            env.CheckParam(0 <= p && p <= 1, nameof(p), "Should be in range [0,1]");
            env.CheckParam(0 <= p && p <= 1, "p");                   // Should fail.
            env.CheckParam(0 <= p && p <= 1, nameof(p) + nameof(p)); // Should fail.
            env.CheckValue(paramName: nameof(p), val: "p");          // Should succeed despite confusing order.
            env.CheckValue(paramName: "p", val: nameof(p));          // Should fail despite confusing order.
            env.CheckValue("p", nameof(p));
            env.CheckUserArg(foo > 5, "foo", "Nice");
            env.CheckUserArg(foo > 5, nameof(foo), "Nice");
            env.Except();                                           // Not throwing or doing anything with the exception, so should fail.
            Contracts.ExceptParam(nameof(env), "What a silly env"); // Should also fail.
            if (false)
            {
                throw env.Except(); // Should not fail.
            }
            if (false)
            {
                throw env.ExceptParam(nameof(env), "What a silly env"); // Should not fail.
            }
            if (false)
            {
                throw env.ExceptParam("env", "What a silly env"); // Should fail due to name error.
            }
            var e = env.Except();

            env.Check(true, $"Hello {foo} is cool");
            env.Check(true, "Hello it is cool");
            string coolMessage = "Hello it is cool";

            env.Check(true, coolMessage);
            env.Check(true, string.Format("Hello {0} is cool", foo));
            env.Check(true, Messages.CoolMessage);
            env.CheckDecode(true, "Not suspicious, no ModelLoadContext");
            Contracts.Check(true, "Fine: " + nameof(env));
            Contracts.Check(true, "Less fine: " + env.GetType().Name);
            Contracts.CheckUserArg(0 <= p && p <= 1,
                                   "p", "On a new line");
        }
Example #11
0
            /// <summary>
            /// Validates the zlib settings in <paramref name="args"/> and stores them.
            /// </summary>
            /// <param name="args">Compression settings.</param>
            /// <param name="isDeflate">Stored as-is in <c>_isDeflate</c>.</param>
            private ZlibImpl(ArgumentsBase args, bool isDeflate)
            {
                // A null compression level selects the library default; otherwise it must be 0..9.
                Contracts.CheckUserArg(args.CompressionLevel == null ||
                                       (0 <= args.CompressionLevel && args.CompressionLevel <= 9),
                                       nameof(args.CompressionLevel), "Must be in range 0 to 9 or null");
                Contracts.CheckUserArg(8 <= args.WindowBits && args.WindowBits <= 15, nameof(args.WindowBits), "Must be in range 8 to 15");
                Contracts.CheckUserArg(1 <= args.MemoryLevel && args.MemoryLevel <= 9, nameof(args.MemoryLevel), "Must be in range 1 to 9");
                Contracts.CheckUserArg(Enum.IsDefined(typeof(Constants.Strategy), args.Strategy), nameof(args.Strategy), "Value was not defined");

                _level = args.CompressionLevel.HasValue
                    ? (Constants.Level)args.CompressionLevel.Value
                    : Constants.Level.DefaultCompression;
                Contracts.Assert(Enum.IsDefined(typeof(Constants.Level), _level));

                _isDeflate   = isDeflate;
                _windowBits  = args.WindowBits;
                _memoryLevel = args.MemoryLevel;
                _strategy    = args.Strategy;
            }
        /// <summary>
        /// Validates the L-BFGS settings in <paramref name="args"/> and copies them into the trainer's fields.
        /// </summary>
        /// <param name="args">Trainer arguments; must be non-null.</param>
        /// <param name="env">Host environment forwarded to the base constructor.</param>
        /// <param name="name">Name forwarded to the base constructor.</param>
        /// <param name="showTrainingStats">Whether to compute training statistics; forced off (with a warning)
        /// when non-negativity is enforced.</param>
        internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, bool showTrainingStats = false)
            : base(env, name)
        {
            // Report a null args as an argument error instead of a NullReferenceException on first access.
            Host.CheckValue(args, nameof(args));

            // With threading enabled, an explicit thread count must be positive.
            Contracts.CheckUserArg(!args.UseThreads || args.NumThreads > 0 || args.NumThreads == null,
                                   nameof(args.NumThreads), "numThreads must be positive (or empty for default)");
            Contracts.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "Must be non-negative");
            Contracts.CheckUserArg(args.L1Weight >= 0, nameof(args.L1Weight), "Must be non-negative");
            Contracts.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Must be positive");
            Contracts.CheckUserArg(args.MemorySize > 0, nameof(args.MemorySize), "Must be positive");
            Contracts.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Must be positive");
            Contracts.CheckUserArg(args.SgdInitializationTolerance >= 0, nameof(args.SgdInitializationTolerance), "Must be non-negative");
            Contracts.CheckUserArg(args.NumThreads == null || args.NumThreads.Value >= 0, nameof(args.NumThreads), "Must be non-negative");

            L2Weight                   = args.L2Weight;
            L1Weight                   = args.L1Weight;
            OptTol                     = args.OptTol;
            MemorySize                 = args.MemorySize;
            MaxIterations              = args.MaxIterations;
            SgdInitializationTolerance = args.SgdInitializationTolerance;
            Quiet                      = args.Quiet;
            InitWtsDiameter            = args.InitWtsDiameter;
            UseThreads                 = args.UseThreads;
            NumThreads                 = args.NumThreads;
            DenseOptimizer             = args.DenseOptimizer;
            ShowTrainingStats          = showTrainingStats;
            EnforceNonNegativity       = args.EnforceNonNegativity;

            // Training statistics are incompatible with the non-negativity constraint.
            if (EnforceNonNegativity && ShowTrainingStats)
            {
                ShowTrainingStats = false;
                using (var ch = Host.Start("Initialization"))
                {
                    ch.Warning("The training statistics cannot be computed with non-negativity constraint.");
                    ch.Done();
                }
            }
        }
Example #13
0
            /// <summary>
            /// Describes one n-gram extraction column and validates its lengths and per-length term limits.
            /// </summary>
            /// <param name="name">Output column name.</param>
            /// <param name="ngramLength">Maximum n-gram length; must be in 1..MaxSkipNgramLength.</param>
            /// <param name="skipLength">Number of tokens that may be skipped; must be non-negative, and
            /// <paramref name="ngramLength"/> + <paramref name="skipLength"/> must not exceed MaxSkipNgramLength.</param>
            /// <param name="allLengths">When true, n-grams of every length up to <paramref name="ngramLength"/> are kept;
            /// otherwise only the longest length gets a nonzero limit.</param>
            /// <param name="weighting">Weighting criterion stored as-is.</param>
            /// <param name="maxNumTerms">Per-length term limits. Without <paramref name="allLengths"/>: empty or a single positive value.
            /// With it: at most <paramref name="ngramLength"/> non-negative values whose last is positive; missing trailing
            /// entries repeat the last value. Empty means the estimator default.</param>
            /// <param name="inputColumnName">Source column; defaults to <paramref name="name"/>.</param>
            internal ColumnInfo(string name,
                                int ngramLength,
                                int skipLength,
                                bool allLengths,
                                NgramExtractingEstimator.WeightingCriteria weighting,
                                int[] maxNumTerms,
                                string inputColumnName = null)
            {
                Name            = name;
                InputColumnName = inputColumnName ?? name;

                NgramLength = ngramLength;
                Contracts.CheckUserArg(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength));

                SkipLength = skipLength;
                // A negative skip length would otherwise slip past the sum check below.
                Contracts.CheckUserArg(0 <= SkipLength, nameof(skipLength), "Must be non-negative");
                if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength)
                {
                    throw Contracts.ExceptUserArg(nameof(skipLength),
                                                  $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}");
                }

                AllLengths = allLengths;
                Weighting  = weighting;

                int termCount = Utils.Size(maxNumTerms);
                int[] limits;
                if (!AllLengths)
                {
                    // Only the full-length n-grams get a limit; shorter lengths stay at 0.
                    Contracts.CheckUserArg(termCount == 0 || (termCount == 1 && maxNumTerms[0] > 0), nameof(maxNumTerms));
                    limits = new int[ngramLength];
                    limits[ngramLength - 1] = termCount == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[0];
                }
                else
                {
                    Contracts.CheckUserArg(termCount <= ngramLength, nameof(maxNumTerms));
                    Contracts.CheckUserArg(termCount == 0 || (maxNumTerms.All(i => i >= 0) && maxNumTerms[termCount - 1] > 0), nameof(maxNumTerms));
                    // Lengths without an explicit limit reuse the last one given (or the default).
                    var extend = termCount == 0 ? NgramExtractingEstimator.Defaults.MaxNumTerms : maxNumTerms[termCount - 1];
                    limits = Utils.BuildArray(ngramLength, i => i < termCount ? maxNumTerms[i] : extend);
                }
                Limits = ImmutableArray.Create(limits);
            }
Example #14
0
 // Constructor used by the command-line tool to build a loadable instance of this class.
 private TolerantEarlyStoppingRule(Options options, bool lowerIsBetter)
     : base(lowerIsBetter)
 {
     var threshold = options.Threshold;
     Contracts.CheckUserArg(threshold >= 0, nameof(options.Threshold), "Must be non-negative.");
     Threshold = threshold;
 }
 /// <summary>
 /// Creates the LightGBM ranking trainer; the label column name is taken from <paramref name="options"/>.
 /// </summary>
 internal LightGbmRankingTrainer(IHostEnvironment env, Options options)
     : base(env, LoadNameValue, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumnName))
 {
     // The sigmoid parameter must be strictly positive.
     var sigmoid = options.Sigmoid;
     Contracts.CheckUserArg(sigmoid > 0, nameof(Options.Sigmoid), "must be > 0.");
 }
 /// <summary>
 /// Creates the criterion after checking that the configured threshold is non-negative.
 /// </summary>
 public TolerantEarlyStoppingCriterion(Options options, bool lowerIsBetter)
     : base(options, lowerIsBetter)
 {
     var threshold = EarlyStoppingCriterionOptions.Threshold;
     Contracts.CheckUserArg(threshold >= 0, nameof(options.Threshold), "Must be non-negative.");
 }
 /// <summary>
 /// Creates the generality-loss criterion after checking that the threshold lies in [0,1].
 /// </summary>
 public GLEarlyStoppingCriterion(Options options, bool lowerIsBetter)
     : base(options, lowerIsBetter)
 {
     // Check both bounds against one value. The original read the lower bound from
     // EarlyStoppingCriterionOptions and the upper bound from options. This reads the options
     // presumably stored by the base constructor, as TolerantEarlyStoppingCriterion does.
     var threshold = EarlyStoppingCriterionOptions.Threshold;
     Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
 }
 /// <summary>
 /// Records the data dimension and the rank (column setting overrides <paramref name="args"/>), which must lie in 1..<paramref name="d"/>.
 /// </summary>
 public TransformInfo(Column item, Arguments args, int d)
 {
     var rank = item.Rank ?? args.Rank;
     Dimension = d;
     Rank      = rank;
     Contracts.CheckUserArg(0 < Rank && Rank <= Dimension, nameof(item.Rank), "Rank must be positive, and at most the dimension of untransformed data");
 }
Example #19
0
        /// <summary>
        /// Exports the data pipeline to an ONNX model, adding the predictor's scorer when a predictor is loaded
        /// and the scorer can be saved as ONNX. Optionally writes an indented JSON rendering of the model and
        /// re-saves the data pipe.
        /// </summary>
        /// <param name="ch">Channel used for warnings and trace output.</param>
        private void Run(IChannel ch)
        {
            ILegacyDataLoader loader  = null;
            IPredictor        rawPred = null;
            IDataView         view;
            RoleMappedSchema  trainSchema = null;

            if (_model == null)
            {
                if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
                {
                    // No input model: build the loader from the options; there is no predictor or training schema.
                    loader      = CreateLoader();
                    rawPred     = null;
                    trainSchema = null;
                    Host.CheckUserArg(ImplOptions.LoadPredictor != true, nameof(ImplOptions.LoadPredictor),
                                      "Cannot be set to true unless " + nameof(ImplOptions.InputModelFile) + " is also specified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else
            {
                // A model was supplied directly: apply it to an EmptyDataView over its input schema.
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }

            // Create the ONNX context for storing global information
            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                                                  ModelVersion, _domain, ImplOptions.OnnxVersion);

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSaveOnnx> transforms;

            GetPipe(ctx, ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                                              DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx(ctx) == true)
                {
                    // The scorer extends the chain, so it becomes the new end.
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);
                }
                else
                {
                    // Skipping the predictor is only acceptable if the user did not explicitly ask for it.
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            var model = ConvertTransformListToOnnxModel(ctx, ch, source, end, transforms, _inputsToDrop, _outputsToDrop);

            using (var file = Host.CreateOutputFile(_outputModelPath))
            using (var stream = file.CreateWriteStream())
            {
                model.WriteTo(stream);
            }

            if (_outputJsonModelPath != null)
            {
                // Round-trip the model's string form through JSON to get an indented rendering.
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                {
                    var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                    writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                }
            }

            if (!string.IsNullOrWhiteSpace(ImplOptions.OutputModelFile))
            {
                // NOTE(review): loader is null when _model was supplied. Contracts.Assert presumably checks only
                // in debug builds; confirm OutputModelFile cannot be set on that path.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, ImplOptions.OutputModelFile);
            }
        }
Example #20
0
        /// <summary>
        /// Exports the data pipeline to an ONNX model, adding the predictor's scorer when a predictor is loaded
        /// and the scorer can be saved as ONNX. Optionally writes an indented JSON rendering of the model and
        /// re-saves the data pipe.
        /// </summary>
        /// <param name="ch">Channel used for warnings and trace output.</param>
        private void Run(IChannel ch)
        {
            IDataLoader      loader  = null;
            IPredictor       rawPred = null;
            IDataView        view;
            RoleMappedSchema trainSchema = null;

            if (_model == null)
            {
                if (string.IsNullOrEmpty(Args.InputModelFile))
                {
                    // No input model: build the loader from the arguments; there is no predictor or training schema.
                    loader      = CreateLoader();
                    rawPred     = null;
                    trainSchema = null;
                    Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor),
                                      "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else
            {
                // A model was supplied directly: apply it to an EmptyDataView over its input schema.
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList<ITransformCanSaveOnnx> transforms;

            GetPipe(ch, view, out source, out end, out transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);

            var ctx = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                                          ModelVersion, _domain);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = RoleMappedData.Create(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = TrainUtils.CreateExamplesOpt(end, DefaultColumnNames.Label,
                                                        DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx == true)
                {
                    // The scorer extends the chain, so it becomes the new end.
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);
                }
                else
                {
                    // Skipping the predictor is only acceptable if the user did not explicitly ask for it.
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            // Create graph inputs for every source column that is not explicitly dropped.
            for (int i = 0; i < source.Schema.ColumnCount; i++)
            {
                string colName = source.Schema.GetColumnName(i);
                if (_inputsToDrop.Contains(colName))
                {
                    continue;
                }

                ctx.AddInputVariable(source.Schema.GetColumnType(i), colName);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                Host.Assert(trans.CanSaveOnnx);
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs: visible columns that are not dropped and that the context has a variable for.
            for (int i = 0; i < end.Schema.ColumnCount; ++i)
            {
                if (end.Schema.IsHidden(i))
                {
                    continue;
                }

                var idataviewColumnName = end.Schema.GetColumnName(i);
                if (_outputsToDrop.Contains(idataviewColumnName) || _inputsToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                var variableName = ctx.TryGetVariableName(idataviewColumnName);
                if (variableName != null)
                {
                    ctx.AddOutputVariable(end.Schema.GetColumnType(i), variableName);
                }
            }

            var model = ctx.MakeModel();

            if (_outputModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputModelPath))
                using (var stream = file.CreateWriteStream())
                {
                    model.WriteTo(stream);
                }
            }

            if (_outputJsonModelPath != null)
            {
                // Round-trip the model's string form through JSON to get an indented rendering.
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                using (var stream = file.CreateWriteStream())
                using (var writer = new StreamWriter(stream))
                {
                    var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                    writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                }
            }

            if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
            {
                // NOTE(review): loader is null when _model was supplied. Contracts.Assert presumably checks only
                // in debug builds; confirm OutputModelFile cannot be set on that path.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, Args.OutputModelFile);
            }
        }
Exemple #21
0
 /// <summary>
 /// Creates a rule that may stop the training process when the model starts to lose generality, i.e. when the
 /// specified validation score starts increasing relative to its best historical value.
 /// </summary>
 /// <param name="threshold">The largest allowed gap between the current validation score and its best historical
 /// value, expressed as a fraction (e.g. 0.01 for 1% and 0.5 for 50%). Must lie in [0, 1].</param>
 public GeneralityLossRule(float threshold = 0.01f) :
     base()
 {
     // Both comparisons are false for NaN, so NaN is rejected as well.
     Contracts.CheckUserArg(threshold >= 0 && threshold <= 1, nameof(threshold), "Must be in range [0,1].");
     Threshold = threshold;
 }
Exemple #22
0
 /// <summary>
 /// Builds the tolerant early-stopping criterion and rejects a negative threshold.
 /// </summary>
 public TolerantEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
     : base(args, lowerIsBetter)
 {
     // The base constructor has already run, so validate the threshold as exposed through Args.
     var threshold = Args.Threshold;
     Contracts.CheckUserArg(threshold >= 0, nameof(args.Threshold), "Must be non-negative.");
 }
        /// <summary>
        /// Builds the name/value parameter map for these options. The booster component contributes its entries
        /// first; the general options are then written under literal keys ("verbose", "nthread", "seed", "metric",
        /// "sigmoid", "label_gain") or under the names returned by <c>GetArgName</c>.
        /// </summary>
        /// <param name="host">Host used to create the booster component and to draw a random seed when
        /// <c>Seed</c> is not set.</param>
        /// <returns>A newly created dictionary holding the parameters.</returns>
        internal Dictionary <string, object> ToDictionary(IHost host)
        {
            Contracts.CheckValue(host, nameof(host));
            Contracts.CheckUserArg(MaxBin > 0, nameof(MaxBin), "must be > 0.");
            Contracts.CheckUserArg(Sigmoid > 0, nameof(Sigmoid), "must be > 0.");
            Dictionary <string, object> res = new Dictionary <string, object>();

            // The booster adds its own entries first (presumably written into res by UpdateParameters);
            // any same-named key assigned below overwrites them.
            var boosterParams = Booster.CreateComponent(host);
            boosterParams.UpdateParameters(res);

            res[GetArgName(nameof(MaxBin))] = MaxBin;

            res["verbose"] = Silent ? "-1" : "1";
            if (NThread.HasValue)
            {
                res["nthread"] = NThread.Value;
            }

            // Use the configured seed when present; otherwise draw one from the host's random source.
            res["seed"] = Seed ?? host.Rand.Next();

            string metric = null;

            switch (EvalMetric)
            {
            case EvalMetricType.DefaultMetric:
                // Leave "metric" unset.
                break;

            case EvalMetricType.Mae:
                metric = "l1";
                break;

            case EvalMetricType.Logloss:
                metric = "binary_logloss";
                break;

            case EvalMetricType.Error:
                metric = "binary_error";
                break;

            case EvalMetricType.Merror:
                metric = "multi_error";
                break;

            case EvalMetricType.Mlogloss:
                metric = "multi_logloss";
                break;

            case EvalMetricType.Rmse:
            case EvalMetricType.Auc:
            case EvalMetricType.Ndcg:
            case EvalMetricType.Map:
                // The metric name is the lower-cased enum name. It is a machine-readable parameter value,
                // so lower-case it with the invariant culture rather than the current one.
                metric = EvalMetric.ToString().ToLowerInvariant();
                break;
            }
            if (!string.IsNullOrEmpty(metric))
            {
                res["metric"] = metric;
            }
            res["sigmoid"]    = Sigmoid;
            res["label_gain"] = CustomGains;
            res[GetArgName(nameof(UseMissing))]      = UseMissing;
            res[GetArgName(nameof(MinDataPerGroup))] = MinDataPerGroup;
            res[GetArgName(nameof(MaxCatThreshold))] = MaxCatThreshold;
            res[GetArgName(nameof(CatSmooth))]       = CatSmooth;
            res[GetArgName(nameof(CatL2))]           = CatL2;
            return res;
        }
Exemple #24
0
 /// <summary>
 /// Builds the generality-loss early-stopping criterion and requires its threshold to lie in [0, 1].
 /// </summary>
 public GLEarlyStoppingCriterion(Arguments args, bool lowerIsBetter)
     : base(args, lowerIsBetter)
 {
     // Read the threshold once so that both bounds are checked against the same value.
     var threshold = Args.Threshold;
     Contracts.CheckUserArg(0 <= threshold && threshold <= 1, nameof(args.Threshold), "Must be in range [0,1].");
 }
        /// <summary>
        /// Exports the data pipeline as a PFA document.
        /// </summary>
        /// <remarks>
        /// The loader comes from the input model file when one is given, which may also supply a predictor;
        /// otherwise it is created from the arguments. The pipe's PFA-capable transforms, plus a PFA-capable
        /// scorer when a predictor is present, are saved into a <c>BoundPfaContext</c>. The resulting document
        /// is logged to <paramref name="ch"/> or written to the output model path. Finally, the data pipe is
        /// saved when <c>Args.OutputModelFile</c> is set.
        /// </remarks>
        /// <param name="ch">Channel used for tracing, warnings and, when no output path is configured,
        /// for printing the PFA document.</param>
        private void Run(IChannel ch)
        {
            IDataLoader      loader;
            IPredictor       rawPred;
            RoleMappedSchema trainSchema;

            if (string.IsNullOrEmpty(Args.InputModelFile))
            {
                // Without an input model there is no predictor or training schema to export.
                loader      = CreateLoader();
                rawPred     = null;
                trainSchema = null;
                Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor),
                                  "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specified.");
            }
            else
            {
                LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
            }

            // Get the transform chain.
            IDataView source;
            IDataView end;
            LinkedList <ITransformCanSavePfa> transforms;

            GetPipe(ch, loader, out source, out end, out transforms);
            // The last transform in the chain, if any, must be the end of the pipe.
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                                              DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scorePfa  = scorePipe as ITransformCanSavePfa;
                if (scorePfa?.CanSavePfa == true)
                {
                    // Append the scorer so it becomes the new end of the chain.
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scorePfa);
                }
                else
                {
                    // Fail only when loading the predictor was explicitly requested; otherwise warn and skip it.
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as PFA.");
                    ch.Warning("We do not know how to save the predictor as PFA. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            var ctx = new BoundPfaContext(Host, source.Schema, _inputsToDrop, allowSet: _allowSet);

            foreach (var trans in transforms)
            {
                Host.Assert(trans.CanSavePfa);
                trans.SaveAsPfa(ctx);
            }

            // Export every visible, non-dropped column; columns that are PFA inputs are
            // exported only when _keepInput is set.
            var toExport = new List <string>();

            for (int i = 0; i < end.Schema.ColumnCount; ++i)
            {
                if (end.Schema.IsHidden(i))
                {
                    continue;
                }
                var name = end.Schema.GetColumnName(i);
                if (_outputsToDrop.Contains(name))
                {
                    continue;
                }
                if (!ctx.IsInput(name) || _keepInput)
                {
                    toExport.Add(name);
                }
            }
            JObject pfaDoc = ctx.Finalize(end.Schema, toExport.ToArray());

            if (_name != null)
            {
                pfaDoc["name"] = _name;
            }

            if (_outputModelPath == null)
            {
                // No output path: print the document to the channel instead.
                ch.Info(MessageSensitivity.Schema, pfaDoc.ToString(_formatting));
            }
            else
            {
                using (var file = Host.CreateOutputFile(_outputModelPath))
                    using (var stream = file.CreateWriteStream())
                        using (var writer = new StreamWriter(stream))
                            writer.Write(pfaDoc.ToString(_formatting));
            }

            if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
            {
                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, Args.OutputModelFile);
            }
        }
Exemple #26
0
 /// <summary>
 /// Parses the term separators given in <paramref name="args"/> and requires that the parsed result is non-empty.
 /// </summary>
 public ColInfoEx(ArgumentsBase args)
 {
     var parsed = PredictionUtil.SeparatorFromString(args.TermSeparators);
     Contracts.CheckUserArg(Utils.Size(parsed) > 0, nameof(args.TermSeparators));
     Separators = parsed;
 }
Exemple #27
0
 /// <summary>
 /// Creates the linear SVM trainer and requires both <c>Lambda</c> and <c>BatchSize</c> to be strictly positive.
 /// </summary>
 public LinearSvm(IHostEnvironment env, Arguments args)
     : base(args, env, UserNameValue)
 {
     // Checked in this order, so an invalid Lambda is reported before an invalid BatchSize.
     Contracts.CheckUserArg(0 < args.Lambda, nameof(args.Lambda), UserErrorPositive);
     Contracts.CheckUserArg(0 < args.BatchSize, nameof(args.BatchSize), UserErrorPositive);
 }
Exemple #28
0
 // Used by the command-line tool to construct a loadable instance from parsed options.
 private GeneralityLossRule(Options options, bool lowerIsBetter)
     : base(lowerIsBetter)
 {
     // Both comparisons are false for NaN, so NaN is rejected as well.
     var threshold = options.Threshold;
     Contracts.CheckUserArg(threshold >= 0 && threshold <= 1, nameof(options.Threshold), "Must be in range [0,1].");
     Threshold = threshold;
 }
Exemple #29
0
 /// <summary>
 /// Creates the LightGBM binary trainer with the label column described by
 /// <c>TrainerUtils.MakeBoolScalarLabel</c>. Both <c>Sigmoid</c> and <c>WeightOfPositiveExamples</c>
 /// must be strictly positive.
 /// </summary>
 internal LightGbmBinaryTrainer(IHostEnvironment env, Options options)
     : base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumnName))
 {
     const string mustBePositive = "must be > 0.";
     Contracts.CheckUserArg(0 < options.Sigmoid, nameof(Options.Sigmoid), mustBePositive);
     Contracts.CheckUserArg(0 < options.WeightOfPositiveExamples, nameof(Options.WeightOfPositiveExamples), mustBePositive);
 }
Exemple #30
0
 /// <summary>
 /// Parses the term separators for column <paramref name="iinfo"/>, falling back to the shared
 /// <c>TermSeparators</c> when the column's own value is null, and requires a non-empty result.
 /// </summary>
 public ColInfoEx(Arguments args, int iinfo)
 {
     var column = args.Column[iinfo];
     var separatorText = column.TermSeparators ?? args.TermSeparators;
     Separators = PredictionUtil.SeparatorFromString(separatorText);
     Contracts.CheckUserArg(Utils.Size(Separators) > 0, nameof(args.TermSeparators));
 }