示例#1
0
        /// <summary>
        /// Trains one scalar learner per class of the label column and wraps the trained
        /// models into a multiclass prediction transformer.
        /// </summary>
        /// <param name="input">The training data; its label column is looked up by <c>LabelColumn.Name</c>.</param>
        /// <returns>
        /// A <see cref="MulticlassPredictionTransformer{OvaPredictor}"/> built over the predictors
        /// trained for every class, using the feature column reported by the first trained learner.
        /// </returns>
        public override MulticlassPredictionTransformer<OvaPredictor> Fit(IDataView input)
        {
            var roles = new KeyValuePair<CR, string>[1];

            roles[0] = new KeyValuePair<CR, string>(new CR(DefaultColumnNames.Label), LabelColumn.Name);
            var td = new RoleMappedData(input, roles);

            td.CheckMultiClassLabel(out var numClasses);

            var    predictors    = new TScalarPredictor[numClasses];
            string featureColumn = null;

            using (var ch = Host.Start("Fitting"))
            {
                for (int i = 0; i < predictors.Length; i++)
                {
                    ch.Info($"Training learner {i}");

                    // Train each learner exactly once. The previous implementation trained
                    // learner 0 twice: once to read its feature column and once for its model.
                    var transformer = TrainOne(ch, Trainer, td, i);

                    if (i == 0)
                    {
                        // The feature column name is taken from the first learner only.
                        featureColumn = transformer.FeatureColumn;
                    }

                    predictors[i] = transformer.Model;
                }
            }

            return new MulticlassPredictionTransformer<OvaPredictor>(
                Host,
                OvaPredictor.Create(Host, _args.UseProbabilities, predictors),
                input.Schema,
                featureColumn,
                LabelColumn.Name);
        }
示例#2
0
        /// <summary>
        /// Deserialization constructor: reads the distribution flag and the predictor count
        /// from <paramref name="ctx"/>, loads that many sub-predictors and builds the matching
        /// implementation object.
        /// </summary>
        /// <param name="env">Host environment forwarded to the base constructor.</param>
        /// <param name="ctx">Model load context positioned at this predictor's data.</param>
        private OvaPredictor(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, RegistrationName, ctx)
        {
            // *** Binary format ***
            // bool: useDist
            // int: predictor count
            bool distributionMode = ctx.Reader.ReadBoolByte();
            int  predictorCount   = ctx.Reader.ReadInt32();
            Host.CheckDecode(predictorCount > 0);

            if (!distributionMode)
            {
                // Raw-score predictors.
                var rawPredictors = new TScalarPredictor[predictorCount];
                LoadPredictors(Host, rawPredictors, ctx);
                _impl = new ImplRaw(rawPredictors);
            }
            else
            {
                // Predictors exposing a distribution mapper.
                var distPredictors = new IValueMapperDist[predictorCount];
                LoadPredictors(Host, distPredictors, ctx);
                _impl = new ImplDist(distPredictors);
            }

            // One output slot per loaded predictor.
            int outputSize = _impl.Predictors.Length;
            OutputType = new VectorType(NumberType.Float, outputSize);
        }
        /// <summary>
        /// Deserialization constructor: reads the output formula and the predictor count from
        /// <paramref name="ctx"/>, loads that many sub-predictors and builds the implementation
        /// object matching the output formula.
        /// </summary>
        /// <param name="env">Host environment forwarded to the base constructor.</param>
        /// <param name="ctx">Model load context positioned at this model's data.</param>
        private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, RegistrationName, ctx)
        {
            // *** Binary format ***
            // byte: OutputFormula as byte
            // int: predictor count
            OutputFormula outputFormula = (OutputFormula)ctx.Reader.ReadByte();
            int           len           = ctx.Reader.ReadInt32();

            Host.CheckDecode(len > 0);

            // Reject unknown formula values up front. Without this check an unrecognized byte
            // would leave _impl unassigned and fail below with a NullReferenceException
            // instead of a decode error.
            Host.CheckDecode(outputFormula == OutputFormula.Raw
                             || outputFormula == OutputFormula.ProbabilityNormalization
                             || outputFormula == OutputFormula.Softmax);

            if (outputFormula == OutputFormula.Raw)
            {
                var predictors = new TScalarPredictor[len];
                LoadPredictors(Host, predictors, ctx);
                _impl = new ImplRaw(predictors);
            }
            else if (outputFormula == OutputFormula.ProbabilityNormalization)
            {
                var predictors = new IValueMapperDist[len];
                LoadPredictors(Host, predictors, ctx);
                _impl = new ImplDist(predictors);
            }
            else
            {
                // Only Softmax remains after the validation above.
                var predictors = new TScalarPredictor[len];
                LoadPredictors(Host, predictors, ctx);
                _impl = new ImplSoftmax(predictors);
            }

            // One output slot per loaded predictor.
            DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
        }
