/// <summary>
/// Returns a value mapper that predicts from the <paramref name="k"/> nearest neighbors stored in <paramref name="trees"/>.
/// Only <c>VBuffer&lt;float&gt;</c> inputs, the <c>kdtree</c> algorithm and <c>uniform</c> weights are supported.
/// </summary>
/// <param name="trees">Trees searched for the nearest neighbors.</param>
/// <param name="k">Number of neighbors to use.</param>
/// <param name="algo">Search algorithm; must be <c>NearestNeighborsAlgorithm.kdtree</c>.</param>
/// <param name="weight">Neighbor weighting; only <c>NearestNeighborsWeights.uniform</c> is implemented.</param>
/// <param name="kind">Prediction kind; binary and multiclass classification are implemented.</param>
/// <returns>A mapper from <typeparamref name="TIn"/> to <typeparamref name="TOut"/>.</returns>
public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>(NearestNeighborsTrees trees, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight, PredictionKind kind)
{
    _host.Check(typeof(TIn) == typeof(VBuffer<float>));
    _host.CheckValue(_labelWeights, "_labelWeights");
    _host.Check(algo == NearestNeighborsAlgorithm.kdtree, "algo");

    // Only uniform weighting is implemented; reject anything else before dispatching on kind.
    if (weight != NearestNeighborsWeights.uniform)
    {
        throw _host.ExceptNotImpl("Not implemented yet for weight={0}", weight);
    }

    switch (kind)
    {
        case PredictionKind.BinaryClassification:
            return GetMapperBinaryPrediction<TIn, TOut>(trees, k, algo, weight);
        case PredictionKind.MulticlassClassification:
            return GetMapperMulticlassPrediction<TIn, TOut>(trees, k, algo, weight);
        default:
            throw _host.ExceptNotImpl("Not implemented yet for kind={0}", kind);
    }
}
/// <summary>
/// Builds the k-d trees from <c>_input</c> on the first call and stores them in <c>_trees</c>.
/// The check-and-build runs under <c>_lock</c>; once <c>_trees</c> is set, later calls return without rebuilding.
/// </summary>
void ComputeNearestNeighbors()
{
    lock (_lock)
    {
        // Trees already built: nothing to do.
        if (_trees != null)
        {
            return;
        }

        using (var ch = _host.Start("Build k-d tree"))
        {
            ch.Info(MessageSensitivity.None, "ComputeNearestNeighbors: build a k-d tree.");

            // The feature column is mandatory; GetColumnIndex signals a missing column with -1.
            int featureIndex = GetColumnIndex(ch, _args.column);
            if (featureIndex == -1)
            {
                throw ch.Except($"Unable to find column '{_args.column}' in {SchemaHelper.ToString(Schema)}.");
            }

            int labelIndex = GetColumnIndex(ch, _args.labelColumn);
            int weightIndex = GetColumnIndex(ch, _args.weightColumn);
            int idIndex = GetColumnIndex(ch, _args.colId);

            // The builder also returns a merged dictionary, which this method does not use.
            Dictionary<long, Tuple<long, float>> unusedMerged;
            _trees = NearestNeighborsBuilder.NearestNeighborsBuild<long>(ch, _input, featureIndex, labelIndex, idIndex, weightIndex, out unusedMerged, _args);

            ch.Info(MessageSensitivity.UserData, "Done. Tree size: {0} points.", _trees.Count());
        }
    }
}
/// <summary>
/// Builds the k-d trees from <c>_input</c> on the first call and stores them in <c>_trees</c>.
/// The check-and-build runs under <c>_lock</c>; once <c>_trees</c> is set, later calls return without rebuilding.
/// </summary>
void ComputeNearestNeighbors()
{
    lock (_lock)
    {
        // Trees already built: nothing to do.
        if (_trees != null)
        {
            return;
        }

        using (var ch = _host.Start("Build k-d tree"))
        {
            ch.Info("ComputeNearestNeighbors: build a k-d tree.");

            // The feature column is mandatory and is looked up directly in the input schema.
            int featureIndex;
            bool hasFeatures = _input.Schema.TryGetColumnIndex(_args.column, out featureIndex);
            if (!hasFeatures)
            {
                throw ch.Except("Unable to find column '{0}'.", _args.column);
            }

            int labelIndex = GetColumnIndex(ch, _args.labelColumn);
            int weightIndex = GetColumnIndex(ch, _args.weightColumn);
            int idIndex = GetColumnIndex(ch, _args.colId);

            // The builder also returns a merged dictionary, which this method does not use.
            Dictionary<long, Tuple<long, float>> unusedMerged;
            _trees = NearestNeighborsBuilder.NearestNeighborsBuild<long>(ch, _input, featureIndex, labelIndex, idIndex, weightIndex, out unusedMerged, _args);

            ch.Info("Done. Tree size: {0} points.", _trees.Count());
        }
    }
}
/// <summary>
/// Creates a binary classifier predictor from prebuilt neighbor trees and a neighbor value mapper.
/// The arguments are stored as given; this constructor performs no validation.
/// </summary>
/// <param name="host">Host used by the predictor.</param>
/// <param name="trees">Trees searched for neighbors.</param>
/// <param name="predictor">Mapper used to turn neighbors into predictions.</param>
/// <param name="k">Number of neighbors.</param>
/// <param name="algo">Search algorithm.</param>
/// <param name="weights">Neighbor weighting scheme.</param>
internal NearestNeighborsBinaryClassifierPredictor(IHost host, NearestNeighborsTrees trees, INearestNeighborsValueMapper predictor, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
{
    _host = host;

    // Neighbor search state.
    _nearestTrees = trees;
    _nearestPredictor = predictor;

    // Search configuration.
    _k = k;
    _algo = algo;
    _weights = weights;
}
/// <summary>
/// Wraps <paramref name="cursor"/> and allocates buffers sized for the parent's k neighbors.
/// </summary>
/// <param name="cursor">Underlying input cursor.</param>
/// <param name="parent">Transform providing the trees and the neighbor count.</param>
/// <param name="predicate">Not used by this constructor.</param>
/// <param name="colFeatures">Index of the feature column read from <paramref name="cursor"/>.</param>
public NearestNeighborsCursor(RowCursor cursor, NearestNeighborsTransform parent, Func<int, bool> predicate, int colFeatures)
{
    _inputCursor = cursor;
    _parent = parent;

    // Share the parent's trees and neighbor count.
    _trees = parent._trees;
    _k = parent._args.k;

    _getterFeatures = cursor.GetGetter<VBuffer<float>>(colFeatures);
    _tempFeatures = new VBuffer<float>();

    // One slot per neighbor for distances and neighbor ids.
    int slots = _k;
    _distance = new VBuffer<float>(slots, new float[slots]);
    _idn = new VBuffer<long>(slots, new long[slots]);
}
/// <summary>
/// Wraps <paramref name="cursor"/> and allocates buffers sized for the parent's k neighbors.
/// </summary>
/// <param name="cursor">Underlying input cursor.</param>
/// <param name="parent">Transform providing the trees and the neighbor count.</param>
/// <param name="columnsNeeded">Not used by this constructor.</param>
/// <param name="colFeatures">Feature column read from <paramref name="cursor"/>.</param>
public NearestNeighborsCursor(DataViewRowCursor cursor, NearestNeighborsTransform parent, IEnumerable<DataViewSchema.Column> columnsNeeded, DataViewSchema.Column colFeatures)
{
    _inputCursor = cursor;
    _parent = parent;

    // Share the parent's trees and neighbor count.
    _trees = parent._trees;
    _k = parent._args.k;

    _getterFeatures = cursor.GetGetter<VBuffer<float>>(colFeatures);
    _tempFeatures = new VBuffer<float>();

    // One slot per neighbor for distances and neighbor ids.
    int slots = _k;
    _distance = new VBuffer<float>(slots, new float[slots]);
    _idn = new VBuffer<long>(slots, new long[slots]);
}
/// <summary>
/// Creates the transform. The k-d trees are not built here (<c>_trees</c> starts as null).
/// </summary>
/// <param name="env">Host environment; a host is registered from it under <c>LoaderSignature</c>.</param>
/// <param name="args">Transform arguments; <c>PostProcess</c> is called on them and <c>column</c> must be set.</param>
/// <param name="input">Input data view; must not be null.</param>
public NearestNeighborsTransform(IHostEnvironment env, Arguments args, IDataView input)
{
    Contracts.CheckValue(env, "env");
    _host = env.Register(LoaderSignature);
    _host.CheckValue(args, "args");
    // Validate input up front (as the deserializing constructor does) instead of failing later in ComputeExtendedSchema.
    _host.CheckValue(input, "input");

    args.PostProcess();
    _host.CheckValue(args.column, "column");

    _input = input;
    _trees = null;
    _args = args;
    _lock = new object();
    _extendedSchema = ComputeExtendedSchema();
}
/// <summary>
/// Deserializes the transform from <paramref name="ctx"/>. It reads the <c>Arguments</c>, then one byte;
/// when that byte is 1, it also reads the k-d trees, otherwise <c>_trees</c> stays null.
/// </summary>
/// <param name="host">Host for this transform.</param>
/// <param name="ctx">Model load context to read from.</param>
/// <param name="input">Input data view.</param>
private NearestNeighborsTransform(IHost host, ModelLoadContext ctx, IDataView input)
{
    // Validate every argument once, before any state is assigned.
    Contracts.CheckValue(host, "host");
    Contracts.CheckValue(input, "input");
    host.CheckValue(ctx, "ctx");

    _lock = new object();
    _host = host;
    _input = input;

    _args = new Arguments();
    _args.Read(ctx, _host);

    bool hasTrees = ctx.Reader.ReadByte() == 1;
    _trees = hasTrees ? new NearestNeighborsTrees(host, ctx) : null;
    _extendedSchema = ComputeExtendedSchema();
}
/// <summary>
/// Builds a <c>NearestNeighborsBinaryClassifierPredictor</c> from k-d trees and per-id label/weight pairs.
/// </summary>
/// <param name="host">Host; must not be null.</param>
/// <param name="kdtrees">Trees to wrap; the array and every element must be non-null.</param>
/// <param name="labelWeights">Label and weight for each point id, passed to the value mapper.</param>
/// <param name="k">Number of neighbors.</param>
/// <param name="algo">Search algorithm.</param>
/// <param name="weights">Neighbor weighting scheme.</param>
/// <returns>The new predictor.</returns>
internal static NearestNeighborsBinaryClassifierPredictor Create<TLabel>(IHost host, KdTree[] kdtrees, Dictionary<long, Tuple<TLabel, float>> labelWeights, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
    where TLabel : IComparable<TLabel>
{
    Contracts.CheckValue(host, "host");
    host.CheckValue(kdtrees, "kdtrees");
    host.Check(kdtrees.All(tree => tree != null), "kdtrees");

    // The channel is opened for the duration of construction but not otherwise used.
    using (host.Start("Creating kNN predictor"))
    {
        var neighborTrees = new NearestNeighborsTrees(host, kdtrees);
        var valueMapper = new NearestNeighborsValueMapper<TLabel>(host, labelWeights);
        return new NearestNeighborsBinaryClassifierPredictor(host, neighborTrees, valueMapper, k, algo, weights);
    }
}
/// <summary>
/// Returns a binary-prediction mapper from <c>VBuffer&lt;float&gt;</c> features to a <c>float</c> output.
/// The positive class is <c>true</c>, <c>1f</c> or <c>1u</c> converted to <c>TLabel</c>, depending on the label type.
/// </summary>
/// <param name="trees">Trees searched for neighbors.</param>
/// <param name="k">Number of neighbors.</param>
/// <param name="algo">Search algorithm (not read by this method).</param>
/// <param name="weight">Neighbor weighting; only <c>uniform</c> is implemented.</param>
/// <returns>
/// The mapper cast with <c>as</c> to <c>ValueMapper&lt;TIn, TOut&gt;</c>; this yields null when
/// <typeparamref name="TIn"/>/<typeparamref name="TOut"/> are not <c>VBuffer&lt;float&gt;</c>/<c>float</c>.
/// </returns>
ValueMapper<TIn, TOut> GetMapperBinaryPrediction<TIn, TOut>(NearestNeighborsTrees trees, int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight)
{
    var conv = new TypedConverters<TLabel>();
    TLabel positiveClass = default(TLabel);
    if (typeof(TLabel) == typeof(bool))
    {
        var convMap = conv.GetMapperFrom<bool>();
        var b = true;
        convMap(in b, ref positiveClass);
    }
    else if (typeof(TLabel) == typeof(float))
    {
        var convMap = conv.GetMapperFrom<float>();
        var b = 1f;
        convMap(in b, ref positiveClass);
    }
    else if (typeof(TLabel) == typeof(uint))
    {
        var convMap = conv.GetMapperFrom<uint>();
        uint b = 1;
        convMap(in b, ref positiveClass);
    }
    else
    {
        // Unsupported label type: fail instead of silently using default(TLabel) as the positive class.
        throw _host.ExceptNotImpl("Not implemented for type {0}", typeof(TLabel));
    }

    // NOTE(review): hist is captured once and shared by every invocation of the returned mapper,
    // so the mapper is presumably not safe to call concurrently — confirm.
    Dictionary<TLabel, float> hist = null;
    if (weight == NearestNeighborsWeights.uniform)
    {
        ValueMapper<VBuffer<float>, float> mapper = (in VBuffer<float> input, ref float output) =>
        {
            GetMapperUniformBinaryPrediction(trees, k, in input, ref output, positiveClass, ref hist);
        };
        return mapper as ValueMapper<TIn, TOut>;
    }
    else
    {
        throw _host.ExceptNotImpl("Not implemented for {0}", weight);
    }
}
/// <summary>
/// Reads the predictor state from <paramref name="ctx"/> in this order: k (int32), algorithm (int32),
/// weights (int32), the neighbor trees, then a <c>DataKind</c> (int32) that selects the label type
/// of the neighbor value mapper read last.
/// </summary>
/// <param name="host">Host used for validation and passed to the deserialized components.</param>
/// <param name="ctx">Model load context to read from.</param>
/// <exception>Throws (via the host) when the stored label kind is not BL, U1, U2, U4 or R4.</exception>
protected void ReadCore(IHost host, ModelLoadContext ctx)
{
    // Use the host that was passed in throughout; the _host field is not guaranteed
    // to be assigned yet when this method runs.
    Contracts.CheckValue(host, "host");
    host.CheckValue(ctx, "ctx");

    _k = ctx.Reader.ReadInt32();
    _algo = (NearestNeighborsAlgorithm)ctx.Reader.ReadInt32();
    _weights = (NearestNeighborsWeights)ctx.Reader.ReadInt32();
    _nearestTrees = new NearestNeighborsTrees(host, ctx);
    host.CheckValue(_nearestTrees, "_nearestTrees");

    var kind = (DataKind)ctx.Reader.ReadInt32();
    switch (kind)
    {
        case DataKind.BL:
            _nearestPredictor = new NearestNeighborsValueMapper<bool>(host, ctx);
            break;
        case DataKind.U1:
            _nearestPredictor = new NearestNeighborsValueMapper<byte>(host, ctx);
            break;
        case DataKind.U2:
            _nearestPredictor = new NearestNeighborsValueMapper<ushort>(host, ctx);
            break;
        case DataKind.U4:
            _nearestPredictor = new NearestNeighborsValueMapper<uint>(host, ctx);
            break;
        case DataKind.R4:
            _nearestPredictor = new NearestNeighborsValueMapper<float>(host, ctx);
            break;
        default:
            throw host.ExceptNotSupp("Not supported kind={0}", kind);
    }
    host.CheckValue(_nearestPredictor, "_nearestPredictor");
}
void GetMapperUniformBinaryPrediction(NearestNeighborsTrees trees, int k, in VBuffer <float> input,