コード例 #1
0
 /// <summary>
 /// Consumes every row of <paramref name="cursor"/>, adds one point per row to a new k-d tree and
 /// records the row's label and weight in <paramref name="labelsWeights"/> under the row id.
 /// </summary>
 /// <typeparam name="TLabel">Type of the values read from the label column.</typeparam>
 /// <param name="data">Not referenced by this method.</param>
 /// <param name="cursor">Row cursor to read from; it is disposed before the method returns.</param>
 /// <param name="featureIndex">Index of the feature column, read as <c>VBuffer{float}</c>.</param>
 /// <param name="labelIndex">Index of the label column. A negative value or <c>int.MaxValue</c>
 /// means no label is read and every row is stored with <c>default(TLabel)</c>.</param>
 /// <param name="idIndex">Index of the id column. A negative value or <c>int.MaxValue</c>
 /// means no id is read; the id is then the number of entries already stored in
 /// <paramref name="labelsWeights"/>.</param>
 /// <param name="weightIndex">Index of the weight column. A negative value or <c>int.MaxValue</c>
 /// means no weight is read and every row is stored with weight 1.</param>
 /// <param name="labelsWeights">Receives the (label, weight) pair of each row, keyed by row id.
 /// A row whose id was already seen replaces the earlier entry.</param>
 /// <param name="args">Supplies the <c>distance</c> and <c>seed</c> handed to the tree constructor.</param>
 /// <returns>The tree holding one point per row read from the cursor.</returns>
 private static KdTree BuildKDTree <TLabel>(IDataView data, DataViewRowCursor cursor,
                                            int featureIndex, int labelIndex, int idIndex, int weightIndex,
                                            out Dictionary <long, Tuple <TLabel, float> > labelsWeights, NearestNeighborsArguments args)
     where TLabel : IComparable <TLabel>
 {
     using (cursor)
     {
         // An optional column is present when its index lies in [0, int.MaxValue).
         bool hasLabel  = labelIndex >= 0 && labelIndex < int.MaxValue;
         bool hasWeight = weightIndex >= 0 && weightIndex < int.MaxValue;
         bool hasId     = idIndex >= 0 && idIndex < int.MaxValue;

         var readFeatures = cursor.GetGetter <VBuffer <float> >(SchemaHelper._dc(featureIndex, cursor));
         var readLabel    = hasLabel ? cursor.GetGetter <TLabel>(SchemaHelper._dc(labelIndex, cursor)) : null;
         var readWeight   = hasWeight ? cursor.GetGetter <float>(SchemaHelper._dc(weightIndex, cursor)) : null;
         var readId       = hasId ? cursor.GetGetter <long>(SchemaHelper._dc(idIndex, cursor)) : null;

         var tree  = new KdTree(distance: args.distance, seed: args.seed);
         var table = new Dictionary <long, Tuple <TLabel, float> >();
         labelsWeights = table;

         // Row buffers; a column that is not read keeps its initial value for every row.
         var    rowFeatures = new VBuffer <float>();
         TLabel rowLabel    = default(TLabel);
         float  rowWeight   = 1;
         long   rowId       = default(long);

         while (cursor.MoveNext())
         {
             readFeatures(ref rowFeatures);
             readLabel?.Invoke(ref rowLabel);
             readWeight?.Invoke(ref rowWeight);

             if (readId == null)
             {
                 // No id column: use the current number of stored entries as the id.
                 rowId = table.Count;
             }
             else
             {
                 readId(ref rowId);
             }

             table[rowId] = Tuple.Create(rowLabel, rowWeight);
             // NOTE(review): the 'true' flag presumably asks PointIdFloat to copy the reused
             // feature buffer - confirm against PointIdFloat's constructor.
             tree.Add(new PointIdFloat(rowId, rowFeatures, true));
         }

         return tree;
     }
 }
コード例 #2
0
        /// <summary>
        /// Finds the nearest neighbors of <paramref name="target"/> across all k-d trees held by this instance.
        /// </summary>
        /// <param name="target">Feature vector to search neighbors for.</param>
        /// <param name="k">Number of neighbors requested from each tree and the maximum number returned.</param>
        /// <returns>
        /// Pairs of (distance, point id). With a single tree the pairs come back in the order the tree
        /// produced them; with several trees the per-tree results are merged, sorted by ascending
        /// distance and truncated to <paramref name="k"/> entries. With no tree at all the result is empty.
        /// </returns>
        public KeyValuePair <float, long>[] NearestNNeighbors(VBuffer <float> target, int k)
        {
            // Without this guard the parallel branch below would set MaxDegreeOfParallelism to 0,
            // which ParallelOptions rejects with an ArgumentOutOfRangeException.
            if (_kdtrees.Length == 0)
            {
                return Array.Empty <KeyValuePair <float, long> >();
            }

            // NOTE(review): -1 looks like a placeholder id for the query point and 'false' presumably
            // skips copying the buffer - confirm against PointIdFloat's constructor.
            var point = new PointIdFloat(-1, target, false);

            // Queries one tree. The tree returns the opposite of the distance, so the sign is flipped back.
            KeyValuePair <float, long>[] Query(int treeIndex)
            {
                return _kdtrees[treeIndex].NearestNNeighborsAndDistance(point, k)
                                          .Select(c => new KeyValuePair <float, long>(-c.Key, c.Value.id))
                                          .ToArray();
            }

            if (_kdtrees.Length == 1)
            {
                return Query(0);
            }

            // Query every tree in parallel; each iteration writes only to its own slot.
            var perTree = new KeyValuePair <float, long> [_kdtrees.Length][];
            Parallel.For(0, perTree.Length,
                         new ParallelOptions { MaxDegreeOfParallelism = perTree.Length },
                         i => { perTree[i] = Query(i); });

            // Concatenate in tree order, then keep the k closest overall.
            return perTree.SelectMany(found => found)
                          .OrderBy(c => c.Key)
                          .Take(k)
                          .ToArray();
        }