Пример #1
0
        /// <summary>
        /// Walks through the label column of <paramref name="tr"/> to compute the label range
        /// and the number of classes.
        /// </summary>
        /// <param name="tr">transform whose output is scanned</param>
        /// <param name="label">name of the label column</param>
        /// <param name="nb">receives the number of classes: max + 1 for a (0 based) float label,
        /// max for a (1 based) key label</param>
        /// <returns>(min, max) for a float label (values truncated to int); (1, max) for a key label</returns>
        protected Tuple <int, int> MinMaxLabelOverDataSet(MultiToBinaryTransform tr, string label, out int nb)
        {
            int index = SchemaHelper.GetColumnIndex(tr.Schema, label);
            var ty    = tr.Schema[index].Type;

            switch (ty.RawKind())
            {
            case DataKind.Single:
                // float is 0 based
                var tf = MinMaxLabel <float>(tr, index);
                nb = (int)tf.Item2 + 1;
                return new Tuple <int, int>((int)tf.Item1, (int)tf.Item2);

            case DataKind.Byte:
                // key is 1 based.
                // A byte-backed key reports DataKind.Byte (key types are never SByte);
                // matching it against SByte sent byte keys to the default branch.
                var tb = MinMaxLabel <byte>(tr, index);
                nb = tb.Item2;
                return new Tuple <int, int>(1, (int)tb.Item2);

            case DataKind.UInt16:
                // key is 1 based
                var ts = MinMaxLabel <ushort>(tr, index);
                nb = ts.Item2;
                return new Tuple <int, int>(1, (int)ts.Item2);

            case DataKind.UInt32:
                // key is 1 based
                var tu = MinMaxLabel <uint>(tr, index);
                nb = (int)tu.Item2;
                return new Tuple <int, int>(1, (int)tu.Item2);

            default:
                throw Contracts.ExceptNotImpl("Type '{0}' not implemented", ty.RawKind());
            }
        }
