예제 #1
0
        /// <summary>
        /// Checks that the label column of <paramref name="data"/> passes the multiclass label check,
        /// then delegates training to the base implementation.
        /// </summary>
        /// <param name="data">Training data with mapped roles.</param>
        /// <returns>The predictor returned by the base class.</returns>
        protected override INearestNeighborsPredictor Train(RoleMappedData data)
        {
            // Only the validation side effect matters here; the reported class count is discarded.
            data.CheckMulticlassLabel(out _);
            return base.Train(data);
        }
        /// <summary>
        /// Validates <paramref name="data"/>, trains the single underlying predictor stored in
        /// <c>_predictors</c>, and returns the result of <c>CreatePredictor()</c>.
        /// </summary>
        /// <param name="data">Training data; must pass the float-vector feature check and the multiclass label check.</param>
        /// <returns>The predictor built by <c>CreatePredictor()</c>.</returns>
        /// <remarks>
        /// If a trainer instance was supplied beforehand (<c>_trainer</c>), it is used and then cleared.
        /// Otherwise a trainer is instantiated from <c>_args.predictorType</c>, which is therefore only
        /// validated in that case (previously it was required even when a trainer had been supplied).
        /// </remarks>
        TVectorPredictor Train(RoleMappedData data)
        {
            Contracts.CheckValue(data, "data");
            data.CheckFeatureFloatVector();

            // The class count reported by the label check is forwarded to TrainPredictor.
            int count;
            data.CheckMulticlassLabel(out count);

            using (var ch = Host.Start("Training"))
            {
                // Exactly one predictor is trained; the array always has a single slot.
                _predictors = new TVectorPredictor[1];
                for (int i = 0; i < _predictors.Length; i++)
                {
                    ch.Info("Training learner {0}", i);

                    TScalarTrainer trainer;
                    if (_trainer != null)
                    {
                        // A pre-built trainer was supplied, so predictorType is not needed.
                        trainer = _trainer;
                    }
                    else
                    {
                        // predictorType is only required when the trainer must be created from its description.
                        Contracts.CheckValue(_args.predictorType, "predictorType", "Must specify a base learner type");

                        var temp = ScikitSubComponent<ITrainer, SignatureBinaryClassifierTrainer>.AsSubComponent(_args.predictorType);
                        var inst = temp.CreateInstance(Host);
                        trainer = inst as TScalarTrainer;
                        if (trainer == null)
                        {
                            // Include the types reported by TrainerHelper.GetParentTypes to make the cast failure diagnosable.
                            var allTypes    = TrainerHelper.GetParentTypes(inst.GetType()).ToArray();
                            var allTypesStr = string.Join("\n", allTypes.Select(c => c.ToString()));
                            throw ch.ExceptNotSupp($"Unable to cast {inst.GetType()} into {typeof(TScalarTrainer)}.\n-TYPES-\n{allTypesStr}");
                        }
                    }

                    // The supplied trainer instance is cleared so it is not reused by a later call.
                    _trainer       = null;
                    _predictors[i] = TrainPredictor(ch, trainer, data, count);
                }
            }

            return CreatePredictor();
        }
