예제 #1
0
        /// <summary>
        /// Reads a data loader from <paramref name="rep"/>. When <paramref name="loadTransforms"/> is true the full data view
        /// (loader plus transforms) is loaded. Otherwise the nested "Loader" entry of a composite loader is tried first,
        /// falling back to the top-level data loader entry when that nested entry does not exist.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="rep">The repository to read the model from. Must not be null.</param>
        /// <param name="files">The data source passed through to <c>ModelLoadContext.LoadModel</c>. Must not be null.</param>
        /// <param name="loadTransforms">Whether to load the transforms together with the loader.</param>
        /// <returns>The deserialized <see cref="IDataLoader"/>.</returns>
        public static IDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool loadTransforms)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            env.CheckValue(files, nameof(files));

            string entryDir = "";
            Repository.Entry entry = null;

            if (!loadTransforms)
            {
                // Loader-only request: look for the nested loader of a composite loader first.
                entryDir = Path.Combine(DirDataLoaderModel, "Loader");
                entry = rep.OpenEntryOrNull(entryDir, ModelLoadContext.ModelStreamName);
            }

            if (entry == null)
            {
                // Either transforms were requested, or the saved loader is not composite:
                // read the top-level data loader entry.
                entryDir = DirDataLoaderModel;
                entry = rep.OpenEntry(entryDir, ModelLoadContext.ModelStreamName);
            }

            env.CheckDecode(entry != null, "Loader is not found.");
            env.AssertNonEmpty(entryDir);

            IDataLoader result;
            using (entry)
            {
                env.Assert(entry.Stream.Position == 0);
                ModelLoadContext.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out result, rep, entry, entryDir, files);
            }
            return result;
        }
예제 #2
0
        /// <summary>
        /// Used when the label column is not a key. For binary tasks the class count is 2. Otherwise the class count is
        /// the vector size of <see cref="IValueMapper.OutputType"/> of <paramref name="pred"/>, and every model in
        /// <paramref name="models"/> from index 1 onward must have a predictor that is an <see cref="IValueMapper"/>
        /// producing the same vector size.
        /// </summary>
        /// <param name="env">The host environment used for assertions and exceptions.</param>
        /// <param name="pred">
        /// The reference predictor. The loop starts at index 1 and errors refer to "model 0", so this is presumably
        /// the predictor of <c>models[0]</c> — confirm at call sites.
        /// </param>
        /// <param name="models">The models to check; must be non-empty.</param>
        /// <param name="isBinary">Whether the task is binary; if so, 2 is returned without further checks.</param>
        /// <param name="labelType">The label column type; must not be a key type.</param>
        /// <returns>The number of classes.</returns>
        /// <remarks>
        /// Throws (via <c>env.Except</c>) when any predictor is not an <see cref="IValueMapper"/>, or when the
        /// class counts differ between models.
        /// </remarks>
        private static int CheckNonKeyLabelColumnCore(IHostEnvironment env, IPredictor pred, IPredictorModel[] models, bool isBinary, ColumnType labelType)
        {
            env.Assert(!labelType.IsKey);
            env.AssertNonEmpty(models);

            if (isBinary)
            {
                return 2;
            }

            // The label is numeric, we just have to check that the number of classes is the same.
            if (!(pred is IValueMapper vm))
            {
                throw env.Except("Cannot determine the number of classes the predictor outputs");
            }
            int classCount = vm.OutputType.VectorSize;

            for (int i = 1; i < models.Length; i++)
            {
                var model = models[i];
                var edv = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, edv, out RoleMappedData _, out IPredictor curPred);

                // Every predictor must expose its output type; without this check a non-IValueMapper
                // predictor would surface as a NullReferenceException instead of a descriptive error.
                if (!(curPred is IValueMapper curVm))
                {
                    throw env.Except("Cannot determine the number of classes the predictor of model {0} outputs", i);
                }
                if (curVm.OutputType.VectorSize != classCount)
                {
                    throw env.Except("Label of model {0} has different number of classes than model 0", i);
                }
            }
            return classCount;
        }
            /// <summary>
            /// Restricts the available learners to those allowed for <c>TrainerKind</c> whose <c>LearnerName</c> appears in
            /// <paramref name="learnersToKeep"/>, stores them in <c>_availableLearners</c>, and passes them to
            /// <c>AutoMlEngine.UpdateLearners</c>.
            /// </summary>
            /// <param name="learnersToKeep">Names of the learners to keep. Must not be null.</param>
            public void KeepSelectedLearners(IEnumerable<string> learnersToKeep)
            {
                _env.CheckValue(learnersToKeep, nameof(learnersToKeep));

                var allLearners = RecipeInference.AllowedLearners(_env, TrainerKind);
                _env.AssertNonEmpty(allLearners);

                // Enumerate the names only once so a lazily-evaluated sequence is not re-run for every
                // candidate learner. An existing collection is used as-is, so its own Contains (and comparer)
                // is honored exactly as Enumerable.Contains would.
                ICollection<string> namesToKeep = learnersToKeep as ICollection<string> ?? new HashSet<string>(learnersToKeep);
                _availableLearners = allLearners.Where(l => namesToKeep.Contains(l.LearnerName)).ToArray();
                AutoMlEngine.UpdateLearners(_availableLearners);
            }
