コード例 #1
0
        /// <summary>
        /// Creates a node for the entry point named <paramref name="entryPointName"/>, looked up in
        /// <paramref name="moduleCatalog"/>, and binds the supplied input and output pairs to it.
        /// </summary>
        /// <param name="env">Host environment; a child host named <paramref name="id"/> is registered from it.</param>
        /// <param name="moduleCatalog">Catalog used to resolve <paramref name="entryPointName"/>.</param>
        /// <param name="context">Run context stored on the node.</param>
        /// <param name="id">Non-empty node identifier.</param>
        /// <param name="entryPointName">Name of the entry point to resolve.</param>
        /// <param name="inputs">Input pairs, each passed to <c>CheckAndSetInputValue</c>; may be null.</param>
        /// <param name="outputs">Output pairs, each passed to <c>CheckAndMarkOutputValue</c>; may be null.</param>
        /// <param name="checkpoint">Stored in <c>Checkpoint</c>.</param>
        /// <param name="stageId">Stored in <c>StageId</c>.</param>
        /// <param name="cost">Stored in <c>Cost</c>.</param>
        /// <remarks>
        /// Throws (via <c>_host.Except</c>) when the entry point cannot be found, or when the input
        /// builder still reports required values that were not provided.
        /// </remarks>
        public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context,
            string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
            string stageId = "", float cost = float.NaN)
        {
            Contracts.AssertValue(env);
            env.AssertNonEmpty(id);
            _host = env.Register(id);
            _host.AssertValue(context);
            _host.AssertNonEmpty(entryPointName);
            _host.AssertValue(moduleCatalog);
            _host.AssertValueOrNull(inputs);
            _host.AssertValueOrNull(outputs);

            _context = context;
            _catalog = moduleCatalog;
            Id = id;

            bool entryPointFound = moduleCatalog.TryFindEntryPoint(entryPointName, out _entryPoint);
            if (!entryPointFound)
            {
                throw _host.Except($"Entry point '{entryPointName}' not found");
            }

            // Input side: fresh maps plus a builder typed for the entry point's input.
            _inputMap = new Dictionary<ParameterBinding, VariableBinding>();
            _inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
            _inputBuilder = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog);

            // REVIEW: This logic should move out of Node eventually and be delegated to
            // a class that can nest to handle Components with variables.
            if (inputs != null)
            {
                foreach (var inputPair in inputs)
                {
                    CheckAndSetInputValue(inputPair);
                }
            }

            // Values the builder still reports as missing, minus names recorded in _inputBindingMap,
            // were never supplied by the caller.
            var unboundRequired = _inputBuilder.GetMissingValues()
                .Except(_inputBindingMap.Keys)
                .ToArray();
            if (unboundRequired.Length != 0)
            {
                throw _host.Except($"The following required inputs were not provided: {string.Join(", ", unboundRequired)}");
            }

            // Output side: helper typed for the entry point's output, plus the output map.
            _outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
            _outputMap = new Dictionary<string, string>();
            if (outputs != null)
            {
                foreach (var outputPair in outputs)
                {
                    CheckAndMarkOutputValue(outputPair);
                }
            }

            Checkpoint = checkpoint;
            StageId = stageId;
            Cost = cost;
        }
コード例 #2
0
            /// <summary>
            /// Sends a trainer telemetry message built from <paramref name="sub"/>'s kind and settings
            /// through <paramref name="pipe"/>.
            /// </summary>
            /// <param name="pipe">Telemetry pipe; must not be null.</param>
            /// <param name="sub">Component to report; may be null. Nothing is sent when it is null or <c>IsGood()</c> is false.</param>
            protected void SendTelemetryComponent(IPipe<TelemetryMessage> pipe, SubComponent sub)
            {
                Host.AssertValue(pipe);
                Host.AssertValueOrNull(sub);

                // null is an accepted value (see the assert above), so guard before dereferencing it.
                if (sub != null && sub.IsGood())
                {
                    pipe.Send(TelemetryMessage.CreateTrainer(sub.Kind, sub.SubComponentSettings));
                }
            }