예제 #3
0
 /// <summary>
 /// Runs the multiclass label check on <paramref name="examples"/> and returns, through
 /// <paramref name="weightSetCount"/>, the class count that check reports.
 /// </summary>
 /// <param name="examples">Training data whose label column is validated.</param>
 /// <param name="weightSetCount">Receives the count produced by <c>CheckMulticlassLabel</c>.</param>
 private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
     => examples.CheckMulticlassLabel(out weightSetCount);
        /// <summary>
        /// Trains a single binary classifier on a transformed copy of <paramref name="data"/>.
        /// The steps are:
        /// (1) add a generated group column;
        /// (2) map/insert labels and re-type the <c>labName</c> column;
        /// (3) concatenate it with the original features into a new feature column;
        /// (4) convert the label (<c>dstName</c>) and group columns to keys;
        /// (5) train <paramref name="trainer"/> with Feature/Label/Group roles bound to those columns.
        /// It then optionally trains a reclassification predictor and returns the final predictor.
        /// </summary>
        /// <param name="ch">Channel used for logging and for raising errors.</param>
        /// <param name="trainer">Binary trainer fitted on the transformed data.</param>
        /// <param name="data">Original training data; its column roles are reused for the transformed views.</param>
        /// <param name="count">
        /// Class count forwarded to MapLabelsAndInsertTransform and CreateFinalPredictor. When the label
        /// is neither used as a single column nor a key, and <paramref name="count"/> is not positive,
        /// the vector size is obtained from MinMaxLabelOverDataSet instead.
        /// </param>
        /// <returns>The predictor returned by CreateFinalPredictor.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            // Keep the untouched input: it is what TrainReclassificationPredictor receives below.
            var data0 = data;

            #region adding group ID

            // We insert a group Id.
            // GenerateNumberTransform (UseCounter = true) adds a column named <temp name> + "GR".
            // NOTE(review): presumably a per-row counter so each source row forms its own group — confirm.
            string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
            var    groupArgs       = new GenerateNumberTransform.Options
            {
                Columns    = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
                UseCounter = true
            };

            var withGroup = new GenerateNumberTransform(Host, groupArgs, data.Data);
            // From here on, 'data' refers to the view that contains the generated group column.
            data = new RoleMappedData(withGroup, data.Schema.GetColumnRoleNames());

            #endregion

            #region preparing the training dataset

            // MapLabelsAndInsertTransform returns the transformed view plus two column names:
            // dstName (converted below into the key-typed training label) and labName (concatenated
            // below into the features). NOTE(review): how it reshapes rows is not visible here — confirm.
            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // We check the label is not boolean.
            // (This applies to the dstName column of the transformed view.)
            int indexLab = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
            var typeLab  = trans.Schema[indexLab].Type;
            if (typeLab.RawKind() == DataKind.Boolean)
            {
                throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, typeLab.RawKind());
            }

            // DescribeTransform over labName and dstName (one row per column); its output is re-typed below.
            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            // Re-type labName so it can be concatenated with the features:
            //  - singleColumn with a Single label: the view is used as is;
            //  - singleColumn otherwise: labName is declared as a single float;
            //  - key label: labName becomes a float vector of length KeyCount;
            //  - otherwise: a float vector of length 'count', or of a size computed from the data when count <= 0.
            // NOTE(review): TypeReplacementDataView appears to change only the declared schema type — confirm
            // that the underlying values already have the declared shape.
            IDataView viewI;
            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                // NOTE(review): unchecked ulong -> int cast of the key count.
                ulong nb  = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
                var   sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
                // Debug builds only: the key count must match the count reported by CheckMulticlassLabel.