예제 #4
0
        /// <summary>
        /// Creates a node that runs the entry point named <paramref name="entryPointName"/> from
        /// <paramref name="moduleCatalog"/>, binding each property of <paramref name="inputs"/> and
        /// <paramref name="outputs"/>. Throws (via the host) when the entry point cannot be found or when
        /// required inputs are left unbound.
        /// </summary>
        /// <param name="env">The host environment; a child host named <paramref name="id"/> is registered from it.</param>
        /// <param name="moduleCatalog">Catalog used to look up the entry point.</param>
        /// <param name="context">The run context stored on the node.</param>
        /// <param name="id">The node identifier; must be non-empty.</param>
        /// <param name="entryPointName">Name of the entry point to look up; must be non-empty.</param>
        /// <param name="inputs">JSON object of input bindings, or null for none.</param>
        /// <param name="outputs">JSON object of output bindings, or null for none.</param>
        /// <param name="checkpoint">Value stored in <c>Checkpoint</c>.</param>
        /// <param name="stageId">Value stored in <c>StageId</c>.</param>
        /// <param name="cost">Value stored in <c>Cost</c>.</param>
        public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context,
                              string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
                              string stageId = "", float cost = float.NaN)
        {
            Contracts.AssertValue(env);
            env.AssertNonEmpty(id);

            _host = env.Register(id);
            _host.AssertValue(context);
            _host.AssertNonEmpty(entryPointName);
            _host.AssertValue(moduleCatalog);
            _host.AssertValueOrNull(inputs);
            _host.AssertValueOrNull(outputs);

            _catalog = moduleCatalog;
            _context = context;
            Id = id;

            bool found = moduleCatalog.TryFindEntryPoint(entryPointName, out _entryPoint);
            if (!found)
            {
                throw _host.Except($"Entry point '{entryPointName}' not found");
            }

            // Input binding state must exist before CheckAndSetInputValue runs.
            // REVIEW: This logic should move out of Node eventually and be delegated to
            // a class that can nest to handle Components with variables.
            _inputMap = new Dictionary<ParameterBinding, VariableBinding>();
            _inputBindingMap = new Dictionary<string, List<ParameterBinding>>();
            _inputBuilder = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog);
            if (inputs != null)
            {
                foreach (var inputEntry in inputs)
                {
                    CheckAndSetInputValue(inputEntry);
                }
            }

            // Anything the builder still reports as missing, and that was not bound through
            // _inputBindingMap, is an error.
            var unbound = _inputBuilder.GetMissingValues().Except(_inputBindingMap.Keys).ToArray();
            if (unbound.Length != 0)
            {
                var unboundList = String.Join(", ", unbound);
                throw _host.Except($"The following required inputs were not provided: {unboundList}");
            }

            // Output binding state must exist before CheckAndMarkOutputValue runs.
            _outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
            _outputMap = new Dictionary<string, string>();
            if (outputs != null)
            {
                foreach (var outputEntry in outputs)
                {
                    CheckAndMarkOutputValue(outputEntry);
                }
            }

            Checkpoint = checkpoint;
            StageId = stageId;
            Cost = cost;
        }
