Example #1
        /// <summary>
        /// Loads an <see cref="IDataLoader"/> from <paramref name="rep"/>. When <paramref name="loadTransforms"/> is true the
        /// data view (loader and transforms) is loaded; otherwise the nested loader-only entry is tried first.
        /// </summary>
        /// <param name="env">Host environment; validated with <c>Contracts.CheckValue</c> and then used for all further checks.</param>
        /// <param name="rep">Repository the model entry is opened from; validated with <c>env.CheckValue</c>.</param>
        /// <param name="files">Source files handed through to <c>ModelLoadContext.LoadModel</c>; validated with <c>env.CheckValue</c>.</param>
        /// <param name="loadTransforms">False to first try the "Loader" sub-entry; true to open the top-level entry directly.</param>
        /// <returns>
        /// NOTE(review): no return statement is visible in this listing; presumably the loader produced by
        /// <c>LoadModel</c> is returned - confirm against the full source.
        /// </returns>
        public static IDataLoader LoadLoader(IHostEnvironment env, RepositoryReader rep, IMultiStreamSource files, bool loadTransforms)
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(rep, nameof(rep));
            env.CheckValue(files, nameof(files));

            IDataLoader loader;

            // If loadTransforms is false, load the loader only, not the transforms.
            Repository.Entry ent = null;
            string           dir = "";

            // Note that 'dir' is assigned inside the argument list, so it always names the directory
            // of the most recent open attempt.
            if (!loadTransforms)
                ent = rep.OpenEntryOrNull(dir = Path.Combine(DirDataLoaderModel, "Loader"), ModelLoadContext.ModelStreamName);

            // Fall back to the top-level entry; this overwrites 'dir' with DirDataLoaderModel.
            if (ent == null) // either loadTransforms is true, or it's not a composite loader
                ent = rep.OpenEntry(dir = DirDataLoaderModel, ModelLoadContext.ModelStreamName);

            env.CheckDecode(ent != null, "Loader is not found.");
            // NOTE(review): the braces of this using statement are not present in the listing; by indentation
            // both statements below belong to it.
            using (ent)
                // The entry stream is expected to be unread at this point.
                env.Assert(ent.Stream.Position == 0);
                // 'loader' is populated through the out parameter.
                ModelLoadContext.LoadModel <IDataLoader, SignatureLoadDataLoader>(env, out loader, rep, ent, dir, files);
Example #2
        // When the label column is not a key, we check that the number of classes is the same for all the predictors, by checking the
        // OutputType property of the IValueMapper.
        // If any of the predictors do not implement IValueMapper we throw an exception. Returns the class count.
        //
        // 'pred' is the predictor of model 0; it is overwritten with the predictor of each subsequent model while iterating.
        // NOTE(review): no return statement is visible in this listing - confirm against the full source.
        private static int CheckNonKeyLabelColumnCore(IHostEnvironment env, IPredictor pred, IPredictorModel[] models, bool isBinary, ColumnType labelType)

            // NOTE(review): the statement guarded by this check is not present in the listing; left as found.
            if (isBinary)

            // The label is numeric, we just have to check that the number of classes is the same.
            if (!(pred is IValueMapper vm))
                throw env.Except("Cannot determine the number of classes the predictor outputs");
            var classCount = vm.OutputType.VectorSize;

            // Model 0 supplied the reference class count above, so the comparison starts at index 1.
            for (int i = 1; i < models.Length; i++)
                var model = models[i];
                var edv   = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, edv, out RoleMappedData rmd, out pred);
                vm = pred as IValueMapper;
                // 'as' yields null when this model's predictor is not an IValueMapper. Report that explicitly
                // instead of failing with a NullReferenceException on the dereference below.
                if (vm == null)
                    throw env.Except("Cannot determine the number of classes the predictor of model {0} outputs", i);
                if (vm.OutputType.VectorSize != classCount)
                    throw env.Except("Label of model {0} has different number of classes than model 0", i);
            /// <summary>
            /// Restricts the available learners to those allowed learners whose <c>LearnerName</c> is contained in
            /// <paramref name="learnersToKeep"/>.
            /// </summary>
            /// <param name="learnersToKeep">
            /// Names of the learners to retain. A null value now always fails with <see cref="ArgumentNullException"/>,
            /// rather than only when there is at least one allowed learner to test.
            /// </param>
            public void KeepSelectedLearners(IEnumerable <string> learnersToKeep)
                var allLearners = RecipeInference.AllowedLearners(_env, TrainerKind);

                // Enumerate the names at most once. A caller-supplied collection is reused as-is so that its own
                // Contains semantics (for example a custom comparer) stay in effect; a lazy sequence is copied
                // into a set instead of being re-enumerated for every allowed learner.
                var namesToKeep = learnersToKeep as ICollection <string> ?? new HashSet <string>(learnersToKeep);

                _availableLearners = allLearners.Where(l => namesToKeep.Contains(l.LearnerName)).ToArray();