コード例 #3
0
        /// <summary>
        /// Saves each non-empty entry of <paramref name="keyValues"/> with index below
        /// <paramref name="infoLim"/> as a sub-model of <paramref name="ctx"/> named
        /// <c>Vocabulary_000</c>, <c>Vocabulary_001</c>, and so on, using the entry's index.
        /// Does nothing when <paramref name="keyValues"/> is null.
        /// </summary>
        /// <param name="host">Host used for assertions and to open the saving channel.</param>
        /// <param name="ctx">Save context that receives the sub-models.</param>
        /// <param name="infoLim">Number of leading entries of <paramref name="keyValues"/> to consider.</param>
        /// <param name="keyValues">Key values per column; may be null.</param>
        public static void SaveAll(IHost host, ModelSaveContext ctx, int infoLim, VBuffer<ReadOnlyMemory<char>>[] keyValues)
        {
            Contracts.AssertValue(host);
            host.AssertValue(ctx);
            host.AssertValueOrNull(keyValues);

            if (keyValues == null)
            {
                return;
            }

            // Every index below infoLim is read from keyValues, so the array must be at least that long.
            host.Assert(0 <= infoLim && infoLim <= keyValues.Length);

            using (var ch = host.Start("SaveTextValues"))
            {
                // Save the key names as separate submodels.
                const string dirFormat = "Vocabulary_{0:000}";
                CodecFactory factory = new CodecFactory(host);

                for (int iinfo = 0; iinfo < infoLim; iinfo++)
                {
                    if (keyValues[iinfo].Length == 0)
                    {
                        continue;
                    }

                    // C# shares one 'for' variable across all iterations. Capture a per-iteration copy
                    // so the callback reads the intended entry even if SaveSubModel ever defers it.
                    int index = iinfo;
                    ctx.SaveSubModel(string.Format(dirFormat, index),
                        c => Save(ch, c, factory, ref keyValues[index]));
                }
                ch.Done();
            }
        }
コード例 #4
0
        /// <summary>
        /// Get the calibration summary in INI format.
        /// </summary>
        /// <param name="host">Host used for assertions.</param>
        /// <param name="ini">Existing INI text; must not be null.</param>
        /// <param name="calibrator">Calibrator to describe; may be null.</param>
        /// <returns>
        /// <paramref name="ini"/> unchanged when <paramref name="calibrator"/> is null. For a
        /// <c>PlattCalibrator</c>, the INI with a calibrator evaluator added. For any other calibrator,
        /// the INI with a <c>[TLCCalibration]</c> section appended that records the calibrator's type name.
        /// </returns>
        private static string AddCalibrationToIni(IHost host, string ini, ICalibrator calibrator)
        {
            host.AssertValue(ini);
            host.AssertValueOrNull(calibrator);

            if (calibrator == null)
            {
                return ini;
            }

            // The type pattern tests and casts in one step, instead of an 'is' check followed by 'as'.
            if (calibrator is PlattCalibrator plattCalibrator)
            {
                string calibratorEvaluatorIni = IniFileUtils.GetCalibratorEvaluatorIni(ini, plattCalibrator);
                return IniFileUtils.AddEvaluator(ini, calibratorEvaluatorIni);
            }

            // Other calibrators are recorded only by their type name.
            StringBuilder newSection = new StringBuilder();
            newSection.AppendLine();
            newSection.AppendLine();
            newSection.AppendLine("[TLCCalibration]");
            newSection.AppendLine("Type=" + calibrator.GetType().Name);
            return ini + newSection;
        }
コード例 #5
0
        /// <summary>
        /// Builds a view that appends the rows of <paramref name="sources"/> (at least two views).
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="schema">Schema to expose; when null, the schema of the first source is used.</param>
        /// <param name="sources">Views to append; at least two.</param>
        /// <remarks>
        /// Shuffling is reported as supported only when every source can shuffle and reports a
        /// non-null row count in [0, int.MaxValue]. In that case the per-source counts are kept in
        /// <c>_counts</c>; otherwise <c>_counts</c> is null.
        /// </remarks>
        private AppendRowsDataView(IHostEnvironment env, Schema schema, IDataView[] sources)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);

            _host.AssertValueOrNull(schema);
            _host.AssertValue(sources);
            _host.Assert(sources.Length >= 2);

            _sources = sources;
            _schema = schema ?? _sources[0].Schema;

            CheckSchemaConsistency();

            bool shufflable = true;
            int[] rowCounts = new int[_sources.Length];
            for (int idx = 0; idx < _sources.Length; idx++)
            {
                IDataView source = _sources[idx];
                if (!source.CanShuffle)
                {
                    shufflable = false;
                    break;
                }

                long? rowCount = source.GetRowCount();
                bool fitsInInt = rowCount.HasValue && rowCount.Value >= 0 && rowCount.Value <= int.MaxValue;
                if (!fitsInInt)
                {
                    shufflable = false;
                    break;
                }

                rowCounts[idx] = (int)rowCount.Value;
            }

            _canShuffle = shufflable;
            _counts = shufflable ? rowCounts : null;
        }
