/// <summary>
/// Validates that <paramref name="data"/> carries a multiclass label and then
/// delegates the actual training to the base implementation.
/// </summary>
/// <param name="data">The role-mapped training data.</param>
/// <returns>The predictor produced by the base class.</returns>
protected override INearestNeighborsPredictor Train(RoleMappedData data)
{
    // Only the validation matters here; the class count it reports is not used.
    data.CheckMulticlassLabel(out int ignoredClassCount);
    return base.Train(data);
}
/// <summary>
/// Validates the training data, resolves the scalar trainer (either the one already
/// stored in <c>_trainer</c> or a new instance created from <c>_args.predictorType</c>),
/// trains the predictor(s) held in <c>_predictors</c> and returns the final predictor
/// built by <c>CreatePredictor</c>.
/// </summary>
/// <param name="data">The role-mapped training data; checked for a float feature
/// vector and a multiclass label.</param>
/// <returns>The result of <c>CreatePredictor()</c>.</returns>
TVectorPredictor Train(RoleMappedData data)
{
    Contracts.CheckValue(data, "data");
    data.CheckFeatureFloatVector();
    data.CheckMulticlassLabel(out int classCount);

    using (var ch = Host.Start("Training"))
    {
        // The predictor array has a single slot, so the loop body runs once.
        _predictors = new TVectorPredictor[1];
        for (int index = 0; index < _predictors.Length; ++index)
        {
            ch.Info("Training learner {0}", index);
            Contracts.CheckValue(_args.predictorType, "predictorType", "Must specify a base learner type");

            // Prefer the trainer that was handed to this instance; otherwise
            // instantiate one from the configured sub-component.
            TScalarTrainer learner = _trainer;
            if (learner == null)
            {
                var component = ScikitSubComponent<ITrainer, SignatureBinaryClassifierTrainer>.AsSubComponent(_args.predictorType);
                var created = component.CreateInstance(Host);
                learner = created as TScalarTrainer;
                if (learner == null)
                {
                    // Report the whole type hierarchy to make the mismatch diagnosable.
                    var parentTypes = TrainerHelper.GetParentTypes(created.GetType()).ToArray();
                    var parentTypesText = string.Join("\n", parentTypes.Select(t => t.ToString()));
                    throw ch.ExceptNotSupp($"Unable to cast {created.GetType()} into {typeof(TScalarTrainer)}.\n-TYPES-\n{parentTypesText}");
                }
            }

            // The stored trainer is cleared before the predictor is trained.
            _trainer = null;
            _predictors[index] = TrainPredictor(ch, learner, data, classCount);
        }
    }

    return CreatePredictor();
}
/// <summary>
/// Checks that <paramref name="examples"/> has a multiclass label and reports the
/// value produced by that check through <paramref name="weightSetCount"/>.
/// </summary>
/// <param name="examples">The role-mapped data whose label is validated.</param>
/// <param name="weightSetCount">Receives the count returned by <c>CheckMulticlassLabel</c>.</param>
private protected override void CheckLabel(RoleMappedData examples, out int weightSetCount)
    => examples.CheckMulticlassLabel(out weightSetCount);