#if (DEBUG)
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMulticlassLabel(out count3);
                if ((ulong)count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            // newFeatures = Concat(original feature column, labName), built from a transform command string.
            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            var after_concatenation_ = ComponentCreation.CreateTransform(Host, args, viewI);

            #endregion

            #region converting label and group into keys

            // We need to convert the label into a Key.
            // dstName -> "<dstName>k" as a U4 key.
            // NOTE(review): KeyCount(4) is a magic value whose meaning is not visible here — confirm.
            var convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
                keyCount   = new KeyCount(4),
                resultType = DataKind.UInt32
            };
            IDataView after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_);

            // The group must be a key too!
            // groupColumnTemp -> "<groupColumnTemp>k", typed U4 or U8 depending on _args.groupIsU4.
            convArgs = new MulticlassConvertTransform.Arguments
            {
                column     = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
                keyCount   = new KeyCount(),
                resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
            };
            after_concatenation_key_label = new MulticlassConvertTransform(Host, convArgs, after_concatenation_key_label);

            #endregion

            #region preparing the RoleMapData view

            // Names of the key columns created above.
            string groupColumn = groupColumnTemp + "k";
            dstName += "k";

            // Keep every original role except Label and Feature, then bind Feature/Label/Group to the new columns.
            // NOTE(review): rolesArray assigned on the next line is never read before it is reassigned.
            // NOTE(review): the groupColumn/groupColumnTemp filters compare kvp.Key.Value (a role name) with
            // column names, so they look like no-ops; kvp.Value may have been intended. An existing Group
            // role is also kept alongside the new one — confirm this is intended.
            // NOTE(review): the "Duplicated group" check below only inspects the first remaining role.
            var roles      = data.Schema.GetColumnRoleNames();
            var rolesArray = roles.ToArray();
            roles = roles
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                    .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                    .Where(kvp => kvp.Key.Value != groupColumn)
                    .Where(kvp => kvp.Key.Value != groupColumnTemp);
            rolesArray = roles.ToArray();
            if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
            {
                throw ch.Except("Duplicated group.");
            }
            roles = roles
                    .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                    .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
                    .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
            var trainer_input = new RoleMappedData(after_concatenation_key_label, roles);

            #endregion

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            // A final predictor without reclassification is built first and passed, together with the
            // original data (data0), to TrainReclassificationPredictor.
            // NOTE(review): _reclassPredictor is presumably set by that call — confirm.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }

            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
        /// <summary>
        /// Trains a single binary classifier on a transformed copy of <paramref name="data"/>.
        /// Labels are mapped/inserted, the <c>labName</c> column is re-typed and concatenated with
        /// the original features, and the trainer is fitted with Feature = new column and
        /// Label = <c>dstName</c>. It then optionally trains a reclassification predictor and
        /// returns the final predictor.
        /// </summary>
        /// <param name="ch">Channel used for logging and for raising errors.</param>
        /// <param name="trainer">Binary trainer fitted on the transformed data.</param>
        /// <param name="data">Training data; its column roles (except Label and Feature) are carried over.</param>
        /// <param name="count">
        /// Class count forwarded to MapLabelsAndInsertTransform and CreateFinalPredictor. When the label
        /// is neither used as a single column nor a key, and <paramref name="count"/> is not positive,
        /// the vector size is obtained from MinMaxLabelOverDataSet instead.
        /// </param>
        /// <returns>The predictor returned by CreateFinalPredictor.</returns>
        protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
        {
            // data0 is what TrainReclassificationPredictor receives below.
            // MapLabelsAndInsertTransform returns the transformed view plus two column names:
            // dstName (used as the training label) and labName (concatenated into the features).
            // NOTE(review): how it reshapes rows is not visible here — confirm.
            var    data0 = data;
            string dstName, labName;
            var    trans       = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, true, _args);
            var    newFeatures = trans.Schema.GetTempColumnName() + "NF";

            // DescribeTransform over labName and dstName (one row per column); its output is re-typed below.
            var args3 = new DescribeTransform.Arguments {
                columns = new string[] { labName, dstName }, oneRowPerColumn = true
            };
            var desc = new DescribeTransform(Host, args3, trans);

            // Re-type labName so it can be concatenated with the features:
            //  - singleColumn with a Single label: the view is used as is;
            //  - singleColumn otherwise: labName is declared as a single float;
            //  - key label: labName becomes a float vector of length KeyCount;
            //  - otherwise: a float vector of length 'count', or of a size computed from the data when count <= 0.
            // NOTE(review): TypeReplacementDataView appears to change only the declared schema type — confirm.
            IDataView viewI;

            if (_args.singleColumn && data.Schema.Label.Value.Type.RawKind() == DataKind.Single)
            {
                viewI = desc;
            }
            else if (_args.singleColumn)
            {
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, false);
#endif
                #endregion
            }
            else if (data.Schema.Label.Value.Type.IsKey())
            {
                // NOTE(review): unchecked ulong -> int cast of the key count.
                ulong nb  = data.Schema.Label.Value.Type.AsKey().GetKeyCount();
                var   sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
                // Debug builds only: the key count must match the count reported by CheckMulticlassLabel.
#if (DEBUG)
                int nb_;
                MinMaxLabelOverDataSet(trans, labName, out nb_);
                int count3;
                data.CheckMulticlassLabel(out count3);
                if ((ulong)count3 != nb)
                {
                    throw ch.Except("Count mismatch (KeyCount){0} != {1}", nb, count3);
                }
                DebugChecking0(viewI, labName, true);
                DebugChecking0Vfloat(viewI, labName, nb);
#endif
                #endregion
            }
            else
            {
                int nb;
                if (count <= 0)
                {
                    MinMaxLabelOverDataSet(trans, labName, out nb);
                }
                else
                {
                    nb = count;
                }
                var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, nb) });
                viewI = new TypeReplacementDataView(desc, sch);
                #region debug