コード例 #6
0
            /// <summary>
            /// Initializes the aggregator. It registers a host named "Aggregator", stores
            /// <paramref name="stratName"/>, and sets <c>PassNum</c> to -1.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="stratName">Stratification name; may be null.</param>
            protected AggregatorBase(IHostEnvironment env, string stratName)
            {
                Contracts.AssertValue(env);
                Host = env.Register("Aggregator");
                Host.AssertValueOrNull(stratName);

                StratName = stratName;
                PassNum = -1;
            }
コード例 #7
0
        /// <summary>
        /// Applies the transforms described by <paramref name="transformArgs"/> on top of <paramref name="srcLoader"/>.
        /// Returns <paramref name="srcLoader"/> unchanged when <paramref name="transformArgs"/> is null or empty.
        /// </summary>
        /// <param name="host">Host used for assertions; must not be null.</param>
        /// <param name="srcLoader">Loader to build on; must not be null.</param>
        /// <param name="transformArgs">Tag/transform-factory pairs; may be null.</param>
        private static ILegacyDataLoader CreateCore(IHost host, ILegacyDataLoader srcLoader,
                                                    KeyValuePair <string, IComponentFactory <IDataView, IDataTransform> >[] transformArgs)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader, "srcLoader");
            host.AssertValueOrNull(transformArgs);

            // No transforms requested: nothing to wrap. Utils.Size is assumed to report 0 for a null
            // array, as the AssertValueOrNull above permits null here. TODO confirm.
            if (Utils.Size(transformArgs) == 0)
            {
                return(srcLoader);
            }
コード例 #8
0
        /// <summary>
        /// Creates a row cursor with <paramref name="columnsNeeded"/> active. The input cursor is
        /// opened only on the input columns reported by <c>TypedSrc.GetDependencies</c> for those outputs.
        /// </summary>
        /// <param name="columnsNeeded">Output columns the caller wants active.</param>
        /// <param name="rand">Optional random source forwarded to the input cursor; may be null.</param>
        public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
        {
            Host.AssertValueOrNull(rand);
            var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);

            Func<int, bool> inputPred = TypedSrc.GetDependencies(predicate);

            // Materialize once. The column list is consumed twice below, and a lazy Where would
            // re-evaluate inputPred over the entire input schema on each enumeration.
            DataViewSchema.Column[] inputCols = Input.Schema.Where(x => inputPred(x.Index)).ToArray();
            var input = Input.GetRowCursor(inputCols, rand);

            return GetRowCursorCore(input, Utils.BuildArray(Input.Schema.Count, inputCols));
        }
コード例 #9
0
        /// <summary>
        /// Create a filter transform
        /// </summary>
        /// <param name="env">The host environment</param>
        /// <param name="source">The dataview upon which we construct the transform</param>
        /// <param name="filterFunc">The function by which we transform source to destination columns and decide whether
        /// to keep the row.</param>
        /// <param name="initStateAction">The function that is called once per cursor to initialize state. Can be null.</param>
        /// <param name="inputSchemaDefinition">The schema definition overrides for <typeparamref name="TSrc"/></param>
        /// <param name="outputSchemaDefinition">The schema definition overrides for <typeparamref name="TDst"/></param>
        /// <exception cref="ArgumentNullException"><paramref name="env"/> is null.</exception>
        /// <remarks>
        /// <paramref name="source"/> and <paramref name="filterFunc"/> are validated with the host's
        /// <c>CheckValue</c>, so null values are rejected in release builds as well.
        /// </remarks>
        public StatefulFilterTransform(IHostEnvironment env, IDataView source, Func<TSrc, TDst, TState, bool> filterFunc,
            Action<TState> initStateAction,
            SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
        {
            // Public entry point: validate arguments in release builds too. Without this check, a null
            // env would surface as a NullReferenceException from Register.
            if (env == null)
            {
                throw new ArgumentNullException(nameof(env));
            }
            _host = env.Register(RegistrationName);
            _host.CheckValue(source, nameof(source));
            _host.CheckValue(filterFunc, nameof(filterFunc));
            _host.AssertValueOrNull(initStateAction);
            _host.AssertValueOrNull(inputSchemaDefinition);
            _host.AssertValueOrNull(outputSchemaDefinition);

            _source = source;
            _filterFunc = filterFunc;
            _initStateAction = initStateAction;
            _inputSchemaDefinition = inputSchemaDefinition;
            _typedSource = TypedCursorable<TSrc>.Create(_host, Source, false, inputSchemaDefinition);

            var outSchema = InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition);

            _addedSchema = outSchema;
            _bindings = new ColumnBindings(Source.Schema, DataViewConstructionUtils.GetSchemaColumns(outSchema));
        }
