Example #1
0
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx, IDataScorerTransform scorer)
        {
            // The source context exposed by this transform is the untouched input view.
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");
            env.CheckUserArg(!string.IsNullOrWhiteSpace(args.taggedPredictor), "taggedPredictor",
                             "The input tag is required.");

            // A scorer supplied by the caller is used as-is; no predictor lookup happens.
            if (scorer != null)
            {
                _scorer = scorer;
                return scorer;
            }

            // Otherwise rebuild a default scorer around the predictor stored under the requested tag.
            var taggedPredictor = TagHelper.GetTaggedPredictor(env, input, args.taggedPredictor);
            string featureColumn = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                                                                       "featureColumn", args.featureColumn, DefaultColumnNames.Features);
            string groupColumn = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
                                                                     "groupColumn", args.groupColumn, DefaultColumnNames.GroupId);
            // The generated columns are not used below; the call is kept, presumably for its
            // validation of args.customColumnPair (TODO confirm).
            var customColumns = TrainUtils.CheckAndGenerateCustomColumns(env, args.customColumnPair);

            _scorer = PredictorHelper.CreateDefaultScorer(_host, input, featureColumn, groupColumn, taggedPredictor);
            return _scorer;
        }
Example #2
0
        /// <summary>
        /// Debug check for a trainer. It scores <paramref name="td"/> with the default scorer
        /// built on <c>trainer.CreatePredictor()</c>. It throws if the scored view has no rows,
        /// or, for non-ranking trainers, if more than half of the predicted labels differ from
        /// the expected labels.
        /// </summary>
        /// <param name="td">Role-mapped data to score. Its label column is read, and its feature column for ranking.</param>
        /// <param name="trainer">Trainer whose predictor is checked.</param>
        protected void DebugChecking2(RoleMappedData td, ITrainer trainer)
        {
            var scorer = PredictorHelper.CreateDefaultScorer(Host, td, trainer.CreatePredictor());
            bool isRanking = trainer.PredictionKind == PredictionKind.Ranking;

            if (isRanking)
            {
                string schemas = SchemaHelper.ToString(scorer.Schema);
                if (!schemas.Contains("Score"))
                {
                    throw Host.Except("Issue with the schema: {0}", schemas);
                }
            }

            using (var cursor = scorer.GetRowCursor(i => true))
            {
                int ilab, ipred, ifeat;
                // Fail with an explicit message instead of asking for a getter on an invalid index.
                if (!cursor.Schema.TryGetColumnIndex(td.Schema.Label.Name, out ilab))
                {
                    throw Host.Except("Unable to find label column '{0}' in schema: {1}",
                                      td.Schema.Label.Name, SchemaHelper.ToString(cursor.Schema));
                }

                int nbrows = 0;
                int err = 0;

                if (isRanking)
                {
                    if (!cursor.Schema.TryGetColumnIndex("Score", out ipred))
                    {
                        throw Host.Except("Unable to find column '{0}' in schema: {1}",
                                          "Score", SchemaHelper.ToString(cursor.Schema));
                    }
                    if (!cursor.Schema.TryGetColumnIndex(td.Schema.Feature.Name, out ifeat))
                    {
                        throw Host.Except("Unable to find feature column '{0}' in schema: {1}",
                                          td.Schema.Feature.Name, SchemaHelper.ToString(cursor.Schema));
                    }

                    var getter  = cursor.GetGetter<uint>(ilab);
                    var fgetter = cursor.GetGetter<VBuffer<float>>(ifeat);
                    var pgetter = cursor.GetGetter<float>(ipred);
                    if (pgetter == null)
                    {
                        throw Host.Except("Issue with the schema: {0}", SchemaHelper.ToString(cursor.Schema));
                    }

                    uint lab = 0;
                    float pre = 0;
                    VBuffer<float> features = default(VBuffer<float>);
                    while (cursor.MoveNext())
                    {
                        // All three getters are still invoked on every row, even though only the
                        // row count is used below.
                        getter(ref lab);
                        pgetter(ref pre);
                        fgetter(ref features);
                        ++nbrows;
                    }

                    // NOTE(review): the former criterion "pre > 0 && lab < 0" could never hold
                    // because lab is unsigned, so no ranking error was ever counted. Only the row
                    // count is checked for ranking until the intended criterion is confirmed.
                }
                else
                {
                    if (!cursor.Schema.TryGetColumnIndex("PredictedLabel", out ipred))
                    {
                        throw Host.Except("Unable to find column '{0}' in schema: {1}",
                                          "PredictedLabel", SchemaHelper.ToString(cursor.Schema));
                    }

                    var getter  = cursor.GetGetter<bool>(ilab);
                    var pgetter = cursor.GetGetter<bool>(ipred);
                    bool lab = false;
                    bool pre = false;
                    while (cursor.MoveNext())
                    {
                        getter(ref lab);
                        pgetter(ref pre);
                        if (lab != pre)
                        {
                            ++err;
                        }
                        ++nbrows;
                    }
                }

                if (nbrows == 0)
                {
                    throw Host.Except("No results.");
                }
                // More than half of the rows wrong means the predictor did not learn anything.
                if (err * 2 > nbrows)
                {
                    throw Host.Except("No training.");
                }
            }
        }
        /// <summary>
        /// Finalizes the test on a predictor. It performs these steps in order:
        /// <list type="number">
        /// <item>Saves the model and loads it back, checking that it loads and that the predictor type is preserved.</item>
        /// <item>Scores <paramref name="roles"/> with the default scorer and saves the predictions to <paramref name="outData"/>.</item>
        /// <item>Reloads the model, scores again and saves to <paramref name="outData2"/>.</item>
        /// <item>Checks that both outputs match.</item>
        /// </list>
        /// For non-clustering kinds, it then checks the quality of the predictions.
        /// </summary>
        /// <param name="env">environment</param>
        /// <param name="outModelFilePath">output filename for the model</param>
        /// <param name="predictor">predictor</param>
        /// <param name="roles">label, feature, ...</param>
        /// <param name="outData">first output data (predictions of <paramref name="predictor"/>)</param>
        /// <param name="outData2">second output data (predictions of the reloaded model)</param>
        /// <param name="kind">prediction kind</param>
        /// <param name="checkError">checks errors</param>
        /// <param name="ratio">check the error is below that threshold (if checkError is true)</param>
        /// <param name="ratioReadSave">check the predictions difference after reloading the model are below this threshold</param>
        /// <param name="checkType">checks the reloaded predictor has exactly the same type as <paramref name="predictor"/></param>
        /// <exception cref="FileNotFoundException">The model or the first output file was not written.</exception>
        /// <exception cref="NotImplementedException">The prediction kind or label type is not supported.</exception>
        /// <exception cref="Exception">Any of the checks above failed.</exception>
        public static void FinalizeSerializationTest(IHostEnvironment env,
                                                     string outModelFilePath, IPredictor predictor,
                                                     RoleMappedData roles, string outData, string outData2,
                                                     PredictionKind kind, bool checkError = true,
                                                     float ratio    = 0.8f, float ratioReadSave = 0.06f,
                                                     bool checkType = true)
        {
            string labelColumn = kind != PredictionKind.Clustering ? roles.Schema.Label.Value.Name : null;

            #region save, reading, running

            // Saves model.
            using (var ch = env.Start("Save"))
            using (var fs = File.Create(outModelFilePath))
            {
                TrainUtils.SaveModel(env, ch, fs, predictor, roles);
            }
            if (!File.Exists(outModelFilePath))
            {
                throw new FileNotFoundException(outModelFilePath);
            }

            // Loads the model back.
            using (var fs = File.OpenRead(outModelFilePath))
            {
                var pred_local = ModelFileUtils.LoadPredictorOrNull(env, fs);
                if (pred_local == null)
                {
                    throw new Exception(string.Format("Unable to load '{0}'", outModelFilePath));
                }
                if (checkType && predictor.GetType() != pred_local.GetType())
                {
                    throw new Exception(string.Format("Type mismatch {0} != {1}", predictor.GetType(), pred_local.GetType()));
                }
            }

            // Checks the outputs.
            var sch1   = SchemaHelper.ToString(roles.Schema.Schema);
            var scorer = PredictorHelper.CreateDefaultScorer(env, roles, predictor);

            var sch2 = SchemaHelper.ToString(scorer.Schema);
            if (string.IsNullOrEmpty(sch1) || string.IsNullOrEmpty(sch2))
            {
                throw new Exception("Empty schemas");
            }

            // Only the columns the text saver can write are saved.
            var saver   = env.CreateSaver("Text");
            var columns = new int[scorer.Schema.Count];
            for (int i = 0; i < columns.Length; ++i)
            {
                columns[i] = saver.IsColumnSavable(scorer.Schema[i].Type) ? i : -1;
            }
            columns = columns.Where(c => c >= 0).ToArray();
            using (var fs2 = File.Create(outData))
            {
                saver.SaveData(fs2, scorer, columns);
            }

            // Make sure the predictions were actually written.
            if (!File.Exists(outData))
            {
                throw new FileNotFoundException(outData);
            }

            // Check we have the same output with the reloaded model.
            using (var fs = File.OpenRead(outModelFilePath))
            {
                var model = ModelFileUtils.LoadPredictorOrNull(env, fs);
                scorer = PredictorHelper.CreateDefaultScorer(env, roles, model);
                saver  = env.CreateSaver("Text");
                using (var fs2 = File.Create(outData2))
                {
                    saver.SaveData(fs2, scorer, columns);
                }
            }

            var t1 = File.ReadAllLines(outData);
            var t2 = File.ReadAllLines(outData2);
            if (t1.Length != t2.Length)
            {
                throw new Exception(string.Format("Not the same number of lines: {0} != {1}", t1.Length, t2.Length));
            }
            var linesN = Enumerable.Range(0, t1.Length).Where(i => t1[i] != t2[i]).ToList();
            if (linesN.Count > (int)(t1.Length * ratioReadSave))
            {
                var rows = linesN.Select(i => string.Format("1-Mismatch on line {0}/{3}:\n{1}\n{2}", i, t1[i], t2[i], t1.Length)).ToList();
                rows.Add($"Number of differences: {linesN.Count}/{t1.Length}");
                throw new Exception(string.Join("\n", rows));
            }

            #endregion

            #region clustering

            if (kind == PredictionKind.Clustering)
            {
                // Nothing to do here.
                return;
            }

            #endregion

            #region supervized

            string expectedOutput = kind == PredictionKind.Regression ? "Score" : "PredictedLabel";

            // Get label and basic checking about performance.
            using (var cursor = scorer.GetRowCursor(scorer.Schema))
            {
                int ilabel = SchemaHelper.GetColumnIndex(cursor.Schema, labelColumn);
                int ipred  = SchemaHelper.GetColumnIndex(cursor.Schema, expectedOutput);
                var ty1    = cursor.Schema[ilabel].Type;
                var ty2    = cursor.Schema[ipred].Type;

                // dist1: count per expected label, dist2: count per predicted value,
                // conf: count per (predicted, expected) pair.
                var dist1 = new Dictionary<int, long>();
                var dist2 = new Dictionary<int, long>();
                var conf  = new Dictionary<Tuple<int, int>, long>();

                Func<Exception> typeMismatch = () => new Exception(string.Format(
                    "Label='{0}' Predicted='{1}'\nSchema: {2}", ty1, ty2, SchemaHelper.ToString(cursor.Schema)));

                if (kind == PredictionKind.MulticlassClassification)
                {
                    #region Multiclass

                    if (!ty2.IsKey())
                    {
                        throw typeMismatch();
                    }

                    if (ty1.RawKind() == DataKind.Single)
                    {
                        var   lgetter = cursor.GetGetter<float>(SchemaHelper._dc(ilabel, cursor));
                        var   pgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ipred, cursor));
                        float ans     = 0;
                        uint  pre     = 0;
                        while (cursor.MoveNext())
                        {
                            lgetter(ref ans);
                            pgetter(ref pre);

                            // The scorer +1 to the argmax.
                            ++ans;
                            AddPredictionObservation(conf, dist1, dist2, (int)pre, (int)ans);
                        }
                    }
                    else if (ty1.RawKind() == DataKind.UInt32 && ty1.IsKey())
                    {
                        var  lgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ilabel, cursor));
                        var  pgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ipred, cursor));
                        uint ans     = 0;
                        uint pre     = 0;
                        while (cursor.MoveNext())
                        {
                            lgetter(ref ans);
                            pgetter(ref pre);
                            AddPredictionObservation(conf, dist1, dist2, (int)pre, (int)ans);
                        }
                    }
                    else
                    {
                        throw new NotImplementedException(string.Format("Not implemented for type {0}", ty1.ToString()));
                    }

                    #endregion
                }
                else if (kind == PredictionKind.BinaryClassification)
                {
                    #region binary classification

                    if (ty2.RawKind() != DataKind.Boolean)
                    {
                        throw typeMismatch();
                    }

                    if (ty1.RawKind() == DataKind.Single)
                    {
                        var   lgetter = cursor.GetGetter<float>(SchemaHelper._dc(ilabel, cursor));
                        var   pgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ipred, cursor));
                        float ans     = 0;
                        bool  pre     = default(bool);
                        while (cursor.MoveNext())
                        {
                            lgetter(ref ans);
                            pgetter(ref pre);

                            if (ans != 0 && ans != 1)
                            {
                                throw Contracts.Except("The problem is not binary, expected answer is {0}", ans);
                            }
                            AddPredictionObservation(conf, dist1, dist2, pre ? 1 : 0, (int)ans);
                        }
                    }
                    else if (ty1.RawKind() == DataKind.UInt32)
                    {
                        var  lgetter = cursor.GetGetter<uint>(SchemaHelper._dc(ilabel, cursor));
                        var  pgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ipred, cursor));
                        uint ans     = 0;
                        bool pre     = default(bool);
                        while (cursor.MoveNext())
                        {
                            lgetter(ref ans);
                            pgetter(ref pre);
                            // Key values are 1-based; bring them back to {0, 1}.
                            if (ty1.IsKey())
                            {
                                --ans;
                            }

                            if (ans != 0 && ans != 1)
                            {
                                throw Contracts.Except("The problem is not binary, expected answer is {0}", ans);
                            }
                            AddPredictionObservation(conf, dist1, dist2, pre ? 1 : 0, (int)ans);
                        }
                    }
                    else if (ty1.RawKind() == DataKind.Boolean)
                    {
                        var  lgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ilabel, cursor));
                        var  pgetter = cursor.GetGetter<bool>(SchemaHelper._dc(ipred, cursor));
                        bool ans     = default(bool);
                        bool pre     = default(bool);
                        while (cursor.MoveNext())
                        {
                            lgetter(ref ans);
                            pgetter(ref pre);
                            AddPredictionObservation(conf, dist1, dist2, pre ? 1 : 0, ans ? 1 : 0);
                        }
                    }
                    else
                    {
                        throw new NotImplementedException(string.Format("Not implemented for type {0}", ty1));
                    }

                    #endregion
                }
                else if (kind == PredictionKind.Regression)
                {
                    #region regression

                    if (ty1.RawKind() != DataKind.Single || ty2.RawKind() != DataKind.Single)
                    {
                        throw typeMismatch();
                    }

                    var   lgetter = cursor.GetGetter<float>(SchemaHelper._dc(ilabel, cursor));
                    var   pgetter = cursor.GetGetter<float>(SchemaHelper._dc(ipred, cursor));
                    float ans     = 0;
                    float pre     = 0f;
                    float error   = 0f;
                    while (cursor.MoveNext())
                    {
                        lgetter(ref ans);
                        pgetter(ref pre);
                        error += (ans - pre) * (ans - pre);
                        // Distributions are bucketed by the truncated integer value.
                        IncrementCount(dist1, (int)ans);
                        IncrementCount(dist2, (int)pre);
                    }

                    if (float.IsNaN(error) || float.IsInfinity(error))
                    {
                        throw new Exception("Regression went wrong. Error is NaN or infinite.");
                    }

                    #endregion
                }
                else
                {
                    throw new NotImplementedException(string.Format("Not implemented for kind {0}", kind));
                }

                // conf stays empty for regression, so only the dist2.Count check applies there.
                long nbError = conf.Where(c => c.Key.Item1 != c.Key.Item2).Sum(c => c.Value);
                long nbTotal = conf.Sum(c => c.Value);

                if (checkError && (nbError * 1.0 > nbTotal * ratio || dist2.Count <= 1))
                {
                    var sconf = string.Join("\n", conf.OrderBy(c => c.Key)
                                            .Select(c => string.Format("pred={0} exp={1} count={2}", c.Key.Item1, c.Key.Item2, c.Value)));
                    // DIST1 lists expected labels, DIST2 lists predictions (first 20 values only).
                    var sdist1 = string.Join("\n", dist1.OrderBy(c => c.Key)
                                             .Select(c => string.Format("label={0} count={1}", c.Key, c.Value)));
                    var sdist2 = string.Join("\n", dist2.OrderBy(c => c.Key).Take(20)
                                             .Select(c => string.Format("pred={0} count={1}", c.Key, c.Value)));
                    throw new Exception(string.Format("Too many errors {0}/{1}={7}\n###########\nConfusion:\n{2}\n########\nDIST1\n{3}\n###########\nDIST2\n{4}\nOutput:\n{5}\n...\n{6}",
                                                      nbError, nbTotal,
                                                      sconf, sdist1, sdist2,
                                                      string.Join("\n", t1.Take(Math.Min(30, t1.Length))),
                                                      string.Join("\n", t1.Skip(Math.Max(0, t1.Length - 30))),
                                                      nbError * 1.0 / nbTotal));
                }
            }

            #endregion
        }

        /// <summary>
        /// Adds one to the count stored for <paramref name="key"/>. A missing key starts from zero.
        /// </summary>
        private static void IncrementCount<TKey>(Dictionary<TKey, long> counts, TKey key)
        {
            long current;
            counts.TryGetValue(key, out current);
            counts[key] = current + 1;
        }

        /// <summary>
        /// Records one (predicted, expected) observation in the confusion counts and in both
        /// per-value distributions.
        /// </summary>
        private static void AddPredictionObservation(Dictionary<Tuple<int, int>, long> conf,
                                                     Dictionary<int, long> expectedDist,
                                                     Dictionary<int, long> predictedDist,
                                                     int predicted, int expected)
        {
            IncrementCount(conf, new Tuple<int, int>(predicted, expected));
            IncrementCount(expectedDist, expected);
            IncrementCount(predictedDist, predicted);
        }
 /// <summary>
 /// Builds the default scorer for the stored predictor over its role-mapped data.
 /// </summary>
 IDataScorerTransform GetScorer()
 {
     var roleMappedData   = _predictor.roleMapData;
     var trainedPredictor = _predictor.predictor;
     return PredictorHelper.CreateDefaultScorer(_env, roleMappedData, trainedPredictor, null);
 }
