예제 #1
0
        /// <summary>
        /// Loading constructor: reads a one-byte kind tag from <paramref name="ctx"/> and creates
        /// the implementation object specialized for the corresponding label type.
        /// </summary>
        /// <param name="env">Environment used for validation and for building exceptions.</param>
        /// <param name="ctx">Model load context the predictor is read from.</param>
        private MultiToBinaryPredictor(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, ctx, RegistrationName)
        {
            byte rawKind = ctx.Reader.ReadByte();

            // A byte can never be negative, so only the upper bound needs checking.
            env.Check(rawKind <= 100, "kind");

            var labelKind = (DataKind)rawKind;

            if (labelKind == DataKind.Single)
            {
                _impl = new ImplRawBinary<float>(ctx, env);
            }
            else if (labelKind == DataKind.SByte)
            {
                // NOTE(review): the SByte kind is paired with an unsigned byte implementation —
                // confirm whether DataKind.Byte was intended.
                _impl = new ImplRawBinary<byte>(ctx, env);
            }
            else if (labelKind == DataKind.UInt16)
            {
                _impl = new ImplRawBinary<ushort>(ctx, env);
            }
            else if (labelKind == DataKind.UInt32)
            {
                _impl = new ImplRawBinary<uint>(ctx, env);
            }
            else
            {
                throw env.ExceptNotSupp("Not supported label type.");
            }
        }
        /// <summary>
        /// Creates the transformer for a single input column and initializes its underlying
        /// column transformer from the estimator settings.
        /// </summary>
        /// <param name="host">Environment used to register this component and run checks.</param>
        /// <param name="inputColumnName">Name of the column to transform.</param>
        /// <param name="columnPrefix">Prefix passed through to the typed column.</param>
        /// <param name="country">Holiday list handed to the column when creating the transformer.</param>
        internal DateTimeTransformer(
            IHostEnvironment host,
            string inputColumnName,
            string columnPrefix,
            DateTimeEstimator.HolidayList country)
            : base(host.Register(nameof(DateTimeTransformer)))
        {
            // Refuse to run on the one platform that is explicitly rejected.
            bool runningOnCentOS7 = CommonExtensions.OsIsCentOS7();
            host.Check(!runningOnCentOS7, "CentOS7 is not supported");

            _column = new LongTypedColumn(inputColumnName, columnPrefix);
            _column.CreateTransformerFromEstimator(country);
        }
        /// <summary>
        /// Builds a summary data view for <paramref name="predictor"/> and, when the predictor
        /// can supply one, a separate statistics data view.
        /// </summary>
        /// <remarks>
        /// Calibrated wrappers are unwrapped first. The summary is then taken from
        /// <see cref="ICanGetSummaryAsIDataView"/> or <see cref="ICanGetSummaryAsIRow"/> when
        /// implemented (it is an error for both to produce one). Otherwise a one-cell data view
        /// is built from <see cref="ICanSaveSummary"/>, or from the predictor's type name as a
        /// last resort.
        /// </remarks>
        /// <param name="env">Environment used for checks and for building data views.</param>
        /// <param name="predictor">The predictor to summarize.</param>
        /// <param name="schema">Schema passed to the predictor's summary methods.</param>
        /// <param name="stats">Set to the statistics data view, or null when none is available.</param>
        /// <returns>The summary data view.</returns>
        internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
        {
            // Strip calibrator wrappers until the innermost model is reached.
            for (var wrapper = predictor as IWeaklyTypedCalibratedModelParameters;
                 wrapper != null;
                 wrapper = predictor as IWeaklyTypedCalibratedModelParameters)
            {
                predictor = wrapper.WeeklyTypedSubModel;
            }

            stats = null;
            IDataView summary = null;

            var viewSource = predictor as ICanGetSummaryAsIDataView;
            var rowSource = predictor as ICanGetSummaryAsIRow;
            bool hasViewSource = viewSource != null;
            bool hasRowSource = rowSource != null;

            if (hasViewSource)
            {
                summary = viewSource.GetSummaryDataView(schema);
            }

            if (hasRowSource)
            {
                var summaryRow = rowSource.GetSummaryIRowOrNull(schema);
                env.Check(!(hasViewSource && summaryRow != null),
                          "Predictor outputs two summary data views, don't know which one to choose");
                if (summaryRow != null)
                {
                    summary = RowCursorUtils.RowAsDataView(env, summaryRow);
                }

                var statisticsRow = rowSource.GetStatsIRowOrNull(schema);
                if (statisticsRow != null)
                {
                    stats = RowCursorUtils.RowAsDataView(env, statisticsRow);
                }
            }

            if (!hasViewSource && !hasRowSource)
            {
                // Fall back to a data view with exactly one row and one column.
                var builder = new ArrayDataViewBuilder(env);
                var textSummary = predictor as ICanSaveSummary;

                if (textSummary == null)
                {
                    builder.AddColumn("PredictorName", predictor.GetType().ToString());
                }
                else
                {
                    var text = new StringBuilder();
                    using (StringWriter writer = new StringWriter(text))
                    {
                        textSummary.SaveSummary(writer, schema);
                    }
                    builder.AddColumn("Summary", text.ToString());
                }

                summary = builder.GetDataView();
            }

            env.AssertValue(summary);
            return summary;
        }
        /// <summary>
        /// Backwards compatibility helper function that loads a Choose Column Transform
        /// and returns it as a <see cref="SelectColumnsTransform"/>.
        /// </summary>
        /// <param name="env">Environment used for checks.</param>
        /// <param name="ctx">Model load context positioned at the serialized transform.</param>
        /// <param name="input">Input data view; not used by this method.</param>
        /// <returns>A select-columns transform keeping the loaded output column names.</returns>
        private static SelectColumnsTransform LoadChooseColumnsTransform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            const string renameNotSupportedMsg = "Rename for ChooseColumns is not backwards compatible with the SelectColumnsTranform";
            const string differentHideColumnNotSupportedMsg = "Setting a hide option different from default is not compatible with SelectColumnsTransform";

            // *** Binary format ***
            // int: sizeof(Float)
            // bindings
            // The size field is read only to advance the stream; its value is not used here.
            ctx.Reader.ReadInt32();

            // *** Binary format ***
            // byte: default HiddenColumnOption value
            // int: number of raw column infos
            // for each raw column info
            //   int: id of output column name
            //   int: id of input column name
            //   byte: HiddenColumnOption
            var hiddenOption = (HiddenColumnOption)ctx.Reader.ReadByte();

            Contracts.Assert(Enum.IsDefined(typeof(HiddenColumnOption), hiddenOption));
            env.Check(HiddenColumnOption.Rename != hiddenOption, renameNotSupportedMsg);
            var keepHidden = GetHiddenOption(env, hiddenOption);

            int count = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(count >= 0);

            // No per-column array is allocated: the count comes from the stream and the
            // per-column options are only validated, never stored.
            var names = new HashSet<string>();

            for (int colIdx = 0; colIdx < count; ++colIdx)
            {
                // Output names must be unique.
                string dst = ctx.LoadNonEmptyString();
                Contracts.CheckDecode(names.Add(dst));

                // The input column name is validated as non-empty but otherwise unused.
                ctx.LoadNonEmptyString();

                var colHiddenOption = (HiddenColumnOption)ctx.Reader.ReadByte();
                Contracts.Assert(Enum.IsDefined(typeof(HiddenColumnOption), colHiddenOption));
                env.Check(colHiddenOption != HiddenColumnOption.Rename, renameNotSupportedMsg);

                // Every column must agree with the transform-wide hide option.
                var colKeepHidden = GetHiddenOption(env, colHiddenOption);
                env.Check(colKeepHidden == keepHidden, differentHideColumnNotSupportedMsg);
            }

            return new SelectColumnsTransform(env, names.ToArray(), null, keepHidden);
        }