コード例 #10
0
            /// <summary>
            /// Opens a cursor over <c>Source</c> restricted to the columns the mapper depends on, and wraps it
            /// in a <c>Cursor</c> whose active output columns are <paramref name="columnsNeeded"/>.
            /// </summary>
            /// <param name="columnsNeeded">Output columns the caller wants active.</param>
            /// <param name="rand">Optional random source forwarded to the source cursor; may be null.</param>
            public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                _host.AssertValueOrNull(rand);

                var dependencies = ((IRowToRowMapper)this).GetDependencies(columnsNeeded);
                var upstream = Source.GetRowCursor(dependencies, rand);

                var isActive = Utils.BuildArray(_mapper.OutputSchema.Count, columnsNeeded);
                return new Cursor(_host, _mapper, upstream, isActive);
            }
コード例 #11
0
            /// <summary>
            /// Opens a cursor over <c>Source</c> using the input predicate derived from <paramref name="needCol"/>,
            /// and wraps it in a <c>RowCursor</c> whose active output columns are selected by <paramref name="needCol"/>.
            /// </summary>
            /// <param name="needCol">Predicate over output column indices; must not be null.</param>
            /// <param name="rand">Optional random source forwarded to the source cursor; may be null.</param>
            public IRowCursor GetRowCursor(Func<int, bool> needCol, IRandom rand = null)
            {
                _host.AssertValue(needCol, nameof(needCol));
                _host.AssertValueOrNull(rand);

                var needSourceCol = GetDependencies(needCol);
                var upstream = Source.GetRowCursor(needSourceCol, rand);

                var isActive = Utils.BuildArray(_mapper.Schema.ColumnCount, needCol);
                return new RowCursor(_host, _mapper, upstream, isActive);
            }
コード例 #12
0
            /// <summary>
            /// Sends one trainer telemetry message for <paramref name="factory"/> through <paramref name="pipe"/>.
            /// Command-line factories report their name and settings string. Any other factory, or a null
            /// factory, is reported as "Unknown".
            /// </summary>
            /// <param name="pipe">Telemetry pipe; must not be null.</param>
            /// <param name="factory">Factory to describe; may be null.</param>
            protected void SendTelemetryComponent(IPipe<TelemetryMessage> pipe, IComponentFactory factory)
            {
                Host.AssertValue(pipe);
                Host.AssertValueOrNull(factory);

                TelemetryMessage message = factory is ICommandLineComponentFactory cmdFactory
                    ? TelemetryMessage.CreateTrainer(cmdFactory.Name, cmdFactory.GetSettingsString())
                    : TelemetryMessage.CreateTrainer("Unknown", "Non-ICommandLineComponentFactory object");
                pipe.Send(message);
            }
コード例 #13
0
            /// <summary>
            /// Stores the wrapped bindable mapper together with the metadata type, getter, metadata kind
            /// and optional wrap predicate.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="bindable">Mapper being wrapped; must not be null.</param>
            /// <param name="type">Vector type; must not be null.</param>
            /// <param name="getter">Getter delegate; must not be null.</param>
            /// <param name="metadataKind">Non-empty metadata kind.</param>
            /// <param name="canWrap">Optional predicate; may be null.</param>
            private LabelNameBindableMapper(IHostEnvironment env, ISchemaBindableMapper bindable, VectorType type, Delegate getter,
                string metadataKind, Func<ISchemaBoundMapper, ColumnType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                _host = env.Register(LoaderSignature);
                _host.AssertValue(bindable);
                _host.AssertValue(type);
                _host.AssertValue(getter);
                _host.AssertNonEmpty(metadataKind);
                _host.AssertValueOrNull(canWrap);

                _canWrap = canWrap;
                _metadataKind = metadataKind;
                _getter = getter;
                _type = type;
                _bindable = bindable;
            }
