/// <summary>
/// Runs the multi-class label check on <paramref name="data"/> and then delegates
/// the actual training to the base implementation.
/// </summary>
/// <param name="data">Training data together with its column roles.</param>
/// <returns>Whatever <c>base.Train</c> returns for <paramref name="data"/>.</returns>
protected override INearestNeighborsPredictor Train(RoleMappedData data)
        {
            // The class count itself is not used here; only the check is of interest.
            int numClasses;
            data.CheckMultiClassLabel(out numClasses);

            return base.Train(data);
        }
        /// <summary>
        /// Trains a predictor on <paramref name="data"/> and stores it in <c>_pred</c>.
        /// </summary>
        /// <param name="data">
        /// Training data. It is checked for being non-null, for its feature column
        /// (<c>CheckFeatureFloatVector</c>) and for its label (<c>CheckMultiClassLabel</c>).
        /// </param>
        public override void Train(RoleMappedData data)
        {
            Host.CheckValue(data, nameof(data));
            data.CheckFeatureFloatVector();

            // Number of classes reported by the label check; forwarded to TrainCore.
            int numClasses;
            data.CheckMultiClassLabel(out numClasses);
            Host.Assert(numClasses > 0);

            using (var channel = Host.Start("Training"))
            {
                _pred = TrainCore(channel, data, numClasses);
                channel.Check(_pred != null, "Training did not result in a predictor");
                channel.Done();
            }
        }
示例#3
0
        /// <summary>
        /// Trains the base learner on a multi-class dataset and wraps the trained model(s)
        /// into the final predictor returned by <c>CreatePredictor</c>.
        /// </summary>
        /// <param name="data">
        /// Training data. It is checked for being non-null, for its feature column
        /// (<c>CheckFeatureFloatVector</c>) and for its label (<c>CheckMultiClassLabel</c>).
        /// </param>
        /// <returns>The predictor built by <c>CreatePredictor</c>.</returns>
        /// <remarks>
        /// The base learner is either the pre-built <c>_trainer</c> (consumed once, then reset to null)
        /// or a fresh instance created from <c>_args.predictorType</c>. An exception is raised when the
        /// created component is not a <c>TScalarTrainer</c>.
        /// </remarks>
        TVectorPredictor Train(RoleMappedData data)
        {
            Contracts.CheckValue(data, nameof(data));
            data.CheckFeatureFloatVector();

            int count;

            data.CheckMultiClassLabel(out count);

            using (var ch = Host.Start("Training"))
            {
                // Only one slot is allocated, so the loop below runs exactly once.
                _predictors = new TVectorPredictor[1];
                for (int i = 0; i < _predictors.Length; i++)
                {
                    ch.Info("Training learner {0}", i);
                    Contracts.CheckValue(_args.predictorType, nameof(_args.predictorType), "Must specify a base learner type");

                    TScalarTrainer trainer;
                    if (_trainer != null)
                    {
                        trainer = _trainer;
                    }
                    else
                    {
                        var temp = ScikitSubComponent <ITrainer, SignatureBinaryClassifierTrainer> .AsSubComponent(_args.predictorType);

                        // 'as' yields null when the created component has an incompatible type.
                        // Fail here with a clear message instead of passing null further down.
                        trainer = temp.CreateInstance(Host) as TScalarTrainer;
                        if (trainer == null)
                        {
                            throw ch.Except("The base learner could not be created as a {0}.", typeof(TScalarTrainer).Name);
                        }
                    }

                    // The pre-built trainer is used at most once.
                    _trainer       = null;
                    _predictors[i] = TrainPredictor(ch, trainer, data, count);
                }
            }

            return CreatePredictor();
        }
        /// <summary>
        /// Checks the multi-class label, allocates the prior counts and tries to read the class
        /// names from the key-values metadata of the label column.
        /// </summary>
        /// <param name="data">Training data whose label column is inspected.</param>
        /// <remarks>
        /// <c>_labelNames</c> ends up null when the metadata is not a text vector of size
        /// <c>_numClasses</c>, when the name buffer is not dense, or when a name is empty or
        /// duplicated. Otherwise it holds one name per class.
        /// </remarks>
        private protected override void CheckLabel(RoleMappedData data)
        {
            Contracts.AssertValue(data);
            // REVIEW: For floating point labels, this will make a pass over the data.
            // Should we instead leverage the pass made by the LBFGS base class? Ideally, it wouldn't
            // make a pass over the data...
            data.CheckMultiClassLabel(out _numClasses);

            // One prior count per class.
            _prior = new Double[_numClasses];

            // Look up the type of the label's key-values metadata, if there is any.
            var label         = data.Schema.Label.Value;
            var keyValuesType = label.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type;

            bool isTextVectorOfClassCount = keyValuesType is VectorType vectorType
                                            && vectorType.ItemType == TextType.Instance
                                            && vectorType.Size == _numClasses;
            if (!isTextVectorOfClassCount)
            {
                _labelNames = null;
                return;
            }

            VBuffer <ReadOnlyMemory <char> > nameBuffer = default;
            label.Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref nameBuffer);

            // A non-dense buffer, an empty name or a duplicated name means that at least one class
            // has no usable name. In all of these cases _labelNames is set to null.
            if (!nameBuffer.IsDense)
            {
                _labelNames = null;
                return;
            }

            _labelNames = new string[_numClasses];
            ReadOnlySpan <ReadOnlyMemory <char> > rawNames = nameBuffer.GetValues();

            // Tracks the names seen so far in order to detect duplicates.
            var seenNames = new HashSet <string>();

            for (int classIndex = 0; classIndex < _numClasses; classIndex++)
            {
                ReadOnlyMemory <char> rawName = rawNames[classIndex];
                if (rawName.IsEmpty)
                {
                    _labelNames = null;
                    break;
                }

                string name = rawName.ToString();
                if (!seenNames.Add(name))
                {
                    _labelNames = null;
                    break;
                }

                _labelNames[classIndex] = name;
                Contracts.Assert(!string.IsNullOrEmpty(_labelNames[classIndex]));
            }

            Contracts.Assert(_labelNames == null || _labelNames.Length == _numClasses);
        }
