/// <summary>
/// Checks that <paramref name="mapper"/> takes a known-size vector of Single values and that its
/// output and distribution types are both Single.
/// </summary>
/// <param name="mapper">Mapper to validate; null is rejected.</param>
/// <param name="inputType">
/// Expected input type. When null on entry it is set to the mapper's input type; otherwise the
/// mapper's vector size must equal its size. Note that it may be set even when the method goes on
/// to return false because of the output or distribution type.
/// </param>
/// <returns>True when all checks pass; false otherwise.</returns>
private bool IsValid(IValueMapperDist mapper, ref VectorType inputType)
{
    if (mapper == null)
    {
        return false;
    }

    var mapperInput = mapper.InputType as VectorType;
    if (mapperInput == null || !mapperInput.IsKnownSize || mapperInput.ItemType != NumberDataViewType.Single)
    {
        return false;
    }

    // A null expected type is seeded from this mapper; otherwise the sizes must agree.
    if (inputType == null)
    {
        inputType = mapperInput;
    }
    else if (inputType.Size != mapperInput.Size)
    {
        return false;
    }

    bool typeMismatch = mapper.OutputType != NumberDataViewType.Single
        || mapper.DistType != NumberDataViewType.Single;
    return !typeMismatch;
}
/// <summary>
/// Creates an <see cref="OvaPredictor"/> over <paramref name="predictors"/>. A distribution-based
/// implementation is used only when <paramref name="useProb"/> is set and the first predictor is
/// an <see cref="IValueMapperDist"/> with Float output and distribution types; otherwise the raw
/// implementation is used.
/// </summary>
/// <param name="host">Host used to open the channel and construct the predictor.</param>
/// <param name="useProb">Whether probability output was requested.</param>
/// <param name="predictors">The per-class predictors.</param>
/// <returns>The new predictor.</returns>
internal static OvaPredictor Create(IHost host, bool useProb, TScalarPredictor[] predictors)
{
    ImplBase impl;
    using (var ch = host.Start("Creating OVA predictor"))
    {
        // Only the first predictor is inspected to decide whether probabilities are available.
        IValueMapperDist firstDist = null;
        if (useProb)
        {
            firstDist = predictors[0] as IValueMapperDist;
            if (firstDist == null || firstDist.OutputType != NumberType.Float || firstDist.DistType != NumberType.Float)
            {
                ch.Warning($"{nameof(Ova.Arguments.UseProbabilities)} specified with {nameof(Ova.Arguments.PredictorType)} that can't produce probabilities.");
                firstDist = null;
            }
        }

        if (firstDist == null)
        {
            impl = new ImplRaw(predictors);
        }
        else
        {
            // Every predictor is cast; a predictor that is not an IValueMapperDist makes the cast throw.
            var distPredictors = new IValueMapperDist[predictors.Length];
            int slot = 0;
            foreach (var predictor in predictors)
            {
                distPredictors[slot++] = (IValueMapperDist)predictor;
            }

            impl = new ImplDist(distPredictors);
        }
    }

    return new OvaPredictor(host, impl);
}
/// <summary>
/// Deserializing constructor: reads the distribution flag and predictor count from
/// <paramref name="ctx"/>, loads the predictors, and sets the output type to a Float vector with
/// one slot per predictor.
/// </summary>
/// <param name="env">Host environment passed to the base constructor.</param>
/// <param name="ctx">Model load context to read from.</param>
private OvaPredictor(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, RegistrationName, ctx)
{
    // *** Binary format ***
    // bool: useDist
    // int: predictor count
    bool useDist = ctx.Reader.ReadBoolByte();
    int count = ctx.Reader.ReadInt32();
    Host.CheckDecode(count > 0);

    if (!useDist)
    {
        var rawPredictors = new TScalarPredictor[count];
        LoadPredictors(Host, rawPredictors, ctx);
        _impl = new ImplRaw(rawPredictors);
    }
    else
    {
        var distPredictors = new IValueMapperDist[count];
        LoadPredictors(Host, distPredictors, ctx);
        _impl = new ImplDist(distPredictors);
    }

    OutputType = new VectorType(NumberType.Float, _impl.Predictors.Length);
}
/// <summary>
/// Checks that <paramref name="mapper"/> takes a known-size vector of Float values and that its
/// output and distribution types are both Float.
/// </summary>
/// <param name="mapper">Mapper to validate; null is rejected.</param>
/// <param name="inputType">
/// Expected input type. When null on entry it is set to the mapper's input type; otherwise the
/// vector sizes must match. It may be set even when the method then returns false because of the
/// output or distribution type.
/// </param>
/// <returns>True when all checks pass; false otherwise.</returns>
private bool IsValid(IValueMapperDist mapper, ref ColumnType inputType)
{
    if (mapper == null)
    {
        return false;
    }

    ColumnType mapperInput = mapper.InputType;
    if (!mapperInput.IsKnownSizeVector || mapperInput.ItemType != NumberType.Float)
    {
        return false;
    }

    // A null expected type is seeded from this mapper; otherwise the vector sizes must agree.
    if (inputType == null)
    {
        inputType = mapperInput;
    }
    else if (inputType.VectorSize != mapperInput.VectorSize)
    {
        return false;
    }

    bool typeMismatch = mapper.OutputType != NumberType.Float
        || mapper.DistType != NumberType.Float;
    return !typeMismatch;
}
/// <summary>
/// Returns true when <paramref name="mapper"/> is non-null, takes a vector of Float values, and
/// has Float output and distribution types.
/// </summary>
/// <param name="mapper">Mapper to validate.</param>
/// <returns>True when all checks pass; false otherwise.</returns>
private bool IsValid(IValueMapperDist mapper)
{
    if (mapper == null)
    {
        return false;
    }

    bool inputIsFloatVector = mapper.InputType.IsVector && mapper.InputType.ItemType == NumberType.Float;
    if (!inputIsFloatVector)
    {
        return false;
    }

    bool outputIsFloat = mapper.OutputType == NumberType.Float;
    if (!outputIsFloat)
    {
        return false;
    }

    return mapper.DistType == NumberType.Float;
}
/// <summary>
/// Collects the <see cref="IValueMapperDist"/> of every model into <paramref name="mappers"/> and
/// determines their common input type.
/// </summary>
/// <param name="mappers">Receives one mapper per entry in Models, in the same order.</param>
/// <returns>
/// The input type of the first mapper whose vector size is greater than zero, or a new Float
/// <see cref="VectorType"/> when no mapper has one.
/// </returns>
private ColumnType InitializeMappers(out IValueMapperDist[] mappers)
{
    Host.AssertNonEmpty(Models);

    mappers = new IValueMapperDist[Models.Length];
    ColumnType sharedInputType = null;

    for (int index = 0; index < Models.Length; index++)
    {
        var distMapper = Models[index].Predictor as IValueMapperDist;
        if (!IsValid(distMapper))
        {
            throw Host.Except("Predictor does not implement expected interface");
        }

        // Mappers whose vector size is not positive do not take part in the size comparison.
        ColumnType currentInput = distMapper.InputType;
        if (currentInput.VectorSize > 0)
        {
            if (sharedInputType == null)
            {
                sharedInputType = currentInput;
            }
            else if (currentInput.VectorSize != sharedInputType.VectorSize)
            {
                throw Host.Except("Predictor input type mismatch");
            }
        }

        mappers[index] = distMapper;
    }

    if (sharedInputType != null)
    {
        return sharedInputType;
    }

    return new VectorType(NumberType.Float);
}
/// <summary>
/// Deserializing constructor: reads the output formula and predictor count from
/// <paramref name="ctx"/>, loads the predictors into the matching implementation, and sets
/// DistType to a Single vector with one slot per predictor.
/// </summary>
/// <param name="env">Host environment passed to the base constructor.</param>
/// <param name="ctx">Model load context to read from.</param>
private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
    : base(env, RegistrationName, ctx)
{
    // *** Binary format ***
    // byte: OutputFormula as byte
    // int: predictor count
    OutputFormula outputFormula = (OutputFormula)ctx.Reader.ReadByte();
    int len = ctx.Reader.ReadInt32();
    Host.CheckDecode(len > 0);

    // Treat an unrecognized formula byte as a decode failure. Without this check none of the
    // branches below would run, this constructor would never assign _impl, and the DistType
    // assignment at the end would dereference it.
    Host.CheckDecode(
        outputFormula == OutputFormula.Raw ||
        outputFormula == OutputFormula.ProbabilityNormalization ||
        outputFormula == OutputFormula.Softmax);

    if (outputFormula == OutputFormula.Raw)
    {
        var predictors = new TScalarPredictor[len];
        LoadPredictors(Host, predictors, ctx);
        _impl = new ImplRaw(predictors);
    }
    else if (outputFormula == OutputFormula.ProbabilityNormalization)
    {
        var predictors = new IValueMapperDist[len];
        LoadPredictors(Host, predictors, ctx);
        _impl = new ImplDist(predictors);
    }
    else
    {
        // Softmax: the only value left after the check above.
        var predictors = new TScalarPredictor[len];
        LoadPredictors(Host, predictors, ctx);
        _impl = new ImplSoftmax(predictors);
    }

    DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
}
/// <summary>
/// Creates an <see cref="OptimizedOVAPredictor"/> over <paramref name="predictors"/>. The
/// distribution-based implementation is chosen only when <paramref name="useProb"/> is set and
/// the first predictor is an <see cref="IValueMapperDist"/> with Single output and distribution
/// types; otherwise the raw implementation is used.
/// </summary>
/// <param name="host">Host used to open the channel and construct the predictor.</param>
/// <param name="useProb">Whether probability output was requested.</param>
/// <param name="predictors">The per-class predictors.</param>
/// <returns>The new predictor.</returns>
internal static OptimizedOVAPredictor Create(IHost host, bool useProb, TScalarPredictor[] predictors)
{
    ImplBase impl;
    using (var ch = host.Start("Creating OVA predictor"))
    {
        // Decide from the first predictor alone whether probabilities can be produced.
        bool useDist = false;
        if (useProb)
        {
            var probe = predictors[0] as IValueMapperDist;
            if (probe == null || probe.OutputType != NumberDataViewType.Single || probe.DistType != NumberDataViewType.Single)
            {
                ch.Warning("useProbabilities specified with basePredictor that can't produce probabilities.");
            }
            else
            {
                useDist = true;
            }
        }

        if (useDist)
        {
            // Every predictor is cast; one that is not an IValueMapperDist makes the cast throw.
            var distPredictors = new IValueMapperDist[predictors.Length];
            for (int index = 0; index < distPredictors.Length; ++index)
            {
                distPredictors[index] = (IValueMapperDist)predictors[index];
            }

            impl = new ImplDist(distPredictors);
        }
        else
        {
            impl = new ImplRaw(predictors);
        }
    }

    return new OptimizedOVAPredictor(host, impl);
}
private bool IsValid(IValueMapperDist mapper, out VectorType inputType) { if (mapper != null && mapper.InputType is VectorType inVectorType && inVectorType.ItemType == NumberType.Float && mapper.OutputType == NumberType.Float && mapper.DistType == NumberType.Float) { inputType = inVectorType; return(true); }
/// <summary>
/// Runs the base class validation and additionally requires the mapper's distribution type to be
/// Single.
/// </summary>
/// <param name="mapper">Mapper to validate.</param>
/// <param name="inputType">Expected input type, passed through to the base validation by reference.</param>
/// <returns>True when the base validation passes and the distribution type is Single.</returns>
private bool IsValid(IValueMapperDist mapper, ref DataViewType inputType)
{
    // The distribution type is only inspected once the base validation has accepted the mapper.
    return base.IsValid(mapper, ref inputType)
        && mapper.DistType == NumberDataViewType.Single;
}
/// <summary>
/// Runs the base class validation and additionally requires the mapper's distribution type to be
/// Float.
/// </summary>
/// <param name="mapper">Mapper to validate.</param>
/// <param name="inputType">Expected input type, passed through to the base validation by reference.</param>
/// <returns>True when the base validation passes and the distribution type is Float.</returns>
private bool IsValid(IValueMapperDist mapper, ref ColumnType inputType)
{
    bool acceptedByBase = base.IsValid(mapper, ref inputType);

    // The distribution type is only inspected once the base validation has accepted the mapper.
    return acceptedByBase && mapper.DistType == NumberType.Float;
}
/// <summary>
/// Fills <paramref name="mappers"/> with the <see cref="IValueMapperDist"/> view of every entry in
/// _predictors, validating each one with IsValid.
/// </summary>
/// <param name="mappers">Receives one mapper per predictor, in the same order.</param>
/// <returns>
/// The input type left in the by-ref argument after the IsValid calls; null if no call assigned it.
/// </returns>
private VectorDataViewType InitializeMappers(out IValueMapperDist[] mappers)
{
    mappers = new IValueMapperDist[_predictors.Length];

    VectorDataViewType sharedInputType = null;
    for (int index = 0; index < _predictors.Length; index++)
    {
        var distMapper = _predictors[index] as IValueMapperDist;

        // The same by-ref type is passed to every IsValid call.
        bool accepted = IsValid(distMapper, ref sharedInputType);
        Host.Check(accepted, "Predictor doesn't implement the expected interface");

        mappers[index] = distMapper;
    }

    return sharedInputType;
}
/// <summary>
/// Verifies that the score type is Float and that Predictor is an <see cref="IValueMapperDist"/>
/// taking a Float vector and producing a Float distribution value.
/// </summary>
/// <param name="distMapper">Receives Predictor as an <see cref="IValueMapperDist"/>.</param>
private void CheckValid(out IValueMapperDist distMapper)
{
    Contracts.Check(ScoreType == NumberType.Float, "Expected predictor result type to be Float");

    distMapper = Predictor as IValueMapperDist;
    if (distMapper == null)
    {
        throw Contracts.Except("Predictor does not provide probabilities");
    }

    // REVIEW: In theory the restriction on input type could be relaxed at the expense
    // of more complicated code in CalibratedRowMapper.GetGetters. Not worth it at this point
    // and no good way to test it.
    var mapperInput = distMapper.InputType;
    bool inputIsFloatVector = mapperInput.IsVector && mapperInput.ItemType == NumberType.Float;
    Contracts.Check(inputIsFloatVector, "Invalid input type for the IValueMapperDist");

    bool distIsFloat = distMapper.DistType == NumberType.Float;
    Contracts.Check(distIsFloat, "Invalid probability type for the IValueMapperDist");
}
/// <summary>
/// Creates a <see cref="OneVersusAllModelParameters"/> over <paramref name="predictors"/> using the
/// implementation selected by <paramref name="outputFormula"/>. Probability normalization falls
/// back to the raw implementation, with a warning, when the first predictor is not an
/// <see cref="IValueMapperDist"/> with Single output and distribution types.
/// </summary>
/// <param name="host">Host used to open the channel and construct the model parameters.</param>
/// <param name="outputFormula">The requested output formula.</param>
/// <param name="predictors">The per-class predictors.</param>
/// <returns>The new model parameters.</returns>
internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors)
{
    ImplBase impl;
    using (var ch = host.Start("Creating OVA predictor"))
    {
        // Softmax needs no probability check and returns straight away.
        if (outputFormula == OutputFormula.Softmax)
        {
            impl = new ImplSoftmax(predictors);
            return new OneVersusAllModelParameters(host, impl);
        }

        // When the caller asks for probability output, check whether the first predictor can
        // produce probabilities. If it cannot, fall back to raw output.
        bool useDist = false;
        if (outputFormula == OutputFormula.ProbabilityNormalization)
        {
            var firstDist = predictors[0] as IValueMapperDist;
            if (firstDist == null || firstDist.OutputType != NumberDataViewType.Single || firstDist.DistType != NumberDataViewType.Single)
            {
                ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities.");
            }
            else
            {
                useDist = true;
            }
        }

        // useDist is false when either the caller didn't ask for probability or the provided
        // predictors can't produce it.
        if (!useDist)
        {
            impl = new ImplRaw(predictors);
        }
        else
        {
            var distPredictors = new IValueMapperDist[predictors.Length];
            for (int index = 0; index < distPredictors.Length; ++index)
            {
                distPredictors[index] = (IValueMapperDist)predictors[index];
            }

            impl = new ImplDist(distPredictors);
        }
    }

    return new OneVersusAllModelParameters(host, impl);
}
/// <summary>
/// Runs the base class validation and additionally requires the mapper's distribution type to be
/// Float.
/// </summary>
/// <param name="mapper">Mapper to validate.</param>
/// <param name="inputType">Expected input type, passed through to the base validation by reference.</param>
/// <returns>True when the base validation passes and the distribution type is Float.</returns>
private bool IsValid(IValueMapperDist mapper, ref ColumnType inputType)
{
    // The distribution type is only inspected once the base validation has accepted the mapper.
    if (!base.IsValid(mapper, ref inputType))
    {
        return false;
    }

    return mapper.DistType == NumberType.Float;
}