예제 #5
0
            /// <summary>
            /// Creates an instance from the arguments object and the arguments to the signature
            /// delegate. <paramref name="args"/> must be non-null exactly when ArgType is non-null.
            /// The length of <paramref name="extra"/> must match the number of parameters of the
            /// signature delegate; when that number is zero, <paramref name="extra"/> may be null.
            /// </summary>
            /// <param name="env">Environment; prepended to the call values when RequireEnvironment is set.</param>
            /// <param name="args">Arguments object, or null when ArgType is null.</param>
            /// <param name="extra">Extra values appended after the environment and arguments.</param>
            /// <returns>The object produced by CreateInstanceCore.</returns>
            public object CreateInstance(IHostEnvironment env, object args, object[] extra)
            {
                Contracts.CheckValue(env, nameof(env));

                bool argsSupplied = args != null;
                env.Check((ArgType != null) == argsSupplied);

                int extraCount = Utils.Size(extra);
                env.Check(extraCount == ExtraArgCount);

                // Leading values, in order: the environment (if required), then the arguments object.
                var leading = new List<object>();
                if (RequireEnvironment)
                {
                    leading.Add(env);
                }
                if (ArgType != null)
                {
                    leading.Add(args);
                }

                return CreateInstanceCore(Utils.Concat(leading.ToArray(), extra));
            }
예제 #6
0
        /// <summary>
        /// Creates one typed column per requested column, using the raw type found in the input
        /// schema, and then initializes each column's transformer from <paramref name="input"/>.
        /// </summary>
        /// <param name="host">Environment used to register this component and run checks.</param>
        /// <param name="input">Data whose schema determines the column types.</param>
        /// <param name="columns">Columns (name and source) to set up.</param>
        internal CategoricalImputerTransformer(
            IHostEnvironment host,
            IDataView input,
            CategoricalImputerEstimator.Column[] columns)
            : base(host.Register(nameof(CategoricalImputerEstimator)))
        {
            bool runningOnCentOS7 = CommonExtensions.OsIsCentOS7();
            host.Check(!runningOnCentOS7, "CentOS7 is not supported");

            var inputSchema = input.Schema;

            // All typed columns are created first; only then are their transformers initialized.
            _columns = columns
                .Select(col => TypedColumn.CreateTypedColumn(
                    col.Name,
                    col.Source,
                    inputSchema[col.Source].Type.RawType.ToString()))
                .ToArray();

            foreach (var typedColumn in _columns)
            {
                typedColumn.CreateTransformerFromEstimator(input);
            }
        }