コード例 #14
0
            /// <summary>
            /// Opens a cursor over <c>Source</c> restricted to the input columns that the requested outputs
            /// depend on (via <c>GetDependencies</c>), and wraps it in a <c>Cursor</c> whose active output
            /// columns are <paramref name="columnsNeeded"/>.
            /// </summary>
            /// <param name="columnsNeeded">Output columns the caller wants active.</param>
            /// <param name="rand">Optional random source forwarded to the source cursor; may be null.</param>
            public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
            {
                var outputPredicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
                _host.AssertValueOrNull(rand);

                var needInput = GetDependencies(outputPredicate);
                var upstreamColumns = Source.Schema.Where(col => needInput(col.Index));
                var upstream = Source.GetRowCursor(upstreamColumns, rand);

                var isActive = Utils.BuildArray(_mapper.OutputSchema.Count, columnsNeeded);
                return new Cursor(_host, _mapper, upstream, isActive);
            }
コード例 #15
0
        /// <summary>
        /// Applies the transform factories in <paramref name="transformArgs"/> on top of
        /// <paramref name="srcLoader"/> via <c>ApplyTransformsCore</c>.
        /// </summary>
        /// <param name="host">Host used for assertions and channels; must not be null.</param>
        /// <param name="srcLoader">Loader to build on; must not be null.</param>
        /// <param name="transformArgs">Tag/factory pairs; may be null or empty, in which case <paramref name="srcLoader"/> is returned.</param>
        /// <remarks>
        /// When <paramref name="srcLoader"/> is a <c>CompositeDataLoader</c>, a warning is emitted for each
        /// non-empty tag that already exists in its transform chain.
        /// </remarks>
        private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader,
            KeyValuePair<string, IComponentFactory<IDataView, IDataTransform>>[] transformArgs)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader, "srcLoader");
            host.AssertValueOrNull(transformArgs);

            if (Utils.Size(transformArgs) == 0)
            {
                return srcLoader;
            }

            // Tag data preserves the string arguments for command-line factories. Other factories
            // have no string form, so null is recorded for them.
            var tagData = new KeyValuePair<string, string>[transformArgs.Length];
            for (int i = 0; i < transformArgs.Length; i++)
            {
                var commandLineFactory = transformArgs[i].Value as ICommandLineComponentFactory;
                tagData[i] = new KeyValuePair<string, string>(transformArgs[i].Key, commandLineFactory?.ToString());
            }

            if (srcLoader is CompositeDataLoader composite)
            {
                using (var ch = host.Start("TagValidation"))
                {
                    foreach (var pair in tagData)
                    {
                        if (string.IsNullOrEmpty(pair.Key))
                        {
                            continue;
                        }

                        if (composite._transforms.Any(x => x.Tag == pair.Key))
                        {
                            ch.Warning("The transform with tag '{0}' already exists in the chain", pair.Key);
                        }
                    }

                    ch.Done();
                }
            }

            return ApplyTransformsCore(host, srcLoader, tagData,
                (provider, idx, view) => transformArgs[idx].Value.CreateComponent(provider, view));
        }
コード例 #16
0
        /// <summary>
        /// Trains a LightGBM booster on <paramref name="dtrain"/> (optionally validating on <paramref name="dvalid"/>)
        /// and stores the resulting model in <c>TrainedEnsemble</c>.
        /// </summary>
        /// <param name="ch">Channel used for assertions and logging; must not be null.</param>
        /// <param name="pch">Progress channel passed to training; must not be null.</param>
        /// <param name="dtrain">Training dataset; must not be null.</param>
        /// <param name="catMetaData">Categorical metadata whose boundaries are passed to <c>GetModel</c>.</param>
        /// <param name="dvalid">Optional validation dataset; may be null.</param>
        /// <remarks>Training runs while holding <c>LightGbmShared.LockForMultiThreadingInside</c>.</remarks>
        private void TrainCore(IChannel ch, IProgressChannel pch, Dataset dtrain, CategoricalMetaData catMetaData, Dataset dvalid = null)
        {
            Host.AssertValue(ch);
            Host.AssertValue(pch);
            Host.AssertValue(dtrain);
            Host.AssertValueOrNull(dvalid);

            // Multiclass training requires "num_class" to be present in Options.
            ch.Assert(PredictionKind != PredictionKind.MultiClassClassification || Options.ContainsKey("num_class"),
                "LightGBM requires the number of classes to be specified in the parameters.");

            // Serialize training: only one trainer may run inside LightGBM at a time.
            lock (LightGbmShared.LockForMultiThreadingInside)
            {
                ch.Info("LightGBM objective={0}", Options["objective"]);
                using (Booster booster = WrappedLightGbmTraining.Train(ch, pch, Options, dtrain,
                    dvalid: dvalid,
                    numIteration: Args.NumBoostRound,
                    verboseEval: Args.VerboseEval,
                    earlyStoppingRound: Args.EarlyStoppingRound))
                {
                    TrainedEnsemble = booster.GetModel(catMetaData.CategoricalBoudaries);
                }
            }
        }