/// <summary>
/// Trains the single underlying learner on a reshaped view of <paramref name="data"/>:
/// a generated group-id column is added, the labels are mapped by
/// <c>MapLabelsAndInsertTransform</c>, the label column is concatenated with the
/// features, and both the mapped label and the group id are converted to keys before
/// being handed to <paramref name="trainer"/>. When
/// <c>_args.reclassicationPredictor</c> is set, a reclassification predictor is
/// trained as well.
/// </summary>
/// <param name="ch">Channel used for logging and for raising exceptions.</param>
/// <param name="trainer">The scalar trainer to run on the transformed data.</param>
/// <param name="data">The original role-mapped training data.</param>
/// <param name="count">Class count; when it is not positive (and the label is neither
/// single-column nor a key) the vector size is obtained from
/// <c>MinMaxLabelOverDataSet</c> instead.</param>
/// <returns>The predictor built by <c>CreateFinalPredictor</c>.</returns>
protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
{
    // ---- Adding a group id ------------------------------------------------
    // A generated number column (counter mode) is appended to the input view.
    string groupColumnTemp = DataViewUtils.GetTempColumnName(data.Schema.Schema) + "GR";
    var groupOptions = new GenerateNumberTransform.Options
    {
        Columns = new[] { GenerateNumberTransform.Column.Parse(groupColumnTemp) },
        UseCounter = true
    };
    var viewWithGroup = new GenerateNumberTransform(Host, groupOptions, data.Data);
    var groupedData = new RoleMappedData(viewWithGroup, data.Schema.GetColumnRoleNames());

    // ---- Preparing the training dataset -----------------------------------
    var trans = MapLabelsAndInsertTransform(ch, groupedData, out string dstName, out string labName, count, true, _args);
    var newFeatures = trans.Schema.GetTempColumnName() + "NF";

    // The mapped label must not be boolean.
    int mappedLabelIndex = SchemaHelper.GetColumnIndex(trans.Schema, dstName);
    var mappedLabelKind = trans.Schema[mappedLabelIndex].Type.RawKind();
    if (mappedLabelKind == DataKind.Boolean)
    {
        throw Host.Except("Column '{0}' has an unexpected type {1}.", dstName, mappedLabelKind);
    }

    var describeArgs = new DescribeTransform.Arguments
    {
        columns = new string[] { labName, dstName },
        oneRowPerColumn = true
    };
    var desc = new DescribeTransform(Host, describeArgs, trans);

    // Pick the declared type of the label column depending on the configuration
    // and on the original label type.
    var labelType = groupedData.Schema.Label.Value.Type;
    IDataView viewI;
    if (_args.singleColumn)
    {
        if (labelType.RawKind() == DataKind.Single)
        {
            // Already a float: nothing to replace.
            viewI = desc;
        }
        else
        {
            var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
            viewI = new TypeReplacementDataView(desc, sch);
#if (DEBUG)
            DebugChecking0(viewI, labName, false);
#endif
        }
    }
    else if (labelType.IsKey())
    {
        ulong keyCount = labelType.AsKey().GetKeyCount();
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)keyCount) });
        viewI = new TypeReplacementDataView(desc, sch);
#if (DEBUG)
        MinMaxLabelOverDataSet(trans, labName, out int unusedObservedCount);
        groupedData.CheckMulticlassLabel(out int checkedCount);
        if ((ulong)checkedCount != keyCount)
        {
            throw ch.Except("Count mismatch (KeyCount){0} != {1}", keyCount, checkedCount);
        }
        DebugChecking0(viewI, labName, true);
        DebugChecking0Vfloat(viewI, labName, keyCount);
#endif
    }
    else
    {
        // Use the provided class count when positive, otherwise ask the data.
        int vectorSize = count;
        if (count <= 0)
        {
            MinMaxLabelOverDataSet(trans, labName, out vectorSize);
        }
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, vectorSize) });
        viewI = new TypeReplacementDataView(desc, sch);
#if (DEBUG)
        DebugChecking0(viewI, labName, true);