Example #4
        /// <summary>
        /// Creates a node for the entry point named <paramref name="entryPointName"/>, resolving it in
        /// <paramref name="moduleCatalog"/> and setting up the input and output bookkeeping for it.
        /// </summary>
        /// <param name="env">Environment used to register the node's host under <paramref name="id"/>.</param>
        /// <param name="moduleCatalog">Catalog searched for the entry point; also handed to the input builder.</param>
        /// <param name="context">Run context stored on the node.</param>
        /// <param name="id">Node identifier; used as the host registration name and stored in <c>Id</c>.</param>
        /// <param name="entryPointName">Name of the entry point to look up; an exception is thrown when it is not found.</param>
        /// <param name="inputs">Optional input specification; iterated only when non-null.</param>
        /// <param name="outputs">Optional output specification; iterated only when non-null.</param>
        /// <param name="checkpoint">Stored in <c>Checkpoint</c>.</param>
        /// <param name="stageId">Stored in <c>StageId</c>.</param>
        /// <param name="cost">Stored in <c>Cost</c>.</param>
        public EntryPointNode(IHostEnvironment env, ModuleCatalog moduleCatalog, RunContext context,
                              string id, string entryPointName, JObject inputs, JObject outputs, bool checkpoint = false,
                              string stageId = "", float cost = float.NaN)
            _host = env.Register(id);

            _context = context;
            _catalog = moduleCatalog;

            Id = id;
            // Resolve the entry point up front; everything below depends on its input and output types.
            if (!moduleCatalog.TryFindEntryPoint(entryPointName, out _entryPoint))
                throw _host.Except($"Entry point '{entryPointName}' not found");

            // Validate inputs.
            _inputMap        = new Dictionary <ParameterBinding, VariableBinding>();
            _inputBindingMap = new Dictionary <string, List <ParameterBinding> >();
            _inputBuilder    = new InputBuilder(_host, _entryPoint.InputType, moduleCatalog);
            // REVIEW: This logic should move out of Node eventually and be delegated to
            // a class that can nest to handle Components with variables.
            // NOTE(review): the body of the loop below is not present in this listing.
            if (inputs != null)
                foreach (var pair in inputs)
            // Values the builder still reports as missing, minus the names that already have an entry in the binding map.
            var missing = _inputBuilder.GetMissingValues().Except(_inputBindingMap.Keys).ToArray();

            if (missing.Length > 0)
                throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");

            // Validate outputs.
            _outputHelper = new OutputHelper(_host, _entryPoint.OutputType);
            _outputMap    = new Dictionary <string, string>();
            // NOTE(review): the body of the loop below is not present in this listing.
            if (outputs != null)
                foreach (var pair in outputs)

            Checkpoint = checkpoint;
            StageId    = stageId;
            Cost       = cost;
Example #5
            /// <summary>
            /// Wraps <paramref name="mapper"/> in a <c>Bound&lt;T&gt;</c>, passing <paramref name="getter"/> on as a
            /// strongly typed <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>.
            /// </summary>
            /// <param name="getter">Untyped delegate; asserted and then cast to <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>.</param>
            public static ISchemaBoundMapper CreateBound <T>(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorType type, Delegate getter,
                                                             string metadataKind, Func <ISchemaBoundMapper, ColumnType, bool> canWrap)
                // The delegate arrives untyped, so its concrete type is asserted before the cast.
                env.Assert(getter is ValueGetter <VBuffer <T> >);

                var typedGetter = (ValueGetter <VBuffer <T> >)getter;
                return new Bound <T>(env, mapper, type, typedGetter, metadataKind, canWrap);
        /// <summary>
        /// Appends <paramref name="text"/> to <paramref name="builder"/> as one or more lines, each prefixed with
        /// <paramref name="indent"/>, breaking the text at space characters based on <paramref name="screenWidth"/>.
        /// </summary>
        /// <param name="builder">Receives the formatted lines.</param>
        /// <param name="text">Text to split into lines.</param>
        /// <param name="indent">Prefix written at the start of every emitted line; its length is subtracted from the width.</param>
        /// <param name="screenWidth">Total line width including the indent; asserted to be positive.</param>
        private void AppendFormattedText(StringBuilder builder, string text, string indent, int screenWidth)
            _env.Assert(screenWidth > 0);

            // Index of the first character of 'text' that has not been emitted yet.
            int textIdx = 0;

            while (textIdx < text.Length)
                // Characters available on a line after the indent, and characters still to be emitted.
                int screenLeft  = screenWidth - indent.Length;
                int summaryLeft = text.Length - textIdx;
                // The remainder fits on one line: emit it as-is.
                // NOTE(review): no loop exit is visible after this append in the listing - confirm against the full source.
                if (summaryLeft <= screenLeft)
                    builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine();

                // Search backwards from the end of the current window for a space to break at.
                int spaceIdx = text.LastIndexOf(' ', screenLeft + textIdx, screenLeft);
                if (spaceIdx < 0)
                    // Print to the first space.
                    int startIdx = screenLeft + textIdx + 1;
                    spaceIdx = text.IndexOf(' ', startIdx, text.Length - startIdx);
                    if (spaceIdx < 0)
                        // Print to the end.
                        // NOTE(review): no loop exit is visible after this append either - confirm.
                        builder.Append(indent).Append(text, textIdx, summaryLeft).AppendLine();

                // Emit everything up to (not including) the space, then step past the space itself.
                int appendCount = spaceIdx - textIdx;
                builder.Append(indent).Append(text, textIdx, appendCount).AppendLine();
                textIdx += appendCount + 1;
        /// <summary>
        /// Creates the generator, taking the XML file name from <paramref name="args"/> and registering a host
        /// named "XML Generator" with <paramref name="env"/>.
        /// </summary>
        /// <param name="env">Host environment; validated with <c>Contracts.CheckValue</c>.</param>
        /// <param name="args">Arguments supplying <c>XmlFilename</c>; asserted via <c>env.AssertValue</c>.</param>
        /// <param name="regenerate">Asserted non-empty; it is not otherwise used in the lines visible here.</param>
        public XmlGenerator(IHostEnvironment env, Arguments args, string regenerate)
            Contracts.CheckValue(env, nameof(env));
            env.AssertValue(args, nameof(args));
            env.AssertNonEmpty(regenerate, nameof(regenerate));

            _xmlFilename = args.XmlFilename;
            // A non-blank file name is passed to Utils.CheckOptionalUserDirectory.
            // NOTE(review): in this listing the null assignment sits at the same indentation as the check call;
            // an 'else' (null out a blank name) may have been dropped by the extraction - confirm against the full source.
            if (!string.IsNullOrWhiteSpace(_xmlFilename))
                Utils.CheckOptionalUserDirectory(_xmlFilename, nameof(args.XmlFilename));
                _xmlFilename = null;
            _host = env.Register("XML Generator");
        // Checks that all the label columns of the model have the same key type as their label column - including the same
        // cardinality and the same key values, and returns the cardinality of the label column key.
        //
        // The key values of column 'labelIndex' in 'schema' serve as the reference; models 1..n are each prepared on an
        // empty data view and their label column is compared against that reference. Any mismatch throws via env.Except.
        // NOTE(review): no return statement is visible in this listing - confirm against the full source.
        private static int CheckKeyLabelColumnCore <T>(IHostEnvironment env, PredictorModel[] models, KeyType labelType, Schema schema, int labelIndex, VectorType keyValuesType)
            where T : IEquatable <T>
            env.Assert(keyValuesType.ItemType.RawType == typeof(T));
            var labelNames = default(VBuffer <T>);

            schema[labelIndex].GetKeyValues(ref labelNames);
            var classCount = labelNames.Length;

            // Reused across iterations as the destination buffer for each model's key values.
            var curLabelNames = default(VBuffer <T>);

            // The reference values came from 'schema' above, so the comparison starts at index 1.
            for (int i = 1; i < models.Length; i++)
                var model = models[i];
                var edv   = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor pred);
                if (!rmd.Schema.Label.HasValue)
                    throw env.Except("Training schema for model {0} does not have a label column", i);
                var labelCol = rmd.Schema.Label.Value;

                // A label that is not a KeyType gives null here; Equals(null) then reports the type mismatch.
                var curLabelType = labelCol.Type as KeyType;
                if (!labelType.Equals(curLabelType))
                    throw env.Except("Label column of model {0} has different type than model 0", i);

                // The ?. chain leaves mdType null when the column carries no KeyValues metadata; treat that as a
                // mismatch instead of dereferencing null.
                var mdType = labelCol.Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type;
                if (mdType == null || !mdType.Equals(keyValuesType))
                    throw env.Except("Label column of model {0} has different key value type than model 0", i);
                labelCol.GetKeyValues(ref curLabelNames);
                if (!AreEqual(in labelNames, in curLabelNames))
                    throw env.Except("Label of model {0} has different values than model 0", i);
Example #9
        /// <summary>
        /// Prepares every model in <c>input.Models</c> on an empty data view over its own input schema, returning the
        /// data of the first model and, when <c>input.ValidatePipelines</c> is set, checking the later models against it.
        /// </summary>
        /// <param name="env">Host environment used to build the data views and to start the validation channel.</param>
        /// <param name="input">Supplies <c>Models</c> and the <c>ValidatePipelines</c> flag.</param>
        /// <param name="startingData">Set to the empty data view built for the first model; null when there are no models.</param>
        /// <param name="transformedData">Set to the prepared data of the first model; null when there are no models.</param>
        private static void GetPipeline(IHostEnvironment env, InputBase input, out IDataView startingData, out RoleMappedData transformedData)

            // Doubles as the "first model seen" marker: it stays null until the first iteration has run.
            ISchema inputSchema = null;

            startingData    = null;
            transformedData = null;
            // Serialized form of the first model's prepared data; filled in lazily the first time validation needs it.
            byte[][] transformedDataSerialized    = null;
            string[] transformedDataZipEntryNames = null;
            for (int i = 0; i < input.Models.Length; i++)
                var model = input.Models[i];

                var inputData = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, inputData, out RoleMappedData transformedDataCur, out IPredictor pred);

                // First model: remember its schema and data as the reference.
                if (inputSchema == null)
                    env.Assert(i == 0);
                    inputSchema     = model.TransformModel.InputSchema;
                    startingData    = inputData;
                    transformedData = transformedDataCur;
                // Later models: compare against the reference only when validation was requested.
                else if (input.ValidatePipelines)
                    using (var ch = env.Start("Validating pipeline"))
                        // Serialize the reference data once, on first use.
                        if (transformedDataSerialized == null)
                            ch.Assert(transformedDataZipEntryNames == null);
                            SerializeRoleMappedData(env, ch, transformedData, out transformedDataSerialized,
                                                    out transformedDataZipEntryNames);
                        // NOTE(review): braces are absent from this listing; by indentation this call runs for every
                        // validated model, not only when the serialization above happened.
                        CheckSamePipeline(env, ch, transformedDataCur, transformedDataSerialized, transformedDataZipEntryNames);