예제 #7
0
        // NOTE(review): this constructor appears to be input for a code-analysis test rather than
        // production code. The trailing "Should fail" / "Should succeed" / "Should not fail"
        // comments presumably mark which contract calls a checker is expected to flag, so the
        // calls are left exactly as written — confirm before "fixing" any of them.
        public TypeName(IHostEnvironment env, float p, int foo)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckParam(0 <= p && p <= 1, nameof(p), "Should be in range [0,1]");
            env.CheckParam(0 <= p && p <= 1, "p");                   // Should fail.
            env.CheckParam(0 <= p && p <= 1, nameof(p) + nameof(p)); // Should fail.
            env.CheckValue(paramName: nameof(p), val: "p");          // Should succeed despite confusing order.
            env.CheckValue(paramName: "p", val: nameof(p));          // Should fail despite confusing order.
            env.CheckValue("p", nameof(p));
            env.CheckUserArg(foo > 5, "foo", "Nice");
            env.CheckUserArg(foo > 5, nameof(foo), "Nice");
            env.Except();                                           // Not throwing or doing anything with the exception, so should fail.
            Contracts.ExceptParam(nameof(env), "What a silly env"); // Should also fail.
            if (false)
            {
                throw env.Except(); // Should not fail.
            }
            if (false)
            {
                throw env.ExceptParam(nameof(env), "What a silly env"); // Should not fail.
            }
            if (false)
            {
                throw env.ExceptParam("env", "What a silly env"); // Should fail due to name error.
            }
            var e = env.Except();

            env.Check(true, $"Hello {foo} is cool");
            env.Check(true, "Hello it is cool");
            string coolMessage = "Hello it is cool";

            env.Check(true, coolMessage);
            env.Check(true, string.Format("Hello {0} is cool", foo));
            env.Check(true, Messages.CoolMessage);
            env.CheckDecode(true, "Not suspicious, no ModelLoadContext");
            Contracts.Check(true, "Fine: " + nameof(env));
            Contracts.Check(true, "Less fine: " + env.GetType().Name);
            Contracts.CheckUserArg(0 <= p && p <= 1,
                                   "p", "On a new line");
        }
예제 #8
0
        /// <summary>
        /// Auto-detects the column types of a text file.
        /// </summary>
        /// <param name="env">Environment used for checks and to register the inference host.</param>
        /// <param name="fileSource">The file(s) to inspect.</param>
        /// <param name="args">Inference settings; the separator must be non-empty and MaxRowsToRead positive.</param>
        /// <returns>The result produced by InferTextFileColumnTypesCore.</returns>
        public static InferenceResult InferTextFileColumnTypes(IHostEnvironment env, IMultiStreamSource fileSource, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(fileSource, nameof(fileSource));
            env.CheckValue(args, nameof(args));
            env.CheckNonEmpty(args.Separator, nameof(args.Separator));

            bool hasRowBudget = args.MaxRowsToRead > 0;
            env.Check(hasRowBudget);

            // The channel is disposed once the core routine has returned.
            var inferenceHost = env.Register("InferTextFileColumnTypes");
            using (var channel = inferenceHost.Start("TypeInference"))
            {
                var inferred = InferTextFileColumnTypesCore(env, fileSource, args, channel);
                return inferred;
            }
        }