示例#5
0
 /// <summary>
 /// Runs the multi-class label check on <paramref name="examples"/>.
 /// </summary>
 /// <param name="examples">Training data whose label is checked.</param>
 /// <param name="weightSetCount">Receives the count produced by <c>CheckMultiClassLabel</c>.</param>
 private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
     => examples.CheckMultiClassLabel(out weightSetCount);
示例#6
0
 /// <summary>
 /// Runs the multi-class label check on <paramref name="examples"/> and stores the
 /// resulting count in <c>_numClasses</c>.
 /// </summary>
 /// <param name="examples">Training data whose label is checked.</param>
 protected override void CheckLabel(RoleMappedData examples)
     => examples.CheckMultiClassLabel(out _numClasses);
        /// <summary>
        /// Checks the multi-class label, allocates the prior counts and tries to read the class
        /// names from the key-values metadata of the label column.
        /// </summary>
        /// <param name="data">Training data whose label column is inspected.</param>
        /// <remarks>
        /// <c>_labelNames</c> ends up null when the metadata is missing or is not a known-size text
        /// vector of size <c>_numClasses</c>, when the name buffer is not dense, or when a name is
        /// empty, NA or duplicated. Otherwise it holds one name per class.
        /// </remarks>
        protected override void CheckLabel(RoleMappedData data)
        {
            Contracts.AssertValue(data);
            // REVIEW: For floating point labels, this will make a pass over the data.
            // Should we instead leverage the pass made by the LBFGS base class? Ideally, it wouldn't
            // make a pass over the data...
            data.CheckMultiClassLabel(out _numClasses);

            // One prior count per class.
            _prior = new Double[_numClasses];

            // Look up the type of the label's key-values metadata, if there is any.
            var viewSchema    = data.Data.Schema;
            var labelColumn   = data.Schema.Label.Index;
            var keyValuesType = viewSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, labelColumn);

            bool isTextVectorOfClassCount = keyValuesType != null
                                            && keyValuesType.IsKnownSizeVector
                                            && keyValuesType.ItemType.IsText
                                            && keyValuesType.VectorSize == _numClasses;
            if (!isTextVectorOfClassCount)
            {
                _labelNames = null;
                return;
            }

            var nameBuffer = default(VBuffer <DvText>);
            viewSchema.GetMetadata(MetadataUtils.Kinds.KeyValues, labelColumn, ref nameBuffer);

            // A non-dense buffer, an empty/NA name or a duplicated name means that at least one
            // class has no usable name. In all of these cases _labelNames is set to null.
            if (!nameBuffer.IsDense)
            {
                _labelNames = null;
                return;
            }

            _labelNames = new string[_numClasses];
            DvText[] rawNames = nameBuffer.Values;

            // Tracks the names seen so far in order to detect duplicates.
            var seenNames = new HashSet <string>();

            for (int classIndex = 0; classIndex < _numClasses; classIndex++)
            {
                DvText rawName = rawNames[classIndex];
                if (rawName.IsEmpty || rawName.IsNA)
                {
                    _labelNames = null;
                    break;
                }

                string name = rawName.ToString();
                if (!seenNames.Add(name))
                {
                    _labelNames = null;
                    break;
                }

                _labelNames[classIndex] = name;
                Contracts.Assert(!string.IsNullOrEmpty(_labelNames[classIndex]));
            }

            Contracts.Assert(_labelNames == null || _labelNames.Length == _numClasses);
        }
