コード例 #1
0
        /// <summary>
        /// Trains a k-nearest-neighbors multi-class classifier on the iris features <c>Slength</c> and
        /// <c>Swidth</c>, then runs the serialization round-trip test on the trained predictor.
        /// </summary>
        /// <param name="k">Number of neighbors; <c>k == 1</c> also selects a test environment created with <c>conc: 1</c>.</param>
        /// <param name="weight">Neighbor weighting, sent to the trainer as <c>distance</c> or <c>uniform</c>.</param>
        /// <param name="threads">Sent to the trainer as its <c>nt</c> option.</param>
        /// <param name="ratio">Forwarded as <c>ratio</c> to <c>TestTrainerHelper.FinalizeSerializationTest</c>.</param>
        /// <param name="distance">Sent to the trainer as its <c>distance</c> option.</param>
        public static void TrainkNNMultiClassification(int k, NearestNeighborsWeights weight, int threads, float ratio = 0.2f,
                                                       string distance = "L2")
        {
            // One output folder per parameter combination so different runs do not overwrite each other.
            var methodName = string.Format("{0}-k{1}-W{2}-T{3}-D{4}",
                                           System.Reflection.MethodBase.GetCurrentMethod().Name, k, weight, threads, distance);
            var dataFilePath = FileHelper.GetTestFile("iris.txt");
            var outModelFilePath = FileHelper.GetOutputFile("outModelFilePath.zip", methodName);
            var outData = FileHelper.GetOutputFile("outData1.txt", methodName);
            var outData2 = FileHelper.GetOutputFile("outData2.txt", methodName);

            using (var env = k == 1 ? EnvHelper.NewTestEnvironment(conc: 1) : EnvHelper.NewTestEnvironment())
            {
                var loader = env.CreateLoader("Text{col=Label:R4:0 col=Slength:R4:1 col=Swidth:R4:2 col=Plength:R4:3 col=Pwidth:R4:4 header=+}",
                                              new MultiFileSource(dataFilePath));
                var features = env.CreateTransform("Concat{col=Features:Slength,Swidth}", loader);
                var roles = env.CreateExamples(features, "Features", "Label");

                // Build the trainer definition string; weighting is spelled out by name.
                string weighting = weight == NearestNeighborsWeights.distance ? "distance" : "uniform";
                var trainer = env.CreateTrainer(string.Format("knnmc{{k={0} weighting={1} nt={2} distance={3}}}",
                                                              k, weighting, threads, distance));

                using (var ch = env.Start("test"))
                {
                    var predictor = trainer.Train(env, ch, roles);
                    TestTrainerHelper.FinalizeSerializationTest(env, outModelFilePath, predictor, roles, outData, outData2,
                                                                PredictionKind.MultiClassClassification, true, ratio: ratio);
                }
            }
        }
