コード例 #1
0
        /// <summary>
        /// Returns the value mapper matching the requested prediction kind.
        /// Only <see cref="VBuffer{T}"/> of float inputs, the kdtree algorithm and
        /// uniform weights are accepted; binary and multiclass classification are
        /// the only kinds dispatched.
        /// </summary>
        /// <typeparam name="TIn">Input type; must be <c>VBuffer&lt;float&gt;</c>.</typeparam>
        /// <typeparam name="TOut">Output type of the returned mapper.</typeparam>
        /// <param name="trees">Trees forwarded to the kind-specific mapper factory.</param>
        /// <param name="k">Value of k forwarded to the kind-specific mapper factory.</param>
        /// <param name="algo">Search algorithm; must be <c>kdtree</c>.</param>
        /// <param name="weight">Weighting scheme; only <c>uniform</c> is implemented.</param>
        /// <param name="kind">Prediction kind selecting the mapper factory.</param>
        /// <returns>The mapper produced by the binary or multiclass factory.</returns>
        public ValueMapper<TIn, TOut> GetMapper<TIn, TOut>(NearestNeighborsTrees trees, int k,
                                                           NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight, PredictionKind kind)
        {
            _host.Check(typeof(TIn) == typeof(VBuffer<float>));
            _host.CheckValue(_labelWeights, "_labelWeights");
            _host.Check(algo == NearestNeighborsAlgorithm.kdtree, "algo");

            // Reject unsupported weighting schemes up front so the dispatch below stays flat.
            if (weight != NearestNeighborsWeights.uniform)
            {
                throw _host.ExceptNotImpl("Not implemented yet for weight={0}", weight);
            }

            switch (kind)
            {
                case PredictionKind.BinaryClassification:
                    return GetMapperBinaryPrediction<TIn, TOut>(trees, k, algo, weight);

                case PredictionKind.MulticlassClassification:
                    return GetMapperMulticlassPrediction<TIn, TOut>(trees, k, algo, weight);

                default:
                    throw _host.ExceptNotImpl("Not implemented yet for kind={0}", kind);
            }
        }
コード例 #2
0
        /// <summary>
        /// Builds the k-d trees from the input data the first time it is called.
        /// The whole operation runs under <c>_lock</c>; if <c>_trees</c> is already
        /// set, nothing happens.
        /// </summary>
        void ComputeNearestNeighbors()
        {
            lock (_lock)
            {
                if (_trees == null)
                {
                    using (var channel = _host.Start("Build k-d tree"))
                    {
                        channel.Info(MessageSensitivity.None, "ComputeNearestNeighbors: build a k-d tree.");

                        // The feature column is mandatory: -1 from the lookup is an error.
                        int colFeatures = GetColumnIndex(channel, _args.column);
                        if (colFeatures == -1)
                        {
                            throw channel.Except($"Unable to find column '{_args.column}' in {SchemaHelper.ToString(Schema)}.");
                        }

                        // The remaining lookups are passed on to the builder without a -1 check here.
                        int colLabel = GetColumnIndex(channel, _args.labelColumn);
                        int colWeight = GetColumnIndex(channel, _args.weightColumn);
                        int colId = GetColumnIndex(channel, _args.colId);

                        // The builder also returns a dictionary through 'out'; it is not used here.
                        Dictionary<long, Tuple<long, float>> merged;
                        _trees = NearestNeighborsBuilder.NearestNeighborsBuild<long>(
                            channel, _input, colFeatures, colLabel, colId, colWeight, out merged, _args);

                        channel.Info(MessageSensitivity.UserData, "Done. Tree size: {0} points.", _trees.Count());
                    }
                }
            }
        }
