/// <summary>
/// Reads every row from <paramref name="cursor"/>, inserts its feature vector into a new
/// <see cref="KdTree"/>, and records the row's label and weight under the row's id.
/// </summary>
/// <typeparam name="TLabel">Label type read from the label column.</typeparam>
/// <param name="data">Not read by this method.</param>
/// <param name="cursor">Row cursor to consume. This method disposes it.</param>
/// <param name="featureIndex">Column index of the <c>VBuffer&lt;float&gt;</c> feature column.</param>
/// <param name="labelIndex">Label column index. A negative value or <c>int.MaxValue</c> means "no label column";
/// every row then gets <c>default(TLabel)</c>.</param>
/// <param name="idIndex">Id column index (<c>long</c>). A negative value or <c>int.MaxValue</c> means "no id column";
/// rows are then numbered 0, 1, 2, ... in cursor order.</param>
/// <param name="weightIndex">Weight column index (<c>float</c>). A negative value or <c>int.MaxValue</c> means
/// "no weight column"; every row then gets weight 1.</param>
/// <param name="labelsWeights">Receives a map from row id to (label, weight). If the id column contains
/// duplicate ids, the last row wins in this map, although every row is still added to the tree.</param>
/// <param name="args">Supplies the tree's <c>distance</c> and <c>seed</c>.</param>
/// <returns>The populated tree.</returns>
private static KdTree BuildKDTree<TLabel>(IDataView data, DataViewRowCursor cursor, int featureIndex, int labelIndex, int idIndex, int weightIndex, out Dictionary<long, Tuple<TLabel, float>> labelsWeights, NearestNeighborsArguments args) where TLabel : IComparable<TLabel>
{
    using (cursor)
    {
        // An optional column is considered present only for indices in [0, int.MaxValue).
        bool HasColumn(int columnIndex) => columnIndex >= 0 && columnIndex < int.MaxValue;

        var getFeatures = cursor.GetGetter<VBuffer<float>>(SchemaHelper._dc(featureIndex, cursor));
        var getLabel = HasColumn(labelIndex) ? cursor.GetGetter<TLabel>(SchemaHelper._dc(labelIndex, cursor)) : null;
        var getWeight = HasColumn(weightIndex) ? cursor.GetGetter<float>(SchemaHelper._dc(weightIndex, cursor)) : null;
        var getId = HasColumn(idIndex) ? cursor.GetGetter<long>(SchemaHelper._dc(idIndex, cursor)) : null;

        var tree = new KdTree(distance: args.distance, seed: args.seed);
        var rowInfo = new Dictionary<long, Tuple<TLabel, float>>();
        labelsWeights = rowInfo;

        // These values carry over between rows; a missing column leaves its initial value in place.
        var features = new VBuffer<float>();
        var label = default(TLabel);
        var weight = 1f;
        var rowId = 0L;

        while (cursor.MoveNext())
        {
            getFeatures(ref features);
            getLabel?.Invoke(ref label);
            getWeight?.Invoke(ref weight);

            if (getId == null)
            {
                // Without an id column, the number of rows stored so far is the next unused id.
                rowId = rowInfo.Count;
            }
            else
            {
                getId(ref rowId);
            }

            rowInfo[rowId] = Tuple.Create(label, weight);

            // NOTE(review): `features` is reused across rows; the `true` flag presumably makes
            // PointIdFloat copy it — confirm against PointIdFloat's constructor.
            tree.Add(new PointIdFloat(rowId, features, true));
        }

        return tree;
    }
}
/// <summary>
/// Finds the <paramref name="k"/> nearest neighbors of <paramref name="target"/> across all trees.
/// </summary>
/// <param name="target">Feature vector to query.</param>
/// <param name="k">Number of neighbors to return.</param>
/// <returns>
/// (distance, id) pairs sorted by increasing distance. Both the single-tree and the multi-tree
/// paths use this order. With several trees, the per-tree results are merged and truncated to
/// <paramref name="k"/>. With no trees, the result is an empty array.
/// </returns>
public KeyValuePair<float, long>[] NearestNNeighbors(VBuffer<float> target, int k)
{
    // No tree means no neighbor. Without this guard, the parallel path below would set
    // ParallelOptions.MaxDegreeOfParallelism to 0, which throws ArgumentOutOfRangeException.
    if (_kdtrees.Length == 0)
    {
        return Array.Empty<KeyValuePair<float, long>>();
    }

    // NOTE(review): id -1 marks the query point; the `false` flag presumably means "do not copy
    // the buffer" — confirm against PointIdFloat's constructor.
    var point = new PointIdFloat(-1, target, false);

    // Queries one tree. The kdtrees return the opposite of the distance, hence the negation.
    KeyValuePair<float, long>[] QueryTree(int treeIndex) =>
        _kdtrees[treeIndex].NearestNNeighborsAndDistance(point, k)
                           .Select(c => new KeyValuePair<float, long>(-c.Key, c.Value.id))
                           .ToArray();

    if (_kdtrees.Length == 1)
    {
        // Sort so a single-tree model returns the same ordering as the merged multi-tree path.
        // OrderBy is stable, so this is a no-op if the tree already returns nearest-first.
        return QueryTree(0).OrderBy(c => c.Key).ToArray();
    }

    // Query every tree in parallel; each task writes only its own slot.
    var perTree = new KeyValuePair<float, long>[_kdtrees.Length][];
    var ops = new Action[_kdtrees.Length];
    for (int i = 0; i < ops.Length; ++i)
    {
        int chunkId = i; // copy the loop variable so each closure targets its own tree
        ops[i] = () => perTree[chunkId] = QueryTree(chunkId);
    }
    Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = ops.Length }, ops);

    // Each tree contributes its own best candidates; the global k best are among their union.
    return perTree.SelectMany(candidates => candidates)
                  .OrderBy(c => c.Key)
                  .Take(k)
                  .ToArray();
}