void DecomposableTrainAndPredict()
{
    using (var env = new LocalEnvironment()
        .AddStandardComponents()) // ScoreUtils.GetScorer requires scorers to be registered in the ComponentCatalog
    {
        var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
        var term = TermTransform.Create(env, loader, "Label");
        var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);
        var trainer = new SdcaMultiClassTrainer(env, "Features", "Label", advancedSettings: (s) => { s.MaxIterations = 100; s.Shuffle = true; s.NumThreads = 1; });

        IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
        var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

        // Auto-normalization.
        NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
        IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

        // Cut out term transform from pipeline.
        var newScorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
        var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(newScorer);
        var model = env.CreatePredictionEngine<IrisDataNoLabel, IrisPrediction>(keyToValue);

        var testData = loader.AsEnumerable<IrisDataNoLabel>(env, false);
        foreach (var input in testData.Take(20))
        {
            var prediction = model.Predict(input);
            Assert.True(prediction.PredictedLabel == "Iris-setosa");
        }
    }
}
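// For context, a minimal sketch of the POCO types the test above assumes for CreatePredictionEngine
// and AsEnumerable. The field names mirror the loader columns, but the actual definitions live
// elsewhere in the test project, so treat these shapes as an approximation rather than the real classes.
public class IrisDataNoLabel
{
    public float SepalLength;
    public float SepalWidth;
    public float PetalLength;
    public float PetalWidth;
}

public class IrisData : IrisDataNoLabel
{
    public string Label;
}

public class IrisPrediction
{
    public string PredictedLabel;
}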
private static IDataView ApplyKeyToVec(List<KeyToVectorTransform.ColumnInfo> ktv, IDataView viewTrain, IHost host)
{
    Contracts.AssertValueOrNull(ktv);
    Contracts.AssertValue(viewTrain);
    Contracts.AssertValue(host);
    if (Utils.Size(ktv) > 0)
    {
        // Instead of simply using KeyToVector, we are jumping through some hoops here to do the right thing
        // in a very common case: when the user has slightly different key values between the training and testing set.
        // The solution is to apply KeyToValue, then Term using the terms from the key metadata of the original
        // key column, and finally the KeyToVector transform.
        viewTrain = new KeyToValueTransform(host, ktv.Select(x => (x.Input, x.Output)).ToArray())
            .Transform(viewTrain);
        viewTrain = TermTransform.Create(host,
            new TermTransform.Arguments()
            {
                Column = ktv
                    .Select(c => new TermTransform.Column() { Name = c.Output, Source = c.Output, Terms = GetTerms(viewTrain, c.Input) })
                    .ToArray(),
                TextKeyValues = true
            },
            viewTrain);
        viewTrain = KeyToVectorTransform.Create(host, viewTrain, ktv.Select(c => new KeyToVectorTransform.ColumnInfo(c.Output, c.Output)).ToArray());
    }
    return viewTrain;
}
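// A hypothetical sketch of the GetTerms helper referenced above, to show what the Term transform is
// being fed: the terms are read from the KeyValues metadata of the original key column and joined into
// a comma-separated string, so the new TermTransform reproduces the training-time key mapping. The real
// helper lives alongside this method; details such as the text type and metadata access can differ by version.
private static string GetTerms(IDataView data, string colName)
{
    Contracts.AssertValue(data);
    int col;
    if (!data.Schema.TryGetColumnIndex(colName, out col))
        return null;
    var type = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, col);
    if (type == null || !type.IsVector || !type.ItemType.IsText)
        return null;
    var metadata = default(VBuffer<ReadOnlyMemory<char>>);
    data.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, col, ref metadata);
    var sb = new StringBuilder();
    var pre = "";
    foreach (var term in metadata.DenseValues())
    {
        sb.Append(pre).Append(term.ToString());
        pre = ",";
    }
    return sb.ToString();
}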
void Extensibility()
{
    var dataPath = GetDataPath(IrisDataPath);
    using (var env = new LocalEnvironment())
    {
        var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
        Action<IrisData, IrisData> action = (i, j) =>
        {
            j.Label = i.Label;
            j.PetalLength = i.SepalLength > 3 ? i.PetalLength : i.SepalLength;
            j.PetalWidth = i.PetalWidth;
            j.SepalLength = i.SepalLength;
            j.SepalWidth = i.SepalWidth;
        };
        var lambda = LambdaTransform.CreateMap(env, loader, action);
        var term = TermTransform.Create(env, lambda, "Label");
        var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
            .Transform(term);
        var trainer = new SdcaMultiClassTrainer(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 });

        IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
        var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

        // Auto-normalization.
        NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
        var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));

        var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
        IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);

        var keyToValue = new KeyToValueTransform(env, "PredictedLabel").Transform(scorer);
        var model = env.CreatePredictionEngine<IrisData, IrisPrediction>(keyToValue);

        var testLoader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
        var testData = testLoader.AsEnumerable<IrisData>(env, false);
        foreach (var input in testData.Take(20))
        {
            var prediction = model.Predict(input);
            Assert.True(prediction.PredictedLabel == input.Label);
        }
    }
}
public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvironment env, ClassificationLabelInput input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register("PrepareClassificationLabel");
    host.CheckValue(input, nameof(input));
    EntryPointUtils.CheckInputArgs(host, input);

    int labelCol;
    if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol))
        throw host.Except($"Column '{input.LabelColumn}' not found.");
    var labelType = input.Data.Schema.GetColumnType(labelCol);
    if (labelType.IsKey || labelType.IsBool)
    {
        var nop = NopTransform.CreateIfNeeded(env, input.Data);
        return new CommonOutputs.TransformOutput { Model = new TransformModel(env, nop, input.Data), OutputData = nop };
    }

    var args = new TermTransform.Arguments()
    {
        Column = new[]
        {
            new TermTransform.Column()
            {
                Name = input.LabelColumn,
                Source = input.LabelColumn,
                TextKeyValues = input.TextKeyValues,
                Sort = TermTransform.SortOrder.Value
            }
        }
    };
    var xf = TermTransform.Create(host, args, input.Data);
    return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
}
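// A hedged usage sketch: the entry point can also be driven directly from code by filling in its input
// object and reading back the transformed data. The loader variable and column name are assumptions;
// only Data, LabelColumn, and TextKeyValues come from the method above.
var labelInput = new ClassificationLabelInput
{
    Data = loader,          // any IDataView whose label column may still be text
    LabelColumn = "Label",
    TextKeyValues = true
};
var output = PrepareClassificationLabel(env, labelInput);
IDataView labeledData = output.OutputData;  // label converted to a key type unless it already was a key or bool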
public void Metacomponents()
{
    using (var env = new LocalEnvironment())
    {
        var loader = TextLoader.ReadFile(env, MakeIrisTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.irisData.trainFilename)));
        var term = TermTransform.Create(env, loader, "Label");
        var concat = new ConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth").Transform(term);

        // Build an Ova (one-versus-all) meta-trainer whose per-class binary learner is an averaged
        // perceptron; passing a trainer as a component of another trainer is what this scenario exercises.
        var trainer = new Ova(env, new Ova.Arguments
        {
            PredictorType = ComponentFactoryUtils.CreateFromFunction(
                e => new AveragedPerceptronTrainer(env, new AveragedPerceptronTrainer.Arguments()))
        });

        IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
        var trainRoles = new RoleMappedData(trainData, label: "Label", feature: "Features");

        // Auto-normalization.
        NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
        var predictor = trainer.Train(new TrainContext(trainRoles));
    }
}
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input,
    TermLoaderArguments termLoaderArgs = null)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(LoaderSignature);
    h.CheckValue(args, nameof(args));
    h.CheckValue(input, nameof(input));
    h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column), "Columns must be specified");

    // To each input column to the NgramHashExtractorArguments, a HashTransform using 31
    // bits (to minimize collisions) is applied first, followed by an NgramHashTransform.
    IDataView view = input;

    List<TermTransform.Column> termCols = null;
    if (termLoaderArgs != null)
        termCols = new List<TermTransform.Column>();

    var hashColumns = new List<HashTransformer.Column>();
    var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length];

    var colCount = args.Column.Length;
    // The NGramHashExtractor has a ManyToOne column type. To avoid stepping over the source
    // column name when a 'name' destination column name was specified, we use temporary column names.
    string[][] tmpColNames = new string[colCount][];
    for (int iinfo = 0; iinfo < colCount; iinfo++)
    {
        var column = args.Column[iinfo];
        h.CheckUserArg(!string.IsNullOrWhiteSpace(column.Name), nameof(column.Name));
        h.CheckUserArg(Utils.Size(column.Source) > 0 &&
            column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source));

        int srcCount = column.Source.Length;
        tmpColNames[iinfo] = new string[srcCount];
        for (int isrc = 0; isrc < srcCount; isrc++)
        {
            var tmpName = input.Schema.GetTempColumnName(column.Source[isrc]);
            tmpColNames[iinfo][isrc] = tmpName;
            if (termLoaderArgs != null)
            {
                termCols.Add(
                    new TermTransform.Column
                    {
                        Name = tmpName,
                        Source = column.Source[isrc]
                    });
            }

            hashColumns.Add(
                new HashTransformer.Column
                {
                    Name = tmpName,
                    Source = termLoaderArgs == null ? column.Source[isrc] : tmpName,
                    HashBits = 30,
                    Seed = column.Seed,
                    Ordered = false,
                    InvertHash = column.InvertHash
                });
        }

        ngramHashColumns[iinfo] =
            new NgramHashTransform.Column
            {
                Name = column.Name,
                Source = tmpColNames[iinfo],
                AllLengths = column.AllLengths,
                HashBits = column.HashBits,
                NgramLength = column.NgramLength,
                RehashUnigrams = false,
                Seed = column.Seed,
                SkipLength = column.SkipLength,
                Ordered = column.Ordered,
                InvertHash = column.InvertHash,
                // REVIEW: This is an ugly internal hack to get around
                // the problem that we want the *original* source names surfacing
                // in the descriptions where appropriate, rather than _tmp000 and
                // what have you. The alternative is we do something elaborate
                // with metadata or something, but I'm not sure that's better.
                FriendlyNames = column.FriendlyNames
            };
    }

    if (termLoaderArgs != null)
    {
        h.Assert(Utils.Size(termCols) == hashColumns.Count);
        var termArgs =
            new TermTransform.Arguments()
            {
                MaxNumTerms = int.MaxValue,
                Terms = termLoaderArgs.Terms,
                Term = termLoaderArgs.Term,
                DataFile = termLoaderArgs.DataFile,
                Loader = termLoaderArgs.Loader,
                TermsColumn = termLoaderArgs.TermsColumn,
                Sort = termLoaderArgs.Sort,
                Column = termCols.ToArray()
            };
        view = TermTransform.Create(h, termArgs, view);

        if (termLoaderArgs.DropUnknowns)
        {
            var naDropArgs = new NADropTransform.Arguments { Column = new NADropTransform.Column[termCols.Count] };
            for (int iinfo = 0; iinfo < termCols.Count; iinfo++)
            {
                naDropArgs.Column[iinfo] =
                    new NADropTransform.Column { Name = termCols[iinfo].Name, Source = termCols[iinfo].Name };
            }
            view = new NADropTransform(h, naDropArgs, view);
        }
    }

    // Args for the Hash function with multiple columns
    var hashArgs =
        new HashTransformer.Arguments
        {
            HashBits = 31,
            Seed = args.Seed,
            Ordered = false,
            Column = hashColumns.ToArray(),
            InvertHash = args.InvertHash
        };
    view = HashTransformer.Create(h, hashArgs, view);

    // Creating the NgramHash function
    var ngramHashArgs =
        new NgramHashTransform.Arguments
        {
            AllLengths = args.AllLengths,
            HashBits = args.HashBits,
            NgramLength = args.NgramLength,
            SkipLength = args.SkipLength,
            RehashUnigrams = false,
            Ordered = args.Ordered,
            Seed = args.Seed,
            Column = ngramHashColumns,
            InvertHash = args.InvertHash
        };
    view = new NgramHashTransform(h, ngramHashArgs, view);

    return SelectColumnsTransform.CreateDrop(h, view, tmpColNames.SelectMany(cols => cols).ToArray());
}
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input,
    TermLoaderArguments termLoaderArgs = null)
{
    Contracts.CheckValue(env, nameof(env));
    var h = env.Register(LoaderSignature);
    h.CheckValue(args, nameof(args));
    h.CheckValue(input, nameof(input));
    h.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column), "Columns must be specified");

    IDataView view = input;
    var termCols = new List<Column>();
    var isTermCol = new bool[args.Column.Length];

    for (int i = 0; i < args.Column.Length; i++)
    {
        var col = args.Column[i];

        h.CheckNonWhiteSpace(col.Name, nameof(col.Name));
        h.CheckNonWhiteSpace(col.Source, nameof(col.Source));
        int colId;
        if (input.Schema.TryGetColumnIndex(col.Source, out colId) &&
            input.Schema.GetColumnType(colId).ItemType.IsText)
        {
            termCols.Add(col);
            isTermCol[i] = true;
        }
    }

    // If the column types of args.Column are text, apply the term transform to convert them to keys.
    // Otherwise, skip the term transform and apply the ngram transform directly.
    // This logic allows NgramExtractorTransform to handle both text and key input columns.
    // Note: the ngram transform handles the validation of the types natively (in case the types
    // of args.Column are neither text nor keys).
    if (termCols.Count > 0)
    {
        TermTransform.Arguments termArgs = null;
        MissingValueDroppingTransformer.Arguments naDropArgs = null;
        if (termLoaderArgs != null)
        {
            termArgs =
                new TermTransform.Arguments()
                {
                    MaxNumTerms = int.MaxValue,
                    Terms = termLoaderArgs.Terms,
                    Term = termLoaderArgs.Term,
                    DataFile = termLoaderArgs.DataFile,
                    Loader = termLoaderArgs.Loader,
                    TermsColumn = termLoaderArgs.TermsColumn,
                    Sort = termLoaderArgs.Sort,
                    Column = new TermTransform.Column[termCols.Count]
                };

            if (termLoaderArgs.DropUnknowns)
                naDropArgs = new MissingValueDroppingTransformer.Arguments { Column = new MissingValueDroppingTransformer.Column[termCols.Count] };
        }
        else
        {
            termArgs =
                new TermTransform.Arguments()
                {
                    MaxNumTerms = Utils.Size(args.MaxNumTerms) > 0 ? args.MaxNumTerms[0] : NgramTransform.Arguments.DefaultMaxTerms,
                    Column = new TermTransform.Column[termCols.Count]
                };
        }

        for (int iinfo = 0; iinfo < termCols.Count; iinfo++)
        {
            var column = termCols[iinfo];
            termArgs.Column[iinfo] =
                new TermTransform.Column()
                {
                    Name = column.Name,
                    Source = column.Source,
                    MaxNumTerms = Utils.Size(column.MaxNumTerms) > 0 ? column.MaxNumTerms[0] : default(int?)
                };

            if (naDropArgs != null)
                naDropArgs.Column[iinfo] = new MissingValueDroppingTransformer.Column { Name = column.Name, Source = column.Name };
        }

        view = TermTransform.Create(h, termArgs, view);
        if (naDropArgs != null)
            view = new MissingValueDroppingTransformer(h, naDropArgs, view);
    }

    var ngramArgs =
        new NgramTransform.Arguments()
        {
            MaxNumTerms = args.MaxNumTerms,
            NgramLength = args.NgramLength,
            SkipLength = args.SkipLength,
            AllLengths = args.AllLengths,
            Weighting = args.Weighting,
            Column = new NgramTransform.Column[args.Column.Length]
        };

    for (int iinfo = 0; iinfo < args.Column.Length; iinfo++)
    {
        var column = args.Column[iinfo];
        ngramArgs.Column[iinfo] =
            new NgramTransform.Column()
            {
                Name = column.Name,
                Source = isTermCol[iinfo] ? column.Name : column.Source,
                AllLengths = column.AllLengths,
                MaxNumTerms = column.MaxNumTerms,
                NgramLength = column.NgramLength,
                SkipLength = column.SkipLength,
                Weighting = column.Weighting
            };
    }

    return new NgramTransform(h, ngramArgs, view);
}
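// A hedged usage sketch (the environment, data view, and column names are assumptions): the extractor
// is typically handed a tokenized text column. Text sources are keyed via TermTransform internally,
// while key-typed sources go straight to the ngram transform, as the logic above shows. Only the
// Arguments/Column field names used here are taken from the method itself.
var extractorArgs = new Arguments
{
    NgramLength = 2,
    AllLengths = true,
    Column = new[] { new Column { Name = "NgramFeatures", Source = "Tokens" } }
};
IDataTransform ngrams = Create(env, extractorArgs, tokenizedData);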