コード例 #3
0
        /// <summary>
        /// Builds the k-d trees from the input data the first time it is called.
        /// The whole operation runs under <c>_lock</c>; if <c>_trees</c> is already
        /// set, nothing happens.
        /// </summary>
        void ComputeNearestNeighbors()
        {
            lock (_lock)
            {
                if (_trees == null)
                {
                    using (var channel = _host.Start("Build k-d tree"))
                    {
                        channel.Info("ComputeNearestNeighbors: build a k-d tree.");

                        // The feature column is mandatory: a failed schema lookup is an error.
                        int colFeatures;
                        bool found = _input.Schema.TryGetColumnIndex(_args.column, out colFeatures);
                        if (!found)
                        {
                            throw channel.Except("Unable to find column '{0}'.", _args.column);
                        }

                        // The remaining lookups are passed on to the builder unchecked here.
                        int colLabel = GetColumnIndex(channel, _args.labelColumn);
                        int colWeight = GetColumnIndex(channel, _args.weightColumn);
                        int colId = GetColumnIndex(channel, _args.colId);

                        // The builder also returns a dictionary through 'out'; it is not used here.
                        Dictionary<long, Tuple<long, float>> merged;
                        _trees = NearestNeighborsBuilder.NearestNeighborsBuild<long>(
                            channel, _input, colFeatures, colLabel, colId, colWeight, out merged, _args);

                        channel.Info("Done. Tree size: {0} points.", _trees.Count());
                    }
                }
            }
        }
コード例 #4
0
 /// <summary>
 /// Stores the host, the trees, the label mapper and the search settings
 /// in the corresponding fields. No validation is performed here.
 /// </summary>
 /// <param name="host">Host stored in <c>_host</c>.</param>
 /// <param name="trees">Trees stored in <c>_nearestTrees</c>.</param>
 /// <param name="predictor">Value mapper stored in <c>_nearestPredictor</c>.</param>
 /// <param name="k">Value stored in <c>_k</c>.</param>
 /// <param name="algo">Algorithm stored in <c>_algo</c>.</param>
 /// <param name="weights">Weighting scheme stored in <c>_weights</c>.</param>
 internal NearestNeighborsBinaryClassifierPredictor(IHost host, NearestNeighborsTrees trees, INearestNeighborsValueMapper predictor,
                                                    int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
 {
     // Collaborators first, in parameter order.
     _host = host;
     _nearestTrees = trees;
     _nearestPredictor = predictor;

     // Search settings.
     _k = k;
     _algo = algo;
     _weights = weights;
 }
コード例 #5
0
 /// <summary>
 /// Wraps an input cursor and prepares the feature getter and the output
 /// buffers, each sized to the parent's <c>k</c>.
 /// </summary>
 /// <param name="cursor">Input cursor the feature getter is taken from.</param>
 /// <param name="parent">Owning transform; supplies the trees and <c>k</c>.</param>
 /// <param name="predicate">Not used by this constructor.</param>
 /// <param name="colFeatures">Index of the feature column in the input cursor.</param>
 public NearestNeighborsCursor(RowCursor cursor, NearestNeighborsTransform parent, Func<int, bool> predicate, int colFeatures)
 {
     int neighbors = parent._args.k;

     _parent = parent;
     _trees = parent._trees;
     _k = neighbors;

     _inputCursor = cursor;
     _getterFeatures = cursor.GetGetter<VBuffer<float>>(colFeatures);
     _tempFeatures = new VBuffer<float>();

     // Output buffers hold one slot per requested neighbor.
     _distance = new VBuffer<float>(neighbors, new float[neighbors]);
     _idn = new VBuffer<long>(neighbors, new long[neighbors]);
 }
コード例 #6
0
 /// <summary>
 /// Wraps an input cursor and prepares the feature getter and the output
 /// buffers, each sized to the parent's <c>k</c>.
 /// </summary>
 /// <param name="cursor">Input cursor the feature getter is taken from.</param>
 /// <param name="parent">Owning transform; supplies the trees and <c>k</c>.</param>
 /// <param name="columnsNeeded">Not used by this constructor.</param>
 /// <param name="colFeatures">Feature column in the input cursor.</param>
 public NearestNeighborsCursor(DataViewRowCursor cursor, NearestNeighborsTransform parent,
                               IEnumerable<DataViewSchema.Column> columnsNeeded,
                               DataViewSchema.Column colFeatures)
 {
     int neighbors = parent._args.k;

     _parent = parent;
     _trees = parent._trees;
     _k = neighbors;

     _inputCursor = cursor;
     _getterFeatures = cursor.GetGetter<VBuffer<float>>(colFeatures);
     _tempFeatures = new VBuffer<float>();

     // Output buffers hold one slot per requested neighbor.
     _distance = new VBuffer<float>(neighbors, new float[neighbors]);
     _idn = new VBuffer<long>(neighbors, new long[neighbors]);
 }
コード例 #7
0
        /// <summary>
        /// Creates the transform from user arguments. The trees are left unset
        /// (<c>null</c>) by this constructor; the extended schema is computed
        /// immediately.
        /// </summary>
        /// <param name="env">Host environment; used to register the host. Must not be null.</param>
        /// <param name="args">Transform arguments; post-processed in place. Must not be null and must name a column.</param>
        /// <param name="input">Input data view. Must not be null.</param>
        public NearestNeighborsTransform(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, "env");
            _host = env.Register(LoaderSignature);
            _host.CheckValue(args, "args");
            // Validate the input like the other arguments so a null data view
            // is reported by name instead of failing later.
            _host.CheckValue(input, "input");
            args.PostProcess();
            _host.CheckValue(args.column, "column");

            _input          = input;
            _trees          = null;
            _args           = args;
            _lock           = new object();
            _extendedSchema = ComputeExtendedSchema();
        }
