/// <summary>
/// Returns a value mapper that predicts from the nearest neighbors stored in <paramref name="trees"/>.
/// </summary>
/// <typeparam name="TIn">Input type; must be <c>VBuffer&lt;float&gt;</c>.</typeparam>
/// <typeparam name="TOut">Output type produced by the kind-specific mapper.</typeparam>
/// <param name="trees">Trees searched for neighbors.</param>
/// <param name="k">Number of neighbors to use.</param>
/// <param name="algo">Search algorithm; only <c>kdtree</c> is accepted.</param>
/// <param name="weight">Neighbor weighting; only <c>uniform</c> is implemented.</param>
/// <param name="kind">Prediction kind; binary and multiclass classification are supported.</param>
/// <returns>The mapper built by the kind-specific helper.</returns>
/// <exception cref="System.Exception">
/// Raised through <c>_host</c> when a precondition fails or the requested combination is not implemented.
/// </exception>
public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>(NearestNeighborsTrees trees, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight, PredictionKind kind)
{
    _host.Check(typeof(TIn) == typeof(VBuffer<float>), "TIn must be VBuffer<float>");
    _host.CheckValue(_labelWeights, "_labelWeights");
    _host.Check(algo == NearestNeighborsAlgorithm.kdtree, "algo");

    // Only uniform weighting is implemented so far.
    if (weight != NearestNeighborsWeights.uniform)
    {
        throw _host.ExceptNotImpl("Not implemented yet for weight={0}", weight);
    }

    switch (kind)
    {
        case PredictionKind.BinaryClassification:
            return GetMapperBinaryPrediction<TIn, TOut>(trees, k, algo, weight);
        case PredictionKind.MulticlassClassification:
            return GetMapperMulticlassPrediction<TIn, TOut>(trees, k, algo, weight);
        default:
            throw _host.ExceptNotImpl("Not implemented yet for kind={0}", kind);
    }
}
/// <summary>
/// Creates a binary kNN predictor from already-built neighbor trees and a label value mapper.
/// </summary>
/// <param name="host">Host used for checks and exceptions.</param>
/// <param name="trees">Trees used for neighbor search.</param>
/// <param name="predictor">Mapper holding the labels/weights of the stored points.</param>
/// <param name="k">Number of neighbors.</param>
/// <param name="algo">Neighbor search algorithm.</param>
/// <param name="weights">Neighbor weighting scheme.</param>
internal NearestNeighborsBinaryClassifierPredictor(IHost host, NearestNeighborsTrees trees, INearestNeighborsValueMapper predictor, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
{
    _host = host;

    // Model data.
    _nearestTrees = trees;
    _nearestPredictor = predictor;

    // Hyper-parameters.
    _k = k;
    _algo = algo;
    _weights = weights;
}
/// <summary>
/// Builds the binary-classification value mapper. For each input vector it delegates to
/// <c>GetMapperUniformBinaryPrediction</c> with the <paramref name="k"/> nearest neighbors
/// from <paramref name="trees"/> and the label value treated as the positive class.
/// </summary>
/// <typeparam name="TIn">Must be <c>VBuffer&lt;float&gt;</c>.</typeparam>
/// <typeparam name="TOut">Must be <c>float</c>.</typeparam>
/// <param name="trees">Trees searched for neighbors.</param>
/// <param name="k">Number of neighbors.</param>
/// <param name="algo">Search algorithm (not used by this method).</param>
/// <param name="weight">Neighbor weighting; only <c>uniform</c> is implemented.</param>
/// <returns>A mapper from <c>VBuffer&lt;float&gt;</c> to <c>float</c>.</returns>
/// <exception cref="System.Exception">
/// Raised through <c>_host</c> when TLabel is not bool/float/uint, when the weighting is not uniform,
/// or when TIn/TOut do not match the mapper signature.
/// </exception>
ValueMapper<TIn, TOut> GetMapperBinaryPrediction<TIn, TOut>(NearestNeighborsTrees trees, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight)
{
    // The positive class is whatever TLabel value the converter produces from true / 1f / 1u.
    var conv = new TypedConverters<TLabel>();
    TLabel positiveClass = default(TLabel);
    if (typeof(TLabel) == typeof(bool))
    {
        var convMap = conv.GetMapperFrom<bool>();
        var b = true;
        convMap(in b, ref positiveClass);
    }
    else if (typeof(TLabel) == typeof(float))
    {
        var convMap = conv.GetMapperFrom<float>();
        var b = 1f;
        convMap(in b, ref positiveClass);
    }
    else if (typeof(TLabel) == typeof(uint))
    {
        var convMap = conv.GetMapperFrom<uint>();
        uint b = 1;
        convMap(in b, ref positiveClass);
    }
    else
    {
        // ExceptNotImpl only creates the exception object. Without 'throw', execution would
        // continue with positiveClass left at default(TLabel).
        throw _host.ExceptNotImpl("Not implemented for type {0}", typeof(TLabel));
    }

    if (weight != NearestNeighborsWeights.uniform)
    {
        throw _host.ExceptNotImpl("Not implemented for {0}", weight);
    }

    // The mapper built below is a ValueMapper<VBuffer<float>, float>. The 'as' cast would
    // silently yield null for any other TIn/TOut, so fail explicitly instead.
    _host.Check(typeof(TIn) == typeof(VBuffer<float>) && typeof(TOut) == typeof(float),
        "Binary prediction mapper requires TIn=VBuffer<float> and TOut=float");

    // Captured by the closure and passed by ref on every call, so it is shared across invocations
    // of the returned mapper. NOTE(review): confirm a mapper instance is never used concurrently.
    Dictionary<TLabel, float> hist = null;
    ValueMapper<VBuffer<float>, float> mapper = (in VBuffer<float> input, ref float output) =>
    {
        GetMapperUniformBinaryPrediction(trees, k, in input, ref output, positiveClass, ref hist);
    };
    return mapper as ValueMapper<TIn, TOut>;
}
/// <summary>
/// Restores the kNN settings from <paramref name="ctx"/>. The values are read in this order:
/// k, algo, weighting, distance, numThreads, seed, colId.
/// </summary>
/// <remarks>
/// A stored -1 for numThreads or seed becomes null. An empty stored colId also becomes null.
/// <paramref name="host"/> is not used here.
/// </remarks>
/// <param name="ctx">Load context whose reader is positioned at these settings.</param>
/// <param name="host">Unused.</param>
public virtual void Read(ModelLoadContext ctx, IHost host)
{
    var reader = ctx.Reader;

    k = reader.ReadInt32();
    algo = (NearestNeighborsAlgorithm)reader.ReadInt32();
    weighting = (NearestNeighborsWeights)reader.ReadInt32();
    distance = (NearestNeighborsDistance)reader.ReadInt32();

    int storedThreads = reader.ReadInt32();
    numThreads = storedThreads == -1 ? (int?)null : storedThreads;

    int storedSeed = reader.ReadInt32();
    seed = storedSeed == -1 ? (int?)null : storedSeed;

    string storedColumn = reader.ReadString();
    colId = string.IsNullOrEmpty(storedColumn) ? null : storedColumn;
}
/// <summary>
/// Deserializes the predictor state from <paramref name="ctx"/>, in this order: k, algorithm,
/// weighting, neighbor trees, a DataKind code for the label type, and then the label value mapper
/// of that type.
/// </summary>
/// <param name="host">Host passed to the value-mapper constructor.</param>
/// <param name="ctx">Load context positioned at the predictor state.</param>
/// <exception cref="System.Exception">
/// Raised through <c>_host</c> when the stored label kind is not BL, U1, U2, U4 or R4.
/// </exception>
protected void ReadCore(IHost host, ModelLoadContext ctx)
{
    _k = ctx.Reader.ReadInt32();
    _algo = (NearestNeighborsAlgorithm)ctx.Reader.ReadInt32();
    _weights = (NearestNeighborsWeights)ctx.Reader.ReadInt32();

    // NOTE(review): the trees are built with the _host field while the value mappers below use the
    // 'host' parameter. This assumes _host is already assigned when ReadCore runs; confirm in callers.
    _nearestTrees = new NearestNeighborsTrees(_host, ctx);
    _host.CheckValue(_nearestTrees, "_nearestTrees");

    // The label type of the serialized value mapper is stored as a DataKind code.
    var kind = (DataKind)ctx.Reader.ReadInt32();
    switch (kind)
    {
        case DataKind.BL:
            _nearestPredictor = new NearestNeighborsValueMapper<bool>(host, ctx);
            break;
        case DataKind.U1:
            _nearestPredictor = new NearestNeighborsValueMapper<byte>(host, ctx);
            break;
        case DataKind.U2:
            _nearestPredictor = new NearestNeighborsValueMapper<ushort>(host, ctx);
            break;
        case DataKind.U4:
            _nearestPredictor = new NearestNeighborsValueMapper<uint>(host, ctx);
            break;
        case DataKind.R4:
            _nearestPredictor = new NearestNeighborsValueMapper<float>(host, ctx);
            break;
        default:
            throw _host.ExceptNotSupp("Not supported kind={0}", kind);
    }
    _host.CheckValue(_nearestPredictor, "_nearestPredictor");
}
/// <summary>
/// Factory that wraps <paramref name="kdtrees"/> and <paramref name="labelWeights"/> into a
/// binary kNN predictor. The construction runs inside a host channel named "Creating kNN predictor".
/// </summary>
/// <typeparam name="TLabel">Label type stored for each point.</typeparam>
/// <param name="host">Host; must not be null.</param>
/// <param name="kdtrees">Trees; the array and every element must be non-null.</param>
/// <param name="labelWeights">Per-point (label, weight) pairs keyed by a long id.</param>
/// <param name="k">Number of neighbors.</param>
/// <param name="algo">Neighbor search algorithm.</param>
/// <param name="weights">Neighbor weighting scheme.</param>
/// <returns>The new predictor.</returns>
internal static NearestNeighborsBinaryClassifierPredictor Create<TLabel>(IHost host, KdTree[] kdtrees, Dictionary<long, Tuple<TLabel, float>> labelWeights, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
    where TLabel : IComparable<TLabel>
{
    Contracts.CheckValue(host, nameof(host));
    host.CheckValue(kdtrees, nameof(kdtrees));
    host.Check(kdtrees.All(tree => tree != null), nameof(kdtrees));

    using (host.Start("Creating kNN predictor"))
    {
        var neighborTrees = new NearestNeighborsTrees(host, kdtrees);
        var labelMapper = new NearestNeighborsValueMapper<TLabel>(host, labelWeights);
        return new NearestNeighborsBinaryClassifierPredictor(host, neighborTrees, labelMapper, k, algo, weights);
    }
}