示例#4
0
        /// <summary>
        /// Trains the scalar trainer on the data relabeled for the class pair
        /// (<paramref name="cls1"/>, <paramref name="cls2"/>) and returns the resulting
        /// predictor after optional calibration.
        /// </summary>
        /// <param name="ch">Channel passed to the calibration helper.</param>
        /// <param name="trainer">The trainer to run on the relabeled data.</param>
        /// <param name="data">Source data whose label role is replaced by the mapped label column.</param>
        /// <param name="cls1">First class of the pair, forwarded to <c>MapLabels</c>.</param>
        /// <param name="cls2">Second class of the pair, forwarded to <c>MapLabels</c>.</param>
        /// <returns>The (possibly calibrated) predictor as a <c>TDistPredictor</c>.</returns>
        private TDistPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls1, int cls2)
        {
            var relabeledView = MapLabels(data, cls1, cls2, out string mappedLabel);

            // Keep every role except the label, then bind the label role to the mapped column.
            var otherRoles = data.Schema.GetColumnRoleNames()
                             .Where(role => role.Key.Value != CR.Label.Value);
            var pairData = RoleMappedData.Create(relabeledView, otherRoles.Prepend(CR.Label.Bind(mappedLabel)));

            trainer.Train(pairData);

            // A calibrator trainer is only instantiated when one is configured.
            ICalibratorTrainer calibratorTrainer = null;
            if (Args.Calibrator.IsGood())
            {
                calibratorTrainer = Args.Calibrator.CreateInstance(Host);
            }

            TScalarPredictor rawPredictor = trainer.CreatePredictor();
            var calibrated = CalibratorUtils.TrainCalibratorIfNeeded(
                Host, ch, calibratorTrainer, Args.MaxCalibrationExamples, trainer, rawPredictor, pairData);

            var result = calibrated as TDistPredictor;
            Host.Check(result != null, "Calibrated predictor does not implement the expected interface");
            Host.Check(result is IValueMapperDist, "Calibrated predictor does not implement the IValueMapperDist interface");

            return result;
        }
            /// <summary>
            /// Deserialization constructor. Reads, in order: the class indices, the class values
            /// (whose storage type depends on <c>LabelType</c>), the single-column and label-key
            /// flags, the predictor count, the predictors themselves and a trailing check code.
            /// </summary>
            /// <param name="ctx">Model load context to read from.</param>
            /// <param name="env">Environment used for validation and exception creation.</param>
            internal ImplRaw(ModelLoadContext ctx, IHostEnvironment env)
            {
                // Trailing byte expected at the end of the serialized block.
                const byte ExpectedCheckCode = 213;

                // labelType
                GuessLabelType();
                int[] indices = ctx.Reader.ReadIntArray();

                TLabel[] classes;
                if (LabelType == NumberDataViewType.Single)
                {
                    classes = ctx.Reader.ReadFloatArray() as TLabel[];
                }
                else if (LabelType == NumberDataViewType.Byte)
                {
                    classes = ctx.Reader.ReadByteArray() as TLabel[];
                }
                else if (LabelType == NumberDataViewType.UInt16)
                {
                    // Read as uint values, then narrow each one to ushort.
                    var val = ctx.Reader.ReadUIntArray();
                    env.CheckValue(val, "classes");
                    classes = val.Select(c => (ushort)c).ToArray() as TLabel[];
                }
                else if (LabelType == NumberDataViewType.UInt32)
                {
                    classes = ctx.Reader.ReadUIntArray() as TLabel[];
                }
                else
                {
                    throw env.Except("Unexpected type for LabelType.");
                }

                // Every branch ends with an 'as' cast, which yields null when the array read
                // does not match TLabel[]. Validate once here so that a mismatch is reported
                // by the check rather than as a NullReferenceException on classes.Length.
                env.CheckValue(classes, "classes");

                _classes      = new VBuffer<TLabel>(classes.Length, classes, indices);
                _singleColumn = ctx.Reader.ReadInt32() == 1;
                _labelKey     = ctx.Reader.ReadInt32() == 1;
                FinalizeOutputType();

                int len = ctx.Reader.ReadInt32();

                env.CheckDecode(len > 0);
                var        predictors = new TScalarPredictor[len];
                IPredictor reclassPredictor;

                LoadPredictors(env, predictors, out reclassPredictor, ctx);
                Preparation(predictors, reclassPredictor);

                byte checkCode = ctx.Reader.ReadByte();
                if (checkCode != ExpectedCheckCode)
                {
                    throw env.Except("CheckCode is wrong. Serialization failed.");
                }
            }