예제 #9
0
        /// <summary>
        /// Looks up a loadable class by (lower-cased, trimmed) name and signature type and, when
        /// found, instantiates it.
        /// </summary>
        /// <typeparam name="TRes">Type the loadable class must be assignable to.</typeparam>
        /// <param name="env">Environment used for checks and exceptions.</param>
        /// <param name="signatureType">Signature delegate type; must derive directly from MulticastDelegate.</param>
        /// <param name="result">The created instance, or null when no class is registered under the name.</param>
        /// <param name="name">Name of the loadable class; null is treated as the empty name.</param>
        /// <param name="options">Settings string parsed into the arguments object, if the class takes one.</param>
        /// <param name="extra">Extra values passed to the class; the count must match what the class expects.</param>
        /// <returns>True when an instance was created; false when no matching class was found.</returns>
        private static bool TryCreateInstance<TRes>(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra)
            where TRes : class
        {
            Contracts.CheckValue(env, nameof(env));
            // Validate before dereferencing, so a null type gives an argument error instead of a
            // NullReferenceException.
            env.CheckValue(signatureType, nameof(signatureType));
            env.Check(signatureType.BaseType == typeof(MulticastDelegate));
            env.CheckValueOrNull(name);

            string nameLower = (name ?? "").ToLowerInvariant().Trim();
            LoadableClassInfo info = FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType));

            if (info == null)
            {
                result = null;
                return false;
            }

            if (!typeof(TRes).IsAssignableFrom(info.Type))
            {
                throw env.Except("Loadable class '{0}' does not derive from '{1}'", name, typeof(TRes).FullName);
            }

            int carg = Utils.Size(extra);

            if (info.ExtraArgCount != carg)
            {
                throw env.Except(
                          "Wrong number of extra parameters for loadable class '{0}', need '{1}', given '{2}'",
                          name, info.ExtraArgCount, carg);
            }

            // Classes without an arguments type accept no settings at all.
            if (info.ArgType == null)
            {
                if (!string.IsNullOrEmpty(options))
                {
                    throw env.Except("Loadable class '{0}' doesn't support settings", name);
                }
                result = (TRes)info.CreateInstance(env, null, extra);
                return true;
            }

            object args = info.CreateArguments();

            if (args == null)
            {
                // Raised through env, like every other failure in this method.
                throw env.Except("Can't instantiate arguments object '{0}' for '{1}'", info.ArgType.Name, name);
            }

            ParseArguments(env, args, options, name);
            result = (TRes)info.CreateInstance(env, args, extra);
            return true;
        }
        /// <summary>
        /// Walks back the Source chain of the <see cref="IDataTransform"/> up to the <paramref name="oldSource"/>
        /// (or until a non-transform is reached if <paramref name="oldSource"/> is <c>null</c>),
        /// and reapplies all transforms in the chain, to produce the same chain but bound to different data.
        /// It is valid to have no transforms: in this case the result is <paramref name="newSource"/> itself.
        /// If <paramref name="oldSource"/> is specified and not found in the pipe, an exception is thrown.
        /// </summary>
        /// <param name="env">The environment to use.</param>
        /// <param name="chain">The end of the chain.</param>
        /// <param name="newSource">The new data to attach the chain to.</param>
        /// <param name="oldSource">The 'old source' of the pipe, that doesn't need to be reapplied. If null, all transforms are reapplied.</param>
        /// <returns>The resulting data view.</returns>
        public static IDataView ApplyAllTransformsToData(IHostEnvironment env, IDataView chain, IDataView newSource, IDataView oldSource = null)
        {
            // REVIEW: have a variation that would selectively apply transforms?
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(chain, nameof(chain));
            env.CheckValue(newSource, nameof(newSource));
            env.CheckValueOrNull(oldSource);

            // REVIEW: the composite data loader is 'unwrapped' here so its pipeline can be stepped
            // through. It's probably more robust to make CompositeDataLoader not even be an
            // IDataView, which would force callers to unwrap on their end.
            var loader = chain as CompositeDataLoader;
            if (loader != null)
            {
                chain = loader.View;
            }

            // Collect transforms from the end of the chain towards its start, stopping at the
            // old source or at the first view that is not a transform.
            var collected = new List<IDataTransform>();
            for (; ; )
            {
                var current = chain as IDataTransform;
                if (current == null || chain == oldSource)
                {
                    break;
                }

                collected.Add(current);
                chain = current.Source;

                loader = chain as CompositeDataLoader;
                if (loader != null)
                {
                    chain = loader.View;
                }
            }

            env.Check(oldSource == null || chain == oldSource, "Source data not found in the chain");

            // Re-apply in original pipeline order, i.e. the reverse of collection order.
            IDataView rebuilt = newSource;
            for (int i = collected.Count - 1; i >= 0; i--)
            {
                rebuilt = ApplyTransformToData(env, collected[i], rebuilt);
            }

            return rebuilt;
        }
        /// <summary>
        /// Creates a bindable mapper for a predictor that supports quantile regression.
        /// </summary>
        /// <param name="env">Environment used for checks.</param>
        /// <param name="args">Arguments holding the quantiles specification.</param>
        /// <param name="predictor">Predictor; must implement <see cref="IQuantileRegressionPredictor"/>.</param>
        /// <returns>The mapper created by the predictor for the parsed quantiles.</returns>
        public static ISchemaBindableMapper Create(IHostEnvironment env, Arguments args, IPredictor predictor)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(args, nameof(args));
            env.CheckValue(predictor, nameof(predictor));

            var quantilePredictor = predictor as IQuantileRegressionPredictor;
            bool supportsQuantiles = quantilePredictor != null;
            env.Check(supportsQuantiles, "Predictor doesn't support quantile regression");

            return quantilePredictor.CreateMapper(ParseQuantiles(args.Quantiles));
        }