예제 #5
0
            /// <summary>
            /// Wraps <paramref name="mapper"/> in a <c>Bound&lt;T&gt;</c> that exposes metadata of kind
            /// <paramref name="metadataKind"/> produced by <paramref name="getter"/>.
            /// </summary>
            /// <typeparam name="T">The item type of the vector returned by <paramref name="getter"/>.</typeparam>
            /// <param name="env">The host environment.</param>
            /// <param name="mapper">The bound row mapper to wrap.</param>
            /// <param name="type">The vector type passed to the wrapper.</param>
            /// <param name="getter">Must be a <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>; it is cast to that type.</param>
            /// <param name="metadataKind">The metadata kind; must be non-empty.</param>
            /// <param name="canWrap">Optional predicate passed through to the wrapper; may be null.</param>
            /// <returns>The new wrapper as an <see cref="ISchemaBoundMapper"/>.</returns>
            public static ISchemaBoundMapper CreateBound<T>(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter,
                                                            string metadataKind, Func<ISchemaBoundMapper, ColumnType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                env.AssertValue(mapper);
                env.AssertValue(type);
                env.AssertValue(getter);
                env.Assert(getter is ValueGetter<VBuffer<T>>);
                env.AssertNonEmpty(metadataKind);
                env.AssertValueOrNull(canWrap);

                // The getter arrives untyped; recover its concrete delegate type for the wrapper.
                var typedGetter = (ValueGetter<VBuffer<T>>)getter;
                var bound = new Bound<T>(env, mapper, type, typedGetter, metadataKind, canWrap);
                return bound;
            }
        /// <summary>
        /// Appends <paramref name="text"/> to <paramref name="builder"/>, word-wrapped to <paramref name="screenWidth"/>
        /// columns. Every emitted line is prefixed with <paramref name="indent"/> and terminated with a newline.
        /// Lines break at spaces, and the space used as a break point is dropped. Text is never split inside a word:
        /// if no space falls within the available width, the line extends to the next space (or to the end).
        /// </summary>
        /// <param name="builder">The builder to append to.</param>
        /// <param name="text">The non-empty text to wrap.</param>
        /// <param name="indent">The non-empty prefix written at the start of each line.</param>
        /// <param name="screenWidth">Total line width including the indent; must be positive.</param>
        private void AppendFormattedText(StringBuilder builder, string text, string indent, int screenWidth)
        {
            _env.AssertValue(builder);
            _env.AssertNonEmpty(text);
            _env.AssertNonEmpty(indent);
            _env.Assert(screenWidth > 0);

            // Room left on a line after the indent. This is clamped to at least one column: if the indent
            // is wider than the screen, a negative count would make LastIndexOf below throw
            // ArgumentOutOfRangeException. With one column, each word simply goes on its own line.
            int screenLeft = Math.Max(screenWidth - indent.Length, 1);
            int textIdx = 0;

            while (textIdx < text.Length)
            {
                int summaryLeft = text.Length - textIdx;
                if (summaryLeft <= screenLeft)
                {
                    // The remainder fits on one line.
                    builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine();
                    break;
                }

                // Last space among positions textIdx+1 .. textIdx+screenLeft (all < text.Length here).
                int spaceIdx = text.LastIndexOf(' ', screenLeft + textIdx, screenLeft);
                if (spaceIdx < 0)
                {
                    // No space fits: the current word overflows, so break at the first space after the window.
                    int startIdx = screenLeft + textIdx + 1;
                    spaceIdx = text.IndexOf(' ', startIdx, text.Length - startIdx);
                    if (spaceIdx < 0)
                    {
                        // No more spaces: print the rest as-is.
                        builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine();
                        break;
                    }
                }

                int appendCount = spaceIdx - textIdx;
                builder.Append(indent).Append(text, textIdx, appendCount).AppendLine();
                // Skip the space used as the break point.
                textIdx += appendCount + 1;
            }
        }
        /// <summary>
        /// Creates an XML generator. An empty or whitespace-only <c>XmlFilename</c> in <paramref name="args"/> is
        /// stored as null. Otherwise the name is passed to <c>Utils.CheckOptionalUserDirectory</c> for validation
        /// and then stored.
        /// </summary>
        /// <param name="env">The host environment. Must not be null.</param>
        /// <param name="args">The generator arguments. Must not be null.</param>
        /// <param name="regenerate">
        /// Asserted non-empty. NOTE(review): this constructor does not store it — confirm whether it is still needed.
        /// </param>
        public XmlGenerator(IHostEnvironment env, Arguments args, string regenerate)
        {
            Contracts.CheckValue(env, nameof(env));
            // args is dereferenced unconditionally below, so it is validated in release builds too,
            // not only through a debug-only assertion.
            env.CheckValue(args, nameof(args));
            env.AssertNonEmpty(regenerate, nameof(regenerate));

            string xmlFilename = args.XmlFilename;
            if (string.IsNullOrWhiteSpace(xmlFilename))
            {
                _xmlFilename = null;
            }
            else
            {
                Utils.CheckOptionalUserDirectory(xmlFilename, nameof(args.XmlFilename));
                _xmlFilename = xmlFilename;
            }
            _host = env.Register("XML Generator");
        }
        /// <summary>
        /// Checks the label column of every model in <paramref name="models"/> from index 1 onward against the
        /// reference label column <c>schema[labelIndex]</c>. Each must have a label column with the same key type
        /// (<paramref name="labelType"/>), the same key-values metadata type (<paramref name="keyValuesType"/>), and
        /// the same key values. Returns the number of key values of the reference label column.
        /// </summary>
        /// <typeparam name="T">Raw item type of the key values; asserted to match <paramref name="keyValuesType"/>.</typeparam>
        /// <param name="env">The host environment used for assertions and exceptions.</param>
        /// <param name="models">The models to check; must be non-empty.</param>
        /// <param name="labelType">The key type every label column must equal.</param>
        /// <param name="schema">Schema containing the reference label column.</param>
        /// <param name="labelIndex">Index of the reference label column in <paramref name="schema"/>.</param>
        /// <param name="keyValuesType">The type the key-values metadata of every label column must equal.</param>
        /// <returns>The length of the reference label column's key values.</returns>
        /// <remarks>
        /// Throws (via <c>env.Except</c>) when a model has no label column, or when any of the checks above fail,
        /// including when the key-values metadata is missing.
        /// </remarks>
        private static int CheckKeyLabelColumnCore<T>(IHostEnvironment env, PredictorModel[] models, KeyType labelType, Schema schema, int labelIndex, VectorType keyValuesType)
            where T : IEquatable<T>
        {
            env.Assert(keyValuesType.ItemType.RawType == typeof(T));
            env.AssertNonEmpty(models);

            var labelNames = default(VBuffer<T>);
            schema[labelIndex].GetKeyValues(ref labelNames);
            int classCount = labelNames.Length;

            var curLabelNames = default(VBuffer<T>);
            for (int i = 1; i < models.Length; i++)
            {
                var model = models[i];
                var edv = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor _);
                if (!rmd.Schema.Label.HasValue)
                {
                    throw env.Except("Training schema for model {0} does not have a label column", i);
                }
                var labelCol = rmd.Schema.Label.Value;

                var curLabelType = labelCol.Type as KeyType;
                if (!labelType.Equals(curLabelType))
                {
                    throw env.Except("Label column of model {0} has different type than model 0", i);
                }

                // GetColumnOrNull(...)?.Type is null when the key-values metadata is absent; treat that as a
                // mismatch instead of dereferencing null.
                var mdType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type;
                if (mdType == null || !mdType.Equals(keyValuesType))
                {
                    throw env.Except("Label column of model {0} has different key value type than model 0", i);
                }

                labelCol.GetKeyValues(ref curLabelNames);
                if (!AreEqual(in labelNames, in curLabelNames))
                {
                    throw env.Except("Label of model {0} has different values than model 0", i);
                }
            }
            return classCount;
        }