コード例 #8
0
        /// <summary>
        /// Restores the transform from a model context: reads the arguments,
        /// then a flag byte, and loads the trees only when that byte equals 1.
        /// </summary>
        /// <param name="host">Host used for checks and passed to the loaders. Must not be null.</param>
        /// <param name="ctx">Model load context to read from. Must not be null.</param>
        /// <param name="input">Input data view. Must not be null.</param>
        private NearestNeighborsTransform(IHost host, ModelLoadContext ctx, IDataView input)
        {
            Contracts.CheckValue(host, "host");
            Contracts.CheckValue(input, "input");

            _host = host;
            _input = input;
            _lock = new object();
            _host.CheckValue(input, "input");
            _host.CheckValue(ctx, "ctx");

            _args = new Arguments();
            _args.Read(ctx, _host);

            // A flag byte of 1 means the trees follow in the stream.
            bool hasTrees = ctx.Reader.ReadByte() == 1;
            if (hasTrees)
            {
                _trees = new NearestNeighborsTrees(host, ctx);
            }
            else
            {
                _trees = null;
            }

            _extendedSchema = ComputeExtendedSchema();
        }
コード例 #9
0
        /// <summary>
        /// Builds a binary classifier predictor from k-d trees and a label/weight table.
        /// </summary>
        /// <typeparam name="TLabel">Label type stored in <paramref name="labelWeights"/>.</typeparam>
        /// <param name="host">Host used for checks and passed to the created objects. Must not be null.</param>
        /// <param name="kdtrees">Trees to wrap; the array and each element must be non-null.</param>
        /// <param name="labelWeights">Table handed to the value mapper.</param>
        /// <param name="k">Value of k handed to the predictor.</param>
        /// <param name="algo">Algorithm handed to the predictor.</param>
        /// <param name="weights">Weighting scheme handed to the predictor.</param>
        /// <returns>The newly created predictor.</returns>
        internal static NearestNeighborsBinaryClassifierPredictor Create<TLabel>(IHost host,
                                                                                 KdTree[] kdtrees, Dictionary<long, Tuple<TLabel, float>> labelWeights,
                                                                                 int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
            where TLabel : IComparable<TLabel>
        {
            Contracts.CheckValue(host, "host");
            host.CheckValue(kdtrees, "kdtrees");
            host.Check(!kdtrees.Any(tree => tree == null), "kdtrees");

            // The channel only brackets the construction; nothing is logged to it.
            using (var ch = host.Start("Creating kNN predictor"))
            {
                var wrappedTrees = new NearestNeighborsTrees(host, kdtrees);
                var labelMapper = new NearestNeighborsValueMapper<TLabel>(host, labelWeights);
                return new NearestNeighborsBinaryClassifierPredictor(host, wrappedTrees, labelMapper, k, algo, weights);
            }
        }
コード例 #10
0
        /// <summary>
        /// Builds the mapper used for binary classification with uniform weights.
        /// The positive class is the <c>TLabel</c> value converted from <c>true</c>,
        /// <c>1f</c> or <c>1u</c>, depending on whether <c>TLabel</c> is bool, float or uint.
        /// </summary>
        /// <typeparam name="TIn">Requested input type of the mapper.</typeparam>
        /// <typeparam name="TOut">Requested output type of the mapper.</typeparam>
        /// <param name="trees">Trees captured by the returned mapper.</param>
        /// <param name="k">Value of k captured by the returned mapper.</param>
        /// <param name="algo">Not used by this method.</param>
        /// <param name="weight">Weighting scheme; only <c>uniform</c> is implemented.</param>
        /// <returns>
        /// A <c>VBuffer&lt;float&gt;</c> to <c>float</c> mapper cast with <c>as</c> to the requested
        /// delegate type. NOTE(review): the cast yields null when TIn/TOut do not match;
        /// confirm callers only request the matching types.
        /// </returns>
        ValueMapper<TIn, TOut> GetMapperBinaryPrediction<TIn, TOut>(NearestNeighborsTrees trees, int k,
                                                                    NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight)
        {
            var    conv          = new TypedConverters<TLabel>();
            TLabel positiveClass = default(TLabel);

            if (typeof(TLabel) == typeof(bool))
            {
                var convMap = conv.GetMapperFrom<bool>();
                var b       = true;
                convMap(in b, ref positiveClass);
            }
            else if (typeof(TLabel) == typeof(float))
            {
                var convMap = conv.GetMapperFrom<float>();
                var b       = 1f;
                convMap(in b, ref positiveClass);
            }
            else if (typeof(TLabel) == typeof(uint))
            {
                var  convMap = conv.GetMapperFrom<uint>();
                uint b       = 1;
                convMap(in b, ref positiveClass);
            }
            else
            {
                // The exception must be thrown, not just created: otherwise an
                // unsupported label type would continue with default(TLabel)
                // as the positive class.
                throw _host.ExceptNotImpl("Not implemented for type {0}", typeof(TLabel));
            }

            if (weight != NearestNeighborsWeights.uniform)
            {
                throw _host.ExceptNotImpl("Not implemented for {0}", weight);
            }

            // Captured by the lambda and passed by ref on every call.
            // NOTE(review): sharing this state presumably makes the returned
            // mapper unsafe for concurrent calls; confirm against the helper.
            Dictionary<TLabel, float> hist = null;

            ValueMapper<VBuffer<float>, float> mapper = (in VBuffer<float> input, ref float output) =>
            {
                GetMapperUniformBinaryPrediction(trees, k, in input, ref output, positiveClass, ref hist);
            };
            return mapper as ValueMapper<TIn, TOut>;
        }
コード例 #11
0
        /// <summary>
        /// Reads the predictor state from a model context, in this order: k,
        /// algorithm, weights (each as Int32), the trees, the label data kind
        /// (Int32), then the value mapper for that label kind.
        /// </summary>
        /// <param name="host">Host passed to the value mapper being loaded.</param>
        /// <param name="ctx">Model load context to read from.</param>
        /// <remarks>
        /// NOTE(review): the trees and the checks use the <c>_host</c> field while the
        /// value mapper uses the <paramref name="host"/> parameter; confirm that
        /// <c>_host</c> is already assigned when this method runs.
        /// </remarks>
        protected void ReadCore(IHost host, ModelLoadContext ctx)
        {
            _k            = ctx.Reader.ReadInt32();
            _algo         = (NearestNeighborsAlgorithm)ctx.Reader.ReadInt32();
            _weights      = (NearestNeighborsWeights)ctx.Reader.ReadInt32();
            _nearestTrees = new NearestNeighborsTrees(_host, ctx);
            _host.CheckValue(_nearestTrees, "_nearestTrees");

            // The label type was serialized as a DataKind code.
            var kind = (DataKind)ctx.Reader.ReadInt32();

            switch (kind)
            {
                case DataKind.BL:
                    _nearestPredictor = new NearestNeighborsValueMapper<bool>(host, ctx);
                    break;

                case DataKind.U1:
                    _nearestPredictor = new NearestNeighborsValueMapper<byte>(host, ctx);
                    break;

                case DataKind.U2:
                    _nearestPredictor = new NearestNeighborsValueMapper<ushort>(host, ctx);
                    break;

                case DataKind.U4:
                    _nearestPredictor = new NearestNeighborsValueMapper<uint>(host, ctx);
                    break;

                case DataKind.R4:
                    _nearestPredictor = new NearestNeighborsValueMapper<float>(host, ctx);
                    break;

                default:
                    throw _host.ExceptNotSupp("Not supported kind={0}", kind);
            }
            _host.CheckValue(_nearestPredictor, "_nearestPredictor");
        }
コード例 #12
0
 void GetMapperUniformBinaryPrediction(NearestNeighborsTrees trees, int k, in VBuffer <float> input,