예제 #12
0
        /// <summary>
        /// Creates the transformer for a single input column after verifying, against
        /// <paramref name="schema"/>, that the column's raw type is supported.
        /// </summary>
        /// <param name="host">Environment used to register this component and run checks.</param>
        /// <param name="inputColumnName">Name of the column to transform.</param>
        /// <param name="columnPrefix">Prefix passed through to the typed column.</param>
        /// <param name="country">Holiday list handed to the column when creating the transformer.</param>
        /// <param name="schema">Schema used to look up the input column's type.</param>
        /// <exception cref="Exception">The input column's raw type is neither long nor <see cref="DateTime"/>.</exception>
        internal DateTimeTransformer(IHostEnvironment host, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country, DataViewSchema schema) :
            base(host.Register(nameof(DateTimeTransformer)))
        {
            host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");

            _schema = schema;

            // Look the column up once instead of once per comparison and once for the message.
            var rawType = _schema[inputColumnName].Type.RawType;
            if (rawType != typeof(long) && rawType != typeof(DateTime))
            {
                // No stray '$' before the column name: inside an interpolated string it would be
                // emitted literally.
                throw new Exception($"Unsupported type {rawType} for input column {inputColumnName}. Only long and System.DateTime are supported");
            }

            _column = new LongTypedColumn(inputColumnName, columnPrefix);
            _column.CreateTransformerFromEstimator(country);
        }
        /// <summary>
        /// Loading constructor (used for SignatureLoadModel): rebuilds the transformer from a
        /// saved model.
        /// </summary>
        /// <param name="host">Environment used to register this component and run checks.</param>
        /// <param name="ctx">Model load context to read from.</param>
        internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(host.Register(nameof(DateTimeTransformer)))
        {
            Host.CheckValue(ctx, nameof(ctx));

            bool runningOnCentOS7 = CommonExtensions.OsIsCentOS7();
            host.Check(!runningOnCentOS7, "CentOS7 is not supported");

            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // name of input column
            // column prefix
            // length of C++ state array
            // C++ byte state array

            var savedColumnName = ctx.Reader.ReadString();
            var savedPrefix = ctx.Reader.ReadString();
            _column = new LongTypedColumn(savedColumnName, savedPrefix);

            var stateLength = ctx.Reader.ReadInt32();
            var state = ctx.Reader.ReadByteArray(stateLength);
            _column.CreateTransformerFromSavedData(state);
        }
            /// <summary>
            /// Re-binds <paramref name="bindable"/> to the <paramref name="input"/> schema using the
            /// current row mapper's input column roles, and returns new bindings that point at the
            /// re-bound mapper's score column.
            /// </summary>
            /// <param name="input">The schema to bind against.</param>
            /// <param name="bindable">Mapper to bind; the bound result must implement <see cref="ISchemaBoundRowMapper"/>.</param>
            /// <param name="env">Environment used for assertions and checks.</param>
            /// <returns>The new bindings.</returns>
            public BindingsImpl ApplyToSchema(DataViewSchema input, ISchemaBindableMapper bindable, IHostEnvironment env)
            {
                Contracts.AssertValue(env);
                env.AssertValue(input);
                env.AssertValue(bindable);

                // The score column is located by name in the newly bound mapper's output.
                string scoreColumnName = RowMapper.OutputSchema[ScoreColumnIndex].Name;
                var roleMapped = new RoleMappedSchema(input, RowMapper.GetInputColumnRoles());

                // Binding checks compatibility of the predictor input types.
                var boundRowMapper = bindable.Bind(env, roleMapped) as ISchemaBoundRowMapper;
                env.CheckParam(boundRowMapper != null, nameof(bindable), "Mapper must implement ISchemaBoundRowMapper");

                bool hasScoreColumn = boundRowMapper.OutputSchema.TryGetColumnIndex(scoreColumnName, out int scoreIndex);
                env.Check(hasScoreColumn, "Mapper doesn't have expected score column");

                return new BindingsImpl(input, boundRowMapper, Suffix, ScoreColumnKind, true, scoreIndex, PredColType);
            }
예제 #15
0
        /// <summary>
        /// Saves the reader and the transformer chain to a stream, as a "model file".
        /// </summary>
        /// <param name="env">Environment used for checks and to open the logging channel.</param>
        /// <param name="outputStream">Destination stream; must be both writable and seekable.</param>
        public void SaveTo(IHostEnvironment env, Stream outputStream)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(outputStream, nameof(outputStream));

            bool streamUsable = outputStream.CanWrite && outputStream.CanSeek;
            env.Check(streamUsable, "Need a writable and seekable stream to save");

            // The repository is disposed before the channel, matching the nesting order.
            using (var channel = env.Start("Saving pipeline"))
            using (var repository = RepositoryWriter.CreateNew(outputStream, channel))
            {
                channel.Trace("Saving data reader");
                ModelSaveContext.SaveModel(repository, Reader, "Reader");

                channel.Trace("Saving transformer chain");
                ModelSaveContext.SaveModel(repository, Transformer, TransformerChain.LoaderSignature);

                repository.Commit();
            }
        }
        /// <summary>
        /// Creates an instance of the transform from a context, by deserializing the stored load
        /// method and invoking it on the bytes written by the save action.
        /// </summary>
        /// <param name="env">Environment; also passed through to the load method.</param>
        /// <param name="ctx">Model load context to read from.</param>
        /// <param name="input">Input data passed to the load method.</param>
        /// <returns>The non-null transform returned by the load method.</returns>
        public static ITransformTemplate Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(LoaderSignature);
            host.CheckValue(ctx, nameof(ctx));
            host.CheckValue(input, nameof(input));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: Number of bytes the load method was serialized to
            // byte[n]: The serialized load method info
            // <arbitrary>: Arbitrary bytes saved by the save action

            var serializedLoadMethod = ctx.Reader.ReadByteArray();
            host.CheckDecode(Utils.Size(serializedLoadMethod) > 0);

            // Attempt to reconstruct the method; on failure the reason comes back in 'failure'.
            Exception failure;
            var loadFunc = DeserializeStaticDelegateOrNull(host, serializedLoadMethod, out failure);
            if (loadFunc == null)
            {
                host.AssertValue(failure);
                throw failure;
            }

            // A missing payload is treated as an empty one.
            var payload = ctx.Reader.ReadByteArray();
            if (payload == null)
            {
                payload = new byte[0];
            }

            using (var payloadStream = new MemoryStream(payload))
            using (var payloadReader = new BinaryReader(payloadStream))
            {
                var loaded = loadFunc(payloadReader, env, input);
                env.Check(loaded != null, "Load method returned null");
                return loaded;
            }
        }