コード例 #2
0
 /// <summary>
 /// Builds a binary kNN predictor from already-constructed parts. Every argument is stored as-is;
 /// this constructor performs no validation of its own.
 /// </summary>
 /// <param name="host">Host kept for later checks and exceptions.</param>
 /// <param name="trees">Neighbor search structure.</param>
 /// <param name="predictor">Value mapper used to turn neighbors into predictions.</param>
 /// <param name="k">Number of neighbors.</param>
 /// <param name="algo">Neighbor search algorithm.</param>
 /// <param name="weights">Neighbor weighting scheme.</param>
 internal NearestNeighborsBinaryClassifierPredictor(IHost host, NearestNeighborsTrees trees, INearestNeighborsValueMapper predictor,
                                                    int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
 {
     _host = host;

     // Hyper-parameters.
     _k = k;
     _algo = algo;
     _weights = weights;

     // Fitted state.
     _nearestTrees = trees;
     _nearestPredictor = predictor;
 }
コード例 #3
0
 /// <summary>
 /// Restores the kNN options from <paramref name="ctx"/>, reading in this order:
 /// k, algo, weighting, distance, numThreads, seed, colId.
 /// The stored value -1 for numThreads/seed and an empty string for colId are mapped back to <c>null</c>.
 /// </summary>
 /// <param name="ctx">Model load context whose reader is positioned at the serialized options.</param>
 /// <param name="host">Not used by this implementation.</param>
 public virtual void Read(ModelLoadContext ctx, IHost host)
 {
     var reader = ctx.Reader;

     k = reader.ReadInt32();
     algo = (NearestNeighborsAlgorithm)reader.ReadInt32();
     weighting = (NearestNeighborsWeights)reader.ReadInt32();
     distance = (NearestNeighborsDistance)reader.ReadInt32();

     // -1 encodes "not set" on disk.
     int storedThreads = reader.ReadInt32();
     numThreads = storedThreads == -1 ? (int?)null : storedThreads;

     int storedSeed = reader.ReadInt32();
     seed = storedSeed == -1 ? (int?)null : storedSeed;

     // An empty string encodes "no id column".
     string storedColId = reader.ReadString();
     colId = string.IsNullOrEmpty(storedColId) ? null : storedColId;
 }
コード例 #4
0
        /// <summary>
        /// Deserializes the predictor state from <paramref name="ctx"/> in this order: k, algorithm,
        /// weighting, the neighbor trees, then an integer <c>DataKind</c> that selects the label type
        /// of the value mapper (BL, U1, U2, U4 or R4). Any other kind throws the exception built by
        /// <c>_host.ExceptNotSupp</c>.
        /// </summary>
        /// <param name="host">Host passed to the value-mapper constructor.</param>
        /// <param name="ctx">Model load context positioned at the serialized predictor.</param>
        protected void ReadCore(IHost host, ModelLoadContext ctx)
        {
            _k = ctx.Reader.ReadInt32();
            _algo = (NearestNeighborsAlgorithm)ctx.Reader.ReadInt32();
            _weights = (NearestNeighborsWeights)ctx.Reader.ReadInt32();

            // NOTE(review): the trees use the field _host while the value mapper below uses the
            // parameter host; this assumes the caller assigned _host before calling ReadCore — confirm.
            _nearestTrees = new NearestNeighborsTrees(_host, ctx);
            _host.CheckValue(_nearestTrees, "_nearestTrees");

            // The label type of the value mapper is stored as a DataKind integer.
            var kind = (DataKind)ctx.Reader.ReadInt32();

            switch (kind)
            {
            case DataKind.BL:
                _nearestPredictor = new NearestNeighborsValueMapper <bool>(host, ctx);
                break;

            case DataKind.U1:
                _nearestPredictor = new NearestNeighborsValueMapper <byte>(host, ctx);
                break;

            case DataKind.U2:
                _nearestPredictor = new NearestNeighborsValueMapper <ushort>(host, ctx);
                break;

            case DataKind.U4:
                _nearestPredictor = new NearestNeighborsValueMapper <uint>(host, ctx);
                break;

            case DataKind.R4:
                _nearestPredictor = new NearestNeighborsValueMapper <float>(host, ctx);
                break;

            default:
                throw _host.ExceptNotSupp("Not supported kind={0}", kind);
            }
            _host.CheckValue(_nearestPredictor, "_nearestPredictor");
        }
コード例 #5
0
        /// <summary>
        /// Factory for a binary kNN predictor. It wraps <paramref name="kdtrees"/> in a
        /// <c>NearestNeighborsTrees</c> and <paramref name="labelWeights"/> in a
        /// <c>NearestNeighborsValueMapper&lt;TLabel&gt;</c>, then builds the predictor from both.
        /// </summary>
        /// <remarks>
        /// Fails through the host checks when <paramref name="host"/> or <paramref name="kdtrees"/> is null,
        /// or when any element of <paramref name="kdtrees"/> is null.
        /// </remarks>
        internal static NearestNeighborsBinaryClassifierPredictor Create <TLabel>(IHost host,
                                                                                  KdTree[] kdtrees, Dictionary <long, Tuple <TLabel, float> > labelWeights,
                                                                                  int k, NearestNeighborsAlgorithm algo, NearestNeighborsWeights weights)
            where TLabel : IComparable <TLabel>
        {
            Contracts.CheckValue(host, "host");
            host.CheckValue(kdtrees, "kdtrees");
            host.Check(kdtrees.All(tree => tree != null), "kdtrees");

            // The channel only scopes construction; it is disposed before the predictor is returned.
            using (host.Start("Creating kNN predictor"))
            {
                var neighborTrees = new NearestNeighborsTrees(host, kdtrees);
                var valueMapper = new NearestNeighborsValueMapper <TLabel>(host, labelWeights);
                return new NearestNeighborsBinaryClassifierPredictor(host, neighborTrees, valueMapper, k, algo, weights);
            }
        }
コード例 #6
0
        /// <summary>
        /// Applies the kNN transform (<c>knntr</c>) with an id column to the iris features. On every row it
        /// checks that the neighbor distances (column 7) and neighbor ids (column 8) are dense, have the same
        /// length, and that the distances are sorted ascending. It also checks that every id after the first is odd.
        /// It then runs the transform serialization round-trip test.
        /// </summary>
        /// <param name="k">Used in the output folder name; <c>k == 1</c> also selects an environment created with <c>conc: 1</c>.</param>
        /// <param name="weight">Only used in the output folder name.</param>
        /// <param name="threads">Only used in the output folder name.</param>
        /// <param name="distance">When <c>"cosine"</c>, a Scaler is applied to the features before the transform.</param>
        public static void TrainkNNTransformId(int k, NearestNeighborsWeights weight, int threads, string distance = "L2")
        {
            var methodName = string.Format("{0}-k{1}-W{2}-T{3}-D{4}",
                                           System.Reflection.MethodBase.GetCurrentMethod().Name, k, weight, threads, distance);
            var dataFilePath = FileHelper.GetTestFile("iris_binary_id.txt");
            var outModelFilePath = FileHelper.GetOutputFile("outModelFilePath.zip", methodName);
            var outData = FileHelper.GetOutputFile("outData1.txt", methodName);
            var outData2 = FileHelper.GetOutputFile("outData2.txt", methodName);

            using (var env = k == 1 ? EnvHelper.NewTestEnvironment(conc: 1) : EnvHelper.NewTestEnvironment())
            {
                var loader = env.CreateLoader("Text{col=Label:R4:0 col=Slength:R4:1 col=Swidth:R4:2 col=Plength:R4:3 col=Pwidth:R4:4 col=Uid:I8:5 header=+}",
                                              new MultiFileSource(dataFilePath));

                var pipeline = env.CreateTransform("Concat{col=Features:Slength,Swidth}", loader);
                if (distance == "cosine")
                {
                    pipeline = env.CreateTransform("Scaler{col=Features}", pipeline);
                }
                // NOTE(review): knntr always gets k=5, and `distance` is never forwarded to it; the method's
                // own k/distance parameters do not reach the transform — confirm this is intended.
                pipeline = env.CreateTransform("knntr{k=5 id=Uid}", pipeline);

                if (DataViewUtils.ComputeRowCount(pipeline) == 0)
                {
                    throw new System.Exception("Empty pipeline.");
                }

                using (var cursor = pipeline.GetRowCursor(col => true))
                {
                    // Column 7: neighbor distances, column 8: neighbor ids (positions fixed by this pipeline).
                    var distanceGetter = cursor.GetGetter <VBuffer <float> >(7);
                    var idGetter = cursor.GetGetter <VBuffer <long> >(8);
                    var distances = new VBuffer <float>();
                    var ids = new VBuffer <long>();

                    while (cursor.MoveNext())
                    {
                        distanceGetter(ref distances);
                        idGetter(ref ids);

                        if (!(distances.IsDense && ids.IsDense))
                        {
                            throw new System.Exception("not dense");
                        }
                        if (distances.Count != ids.Count)
                        {
                            throw new System.Exception("not the same dimension");
                        }

                        // Index 0 is skipped: sortedness is checked pairwise, and the odd-id rule is only
                        // applied from the second neighbor on (presumably the first is the row itself — confirm).
                        for (int j = 1; j < distances.Count; j++)
                        {
                            if (distances.Values[j] < distances.Values[j - 1])
                            {
                                throw new System.Exception("not sorted");
                            }
                            if (ids.Values[j] % 2 != 1)
                            {
                                throw new System.Exception("wrong id");
                            }
                        }
                    }
                }

                TestTransformHelper.SerializationTestTransform(env, outModelFilePath, pipeline, loader, outData, outData2, false);
            }
        }
コード例 #7
0
        /// <summary>
        /// Builds the binary-classification mapper. It first resolves the positive class for <c>TLabel</c>
        /// by converting <c>true</c>, <c>1f</c> or <c>1u</c> through <c>TypedConverters</c>. It then returns a
        /// mapper from a <c>VBuffer&lt;float&gt;</c> feature vector to a <c>float</c> computed by
        /// <c>GetMapperUniformBinaryPrediction</c>.
        /// </summary>
        /// <remarks>
        /// Only label types bool, float and uint, and only uniform weighting, are supported; anything else
        /// throws the exception built by <c>_host.ExceptNotImpl</c>. <paramref name="algo"/> is not used here.
        /// The result comes from an <c>as</c> cast, so it is null when TIn/TOut are not
        /// VBuffer&lt;float&gt;/float.
        /// NOTE(review): the returned delegate passes the single captured <c>hist</c> variable by ref on every
        /// call, so it is presumably not safe to invoke concurrently — confirm.
        /// </remarks>
        ValueMapper <TIn, TOut> GetMapperBinaryPrediction <TIn, TOut>(NearestNeighborsTrees trees, int k,
                                                                      NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight)
        {
            var    conv          = new TypedConverters <TLabel>();
            TLabel positiveClass = default(TLabel);

            if (typeof(TLabel) == typeof(bool))
            {
                var convMap = conv.GetMapperFrom <bool>();
                var b       = true;
                convMap(in b, ref positiveClass);
            }
            else if (typeof(TLabel) == typeof(float))
            {
                var convMap = conv.GetMapperFrom <float>();
                var b       = 1f;
                convMap(in b, ref positiveClass);
            }
            else if (typeof(TLabel) == typeof(uint))
            {
                var  convMap = conv.GetMapperFrom <uint>();
                uint b       = 1;
                convMap(in b, ref positiveClass);
            }
            else
            {
                // Must throw: without it an unsupported label type would silently use default(TLabel)
                // as the positive class.
                throw _host.ExceptNotImpl("Not implemented for type {0}", typeof(TLabel));
            }

            // Scratch histogram shared by every invocation of the returned mapper.
            Dictionary <TLabel, float> hist = null;

            if (weight == NearestNeighborsWeights.uniform)
            {
                ValueMapper <VBuffer <float>, float> mapper = (in VBuffer <float> input, ref float output) =>
                {
                    GetMapperUniformBinaryPrediction(trees, k, in input, ref output, positiveClass, ref hist);
                };
                return(mapper as ValueMapper <TIn, TOut>);
            }
            else
            {
                throw _host.ExceptNotImpl("Not implemented for {0}", weight);
            }
        }
コード例 #8
0
        /// <summary>
        /// Returns the value mapper for the requested prediction <paramref name="kind"/>.
        /// </summary>
        /// <remarks>
        /// The following conditions are checked in order:
        /// TIn must be VBuffer&lt;float&gt;, the label weights must be set, and the algorithm must be kdtree.
        /// Only uniform weighting is implemented, and only binary and multiclass classification are supported.
        /// Failures throw the exceptions built by <c>_host</c>.
        /// </remarks>
        public ValueMapper <TIn, TOut> GetMapper <TIn, TOut>(NearestNeighborsTrees trees, int k,
                                                             NearestNeighborsAlgorithm algo, NearestNeighborsWeights weight, PredictionKind kind)
        {
            _host.Check(typeof(TIn) == typeof(VBuffer <float>));
            _host.CheckValue(_labelWeights, "_labelWeights");
            _host.Check(algo == NearestNeighborsAlgorithm.kdtree, "algo");

            if (weight != NearestNeighborsWeights.uniform)
            {
                throw _host.ExceptNotImpl("Not implemented yet for weight={0}", weight);
            }

            switch (kind)
            {
            case PredictionKind.BinaryClassification:
                return GetMapperBinaryPrediction <TIn, TOut>(trees, k, algo, weight);

            case PredictionKind.MulticlassClassification:
                return GetMapperMulticlassPrediction <TIn, TOut>(trees, k, algo, weight);

            default:
                throw _host.ExceptNotImpl("Not implemented yet for kind={0}", kind);
            }
        }