示例#6
0
        /// <summary>
        /// Trains <paramref name="count"/> scalar learners, one per class index, and combines
        /// their models into a single <c>OvaPredictor</c>.
        /// </summary>
        /// <param name="ch">Channel used to report training progress.</param>
        /// <param name="data">Training data handed to each learner.</param>
        /// <param name="count">Number of learners to train.</param>
        /// <returns>The predictor created from all trained models.</returns>
        protected override OvaPredictor TrainCore(IChannel ch, RoleMappedData data, int count)
        {
            // Train one-vs-all models.
            var models = new TScalarPredictor[count];

            int classIndex = 0;
            while (classIndex < models.Length)
            {
                ch.Info($"Training learner {classIndex}");

                var trained = TrainOne(ch, Trainer, data, classIndex);
                models[classIndex] = trained.Model;

                classIndex++;
            }

            return OvaPredictor.Create(Host, _args.UseProbabilities, models);
        }
        /// <summary>
        /// Trains a single predictor on a view where the label column is concatenated to the
        /// features and a generated group column is added, then optionally trains a
        /// reclassification predictor and returns the final predictor.
        /// </summary>
        /// <param name="ch">Channel used for informational messages and exceptions.</param>
        /// <param name="trainer">Trainer run once on the prepared data.</param>
        /// <param name="data">Input data with label and feature roles.</param>
        /// <param name="count">
        /// Class count; when positive it is used as the label vector size in the fallback
        /// branch, otherwise that size is obtained from <c>MinMaxLabelOverDataSet</c>.
        /// </param>
        /// <returns>The predictor produced by <c>CreateFinalPredictor</c>.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            // Keep the caller's data: 'data' is reassigned below, and the reclassification
            // step at the end is given the original.
            var data0 = data;

            #region adding group ID

            // We insert a group Id.
            // The column name is a temporary schema name suffixed with "GR".
            string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
            var    groupArgs       = new GenerateNumberTransform.Options
            {
                Columns    = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
                UseCounter = true
            };

            // Wrap the input view with the number generator and rebuild the role mapping
            // over it with the same roles as before.
            var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
            data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());

            #endregion

            #region preparing the training dataset

            // dstName and labName are produced by MapLabelsAndInsertTransform; from their use
            // below, dstName becomes the training label and labName is merged into the features.
            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // We check the label is not boolean.
            int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
            var typeLab  = trans.Schema[indexLab].Type;
            if (typeLab.RawKind() == DataKind.Boolean)
            {
                throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
            }

            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            // Choose the type under which column labName is exposed before concatenation:
            //  - single column and label already Single: the view is used unchanged;
            //  - single column otherwise: labName is declared as a Single number;
            //  - key label: labName is declared as a Single vector sized by the key count;
            //  - anything else: a Single vector sized by 'count' (or a computed size).
            IDataView viewI;
            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                // The key count is a ulong and is narrowed to int for the vector size.
                ulong nb  = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
                var   sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                // Debug builds cross-check the key count against the multiclass label count.
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMulticlassLabel(out count3);
                if ((ulong)count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                // Use the caller-supplied count when positive; otherwise ask
                // MinMaxLabelOverDataSet for the size.
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            // Build a textual "Concat" transform definition that merges the feature column
            // and labName into the new column newFeatures.
            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);

            #endregion

            #region converting label and group into keys

            // We need to convert the label into a Key.
            // The converted column is named dstName + "k".
            // NOTE(review): the key count is hard-coded to 4 here - confirm this is intended.
            var convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
                keyCount   = new KeyCount(4),
                resultType = DataKind.UInt32
            };
            IDataView after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_);

            // The group must be a key too!
            // The converted column is named groupColumnTemp + "k"; its raw type depends on
            // the groupIsU4 option.
            convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
                keyCount   = new KeyCount(),
                resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
            };
            after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_key_label);

            #endregion

            #region preparing the RoleMapData view

            // From here on the names refer to the key-converted columns.
            string groupColumn = groupColumnTemp + "k";
            dstName += "k";

            var roles      = data.Schema.GetColumnRoleNames();
            // NOTE(review): this array is overwritten a few lines below before being read.
            var rolesArray = roles.ToArray();
            // Drop the label and feature roles.
            // NOTE(review): the last two filters compare a role name (Key.Value) against
            // column names - confirm this is the intended comparison.
            roles = roles
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                    .Where(kvp => kvp.Key.Value != groupColumn)
                    .Where(kvp => kvp.Key.Value != groupColumnTemp);
            rolesArray = roles.ToArray();
            // Only the first remaining role is inspected for a clash with the group column.
            if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
            {
                throw ch.Except("Duplicated group.");
            }
            // Bind the new feature, label and group columns in front of the remaining roles.
            roles = roles
                    .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                    .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
                    .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
            var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);

            #endregion

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            // An intermediate final predictor without reclassification is built first and
            // passed, with the original data, to the reclassification training step.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }

            // NOTE(review): _reclassPredictor is presumably assigned by
            // TrainReclassificationPredictor - confirm.
            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