예제 #17
0
        /// <summary>
        /// Sets up a row mapper over the single feature column of <paramref name="schema"/> and
        /// fetches the booster from <paramref name="parent"/>.
        /// </summary>
        /// <param name="schema">Role-mapped schema; its Feature column must be set.</param>
        /// <param name="parent">Predictor that owns the booster.</param>
        /// <param name="env">Environment used for assertions and checks.</param>
        /// <param name="outputSchema">Schema stored as this mapper's output schema.</param>
        public XGBoostScalarRowMapperBase(RoleMappedSchema schema, XGBoostPredictorBase<TOutput> parent, IHostEnvironment env, ISchema outputSchema)
        {
            Contracts.AssertValue(env, "env");
            env.AssertValue(schema, "schema");
            env.AssertValue(parent, "parent");
            env.AssertValue(schema.Feature, "schema");

            // REVIEW xadupre: only one feature column is allowed.
            // This should be revisited in the future.
            // XGBoost has plans for other types.
            // Look at https://github.com/dmlc/xgboost/issues/874.
            env.Check(schema.Feature != null, "Unexpected number of feature columns, 1 expected.");

            _parent = parent;

            var featureColumn = schema.Feature;
            var featureRole = new KeyValuePair<RoleMappedSchema.ColumnRole, string>(
                RoleMappedSchema.ColumnRole.Feature, featureColumn.Name);

            _inputSchema = new RoleMappedSchema(schema.Schema, new[] { featureRole });
            _outputSchema = outputSchema;

            // Exactly one input column: the feature column, resolved to its index by name.
            _inputCols = new List<int>();
            int featureIndex;
            if (schema.Schema.TryGetColumnIndex(featureColumn.Name, out featureIndex))
            {
                _inputCols.Add(featureIndex);
            }
            else
            {
                Contracts.Assert(false);
            }

            _booster = _parent.GetBooster();
        }
