示例#1
0
 /// <summary>
 /// Creates the aggregator for the stratum named <paramref name="stratName"/>, passing along the host,
 /// the configured <c>LossFunction</c>, and whether <paramref name="schema"/> has a weight column mapped.
 /// </summary>
 /// <param name="schema">Role-mapped schema of the data being evaluated.</param>
 /// <param name="stratName">Name of the stratum this aggregator is created for.</param>
 /// <returns>A new <c>Aggregator</c> instance.</returns>
 protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
 {
     bool hasWeightColumn = schema.Weight != null;
     var aggregator = new Aggregator(Host, LossFunction, hasWeightColumn, stratName);
     return aggregator;
 }
示例#2
0
 /// <summary>
 /// Binds this mapper to <paramref name="schema"/>, producing a <c>SingleValueRowMapper</c> whose
 /// output schema is constructed from <c>ScoreType</c> and the <c>_quantiles</c> field.
 /// </summary>
 /// <param name="ch">Channel supplied by the caller; not used by this implementation.</param>
 /// <param name="schema">The role-mapped input schema to bind against.</param>
 /// <returns>The bound row mapper.</returns>
 protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
 {
     var quantileOutputSchema = new Schema(ScoreType, _quantiles);
     return new SingleValueRowMapper(schema, this, quantileOutputSchema);
 }
示例#3
0
 /// <summary>
 /// Creates the aggregator for the stratum named <paramref name="stratName"/> from this evaluator's
 /// configuration fields and the index of the schema's Name column (or -1 when no Name role is mapped).
 /// </summary>
 /// <param name="schema">Role-mapped schema of the data being evaluated.</param>
 /// <param name="stratName">Name of the stratum this aggregator is created for.</param>
 /// <returns>A new <c>Aggregator</c> instance.</returns>
 protected override Aggregator GetAggregatorCore(RoleMappedSchema schema, string stratName)
 {
     // -1 is passed when the schema has no Name column role.
     int nameColumnIndex = -1;
     if (schema.Name != null)
     {
         nameColumnIndex = schema.Name.Index;
     }

     return new Aggregator(Host, _aucCount, _numTopResults, _k, _p, _streaming, nameColumnIndex, stratName);
 }
示例#4
0
 /// <summary>
 /// Implemented by derived classes to bind this mapper to a concrete role-mapped schema.
 /// </summary>
 /// <param name="ch">Channel supplied by the caller (presumably for diagnostics — confirm against callers).</param>
 /// <param name="schema">The role-mapped input schema to bind against.</param>
 /// <returns>The schema-bound mapper.</returns>
 protected abstract ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema);
示例#5
0
        /// <summary>
        /// Binds this mapper to <paramref name="schema"/>, producing a <c>SingleValueRowMapper</c> whose
        /// output schema is a <c>ScoreMapperSchema</c> built from <c>ScoreType</c> and <c>_scoreColumnKind</c>.
        /// </summary>
        /// <param name="ch">Channel supplied by the caller; not used by this implementation.</param>
        /// <param name="schema">The role-mapped input schema to bind against.</param>
        /// <returns>The bound row mapper.</returns>
        protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema)
            => new SingleValueRowMapper(schema, this, new ScoreMapperSchema(ScoreType, _scoreColumnKind));
 /// <summary>
 /// Creates a multi-class classification scorer. Before being handed to the base class, the bound
 /// mapper is passed through <c>WrapIfNeeded</c>, which may expose the training label's key values as
 /// slot names on the score column, or otherwise returns the mapper unchanged.
 /// </summary>
 /// <param name="env">Host environment.</param>
 /// <param name="args">Scorer arguments.</param>
 /// <param name="data">Input data view to score.</param>
 /// <param name="mapper">The schema-bound mapper producing the scores.</param>
 /// <param name="trainSchema">Role-mapped training schema; may be null.</param>
 public MultiClassClassifierScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
     : base(args, env, data, WrapIfNeeded(env, mapper, trainSchema), trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.MultiClassClassification,
            MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType)
 {
 }