예제 #9
0
        /// <summary>
        /// Rebuilds the data pipeline of the models in <c>input.Models</c>. The empty data view built from the first
        /// model's input schema is returned in <paramref name="startingData"/>, and the role-mapped data that model
        /// produces is returned in <paramref name="transformedData"/>. When <c>input.ValidatePipelines</c> is set,
        /// every later model's transformed data is compared against the first via <c>CheckSamePipeline</c>.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The input holding the models; <c>input.Models</c> must be non-empty.</param>
        /// <param name="startingData">Receives the first model's empty input data view.</param>
        /// <param name="transformedData">Receives the first model's role-mapped transformed data.</param>
        private static void GetPipeline(IHostEnvironment env, InputBase input, out IDataView startingData, out RoleMappedData transformedData)
        {
            Contracts.AssertValue(env);
            env.AssertValue(input);
            env.AssertNonEmpty(input.Models);

            startingData = null;
            transformedData = null;

            ISchema firstSchema = null;
            // Serialized form of the first model's transformed data; produced lazily the first time
            // a comparison is actually needed.
            byte[][] referenceBytes = null;
            string[] referenceEntryNames = null;

            for (int i = 0; i < input.Models.Length; i++)
            {
                var model = input.Models[i];
                var emptyView = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, emptyView, out RoleMappedData currentData, out IPredictor _);

                if (firstSchema == null)
                {
                    env.Assert(i == 0);
                    firstSchema = model.TransformModel.InputSchema;
                    startingData = emptyView;
                    transformedData = currentData;
                    continue;
                }

                if (!input.ValidatePipelines)
                {
                    continue;
                }

                using (var ch = env.Start("Validating pipeline"))
                {
                    if (referenceBytes == null)
                    {
                        ch.Assert(referenceEntryNames == null);
                        SerializeRoleMappedData(env, ch, transformedData, out referenceBytes, out referenceEntryNames);
                    }
                    CheckSamePipeline(env, ch, currentData, referenceBytes, referenceEntryNames);
                    ch.Done();
                }
            }
        }