예제 #18
0
        // REVIEW: It would be nice to support propagation of select metadata.
        /// <summary>
        /// Creates a data view that adds column <paramref name="dst"/> by applying
        /// <paramref name="mapper"/> to column <paramref name="src"/> of <paramref name="input"/>.
        /// When the actual type of the source column differs from <paramref name="typeSrc"/>, a
        /// standard conversion is inserted in front of the mapper.
        /// </summary>
        /// <param name="env">Environment used for checks and exceptions.</param>
        /// <param name="name">Non-empty name passed to the implementation.</param>
        /// <param name="input">The input data.</param>
        /// <param name="src">Name of the source column; must exist in the input schema.</param>
        /// <param name="dst">Name of the destination column.</param>
        /// <param name="typeSrc">Mapper input type; its raw type must be <typeparamref name="TSrc"/>.</param>
        /// <param name="typeDst">Mapper output type; its raw type must be <typeparamref name="TDst"/>.</param>
        /// <param name="mapper">The value mapper to apply.</param>
        /// <param name="keyValueGetter">Optional; only allowed when the destination item type is a key type.</param>
        /// <param name="slotNamesGetter">Optional; only allowed when the destination is a known-size vector.</param>
        /// <returns>The mapped data, wrapped in an <see cref="OpaqueDataView"/>.</returns>
        public static IDataView Create<TSrc, TDst>(IHostEnvironment env, string name, IDataView input,
                                                   string src, string dst, ColumnType typeSrc, ColumnType typeDst, ValueMapper<TSrc, TDst> mapper,
                                                   ValueGetter<VBuffer<ReadOnlyMemory<char>>> keyValueGetter = null, ValueGetter<VBuffer<ReadOnlyMemory<char>>> slotNamesGetter = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckNonEmpty(name, nameof(name));
            env.CheckValue(input, nameof(input));
            env.CheckNonEmpty(src, nameof(src));
            env.CheckNonEmpty(dst, nameof(dst));
            env.CheckValue(typeSrc, nameof(typeSrc));
            env.CheckValue(typeDst, nameof(typeDst));
            env.CheckValue(mapper, nameof(mapper));
            env.Check(keyValueGetter == null || typeDst.GetItemType() is KeyType);
            env.Check(slotNamesGetter == null || typeDst.IsKnownSizeVector());

            // The declared column types must agree with the mapper's generic arguments.
            if (typeSrc.RawType != typeof(TSrc))
            {
                throw env.ExceptParam(nameof(mapper),
                                      "The source column type '{0}' doesn't match the input type of the mapper", typeSrc);
            }
            if (typeDst.RawType != typeof(TDst))
            {
                throw env.ExceptParam(nameof(mapper),
                                      "The destination column type '{0}' doesn't match the output type of the mapper", typeDst);
            }

            if (!input.Schema.TryGetColumnIndex(src, out int srcIndex))
            {
                throw env.ExceptParam(nameof(src), "The input data doesn't have a column named '{0}'", src);
            }
            var actualType = input.Schema[srcIndex].Type;

            // REVIEW: Ideally this should support vector-type conversion. It currently doesn't.
            bool isIdentity = actualType.SameSizeAndItemType(typeSrc);
            Delegate converter = null;

            if (!isIdentity &&
                !Conversions.Instance.TryGetStandardConversion(actualType, typeSrc, out converter, out isIdentity))
            {
                throw env.ExceptParam(nameof(mapper),
                                      "The type of column '{0}', '{1}', cannot be converted to the input type of the mapper '{2}'",
                                      src, actualType, typeSrc);
            }

            var column = new Column(src, dst);
            IDataView mapped;

            if (isIdentity)
            {
                mapped = new Impl<TSrc, TDst, TDst>(env, name, input, column, typeDst, mapper, keyValueGetter: keyValueGetter, slotNamesGetter: slotNamesGetter);
            }
            else
            {
                // The actual source type is only known at run time, so the generic factory is
                // closed over it through reflection; the int type arguments are placeholders.
                Func<IHostEnvironment, string, IDataView, Column, ColumnType, ValueMapper<int, int>,
                     ValueMapper<int, int>, ValueGetter<VBuffer<ReadOnlyMemory<char>>>, ValueGetter<VBuffer<ReadOnlyMemory<char>>>,
                     Impl<int, int, int>> template = CreateImpl<int, int, int>;
                var factory = template.GetMethodInfo().GetGenericMethodDefinition()
                              .MakeGenericMethod(actualType.RawType, typeof(TSrc), typeof(TDst));
                var factoryArgs = new object[] { env, name, input, column, typeDst, converter, mapper, keyValueGetter, slotNamesGetter };
                mapped = (IDataView)factory.Invoke(null, factoryArgs);
            }

            return new OpaqueDataView(mapped);
        }
        /// <summary>
        /// Builds the role-to-column bindings described by <paramref name="input"/> and attaches
        /// them to its training data.
        /// </summary>
        /// <remarks>
        /// Label, group, weight and name roles each accept exactly one column when specified.
        /// The feature and image-path roles accept any number of columns.
        /// </remarks>
        /// <param name="env">Environment used for checks.</param>
        /// <param name="input">Arguments holding the column lists and the training data.</param>
        /// <returns>The training data together with the collected role bindings.</returns>
        private static RoleMappedData GetDataRoles(IHostEnvironment env, Arguments input)
        {
            var bindings = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();

            // Single-column roles.
            var labels = input.LabelColumns;
            if (labels != null)
            {
                env.Check(labels.Length == 1, "LabelColumns expected one column name to be specified.");
                bindings.Add(RoleMappedSchema.ColumnRole.Label.Bind(labels[0]));
            }

            var groups = input.GroupColumns;
            if (groups != null)
            {
                env.Check(groups.Length == 1, "GroupColumns expected one column name to be specified.");
                bindings.Add(RoleMappedSchema.ColumnRole.Group.Bind(groups[0]));
            }

            var weights = input.WeightColumns;
            if (weights != null)
            {
                env.Check(weights.Length == 1, "WeightColumns expected one column name to be specified.");
                bindings.Add(RoleMappedSchema.ColumnRole.Weight.Bind(weights[0]));
            }

            var names = input.NameColumns;
            if (names != null)
            {
                env.Check(names.Length == 1, "NameColumns expected one column name to be specified.");
                bindings.Add(RoleMappedSchema.ColumnRole.Name.Bind(names[0]));
            }

            // Multi-column roles, each named after its column purpose.
            if (input.NumericFeatureColumns != null)
            {
                var numericRole = new RoleMappedSchema.ColumnRole(ColumnPurpose.NumericFeature.ToString());
                foreach (var column in input.NumericFeatureColumns)
                {
                    bindings.Add(numericRole.Bind(column));
                }
            }

            if (input.CategoricalFeatureColumns != null)
            {
                var categoricalRole = new RoleMappedSchema.ColumnRole(ColumnPurpose.CategoricalFeature.ToString());
                foreach (var column in input.CategoricalFeatureColumns)
                {
                    bindings.Add(categoricalRole.Bind(column));
                }
            }

            if (input.TextFeatureColumns != null)
            {
                var textRole = new RoleMappedSchema.ColumnRole(ColumnPurpose.TextFeature.ToString());
                foreach (var column in input.TextFeatureColumns)
                {
                    bindings.Add(textRole.Bind(column));
                }
            }

            if (input.ImagePathColumns != null)
            {
                var imagePathRole = new RoleMappedSchema.ColumnRole(ColumnPurpose.ImagePath.ToString());
                foreach (var column in input.ImagePathColumns)
                {
                    bindings.Add(imagePathRole.Bind(column));
                }
            }

            return new RoleMappedData(input.TrainingData, bindings);
        }