#endif
    }

    string featureName = groupedData.Schema.Feature.Value.Name;
    ch.Info("Merging column label '{0}' with features '{1}'", labName, featureName);
    var concatSettings = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, featureName, labName);
    var concatenated = ComponentCreation.CreateTransform(Host, concatSettings, viewI);

    // ---- Converting label and group into keys -----------------------------
    // The label has to become a key.
    var labelKeyArgs = new MulticlassConvertTransform.Arguments
    {
        column = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", dstName)) },
        keyCount = new KeyCount(4),
        resultType = DataKind.UInt32
    };
    IDataView keyedView = new MulticlassConvertTransform(Host, labelKeyArgs, concatenated);

    // The group must be a key too.
    var groupKeyArgs = new MulticlassConvertTransform.Arguments
    {
        column = new[] { MulticlassConvertTransform.Column.Parse(string.Format("{0}k:{0}", groupColumnTemp)) },
        keyCount = new KeyCount(),
        resultType = _args.groupIsU4 ? DataKind.UInt32 : DataKind.UInt64
    };
    keyedView = new MulticlassConvertTransform(Host, groupKeyArgs, keyedView);

    // ---- Preparing the RoleMappedData view --------------------------------
    string groupColumn = groupColumnTemp + "k";
    dstName += "k";

    var roles = groupedData.Schema.GetColumnRoleNames();
    // NOTE(review): this first materialization is overwritten right below; it looks
    // like leftover code and is kept only to leave the enumeration pattern untouched.
    var rolesArray = roles.ToArray();
    // NOTE(review): the last two conditions compare the role name (kvp.Key.Value)
    // with column names, so they presumably never exclude anything - confirm intent.
    roles = roles.Where(kvp =>
        kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value
        && kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value
        && kvp.Key.Value != groupColumn
        && kvp.Key.Value != groupColumnTemp);
    rolesArray = roles.ToArray();
    // Only the first remaining role is inspected here.
    if (rolesArray.Any() && rolesArray[0].Value == groupColumnTemp)
    {
        throw ch.Except("Duplicated group.");
    }

    roles = roles
        .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
        .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName))
        .Prepend(RoleMappedSchema.ColumnRole.Group.Bind(groupColumn));
    var trainerInput = new RoleMappedData(keyedView, roles);

    ch.Info("New Features: {0}:{1}", trainerInput.Schema.Feature.Value.Name, trainerInput.Schema.Feature.Value.Type);
    ch.Info("New Label: {0}:{1}", trainerInput.Schema.Label.Value.Name, trainerInput.Schema.Label.Value.Type);

    // Train the unique underlying learner.
    var trained = trainer.Train(trainerInput);
    var predictors = new TScalarPredictor[] { trained };

    // Optionally train the reclassification predictor on the original data.
    if (_args.reclassicationPredictor != null)
    {
        var intermediate = CreateFinalPredictor(ch, groupedData, trans, count, _args, predictors, null);
        TrainReclassificationPredictor(data, intermediate, ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(_args.reclassicationPredictor));
    }

    return CreateFinalPredictor(ch, groupedData, trans, count, _args, predictors, _reclassPredictor);
}
/// <summary>
/// Trains the single underlying learner on a reshaped view of <paramref name="data"/>:
/// the labels are mapped by <c>MapLabelsAndInsertTransform</c>, the label column is
/// retyped and concatenated with the features into a new feature column, and the
/// result is handed to <paramref name="trainer"/>. When
/// <c>_args.reclassicationPredictor</c> is set, a reclassification predictor is
/// trained as well.
/// </summary>
/// <param name="ch">Channel used for logging and for raising exceptions.</param>
/// <param name="trainer">The scalar trainer to run on the transformed data.</param>
/// <param name="data">The role-mapped training data.</param>
/// <param name="count">Class count; when it is not positive (and the label is neither
/// single-column nor a key) the vector size is obtained from
/// <c>MinMaxLabelOverDataSet</c> instead.</param>
/// <returns>The predictor built by <c>CreateFinalPredictor</c>.</returns>
protected override TVectorPredictor TrainPredictor(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int count)
{
    var trans = MapLabelsAndInsertTransform(ch, data, out string dstName, out string labName, count, true, _args);
    var newFeatures = trans.Schema.GetTempColumnName() + "NF";

    var describeArgs = new DescribeTransform.Arguments
    {
        columns = new string[] { labName, dstName },
        oneRowPerColumn = true
    };
    var desc = new DescribeTransform(Host, describeArgs, trans);

    // Pick the declared type of the label column depending on the configuration
    // and on the original label type.
    var labelType = data.Schema.Label.Value.Type;
    IDataView viewI;
    if (_args.singleColumn)
    {
        if (labelType.RawKind() == DataKind.Single)
        {
            // Already a float: nothing to replace.
            viewI = desc;
        }
        else
        {
            var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { NumberDataViewType.Single });
            viewI = new TypeReplacementDataView(desc, sch);
#if (DEBUG)
            DebugChecking0(viewI, labName, false);
#endif
        }
    }
    else if (labelType.IsKey())
    {
        ulong keyCount = labelType.AsKey().GetKeyCount();
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, (int)keyCount) });
        viewI = new TypeReplacementDataView(desc, sch);
#if (DEBUG)
        MinMaxLabelOverDataSet(trans, labName, out int unusedObservedCount);
        data.CheckMulticlassLabel(out int checkedCount);
        if ((ulong)checkedCount != keyCount)
        {
            throw ch.Except("Count mismatch (KeyCount){0} != {1}", keyCount, checkedCount);
        }
        DebugChecking0(viewI, labName, true);
        DebugChecking0Vfloat(viewI, labName, keyCount);