#if (DEBUG)
                DebugChecking0(viewI, labName, true);
#endif
                #endregion
            }

            // newFeatures = Concat(original feature column, labName), built from a transform command string.
            ch.Info("Merging column label '{0}' with features '{1}'", labName, data.Schema.Feature.Value.Name);
            var       args = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, data.Schema.Feature.Value.Name, labName);
            IDataView after_concatenation = ComponentCreation.CreateTransform(Host, args, viewI);

            // Keep every original role except Label and Feature, then bind Feature -> newFeatures and Label -> dstName.
            var roles = data.Schema.GetColumnRoleNames()
                        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value)
                        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
                        .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
                        .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName));
            var trainer_input = new RoleMappedData(after_concatenation, roles);

            ch.Info("New Features: {0}:{1}", trainer_input.Schema.Feature.Value.Name, trainer_input.Schema.Feature.Value.Type);
            ch.Info("New Label: {0}:{1}", trainer_input.Schema.Label.Value.Name, trainer_input.Schema.Label.Value.Type);

            // We train the unique binary classifier.
            var trainedPredictor = trainer.Train(trainer_input);
            var predictors       = new TScalarPredictor[] { trainedPredictor };

            // We train the reclassification classifier.
            // A final predictor without reclassification is built first and passed, together with data0,
            // to TrainReclassificationPredictor.
            // NOTE(review): _reclassPredictor is presumably set by that call — confirm.
            if (_args.reclassicationPredictor != null)
            {
                var pred = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
                TrainReclassificationPredictor(data0, pred, ScikitSubComponent <ITrainer, SignatureTrainer> .AsSubComponent(_args.reclassicationPredictor));
            }
            return(CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor));
        }
        /// <summary>
        /// Validates the label column of <paramref name="data"/> and initializes per-class state.
        /// The state consists of <c>_numClasses</c> (from the multiclass label check), a zeroed
        /// <c>_prior</c> array with one slot per class, and <c>_labelNames</c>.
        /// <c>_labelNames</c> is filled from the label's key-values annotation when that annotation
        /// is a text vector with one non-empty, unique name per class; otherwise it is set to null.
        /// </summary>
        /// <param name="data">Training data whose label column is inspected.</param>
        private protected override void CheckLabel(RoleMappedData data)
        {
            Contracts.AssertValue(data);
            // REVIEW: For floating point labels, this will make a pass over the data.
            // Should we instead leverage the pass made by the LBFGS base class? Ideally, it wouldn't
            // make a pass over the data...
            data.CheckMulticlassLabel(out _numClasses);

            // One prior counter per class, all starting at zero.
            _prior = new double[_numClasses];

            // Label names are only usable when the key-values metadata is a text vector sized to the class count.
            var labelColumn = data.Schema.Label.Value;
            var keyValuesType = labelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
            bool hasOneTextNamePerClass = keyValuesType is VectorType vectorType
                && vectorType.ItemType == TextDataViewType.Instance
                && vectorType.Size == _numClasses;
            if (!hasOneTextNamePerClass)
            {
                _labelNames = null;
                return;
            }

            VBuffer<ReadOnlyMemory<char>> keyValues = default;
            labelColumn.GetKeyValues(ref keyValues);

            // A sparse buffer means at least one class has no explicit name. In that case _labelNames
            // stays null and the model summary falls back to "Class_n" naming.
            if (!keyValues.IsDense)
            {
                _labelNames = null;
                return;
            }

            _labelNames = new string[_numClasses];
            var entries = keyValues.GetValues();

            // Tracks names already seen so duplicates can be detected.
            var seenNames = new HashSet<string>();

            for (int classIndex = 0; classIndex < _numClasses; classIndex++)
            {
                var entry = entries[classIndex];
                string name = entry.IsEmpty ? null : entry.ToString();

                // An empty or repeated name cannot identify a class, so the whole set is discarded.
                if (name == null || !seenNames.Add(name))
                {
                    _labelNames = null;
                    break;
                }

                _labelNames[classIndex] = name;
                Contracts.Assert(!string.IsNullOrEmpty(_labelNames[classIndex]));
            }

            Contracts.Assert(_labelNames == null || _labelNames.Length == _numClasses);
        }