示例#8
0
        /// <summary>
        /// Trains one binary learner on a view where the label column has been concatenated to the
        /// features, then builds the final predictor (optionally after training a reclassification
        /// predictor).
        /// </summary>
        /// <param name="ch">Channel used for logging and for raising exceptions.</param>
        /// <param name="trainer">Binary trainer applied to the transformed data.</param>
        /// <param name="data">Original training data with its column roles.</param>
        /// <param name="count">
        /// Class count supplied by the caller; when it is not positive and the label is not a key,
        /// the count is obtained from <c>MinMaxLabelOverDataSet</c> instead.
        /// </param>
        /// <returns>The predictor produced by <c>CreateFinalPredictor</c>.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            string dstColumn, labelColumn;
            var mapped         = MapLabelsAndInsertTransform(ch, data, out dstColumn, out labelColumn, count, true, _args);
            var mergedFeatures = mapped.Schema.GetTempColumnName() + "NF";

            var describeArgs = new DescribeTransform.Arguments
            {
                columns         = new string[] { labelColumn, dstColumn },
                oneRowPerColumn = true
            };
            var described = new DescribeTransform(Host, describeArgs, mapped);

            // Pick the view handed to the Concat transform; the label column's declared type is
            // replaced unless it is already a single R4 column.
            var       labelType = data.Schema.Label.Value.Type;
            IDataView trainView;

            if (_args.singleColumn)
            {
                if (labelType.RawKind() == DataKind.R4)
                {
                    trainView = described;
                }
                else
                {
                    var replaced = new TypeReplacementSchema(described.Schema, new[] { labelColumn }, new[] { NumberType.R4 });
                    trainView = new TypeReplacementDataView(described, replaced);
                    #region debug
#if (DEBUG)
                    DebugChecking0(trainView, labelColumn, false);
#endif
                    #endregion
                }
            }
            else if (labelType.IsKey())
            {
                int keyCount = labelType.AsKey().KeyCount();
                var replaced = new TypeReplacementSchema(described.Schema, new[] { labelColumn }, new[] { new VectorType(NumberType.R4, keyCount) });
                trainView = new TypeReplacementDataView(described, replaced);
                #region debug
#if (DEBUG)
                int observed;
                MinMaxLabelOverDataSet(mapped, labelColumn, out observed);
                int checkedCount;
                data.CheckMultiClassLabel(out checkedCount);
                if (checkedCount != keyCount)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", keyCount, checkedCount);
                }
                DebugChecking0(trainView, labelColumn, true);
                DebugChecking0Vfloat(trainView, labelColumn, keyCount);
#endif
                #endregion
            }
            else
            {
                // Use the caller's count when it is positive, otherwise compute it from the data.
                int width = count;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(mapped, labelColumn, out width);
                }
                var replaced = new TypeReplacementSchema(described.Schema, new[] { labelColumn }, new[] { new VectorType(NumberType.R4, width) });
                trainView = new TypeReplacementDataView(described, replaced);
                #region debug
#if (DEBUG)
                DebugChecking0(trainView, labelColumn, true);