예제 #20
0
        /// <summary>
        /// Macro step of a pipeline sweep: creates the AutoML state if needed, records metrics of
        /// the previous batch of candidate pipelines, requests the next batch, and returns one
        /// train-test node per new candidate.
        /// </summary>
        /// <param name="env">Environment used for checks, exceptions and building sub-experiments.</param>
        /// <param name="input">Macro arguments: state (or arguments to create it), batch size, training and testing data.</param>
        /// <param name="node">The entry point node being expanded; its context and variable maps are read and updated.</param>
        /// <returns>
        /// A macro output whose Nodes are the train-test nodes for the new candidates; the list is
        /// empty when no more candidates are returned.
        /// </returns>
        public static CommonOutputs.MacroOutput <Output> PipelineSweep(
            IHostEnvironment env,
            Arguments input,
            EntryPointNode node)
        {
            env.Check(input.StateArguments != null || input.State is AutoInference.AutoMlMlState,
                      "Must have a valid AutoML State, or pass arguments to create one.");
            env.Check(input.BatchSize > 0, "Batch size must be > 0.");

            // If no current state, create object and set data.
            if (input.State == null)
            {
                input.State = input.StateArguments?.CreateComponent(env);

                if (input.State is AutoInference.AutoMlMlState inState)
                {
                    inState.SetTrainTestData(input.TrainingData, input.TestingData);
                }
                else
                {
                    throw env.Except($"Incompatible type. Expecting type {typeof(AutoInference.AutoMlMlState)}, received type {input.State?.GetType()}.");
                }

                // Register the newly created state as a variable of the node and its context.
                var result = node.AddNewVariable("State", input.State);
                node.Context.AddInputVariable(result.Item2, typeof(IMlState));
            }
            var autoMlState = (AutoInference.AutoMlMlState)input.State;

            // The indicators are just so the macro knows those pipelines need to
            // be run before performing next expansion. If we add them as inputs
            // to the next iteration, the next iteration cannot run until they have
            // their values set. Thus, indicators are needed.
            // NOTE(review): within this method the list is only filled, never read afterwards.
            var pipelineIndicators = new List <Var <IDataView> >();

            var expNodes = new List <EntryPointNode>();

            // Keep versions of the training and testing var names
            var training = new Var <IDataView> {
                VarName = node.GetInputVariable("TrainingData").VariableName
            };
            var testing = new Var <IDataView> {
                VarName = node.GetInputVariable("TestingData").VariableName
            };
            // NOTE(review): amlsVarObj is not read again in this method.
            var amlsVarObj =
                new Var <IMlState>()
            {
                VarName = node.GetInputVariable(nameof(input.State)).VariableName
            };

            // Make sure search space is defined. If not, infer,
            // with default number of transform levels.
            if (!autoMlState.IsSearchSpaceDefined())
            {
                autoMlState.InferSearchSpace(numTransformLevels: 1);
            }

            // Extract performance summaries and assign to previous candidate pipelines.
            foreach (var pipeline in autoMlState.BatchCandidates)
            {
                // Only candidates whose overall-metric variable exists in the context are recorded.
                if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId),
                                                out var v))
                {
                    pipeline.PerformanceSummary =
                        AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name);
                    autoMlState.AddEvaluated(pipeline);
                }
            }

            // NOTE(review): outDvVar and outStateVar are not read again in this method, and the
            // TryGetValue results are ignored, so a missing key leaves the name null.
            node.OutputMap.TryGetValue("Results", out string outDvName);
            var outDvVar = new Var <IDataView>()
            {
                VarName = outDvName
            };

            node.OutputMap.TryGetValue("State", out string outStateName);
            var outStateVar = new Var <IMlState>()
            {
                VarName = outStateName
            };

            // Get next set of candidates.
            var candidatePipelines = autoMlState.GetNextCandidates(input.BatchSize);

            // Check if termination condition was met, i.e. no more candidates were returned.
            // If so, end expansion and add a node to extract the sweep result.
            if (candidatePipelines == null || candidatePipelines.Length == 0)
            {
                // Add a node to extract the sweep result.
                // NOTE(review): no node is actually added here; the still-empty list is returned.
                return(new CommonOutputs.MacroOutput <Output>()
                {
                    Nodes = expNodes
                });
            }

            // Prep all returned candidates
            foreach (var p in candidatePipelines)
            {
                // Add train test experiments to current graph for candidate pipeline
                var subgraph        = new Experiment(env);
                var trainTestOutput = p.AddAsTrainTest(training, testing, autoMlState.TrainerKind, subgraph);

                // Change variable name to reference pipeline ID in output map, context and entrypoint output.
                var uniqueName = ExperimentUtils.GenerateOverallMetricVarName(p.UniqueId);
                // Only the last node of the subgraph is validated and added to the expansion.
                var sgNode     = EntryPointNode.ValidateNodes(env, node.Context,
                                                              new JArray(subgraph.GetNodes().Last()), node.Catalog).Last();
                sgNode.RenameOutputVariable(trainTestOutput.OverallMetrics.VarName, uniqueName, cascadeChanges: true);
                trainTestOutput.OverallMetrics.VarName = uniqueName;
                expNodes.Add(sgNode);

                // Store indicators, to pass to next iteration of macro.
                pipelineIndicators.Add(trainTestOutput.OverallMetrics);
            }

            return(new CommonOutputs.MacroOutput <Output>()
            {
                Nodes = expNodes
            });
        }
예제 #21
0
 /// <summary>
 /// Returns the name of the column at index <paramref name="col"/>, after checking that the
 /// index lies in [0, ColumnCount).
 /// </summary>
 public string GetColumnName(int col)
 {
     bool inRange = col >= 0 && col < ColumnCount;
     _env.Check(inRange);

     var column = _shape[col];
     return column.Name;
 }