示例#8
0
        /// <summary>
        /// Trains a single predictor on a view where the label column is concatenated to the
        /// features, then optionally trains a reclassification predictor and returns the
        /// final predictor.
        /// </summary>
        /// <param name="ch">Channel used for informational messages and exceptions.</param>
        /// <param name="trainer">Trainer run once on the prepared data.</param>
        /// <param name="data">Input data with label and feature roles.</param>
        /// <param name="count">
        /// Class count; when positive it is used as the label vector size in the fallback
        /// branch, otherwise that size is obtained from <c>MinMaxLabelOverDataSet</c>.
        /// </param>
        /// <returns>The predictor produced by <c>CreateFinalPredictor</c>.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            // Second reference to the input data, passed to the reclassification step below.
            var    data0 = data;
            // dstName and labName are produced by MapLabelsAndInsertTransform; from their use
            // below, dstName becomes the training label and labName is merged into the features.
            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            // Choose the type under which column labName is exposed before concatenation:
            //  - single column and label already R4: the view is used unchanged;
            //  - single column otherwise: labName is declared as an R4 number;
            //  - key label: labName is declared as an R4 vector sized by the key count;
            //  - anything else: an R4 vector sized by 'count' (or a computed size).
            IDataView viewI;

            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.R4)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberType.R4 });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                int nb  = data.Schema.Label.Value.Type.AsKey().KeyCount();
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorType(NumberType.R4, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                // Debug builds cross-check the key count against the multiclass label count.
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMultiClassLabel(out count3);
                if (count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                // Use the caller-supplied count when positive; otherwise ask
                // MinMaxLabelOverDataSet for the size.
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorType(NumberType.R4, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            // Build a textual "Concat" transform definition that merges the feature column
            // and labName into the new column newFeatures.
            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var       args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            IDataView after_concatenation = ComponentCreation.CreateTransform(Host, args, viewI);

            // Drop the existing label and feature roles, then bind the merged feature column
            // and dstName as the new feature and label roles.
            var roles = data.Schema.GetColumnRoleNames()
                        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                        .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                        .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName));
            var trainer_input = new RoleMappedData(after_concatenation, roles);

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            // An intermediate final predictor without reclassification is built first and
            // passed, with data0, to the reclassification training step.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }
            // NOTE(review): _reclassPredictor is presumably assigned by
            // TrainReclassificationPredictor - confirm.
            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
示例#9
0
 /// <summary>
 /// Creates the ranker implementation by forwarding every argument, unchanged and in the
 /// same order, to the base constructor.
 /// </summary>
 /// <param name="classes">Class values forwarded to the base constructor.</param>
 /// <param name="predictors">Trained predictors forwarded to the base constructor.</param>
 /// <param name="reclassPredicgtor">Reclassification predictor forwarded to the base constructor.</param>
 /// <param name="singleColumn">Single-column flag forwarded to the base constructor.</param>
 /// <param name="labelKey">Label-key flag forwarded to the base constructor.</param>
 internal ImplRawRanker(
     VBuffer<TLabel> classes,
     TScalarPredictor[] predictors,
     TScalarPredictor reclassPredicgtor,
     bool singleColumn,
     bool labelKey)
     : base(classes, predictors, reclassPredicgtor, singleColumn, labelKey)
 {
     // All state is initialized by the base constructor.
 }