/// <summary>
/// Builds a <see cref="MultiToBinaryTransform"/> on top of the label column of
/// <paramref name="data"/>. When <paramref name="train"/> is true, the data first goes
/// through FilterNA and, for key-typed labels, through a ULabelToR4LabelTransform.
/// </summary>
/// <param name="ch">channel used to log which label is being multiplied</param>
/// <param name="data">input data; its schema must provide a label column</param>
/// <param name="dstName">receives the name of the new label column</param>
/// <param name="labName">receives the name of the original label column</param>
/// <param name="count">indication of the number of expected classes (not read by this method)</param>
/// <param name="train">true when the pipeline is built for training</param>
/// <param name="args">arguments sent to the trainer</param>
/// <returns>the created transform</returns>
protected MultiToBinaryTransform MapLabelsAndInsertTransform(IChannel ch, RoleMappedData data,
    out string dstName, out string labName, int count, bool train, Arguments args)
{
    var labelColumn = data.Schema.Label.Value;
    Host.Assert(!data.Schema.Schema[labelColumn.Index].IsHidden);
    Host.Assert(labelColumn.Type.GetKeyCount() > 0 || labelColumn.Type == NumberDataViewType.Single);

    IDataView view = data.Data;
    if (train)
    {
        view = FilterNA(view, labelColumn.Name, args.dropNALabel);
        if (labelColumn.Type.IsKey())
        {
            // The converted label is written back under the same column name.
            var mapping = new Column1x1() { Source = labelColumn.Name, Name = labelColumn.Name };
            var conversionArgs = new ULabelToR4LabelTransform.Arguments
            {
                columns = new Column1x1[] { mapping }
            };
            view = new ULabelToR4LabelTransform(Host, conversionArgs, view);
        }
    }

    ch.Info("Multiplying rows label '{0}'", labelColumn.Name);

    // The suffix appended to the temporary column name depends on the algorithm.
    string suffix;
    switch (args.algo)
    {
        case MultiToBinaryTransform.MultiplicationAlgorithm.Default:
        case MultiToBinaryTransform.MultiplicationAlgorithm.Reweight:
            suffix = "BL";
            break;
        case MultiToBinaryTransform.MultiplicationAlgorithm.Ranking:
            suffix = "U4";
            break;
        default:
            throw Host.ExceptNotSupp("Not supported algorithm {0}", args.algo);
    }
    dstName = view.Schema.GetTempColumnName() + suffix;

    var transformArgs = new MultiToBinaryTransform.Arguments
    {
        label = labelColumn.Name,
        newColumn = dstName,
        algo = args.algo,
        weight = args.weight,
        maxMulti = args.maxMulti,
        seed = args.seed,
        numThreads = args.numThreads,
    };
    labName = labelColumn.Name;

    return new MultiToBinaryTransform(Host, transformArgs, view);
}
/// <summary>
/// Runs a <see cref="MultiToBinaryTransform"/> over three rows labelled 0, 1 and 2, then reads
/// the column returned by SchemaHelper._dc(1, cursor) as a uint, as a boolean vector and as a
/// float vector, and the column returned by SchemaHelper._dc(2, cursor) as a bool.
/// For every row both vectors must have length 3 and a single stored value located at the
/// label index. The total number of rows must be 9 when <paramref name="max"/> is at least 3,
/// and 3 * max otherwise.
/// </summary>
/// <param name="algo">multiplication algorithm given to the transform</param>
/// <param name="max">value assigned to the transform's maxMulti argument</param>
/// <exception cref="Exception">thrown as soon as one of the checks fails</exception>
public static void TestMultiToBinaryTransformVector(MultiToBinaryTransform.MultiplicationAlgorithm algo, int max)
{
    var host = EnvHelper.NewTestEnvironment();
    var inputs = new InputOutputU[]
    {
        new InputOutputU() { X = new float[] { 0.1f, 1.1f }, Y = 0 },
        new InputOutputU() { X = new float[] { 0.2f, 1.2f }, Y = 1 },
        new InputOutputU() { X = new float[] { 0.3f, 1.3f }, Y = 2 }
    };

    var data = DataViewConstructionUtils.CreateFromEnumerable(host, inputs);
    var args = new MultiToBinaryTransform.Arguments { label = "Y", algo = algo, maxMulti = max };
    var multiplied = new MultiToBinaryTransform(host, args, data);

    using (var cursor = multiplied.GetRowCursor(multiplied.Schema))
    {
        // The same column is requested under three different types on purpose.
        var labelGetter = cursor.GetGetter<uint>(SchemaHelper._dc(1, cursor));
        var labelVectorGetter = cursor.GetGetter<VBuffer<bool>>(SchemaHelper._dc(1, cursor));
        var labelVectorFloatGetter = cursor.GetGetter<VBuffer<float>>(SchemaHelper._dc(1, cursor));
        var binGetter = cursor.GetGetter<bool>(SchemaHelper._dc(2, cursor));
        Contracts.CheckValue(binGetter, "Type mismatch.");

        var cont = new List<Tuple<uint, bool>>();
        bool bin = false;
        uint got = 0;
        var gotv = new VBuffer<bool>();
        var gotvf = new VBuffer<float>();

        while (cursor.MoveNext())
        {
            labelGetter(ref got);
            labelVectorGetter(ref gotv);
            labelVectorFloatGetter(ref gotvf);
            binGetter(ref bin);
            cont.Add(new Tuple<uint, bool>(got, bin));

            // Boolean vector view of the label.
            if (gotv.Length != 3)
            {
                throw new Exception("Bad dimension (Length)");
            }
            if (gotv.Count != 1)
            {
                throw new Exception("Bad dimension (Count)");
            }
            if (!gotv.Values[0])
            {
                throw new Exception("Bad value (Count)");
            }
            if (gotv.Indices[0] != got)
            {
                throw new Exception("Bad index (Count)");
            }
            var denseBool = gotv.DenseValues().ToArray();
            if (denseBool.Length != 3)
            {
                throw new Exception("Bad dimension (dense)");
            }

            // Float vector view of the label.
            if (gotvf.Length != 3)
            {
                throw new Exception("Bad dimension (Length)f");
            }
            if (gotvf.Count != 1)
            {
                throw new Exception("Bad dimension (Count)f");
            }
            if (gotvf.Values[0] != 1)
            {
                throw new Exception("Bad value (Count)f");
            }
            if (gotvf.Indices[0] != got)
            {
                throw new Exception("Bad index (Count)f");
            }
            // Expand the float vector here; the boolean one was already checked above.
            var denseFloat = gotvf.DenseValues().ToArray();
            if (denseFloat.Length != 3)
            {
                throw new Exception("Bad dimension (dense)f");
            }
        }

        if (max >= 3)
        {
            if (cont.Count != 9)
            {
                throw new Exception("It should be 9.");
            }
            if (algo == MultiToBinaryTransform.MultiplicationAlgorithm.Default)
            {
                // With the default algorithm each class must be flagged true exactly once.
                for (int i = 0; i < 3; ++i)
                {
                    var co = cont.Where(c => c.Item1 == (uint)i && c.Item2);
                    if (co.Count() != 1)
                    {
                        throw new Exception(string.Format("Unexpected number of true labels for class {0} - algo={1} - max={2}", i, algo, max));
                    }
                }
            }
        }
        else
        {
            if (cont.Count != 3 * max)
            {
                throw new Exception(string.Format("It should be {0}.", 3 * max));
            }
        }
    }
}
/// <summary>
/// Runs a <see cref="MultiToBinaryTransform"/> over three rows labelled 0, 1 and 2 and collects,
/// for every output row, the uint read from the column given by SchemaHelper._dc(1, cursor)
/// and the bool read from the column given by SchemaHelper._dc(2, cursor).
/// The number of rows must be 9 when <paramref name="max"/> is at least 3 and 3 * max otherwise.
/// With the Default algorithm and max of at least 3, each class must be flagged true exactly once.
/// </summary>
/// <param name="algo">multiplication algorithm given to the transform</param>
/// <param name="max">value assigned to the transform's maxMulti argument</param>
/// <exception cref="Exception">thrown when one of the checks fails</exception>
public static void TestMultiToBinaryTransform(MultiToBinaryTransform.MultiplicationAlgorithm algo, int max)
{
    var env = EnvHelper.NewTestEnvironment();

    var samples = new InputOutputU[]
    {
        new InputOutputU() { X = new float[] { 0.1f, 1.1f }, Y = 0 },
        new InputOutputU() { X = new float[] { 0.2f, 1.2f }, Y = 1 },
        new InputOutputU() { X = new float[] { 0.3f, 1.3f }, Y = 2 }
    };
    var view = DataViewConstructionUtils.CreateFromEnumerable(env, samples);

    var transformArgs = new MultiToBinaryTransform.Arguments { label = "Y", algo = algo, maxMulti = max };
    var transform = new MultiToBinaryTransform(env, transformArgs, view);

    using (var rows = transform.GetRowCursor(transform.Schema))
    {
        var readLabel = rows.GetGetter<uint>(SchemaHelper._dc(1, rows));
        var readFlag = rows.GetGetter<bool>(SchemaHelper._dc(2, rows));

        var pairs = new List<Tuple<uint, bool>>();
        bool flag = false;
        while (rows.MoveNext())
        {
            uint label = 0;
            readLabel(ref label);
            readFlag(ref flag);
            pairs.Add(Tuple.Create(label, flag));
        }

        // Fewer than three copies per row: only the row count can be verified.
        if (max < 3)
        {
            if (pairs.Count != 3 * max)
            {
                throw new Exception(string.Format("It should be {0}.", 3 * max));
            }
            return;
        }

        if (pairs.Count != 9)
        {
            throw new Exception("It should be 9.");
        }

        if (algo != MultiToBinaryTransform.MultiplicationAlgorithm.Default)
        {
            return;
        }

        for (int cls = 0; cls < 3; ++cls)
        {
            int positives = pairs.Count(p => p.Item1 == (uint)cls && p.Item2);
            if (positives != 1)
            {
                throw new Exception(string.Format("Unexpected number of true labels for class {0} - algo={1} - max={2}", cls, algo, max));
            }
        }
    }
}