Example #5
0
        IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input, out IDataView sourceCtx)
        {
            // The source context exposed by this transform is the untouched input view.
            sourceCtx = input;
            Contracts.CheckValue(env, "env");
            env.CheckValue(args, "args");
            env.CheckValue(input, "input");
            env.CheckValue(args.tag, "tag is empty");
            env.CheckValue(args.trainer, "trainer",
                           "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");

            // Refuse to reuse a tag that already names a view in the pipeline.
            bool tagAlreadyUsed = TagHelper.EnumerateTaggedView(true, input).Any(view => view.Item1 == args.tag);
            if (tagAlreadyUsed)
            {
                throw env.Except("Tag '{0}' is already used.", args.tag);
            }

            var host = env.Register("TagTrainOrScoreTransform");

            using (var ch = host.Start("Train"))
            {
                ch.Trace("Constructing trainer");
                var trainerSettings = ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(args.trainer);
                ITrainer trainer = trainerSettings.CreateInstance(host);
                // The generated columns are not used below; the call is kept, presumably for its
                // validation of args.CustomColumn (TODO confirm).
                var customColumns = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

                string featureColumn;
                string groupColumn;
                var data = CreateDataFromArgs(_host, ch, new OpaqueDataView(input), args, out featureColumn, out groupColumn);

                ICalibratorTrainer calibrator = null;
                if (args.calibrator != null)
                {
                    calibrator = ScikitSubComponent<ICalibratorTrainer, SignatureCalibrator>.AsSubComponent(args.calibrator).CreateInstance(host);
                }

                // Compact trainer name derived from its settings string: braces, spaces and '=' are
                // dropped, '+' becomes 'Y' and '-' becomes 'N'.
                var trainerName = args.trainer.ToString()
                                      .Replace("{", "")
                                      .Replace("}", "")
                                      .Replace(" ", "")
                                      .Replace("=", "")
                                      .Replace("+", "Y")
                                      .Replace("-", "N");
                var extendedTrainer = new ExtendedTrainer(trainer, trainerName);
                _predictor = extendedTrainer.Train(host, ch, data, null, calibrator, args.maxCalibrationExamples);

                if (!string.IsNullOrEmpty(args.outputModel))
                {
                    ch.Info("Saving model into '{0}'", args.outputModel);
                    using (var modelStream = File.Create(args.outputModel))
                    {
                        TrainUtils.SaveModel(env, ch, modelStream, _predictor, data);
                    }
                    ch.Info("Done.");
                }

                if (_cali != null)
                {
                    throw ch.ExceptNotImpl("Calibrator is not implemented yet.");
                }

                ch.Trace("Scoring");
                if (_args.scorer == null)
                {
                    _scorer = PredictorHelper.CreateDefaultScorer(_host, input, featureColumn, groupColumn, _predictor);
                }
                else
                {
                    // Bind the trained predictor to the input schema and build the requested scorer on top of it.
                    var mapper         = new SchemaBindablePredictorWrapper(_predictor);
                    var roles          = new RoleMappedSchema(input.Schema, null, featureColumn, group: groupColumn);
                    var bound          = mapper.Bind(_host, roles);
                    var scorerSettings = ScikitSubComponent<IDataScorerTransform, SignatureDataScorer>.AsSubComponent(_args.scorer);
                    _scorer = scorerSettings.CreateInstance(_host, input, bound, roles);
                }

                ch.Info("Tagging with tag '{0}'.", args.tag);
                var tagArgs = new TagViewTransform.Arguments { tag = args.tag };
                return new TagViewTransform(env, tagArgs, _scorer, _predictor);
            }
        }