コード例 #17
0
        /// <summary>
        /// Applies the sub-component transforms in <paramref name="transformArgs"/> on top of
        /// <paramref name="srcLoader"/> via <c>ApplyTransformsCore</c>.
        /// </summary>
        /// <param name="host">Host used for assertions and channels; must not be null.</param>
        /// <param name="srcLoader">Loader to build on; must not be null.</param>
        /// <param name="transformArgs">Tag/sub-component pairs; may be null or empty, in which case <paramref name="srcLoader"/> is returned.</param>
        /// <remarks>
        /// Each pair's tag data is its sub-component's <c>ToString()</c>. When <paramref name="srcLoader"/> is a
        /// <c>CompositeDataLoader</c>, a warning is emitted for each non-empty tag already in its transform chain.
        /// </remarks>
        private static IDataLoader CreateCore(IHost host, IDataLoader srcLoader,
            KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] transformArgs)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader, "srcLoader");
            host.AssertValueOrNull(transformArgs);

            if (Utils.Size(transformArgs) == 0)
            {
                return srcLoader;
            }

            var tagData = new KeyValuePair<string, string>[transformArgs.Length];
            for (int i = 0; i < transformArgs.Length; i++)
            {
                tagData[i] = new KeyValuePair<string, string>(transformArgs[i].Key, transformArgs[i].Value.ToString());
            }

            if (srcLoader is CompositeDataLoader composite)
            {
                using (var ch = host.Start("TagValidation"))
                {
                    foreach (var pair in tagData)
                    {
                        if (string.IsNullOrEmpty(pair.Key))
                        {
                            continue;
                        }

                        if (composite._transforms.Any(x => x.Tag == pair.Key))
                        {
                            ch.Warning("The transform with tag '{0}' already exists in the chain", pair.Key);
                        }
                    }

                    ch.Done();
                }
            }

            return ApplyTransformsCore(host, srcLoader, tagData,
                (provider, idx, view) => transformArgs[idx].Value.CreateInstance(provider, view));
        }
コード例 #18
0
        /// <summary>
        /// Builds a <c>ColumnCodec</c> for each schema column index in <paramref name="colIndices"/>, in order.
        /// </summary>
        /// <param name="schema">Schema that the indices refer to; must not be null.</param>
        /// <param name="colIndices">Indices into <paramref name="schema"/>; may be null or empty, which yields an empty array.</param>
        /// <returns>One codec per entry of <paramref name="colIndices"/>, at the same position.</returns>
        /// <remarks>Throws (via <c>_host.Except</c>) when <c>_factory</c> has no codec for a column's type.</remarks>
        private ColumnCodec[] GetActiveColumns(Schema schema, int[] colIndices)
        {
            _host.AssertValue(schema);
            _host.AssertValueOrNull(colIndices);

            int count = Utils.Size(colIndices);
            ColumnCodec[] activeSourceColumns = new ColumnCodec[count];
            if (count == 0)
            {
                return activeSourceColumns;
            }

            for (int c = 0; c < colIndices.Length; ++c)
            {
                // 'c' is a position in colIndices; 'col' is the actual schema column index.
                int col = colIndices[c];
                ColumnType type = schema[col].Type;
                if (!_factory.TryGetCodec(type, out IValueCodec codec))
                {
                    // Name the schema column that failed. The original used schema[c], which names the
                    // wrong column (or goes out of range) whenever colIndices is not the identity map.
                    throw _host.Except("Could not get codec for requested column {0} of type {1}", schema[col].Name, type);
                }
                _host.Assert(type.Equals(codec.Type));
                activeSourceColumns[c] = new ColumnCodec(col, codec);
            }
            return activeSourceColumns;
        }