#endif
    }
    else
    {
        // Use the provided class count when positive, otherwise ask the data.
        int vectorSize = count;
        if (count <= 0)
        {
            MinMaxLabelOverDataSet(trans, labName, out vectorSize);
        }
        var sch = new TypeReplacementSchema(desc.Schema, new[] { labName }, new[] { new VectorDataViewType(NumberDataViewType.Single, vectorSize) });
        viewI = new TypeReplacementDataView(desc, sch);
#if (DEBUG)
        DebugChecking0(viewI, labName, true);
#endif
    }

    string featureName = data.Schema.Feature.Value.Name;
    ch.Info("Merging column label '{0}' with features '{1}'", labName, featureName);
    var concatSettings = string.Format("Concat{{col={0}:{1},{2}}}", newFeatures, featureName, labName);
    IDataView concatenated = ComponentCreation.CreateTransform(Host, concatSettings, viewI);

    // Rebind the roles: drop the original label/feature bindings and point them
    // at the mapped label and the merged feature column.
    var roles = data.Schema.GetColumnRoleNames()
        .Where(kvp => kvp.Key.Value != RoleMappedSchema.ColumnRole.Label.Value
                      && kvp.Key.Value != RoleMappedSchema.ColumnRole.Feature.Value)
        .Prepend(RoleMappedSchema.ColumnRole.Feature.Bind(newFeatures))
        .Prepend(RoleMappedSchema.ColumnRole.Label.Bind(dstName));
    var trainerInput = new RoleMappedData(concatenated, roles);

    ch.Info("New Features: {0}:{1}", trainerInput.Schema.Feature.Value.Name, trainerInput.Schema.Feature.Value.Type);
    ch.Info("New Label: {0}:{1}", trainerInput.Schema.Label.Value.Name, trainerInput.Schema.Label.Value.Type);

    // Train the unique underlying learner.
    var trained = trainer.Train(trainerInput);
    var predictors = new TScalarPredictor[] { trained };

    // Optionally train the reclassification predictor on the same input data.
    if (_args.reclassicationPredictor != null)
    {
        var intermediate = CreateFinalPredictor(ch, data, trans, count, _args, predictors, null);
        TrainReclassificationPredictor(data, intermediate, ScikitSubComponent<ITrainer, SignatureTrainer>.AsSubComponent(_args.reclassicationPredictor));
    }

    return CreateFinalPredictor(ch, data, trans, count, _args, predictors, _reclassPredictor);
}
/// <summary>
/// Validates the multiclass label of <paramref name="data"/>, stores the class count
/// in <c>_numClasses</c>, allocates <c>_prior</c>, and tries to fill
/// <c>_labelNames</c> from the label column's key-value annotation.
/// <c>_labelNames</c> is left null when the annotation is missing, is not a text
/// vector with one entry per class, is not dense, or contains an empty or duplicated
/// name.
/// </summary>
/// <param name="data">The role-mapped training data.</param>
private protected override void CheckLabel(RoleMappedData data)
{
    Contracts.AssertValue(data);

    // REVIEW: For floating point labels, this will make a pass over the data.
    // Should we instead leverage the pass made by the LBFGS base class? Ideally, it wouldn't
    // make a pass over the data...
    data.CheckMulticlassLabel(out _numClasses);

    // Initialize prior counts.
    _prior = new Double[_numClasses];

    // Try to get the label key values metadata.
    var labelColumn = data.Schema.Label.Value;
    var keyValuesType = labelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
    bool hasOneTextNamePerClass =
        keyValuesType is VectorType vecType
        && vecType.ItemType == TextDataViewType.Instance
        && vecType.Size == _numClasses;
    if (!hasOneTextNamePerClass)
    {
        _labelNames = null;
        return;
    }

    VBuffer<ReadOnlyMemory<char>> keyValues = default;
    labelColumn.GetKeyValues(ref keyValues);

    // If the names are not dense, or contain an empty value, at least one class has no
    // valid name. If the names are not unique they cannot be used in the model summary
    // either. In all these cases _labelNames stays null and the "Class_n" naming (n being
    // the class number) is used for model summary saving instead.
    if (!keyValues.IsDense)
    {
        _labelNames = null;
        return;
    }

    _labelNames = new string[_numClasses];
    ReadOnlySpan<ReadOnlyMemory<char>> nameSpan = keyValues.GetValues();

    // Tracks the names seen so far in order to detect duplicates.
    var seenNames = new HashSet<string>();
    for (int classIndex = 0; classIndex < _numClasses; classIndex++)
    {
        ReadOnlyMemory<char> rawName = nameSpan[classIndex];

        // An empty name and a repeated name are rejected the same way.
        string name = rawName.IsEmpty ? null : rawName.ToString();
        if (name == null || !seenNames.Add(name))
        {
            _labelNames = null;
            break;
        }

        _labelNames[classIndex] = name;
        Contracts.Assert(!string.IsNullOrEmpty(_labelNames[classIndex]));
    }

    Contracts.Assert(_labelNames == null || _labelNames.Length == _numClasses);
}