Пример #2
0
        /// <summary>
        /// Builds a <see cref="MultiToBinaryTransform"/> on top of the input data and returns it.
        /// When <paramref name="train"/> is true, the label column is first passed through FilterNA.
        /// A key label is then converted by ULabelToR4LabelTransform.
        /// </summary>
        /// <param name="ch">channel used for logging</param>
        /// <param name="data">input data; its label role must be set</param>
        /// <param name="dstName">receives the name of the new label column created by the transform</param>
        /// <param name="labName">receives the name of the original label column</param>
        /// <param name="count">indication of the number of expected classes (not read by this method)</param>
        /// <param name="train">true when the pipeline is built for training</param>
        /// <param name="args">arguments given to the trainer</param>
        /// <returns>the created transform</returns>
        protected MultiToBinaryTransform MapLabelsAndInsertTransform(IChannel ch, RoleMappedData data,
                                                                     out string dstName, out string labName, int count, bool train, Arguments args)
        {
            var labelColumn = data.Schema.Label.Value;

            Host.Assert(!data.Schema.Schema[labelColumn.Index].IsHidden);
            Host.Assert(labelColumn.Type.GetKeyCount() > 0 || labelColumn.Type == NumberDataViewType.Single);

            IDataView view = data.Data;
            if (train)
            {
                // NOTE(review): FilterNA presumably drops rows with a missing label
                // depending on args.dropNALabel — confirm against its definition.
                view = FilterNA(view, labelColumn.Name, args.dropNALabel);
                if (labelColumn.Type.IsKey())
                {
                    // The key label is rewritten in place (same source and destination name).
                    var inPlace = new Column1x1() { Source = labelColumn.Name, Name = labelColumn.Name };
                    var convertArgs = new ULabelToR4LabelTransform.Arguments { columns = new[] { inPlace } };
                    view = new ULabelToR4LabelTransform(Host, convertArgs, view);
                }
            }

            ch.Info("Multiplying rows label '{0}'", labelColumn.Name);

            // The suffix appended to the temporary column name depends on the algorithm
            // (presumably BL = boolean label, U4 = uint label — confirm in MultiToBinaryTransform).
            string suffix;
            switch (args.algo)
            {
            case MultiToBinaryTransform.MultiplicationAlgorithm.Default:
            case MultiToBinaryTransform.MultiplicationAlgorithm.Reweight:
                suffix = "BL";
                break;

            case MultiToBinaryTransform.MultiplicationAlgorithm.Ranking:
                suffix = "U4";
                break;

            default:
                throw Host.ExceptNotSupp("Not supported algorithm {0}", args.algo);
            }
            dstName = view.Schema.GetTempColumnName() + suffix;

            var transformArgs = new MultiToBinaryTransform.Arguments
            {
                label      = labelColumn.Name,
                newColumn  = dstName,
                algo       = args.algo,
                weight     = args.weight,
                maxMulti   = args.maxMulti,
                seed       = args.seed,
                numThreads = args.numThreads,
            };

            labName = labelColumn.Name;
            return new MultiToBinaryTransform(Host, transformArgs, view);
        }
        /// <summary>
        /// Builds the final multi-class predictor from the trained binary predictors.
        /// The label mapping is rebuilt with MapLabelsAndInsertTransform (train = false).
        /// It then takes over the state of <paramref name="trans"/> through <c>Steal</c>.
        /// </summary>
        /// <param name="ch">channel used for errors</param>
        /// <param name="data">training data; the raw kind of its label selects the class type</param>
        /// <param name="trans">transform used during training</param>
        /// <param name="count">forwarded to MapLabelsAndInsertTransform</param>
        /// <param name="args">trainer arguments (singleColumn is forwarded to the predictor)</param>
        /// <param name="predictors">the binary predictors</param>
        /// <param name="reclassPredictor">optional reclassification predictor forwarded to the predictor</param>
        /// <returns>the predictor cast to TVectorPredictor</returns>
        protected TVectorPredictor CreateFinalPredictor(IChannel ch, RoleMappedData data,
                                                        MultiToBinaryTransform trans, int count, Arguments args,
                                                        TScalarPredictor[] predictors, IPredictor reclassPredictor)
        {
            // We create the final predictor. We remove every unneeded transform.
            var trainingTransform = trans;
            string dstName, labName;
            trans = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, false, args);
            trans.Steal(trainingTransform);

            int indexLab;
            if (!trans.Schema.TryGetColumnIndex(labName, out indexLab))
            {
                throw ch.Except("Unable to find column '{0}' in \n{1}", labName, SchemaHelper.ToString(trans.Schema));
            }
            var labType = trans.Schema.GetColumnType(indexLab);

            // Key labels (U1/U2/U4) are flagged with true, the float label with false.
            TVectorPredictor predictor;
            switch (data.Schema.Label.Type.RawKind())
            {
            case DataKind.R4:
                predictor = MultiToBinaryPredictor.Create(Host, trans.GetClasses <float>(), predictors, reclassPredictor, args.singleColumn, false) as TVectorPredictor;
                break;

            case DataKind.U1:
                predictor = MultiToBinaryPredictor.Create(Host, trans.GetClasses <byte>(), predictors, reclassPredictor, args.singleColumn, true) as TVectorPredictor;
                break;

            case DataKind.U2:
                predictor = MultiToBinaryPredictor.Create(Host, trans.GetClasses <ushort>(), predictors, reclassPredictor, args.singleColumn, true) as TVectorPredictor;
                break;

            case DataKind.U4:
                predictor = MultiToBinaryPredictor.Create(Host, trans.GetClasses <uint>(), predictors, reclassPredictor, args.singleColumn, true) as TVectorPredictor;
                break;

            default:
                throw ch.ExceptNotSupp("Unsupported type for a multi class label.");
            }

            Host.Assert(predictor != null);
            return predictor;
        }
        /// <summary>
        /// Builds the final multi-class ranking predictor from the trained predictors.
        /// The label mapping is rebuilt with MapLabelsAndInsertTransform (train = false).
        /// It then takes over the state of <paramref name="trans"/> through <c>Steal</c>.
        /// </summary>
        /// <param name="ch">channel used for errors</param>
        /// <param name="data">training data; the raw kind of its label selects the class type</param>
        /// <param name="trans">transform used during training</param>
        /// <param name="count">forwarded to MapLabelsAndInsertTransform</param>
        /// <param name="args">trainer arguments (singleColumn is forwarded to the predictor)</param>
        /// <param name="predictors">the trained predictors</param>
        /// <param name="reclassPredictor">optional reclassification predictor forwarded to the predictor</param>
        /// <returns>the predictor cast to TVectorPredictor</returns>
        protected TVectorPredictor CreateFinalPredictor(IChannel ch, RoleMappedData data,
                                                        MultiToBinaryTransform trans, int count, Arguments args,
                                                        TScalarPredictor[] predictors, IPredictor reclassPredictor)
        {
            // We create the final predictor. We remove every unneeded transform.
            string dstName, labName;
            int    indexLab;
            var    trans_ = trans;

            // Use the arguments passed to this method; the previous version read the
            // _args / _reclassPredictor fields and silently ignored these parameters.
            trans = MapLabelsAndInsertTransform(ch, data, out dstName, out labName, count, false, args);
            trans.Steal(trans_);
            indexLab = SchemaHelper.GetColumnIndex(trans.Schema, labName);
            var labType        = trans.Schema[indexLab].Type;
            var initialLabKind = data.Schema.Label.Value.Type.RawKind();

            TVectorPredictor predictor;

            switch (initialLabKind)
            {
            case DataKind.Single:
                var p4 = MultiToRankerPredictor.Create(Host, trans.GetClasses <float>(), predictors, reclassPredictor, args.singleColumn, false);
                predictor = p4 as TVectorPredictor;
                break;

            case DataKind.Byte:
                // Byte-backed keys report DataKind.Byte (key types are never SByte).
                var pu1 = MultiToRankerPredictor.Create(Host, trans.GetClasses <byte>(), predictors, reclassPredictor, args.singleColumn, true);
                predictor = pu1 as TVectorPredictor;
                break;

            case DataKind.UInt16:
                var pu2 = MultiToRankerPredictor.Create(Host, trans.GetClasses <ushort>(), predictors, reclassPredictor, args.singleColumn, true);
                predictor = pu2 as TVectorPredictor;
                break;

            case DataKind.UInt32:
                var pu4 = MultiToRankerPredictor.Create(Host, trans.GetClasses <uint>(), predictors, reclassPredictor, args.singleColumn, true);
                predictor = pu4 as TVectorPredictor;
                break;

            default:
                throw ch.ExceptNotSupp("Unsupported type for a multi class label.");
            }

            Host.Assert(predictor != null);
            return predictor;
        }
        /// <summary>
        /// The function walk through the data to compute the highest label.
        /// </summary>
        /// <summary>
        /// Scans the label column of <paramref name="tr"/> and returns its value range.
        /// </summary>
        /// <param name="tr">transform whose output is scanned</param>
        /// <param name="label">name of the label column; an exception is thrown if it is missing</param>
        /// <param name="nb">receives the number of classes: max + 1 for a (0 based) R4 label,
        /// max for a (1 based) U1/U2/U4 key label</param>
        /// <returns>(min, max) for an R4 label (values truncated to int); (1, max) for a key label</returns>
        protected Tuple <int, int> MinMaxLabelOverDataSet(MultiToBinaryTransform tr, string label, out int nb)
        {
            int column;
            if (!tr.Schema.TryGetColumnIndex(label, out column))
            {
                throw Contracts.Except("Unable to find column '{0}' in '{1}'", label, SchemaHelper.ToString(tr.Schema));
            }

            var kind = tr.Schema.GetColumnType(column).RawKind();

            if (kind == DataKind.R4)
            {
                // float labels are 0 based
                var range = MinMaxLabel <float>(tr, column);
                nb = (int)range.Item2 + 1;
                return new Tuple <int, int>((int)range.Item1, (int)range.Item2);
            }
            if (kind == DataKind.U1)
            {
                // key labels are 1 based
                byte top = MinMaxLabel <byte>(tr, column).Item2;
                nb = top;
                return new Tuple <int, int>(1, top);
            }
            if (kind == DataKind.U2)
            {
                // key labels are 1 based
                ushort top = MinMaxLabel <ushort>(tr, column).Item2;
                nb = top;
                return new Tuple <int, int>(1, top);
            }
            if (kind == DataKind.U4)
            {
                // key labels are 1 based
                uint top = MinMaxLabel <uint>(tr, column).Item2;
                nb = (int)top;
                return new Tuple <int, int>(1, (int)top);
            }

            throw Contracts.ExceptNotImpl("Type '{0}' not implemented", kind);
        }
Пример #6
0
 /// <summary>
 /// Iterates over column <paramref name="index"/> of <paramref name="tr"/> and returns
 /// the smallest and the largest value found.
 /// </summary>
 /// <typeparam name="TLabel">type requested from the column getter</typeparam>
 /// <param name="tr">transform whose output is scanned</param>
 /// <param name="index">index of the column to scan</param>
 /// <returns>(min, max); both are default(TLabel) when the cursor yields no rows</returns>
 protected Tuple <TLabel, TLabel> MinMaxLabel <TLabel>(MultiToBinaryTransform tr, int index)
     where TLabel : IComparable <TLabel>
 {
     using (var cursor = tr.GetRowCursor(tr.Schema.Where(c => c.Index == index).ToArray()))
     {
         var    getter = cursor.GetGetter <TLabel>(SchemaHelper._dc(index, cursor));
         TLabel cl = default(TLabel), max = default(TLabel), min = default(TLabel);
         bool   first = true;
         while (cursor.MoveNext())
         {
             getter(ref cl);
             // CompareTo only guarantees the sign of its result (byte/ushort return the
             // difference of the two values), so test against 0 rather than +1/-1.
             if (first || cl.CompareTo(max) > 0)
             {
                 max = cl;
             }
             // The running minimum must be compared with min, not with max.
             if (first || cl.CompareTo(min) < 0)
             {
                 min = cl;
             }
             first = false;
         }
         return new Tuple <TLabel, TLabel>(min, max);
     }
 }
Пример #7
0
        /// <summary>
        /// Runs MultiToBinaryTransform over three rows labelled 0, 1 and 2.
        /// It reads column 1 as a uint, a VBuffer&lt;bool&gt; and a VBuffer&lt;float&gt;,
        /// and column 2 as a bool.
        /// Each vector must have length 3 with a single stored value (true / 1) at the index of the label.
        /// Expects 9 rows when <paramref name="max"/> is at least 3, otherwise 3 * max.
        /// With the Default algorithm, each class must appear exactly once with a true binary label.
        /// </summary>
        public static void TestMultiToBinaryTransformVector(MultiToBinaryTransform.MultiplicationAlgorithm algo, int max)
        {
            /*using (*/ var host = EnvHelper.NewTestEnvironment();
            {
                var inputs = new InputOutputU[] {
                    new InputOutputU()
                    {
                        X = new float[] { 0.1f, 1.1f }, Y = 0
                    },
                    new InputOutputU()
                    {
                        X = new float[] { 0.2f, 1.2f }, Y = 1
                    },
                    new InputOutputU()
                    {
                        X = new float[] { 0.3f, 1.3f }, Y = 2
                    }
                };

                var data = DataViewConstructionUtils.CreateFromEnumerable(host, inputs);

                var args = new MultiToBinaryTransform.Arguments {
                    label = "Y", algo = algo, maxMulti = max
                };
                var multiplied = new MultiToBinaryTransform(host, args, data);

                using (var cursor = multiplied.GetRowCursor(multiplied.Schema))
                {
                    var labelGetter            = cursor.GetGetter <uint>(SchemaHelper._dc(1, cursor));
                    var labelVectorGetter      = cursor.GetGetter <VBuffer <bool> >(SchemaHelper._dc(1, cursor));
                    var labelVectorFloatGetter = cursor.GetGetter <VBuffer <float> >(SchemaHelper._dc(1, cursor));
                    var binGetter = cursor.GetGetter <bool>(SchemaHelper._dc(2, cursor));
                    Contracts.CheckValue(binGetter, "Type mismatch.");
                    var  cont  = new List <Tuple <uint, bool> >();
                    bool bin   = false;
                    uint got   = 0;
                    var  gotv  = new VBuffer <bool>();
                    var  gotvf = new VBuffer <float>();
                    while (cursor.MoveNext())
                    {
                        labelGetter(ref got);
                        labelVectorGetter(ref gotv);
                        labelVectorFloatGetter(ref gotvf);
                        binGetter(ref bin);
                        cont.Add(new Tuple <uint, bool>(got, bin));

                        // Boolean vector checks.
                        if (gotv.Length != 3)
                        {
                            throw new Exception("Bad dimension (Length)");
                        }
                        if (gotv.Count != 1)
                        {
                            throw new Exception("Bad dimension (Count)");
                        }
                        if (!gotv.Values[0])
                        {
                            throw new Exception("Bad value (Count)");
                        }
                        if (gotv.Indices[0] != got)
                        {
                            throw new Exception("Bad index (Count)");
                        }
                        var ar = gotv.DenseValues().ToArray();
                        if (ar.Length != 3)
                        {
                            throw new Exception("Bad dimension (dense)");
                        }

                        // Float vector checks.
                        if (gotvf.Length != 3)
                        {
                            throw new Exception("Bad dimension (Length)f");
                        }
                        if (gotvf.Count != 1)
                        {
                            throw new Exception("Bad dimension (Count)f");
                        }
                        if (gotvf.Values[0] != 1)
                        {
                            throw new Exception("Bad value (Count)f");
                        }
                        if (gotvf.Indices[0] != got)
                        {
                            throw new Exception("Bad index (Count)f");
                        }
                        // Densify the float vector (previously this re-checked gotv by mistake).
                        var ar2 = gotvf.DenseValues().ToArray();
                        if (ar2.Length != 3)
                        {
                            throw new Exception("Bad dimension (dense)f");
                        }
                    }

                    if (max >= 3)
                    {
                        if (cont.Count != 9)
                        {
                            throw new Exception("It should be 9.");
                        }
                        if (algo == MultiToBinaryTransform.MultiplicationAlgorithm.Default)
                        {
                            for (int i = 0; i < 3; ++i)
                            {
                                var co = cont.Where(c => c.Item1 == (uint)i && c.Item2);
                                if (co.Count() != 1)
                                {
                                    throw new Exception(string.Format("Unexpected number of true labels for class {0} - algo={1} - max={2}", i, algo, max));
                                }
                            }
                        }
                    }
                    else
                    {
                        if (cont.Count != 3 * max)
                        {
                            throw new Exception(string.Format("It should be {0}.", 3 * max));
                        }
                    }
                }
            }
        }
Пример #8
0
        /// <summary>
        /// Runs MultiToBinaryTransform over three rows labelled 0, 1 and 2 and checks how many rows it emits.
        /// It expects 9 rows when <paramref name="max"/> is at least 3, otherwise 3 * max.
        /// With the Default algorithm, each class must also appear exactly once with a true binary label.
        /// Column 1 is read as the uint label and column 2 as the bool flag.
        /// </summary>
        public static void TestMultiToBinaryTransform(MultiToBinaryTransform.MultiplicationAlgorithm algo, int max)
        {
            // NOTE(review): the environment is not disposed (its `using` was commented out originally).
            var env = EnvHelper.NewTestEnvironment();

            var rowsIn = new[]
            {
                new InputOutputU() { X = new float[] { 0.1f, 1.1f }, Y = 0 },
                new InputOutputU() { X = new float[] { 0.2f, 1.2f }, Y = 1 },
                new InputOutputU() { X = new float[] { 0.3f, 1.3f }, Y = 2 },
            };
            var view = DataViewConstructionUtils.CreateFromEnumerable(env, rowsIn);

            var transformArgs = new MultiToBinaryTransform.Arguments { label = "Y", algo = algo, maxMulti = max };
            var multiplied = new MultiToBinaryTransform(env, transformArgs, view);

            // Collect (label, binary flag) for every output row.
            var rows = new List <Tuple <uint, bool> >();
            using (var cursor = multiplied.GetRowCursor(multiplied.Schema))
            {
                var  readLabel = cursor.GetGetter <uint>(SchemaHelper._dc(1, cursor));
                var  readFlag  = cursor.GetGetter <bool>(SchemaHelper._dc(2, cursor));
                bool flag      = false;
                while (cursor.MoveNext())
                {
                    uint value = 0;
                    readLabel(ref value);
                    readFlag(ref flag);
                    rows.Add(Tuple.Create(value, flag));
                }
            }

            if (max < 3)
            {
                if (rows.Count != 3 * max)
                {
                    throw new Exception(string.Format("It should be {0}.", 3 * max));
                }
                return;
            }

            if (rows.Count != 9)
            {
                throw new Exception("It should be 9.");
            }
            if (algo != MultiToBinaryTransform.MultiplicationAlgorithm.Default)
            {
                return;
            }
            for (uint cls = 0; cls < 3; ++cls)
            {
                int positives = rows.Count(r => r.Item1 == cls && r.Item2);
                if (positives != 1)
                {
                    throw new Exception(string.Format("Unexpected number of true labels for class {0} - algo={1} - max={2}", cls, algo, max));
                }
            }
        }