示例#7
0
 /// <summary>
 /// Hook for exporting to ONNX. This base implementation ignores its arguments and always
 /// returns <c>false</c>; derived classes may override it.
 /// </summary>
 /// <param name="ctx">The ONNX export context.</param>
 /// <param name="schema">The role-mapped schema.</param>
 /// <param name="outputNames">Names of the outputs to produce.</param>
 /// <returns>Always <c>false</c> in this implementation.</returns>
 public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
 {
     return false;
 }
        /// <summary>
        /// Wraps <paramref name="mapper"/> via <c>LabelNameBindableMapper.CreateBound</c> so that the key
        /// values of the training label column are published as <c>SlotNames</c> metadata.
        /// <typeparamref name="T"/> is the item type of the label's key-values vector.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="mapper">Bound mapper to wrap; asserted to be an <c>ISchemaBoundRowMapper</c>.</param>
        /// <param name="trainSchema">Training schema whose label column carries the key-values metadata.</param>
        /// <returns>The wrapped bound mapper.</returns>
        public static ISchemaBoundMapper WrapCore<T>(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
        {
            Contracts.AssertValue(env);
            env.AssertValue(mapper);
            env.AssertValue(trainSchema);
            env.Assert(mapper is ISchemaBoundRowMapper);

            // Metadata type of the label's key values; these become the slot names of the score output.
            var keyValuesType = trainSchema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index);
            env.AssertValue(keyValuesType);
            env.Assert(keyValuesType.IsVector);

            // Reads the label's key-values metadata each time it is invoked.
            void ReadLabelKeyValues(ref VBuffer<T> value)
            {
                trainSchema.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index, ref value);
            }

            ValueGetter<VBuffer<T>> keyValuesGetter = ReadLabelKeyValues;

            return LabelNameBindableMapper.CreateBound<T>(env, (ISchemaBoundRowMapper)mapper, keyValuesType.AsVector,
                keyValuesGetter, MetadataUtils.Kinds.SlotNames, CanWrap);
        }
        /// <summary>
        /// Validates the inputs and, when the training schema has a label column with key-values metadata
        /// that <c>CanWrap</c> accepts for <paramref name="mapper"/>, returns a mapper that exposes those key
        /// values as slot names on the score column. In every other case the input bound mapper is returned
        /// unchanged.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="mapper">The bound mapper; must not be null.</param>
        /// <param name="trainSchema">The training schema; may be null, in which case no wrapping happens.</param>
        /// <returns>The wrapped mapper, or <paramref name="mapper"/> itself when wrapping is not possible.</returns>
        private static ISchemaBoundMapper WrapIfNeeded(IHostEnvironment env, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(mapper, nameof(mapper));
            env.CheckValueOrNull(trainSchema);

            // Slot names come from the key values of the training label, so there must be a training
            // schema with an identified label column to read them from.
            bool hasTrainingLabel = trainSchema != null && trainSchema.Label != null;
            if (!hasTrainingLabel)
            {
                return mapper;
            }

            // The label must also carry key-values metadata, and CanWrap must accept this mapper/type pair.
            var labelKeyValuesType = trainSchema.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, trainSchema.Label.Index);
            bool canWrap = labelKeyValuesType != null && CanWrap(mapper, labelKeyValuesType);
            if (!canWrap)
            {
                return mapper;
            }

            // Presumably re-invokes WrapCore<T> with T set to the key values' raw item type — confirm against Utils.MarshalInvoke.
            return Utils.MarshalInvoke(WrapCore<int>, labelKeyValuesType.ItemType.RawType, env, mapper, trainSchema);
        }
示例#10
0
        /// <summary>
        /// Yields the input column roles reported by the base class, dropping the Label role when its column
        /// is not present in <paramref name="schema"/>. When <c>_calculateDbi</c> is set, the feature column
        /// role is appended as well; a user-argument exception is thrown if that column cannot be found.
        /// </summary>
        /// <remarks>
        /// This is an iterator, so the missing-feature-column exception surfaces while the result is
        /// enumerated, not when the method is called.
        /// </remarks>
        /// <param name="schema">The role-mapped schema to inspect.</param>
        /// <returns>The (role, column name) pairs this evaluator reads.</returns>
        protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
        {
            foreach (var role in base.GetInputColumnRolesCore(schema))
            {
                // Non-label roles always pass through; the label role only if its column exists.
                bool isLabelRole = role.Key.Equals(RoleMappedSchema.ColumnRole.Label);
                if (!isLabelRole || schema.Schema.TryGetColumnIndex(role.Value, out _))
                {
                    yield return role;
                }
            }

            if (!_calculateDbi)
            {
                yield break;
            }

            string featureColumnName = EvaluateUtils.GetColName(_featureCol, schema.Feature, DefaultColumnNames.Features);
            if (!schema.Schema.TryGetColumnIndex(featureColumnName, out _))
            {
                throw Host.ExceptUserArg(nameof(Arguments.FeatureColumn), "Features column '{0}' not found", featureColumnName);
            }

            yield return RoleMappedSchema.ColumnRole.Feature.Bind(featureColumnName);
        }
 /// <summary>
 /// Creates a clustering scorer. All arguments are forwarded to the base class unchanged, together with
 /// the clustering score-column kind, the score value kind, and the <c>OutputTypeMatches</c> and
 /// <c>GetPredColType</c> callbacks.
 /// </summary>
 /// <param name="env">Host environment.</param>
 /// <param name="args">Scorer arguments.</param>
 /// <param name="data">Input data view to score.</param>
 /// <param name="mapper">The schema-bound mapper producing the scores.</param>
 /// <param name="trainSchema">Role-mapped training schema.</param>
 public ClusteringScorer(IHostEnvironment env, Arguments args, IDataView data, ISchemaBoundMapper mapper, RoleMappedSchema trainSchema)
     : base(args, env, data, mapper, trainSchema, RegistrationName, MetadataUtils.Const.ScoreColumnKind.Clustering,
            MetadataUtils.Const.ScoreValueKind.Score, OutputTypeMatches, GetPredColType)
 {
 }
示例#12
0
 /// <summary>
 /// Stores the results produced for a single fold.
 /// </summary>
 /// <param name="metrics">Metric data views, keyed by name.</param>
 /// <param name="scoreSchema">Schema of the scored data.</param>
 /// <param name="perInstance">Per-instance results.</param>
 /// <param name="trainSchema">Role-mapped training schema.</param>
 public FoldResult(Dictionary<string, IDataView> metrics, Schema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema)
 {
     this.Metrics = metrics;
     this.ScoreSchema = scoreSchema;
     this.PerInstanceResults = perInstance;
     this.TrainSchema = trainSchema;
 }