#endif
                #endregion
            }

            // Concatenate the original features and the label column into one new feature column.
            var featureName = data.Schema.Feature.Value.Name;
            ch.Info("Merging column label '{0}' with features '{1}'", labelColumn, featureName);
            var       concatSpec   = string.Format("Concat{{col={0}:{1},{2}}}", mergedFeatures, featureName, labelColumn);
            IDataView concatenated = ComponentCreation.CreateTransform(Host, concatSpec, trainView);

            // Keep every role except Label and Feature, which are rebound to the new columns.
            var remappedRoles = data.Schema.GetColumnRoleNames()
                                .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                                .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                                .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(mergedFeatures))
                                .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstColumn));
            var trainerInput = new RoleMappedData(concatenated, remappedRoles);

            ch.Info("New Features: {0}:{1}", trainerInput.Schema.Feature.Value.Name, trainerInput.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainerInput.Schema.Label.Value.Name, trainerInput.Schema.Label.Value.Type);

            // Train the single binary classifier.
            var binaryModel = trainer.Train(trainerInput);
            var models      = new TScalarPredictor[] { binaryModel };

            // Optionally train the reclassification predictor on the original data.
            if (_args.reclassicationPredictor != null)
            {
                var intermediate     = CreateFinalPredictor(ch, data, mapped, count, _args, models, null);
                var reclassComponent = ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor);
                TrainReclassificationPredictor(data, intermediate, reclassComponent);
            }

            return CreateFinalPredictor(ch, data, mapped, count, _args, models, _reclassPredictor);
        }
        /// <summary>
        /// Trains one learner on a view where the label column has been concatenated to the features,
        /// after adding a generated group id column and converting both the label and the group id
        /// into keys. Then builds the final predictor (optionally after training a reclassification
        /// predictor).
        /// </summary>
        /// <param name="ch">Channel used for logging and for raising exceptions.</param>
        /// <param name="trainer">Trainer applied to the transformed data.</param>
        /// <param name="data">Original training data with its column roles.</param>
        /// <param name="count">
        /// Class count supplied by the caller; when it is not positive and the label is not a key,
        /// the count is obtained from <c>MinMaxLabelOverDataSet</c> instead.
        /// </param>
        /// <returns>The predictor produced by <c>CreateFinalPredictor</c>.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            // Keep the untouched input: 'data' is replaced below, but the reclassification
            // predictor is trained on the original.
            var data0 = data;

            #region adding group ID

            // We insert a group Id.
            // The column is produced by GenerateNumberTransform with UseCounter = true.
            string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
            var    groupArgs       = new GenerateNumberTransform.Arguments
            {
                Column     = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
                UseCounter = true
            };

            var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
            // Same roles as before, now on top of the view that contains the group column.
            data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());

            #endregion

            #region preparing the training dataset

            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // We check the label is not boolean.
            int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
            var typeLab  = trans.Schema[indexLab].Type;
            if (typeLab.RawKind() == DataKind.BL)
            {
                throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
            }

            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            // Pick the view handed to the Concat transform; the label column's declared type is
            // replaced unless it is already a single R4 column.
            IDataView viewI;
            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.R4)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberType.R4 });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                // Key label: the vector size is the key count of the label type.
                int nb  = data.Schema.Label.Value.Type.AsKey().KeyCount();
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorType(NumberType.R4, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                // NOTE(review): nb_ is computed but never read afterwards.
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMultiClassLabel(out count3);
                if (count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                // Use the caller's count when it is positive, otherwise compute it from the data.
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorType(NumberType.R4, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            // Concatenate the original features and the label column into one new feature column.
            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);

            #endregion

            #region converting label and group into keys

            // We need to convert the label into a Key.
            // The converted column is named "<dstName>k".
            // NOTE(review): the key range is hard-coded to [0, 4]; confirm this bound is intended
            // for every label the transform can receive.
            var convArgs = new MultiClassConvertTransform.Arguments
            {
                column   = new[] { MultiClassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
                keyRange = new KeyRange()
                {
                    Min = 0, Max = 4
                },
                resultType = DataKind.U4
            };
            IDataView after_concatenation_key_label = new MultiClassConvertTransform(Host, convArgs, after_concatenation_);

            // The group must be a key too!
            // The converted column is named "<groupColumnTemp>k"; no upper bound is given and the
            // raw type is U4 or U8 depending on _args.groupIsU4.
            convArgs = new MultiClassConvertTransform.Arguments
            {
                column   = new[] { MultiClassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
                keyRange = new KeyRange()
                {
                    Min = 0, Max = null
                },
                resultType = _args.groupIsU4 ? DataKind.U4 : DataKind.U8
            };
            after_concatenation_key_label = new MultiClassConvertTransform(Host, convArgs, after_concatenation_key_label);

            #endregion

            #region preparing the RoleMapData view

            // From here on, both names refer to the key-converted columns.
            string groupColumn = groupColumnTemp + "k";
            dstName += "k";

            var roles      = data.Schema.GetColumnRoleNames();
            // NOTE(review): this first ToArray() result is overwritten below before it is read.
            var rolesArray = roles.ToArray();
            // Drop the Label and Feature roles; they are rebound to the new columns below.
            // NOTE(review): the last two filters compare kvp.Key.Value with the group column names
            // (groupColumn / groupColumnTemp), whereas the first two compare it with role values.
            // Presumably an existing Group role or kvp.Value was meant - confirm.
            roles = roles
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                    .Where(kvp => kvp.Key.Value != groupColumn)
                    .Where(kvp => kvp.Key.Value != groupColumnTemp);
            rolesArray = roles.ToArray();
            // NOTE(review): only the first remaining role is inspected for the temporary group column.
            if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
            {
                throw ch.Except("Duplicated group.");
            }
            roles = roles
                    .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                    .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
                    .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
            var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);

            #endregion

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            // It is trained on data0, i.e. the input as received, before the group column was added.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }

            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }