internal BatchPredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
                                       SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
        {
            // Builds the prediction pipeline from a model stream: an empty TSrc-backed source view,
            // the transforms loaded on top of it, an optional default scorer, and finally the
            // PipeEngine<TDst> that wraps the resulting pipe.
            Contracts.AssertValue(env);
            Contracts.AssertValue(modelStream);
            Contracts.AssertValueOrNull(inputSchemaDefinition);
            Contracts.AssertValueOrNull(outputSchemaDefinition);

            // The source view is created over an empty array; the loaded transforms are attached to it.
            var emptySource = new TSrc[] { };
            _srcDataView = DataViewConstructionUtils.CreateFromEnumerable(env, emptySource, inputSchemaDefinition);

            // Transforms stored in the model are applied on top of the source view.
            var pipe = env.LoadTransforms(modelStream, _srcDataView);

            // The predictor is optional. When the model has one, a default scorer is appended.
            // REVIEW: distinguish the case of predictor / no predictor?
            var loadedPredictor = env.LoadPredictorOrNull(modelStream);
            if (loadedPredictor != null)
            {
                var roleMappings = ModelFileUtils.LoadRoleMappingsOrNull(env, modelStream);
                if (roleMappings != null)
                {
                    // Role mappings were saved with the model: use them as-is.
                    pipe = env.CreateDefaultScorer(RoleMappedData.CreateOpt(pipe, roleMappings), loadedPredictor);
                }
                else
                {
                    // No saved role mappings: fall back to the "Features" column.
                    pipe = env.CreateDefaultScorer(env.CreateExamples(pipe, "Features"), loadedPredictor);
                }
            }

            _pipeEngine = new PipeEngine<TDst>(env, pipe, ignoreMissingColumns, outputSchemaDefinition);
        }
        /// <summary>
        /// Applies the transform model to <paramref name="input"/> and returns the result wrapped
        /// with this model's role mappings, together with this model's predictor.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="input">The data view to transform; must not be null.</param>
        /// <param name="roleMappedData">Receives the transformed data combined with the stored role mappings.</param>
        /// <param name="predictor">Receives the stored predictor.</param>
        public void PrepareData(IHostEnvironment env, IDataView input, out RoleMappedData roleMappedData, out IPredictor predictor)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));

            // Run the input through the transform model before attaching roles.
            IDataView transformed = _transformModel.Apply(env, input);

            roleMappedData = RoleMappedData.CreateOpt(transformed, _roleMappings);
            predictor = _predictor;
        }
        /// <summary>
        /// Writes the model (transform chain, role mappings and predictor) to <paramref name="stream"/>
        /// via <c>TrainUtils.SaveModel</c>.
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="stream">The destination stream; must not be null.</param>
        public void Save(IHostEnvironment env, Stream stream)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(stream, nameof(stream));

            using (var channel = env.Start("Saving predictor model"))
            {
                // REVIEW: address the asymmetry in the way we're loading and saving the model.
                // Effectively, we have methods to load the transform model from a model.zip, but don't have
                // methods to compose the model.zip out of transform model, predictor and role mappings
                // (we use the TrainUtils.SaveModel that does all three).

                // Rebuild the transform chain over an empty view that has the model's input schema,
                // so the chain can be handed to SaveModel.
                IDataView emptyInput = new EmptyDataView(env, _transformModel.InputSchema);
                IDataView transformed = _transformModel.Apply(env, emptyInput);

                TrainUtils.SaveModel(env, channel, stream, _predictor,
                    RoleMappedData.CreateOpt(transformed, _roleMappings));
                channel.Done();
            }
        }
        /// <summary>
        /// Merges the metric data views carried by <paramref name="input"/> into one
        /// <see cref="CombinedOutput"/>: the concatenated per-instance metrics, the combined overall
        /// metrics, the concatenated confusion matrix (when one is supplied) and any warnings.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The metrics to combine, plus the label/weight/group column names.</param>
        /// <returns>The combined metrics.</returns>
        public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetricsInput input)
        {
            var evaluator = GetEvaluator(env, input.Kind);

            // Attach the label, weight and group roles to every per-instance view before concatenating.
            var perInstanceData = input.PerInstanceMetrics
                .Select(view => RoleMappedData.CreateOpt(view, new[]
                {
                    RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn),
                    RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Weight, input.WeightColumn.Value),
                    RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, input.GroupColumn.Value)
                }))
                .ToArray();
            var perInstance = EvaluateUtils.ConcatenatePerInstanceDataViews(
                env, evaluator, true, true, perInstanceData, out var variableSizeVectorColumnNames);

            // Start from a copy of the incoming warnings (if any) so new ones can be appended.
            List<IDataView> warnings;
            if (input.Warnings != null)
            {
                warnings = new List<IDataView>(input.Warnings);
            }
            else
            {
                warnings = new List<IDataView>();
            }

            // Report any variable-length columns found during concatenation as an extra warning row.
            if (variableSizeVectorColumnNames.Length > 0)
            {
                var warningText = $"Detected columns of variable length: {string.Join(", ", variableSizeVectorColumnNames)}." +
                                  $" Consider setting collateMetrics- for meaningful per-Folds results.";
                var warningBuilder = new ArrayDataViewBuilder(env);
                warningBuilder.AddColumn(MetricKinds.ColumnNames.WarningText, TextType.Instance, new DvText(warningText));
                warnings.Add(warningBuilder.GetDataView());
            }

            env.Assert(Utils.Size(perInstance) == 1);

            var overall = evaluator.GetOverallResults(input.OverallMetrics);
            overall = EvaluateUtils.CombineFoldMetricsDataViews(env, overall, input.OverallMetrics.Length);

            IDataView conf = null;
            if (Utils.Size(input.ConfusionMatrix) > 0)
            {
                EvaluateUtils.ReconcileSlotNames<double>(env, input.ConfusionMatrix, MetricKinds.ColumnNames.Count, NumberType.R8);

                for (int i = 0; i < input.ConfusionMatrix.Length; i++)
                {
                    var matrix = input.ConfusionMatrix[i];

                    // Look for the first hidden column named Count and replace the entry with a view that drops it.
                    for (int col = 0; col < matrix.Schema.ColumnCount; col++)
                    {
                        if (!matrix.Schema.IsHidden(col) ||
                            !matrix.Schema.GetColumnName(col).Equals(MetricKinds.ColumnNames.Count))
                        {
                            continue;
                        }

                        var dropArgs = new ChooseColumnsByIndexTransform.Arguments()
                        {
                            Drop = true,
                            Index = new[] { col }
                        };
                        input.ConfusionMatrix[i] = new ChooseColumnsByIndexTransform(env, dropArgs, matrix);
                        break;
                    }
                }

                conf = EvaluateUtils.ConcatenateOverallMetrics(env, input.ConfusionMatrix);
            }

            // All warnings are appended into one view, using the first warning's schema; null when there are none.
            var warningsIdv = warnings.Count > 0
                ? AppendRowsDataView.Create(env, warnings[0].Schema, warnings.ToArray())
                : null;

            return new CombinedOutput()
            {
                PerInstanceMetrics = perInstance[0],
                OverallMetrics = overall,
                ConfusionMatrix = conf,
                Warnings = warningsIdv
            };
        }
        /// <summary>
        /// Generates a C# source file for the model in <c>_args.InputModelFile</c>.
        /// Reads the embedded code template, builds field declarations for the loader columns and
        /// for the (optionally scored) output columns, substitutes them and the model path into the
        /// template, and writes the result to <c>_args.CSharpOutput</c>.
        /// </summary>
        public void Run()
        {
            string template;

            // The template is an embedded resource of the executing assembly.
            using (var stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(CodeTemplatePath))
            using (var reader = new StreamReader(stream))
            {
                template = reader.ReadToEnd();
            }

            // CSharpCodeProvider is disposable, so it is scoped with 'using' alongside the model
            // file stream instead of being left undisposed.
            using (var codeProvider = new CSharpCodeProvider())
            using (var fs = File.OpenRead(_args.InputModelFile))
            {
                var transformPipe = ModelFileUtils.LoadPipeline(_host, fs, new MultiFileSource(null), true);
                var pred          = _host.LoadPredictorOrNull(fs);

                // Walk back through the chain of transforms until something that is not a transform is reached.
                IDataView root = transformPipe;
                while (root is IDataTransform)
                {
                    root = ((IDataTransform)root).Source;
                }

                // root is now the loader.
                _host.Assert(root is IDataLoader);

                // Loader columns: one field declaration per visible column.
                var loaderSb = new StringBuilder();
                for (int i = 0; i < root.Schema.ColumnCount; i++)
                {
                    if (root.Schema.IsHidden(i))
                    {
                        continue;
                    }
                    if (loaderSb.Length > 0)
                    {
                        loaderSb.AppendLine();
                    }

                    ColumnType colType = root.Schema.GetColumnType(i);
                    CodeGenerationUtils.AppendFieldDeclaration(codeProvider, loaderSb, i, root.Schema.GetColumnName(i), colType, true, _args.SparseVectorDeclaration);
                }

                // Scored example columns. Without a predictor the transform pipe itself is used;
                // otherwise a default scorer is created over it.
                IDataView scorer;
                if (pred == null)
                {
                    scorer = transformPipe;
                }
                else
                {
                    var roles = ModelFileUtils.LoadRoleMappingsOrNull(_host, fs);
                    scorer = roles != null
                        ? _host.CreateDefaultScorer(RoleMappedData.CreateOpt(transformPipe, roles), pred)
                        : _host.CreateDefaultScorer(_host.CreateExamples(transformPipe, "Features"), pred);
                }

                // Visible columns carrying ScoreColumnSetId metadata go to the score class;
                // all other visible columns go to the scored-example class.
                var nonScoreSb = new StringBuilder();
                var scoreSb    = new StringBuilder();
                for (int i = 0; i < scorer.Schema.ColumnCount; i++)
                {
                    if (scorer.Schema.IsHidden(i))
                    {
                        continue;
                    }
                    bool isScoreColumn = scorer.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.ScoreColumnSetId, i) != null;

                    var sb = isScoreColumn ? scoreSb : nonScoreSb;

                    if (sb.Length > 0)
                    {
                        sb.AppendLine();
                    }

                    ColumnType colType = scorer.Schema.GetColumnType(i);
                    CodeGenerationUtils.AppendFieldDeclaration(codeProvider, sb, i, scorer.Schema.GetColumnName(i), colType, false, _args.SparseVectorDeclaration);
                }

                // Turn the model path (or its override, when one is given) into a C# string literal
                // and wrap it in an assignment statement.
                var modelPath = !string.IsNullOrWhiteSpace(_args.ModelNameOverride) ? _args.ModelNameOverride : _args.InputModelFile;
                modelPath = CodeGenerationUtils.GetCSharpString(codeProvider, modelPath);
                modelPath = string.Format("modelPath = {0};", modelPath);

                // Replace values inside the template.
                var replacementMap =
                    new Dictionary<string, string>
                {
                    { "EXAMPLE_CLASS_DECL", loaderSb.ToString() },
                    { "SCORED_EXAMPLE_CLASS_DECL", nonScoreSb.ToString() },
                    { "SCORE_CLASS_DECL", scoreSb.ToString() },
                    { "MODEL_PATH", modelPath }
                };

                var classSource = CodeGenerationUtils.MultiReplace(template, replacementMap);
                File.WriteAllText(_args.CSharpOutput, classSource);
            }
        }