/// <summary>
/// Reads the label column (a single transposed slot) from <paramref name="trans"/>, converts it to
/// integer label codes and stores them in <c>_labels</c>, with the number of distinct codes in
/// <c>_numLabels</c>.
/// </summary>
/// <param name="trans">Transposer over the input data; the label slot is read via GetSingleSlotValue.</param>
/// <param name="labelType">Label column type: I4, R4, R8, bool, or (fallback branch) a key type.</param>
/// <param name="labelCol">Index of the label column in <paramref name="trans"/>'s schema.</param>
private void GetLabels(Transposer trans, ColumnType labelType, int labelCol)
            {
                int min;
                int lim;
                var labels = default(VBuffer<int>);

                // Note: NAs have their own separate bin.
                if (labelType == NumberType.I4)
                {
                    var tmp = default(VBuffer<int>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinInts(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberType.R4)
                {
                    var tmp = default(VBuffer<Single>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinSingles(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberType.R8)
                {
                    var tmp = default(VBuffer<Double>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinDoubles(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType is BoolType)
                {
                    var tmp = default(VBuffer<bool>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinBools(in tmp, ref labels);
                    // Three label bins; assumes BinBools emits codes in [-1, 2) so that the
                    // shift by min = -1 below maps them to 0..2 — TODO confirm against BinBools.
                    _numLabels = 3;
                    min        = -1;
                    lim        = 2;
                }
                else
                {
                    // Fallback: the label is expected to be a key type. GetKeyCount() is presumably 0
                    // for non-key types, so requiring a positive count catches an unexpected type here
                    // (the previous check, "< ArrayMaxSize" alone, passed trivially for a count of 0).
                    ulong labelKeyCount = labelType.GetKeyCount();
                    Contracts.Assert(0 < labelKeyCount && labelKeyCount < Utils.ArrayMaxSize);

                    // GetKeyLabels<T> is generic in the key's raw type, which is only known at runtime,
                    // so the closed generic method is constructed and invoked through reflection.
                    KeyLabelGetter<int> del = GetKeyLabels<int>;
                    var methodInfo           = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(labelType.RawType);
                    var parameters           = new object[] { trans, labelCol, labelType };
                    _labels    = (VBuffer<int>)methodInfo.Invoke(this, parameters);
                    // Key count plus one extra label (presumably reserved for missing keys — verify).
                    _numLabels = labelType.GetKeyCountAsInt32(_host) + 1;

                    // No need to densify or shift in this case.
                    return;
                }

                // Densify and shift labels by -min so that, presumably, every code lands in [0, _numLabels).
                VBufferUtils.Densify(ref labels);
                Contracts.Assert(labels.IsDense);
                var labelsEditor = VBufferEditor.CreateFromBuffer(ref labels);

                for (int i = 0; i < labels.Length; i++)
                {
                    labelsEditor.Values[i] -= min;
                    Contracts.Assert(0 <= labelsEditor.Values[i] && labelsEditor.Values[i] < _numLabels);
                }
                _labels = labelsEditor.Commit();
            }
            /// <summary>
            /// Validates the label column and the requested feature columns, transposes them, and computes
            /// one mutual-information score array per entry of <paramref name="columns"/>.
            /// </summary>
            /// <param name="input">Source data view.</param>
            /// <param name="labelColumnName">Name of the label column; must exist and have a compatible type.</param>
            /// <param name="columns">Names of the feature columns to score.</param>
            /// <param name="numBins">Bin count stored in <c>_numBins</c> for the label/feature binning helpers.</param>
            /// <param name="colSizes">Output: receives the value count of each column in <paramref name="columns"/>.</param>
            /// <returns>Scores indexed in the same order as <paramref name="columns"/>.</returns>
            public float[][] GetScores(IDataView input, string labelColumnName, string[] columns, int numBins, int[] colSizes)
            {
                _numBins = numBins;
                var inputSchema = input.Schema;
                int size = columns.Length;

                // --- Label column: must exist and have a supported type. ---
                if (!inputSchema.TryGetColumnIndex(labelColumnName, out int labelCol))
                {
                    throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.LabelColumn),
                                              "Label column '{0}' not found", labelColumnName);
                }

                var labelType = inputSchema[labelCol].Type;
                if (!IsValidColumnType(labelType))
                {
                    throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.LabelColumn),
                                              "Label column '{0}' does not have compatible type", labelColumnName);
                }

                // --- Feature columns: the label goes in the extra trailing slot of the source list. ---
                var sourceIndices = new int[size + 1];
                sourceIndices[size] = labelCol;

                for (int k = 0; k < size; k++)
                {
                    string name = columns[k];
                    if (!inputSchema.TryGetColumnIndex(name, out int sourceIndex))
                    {
                        throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.Columns),
                                                  "Source column '{0}' not found", name);
                    }

                    var type = inputSchema[sourceIndex].Type;

                    // Only fixed-size vectors (or scalars) can be transposed into per-slot columns.
                    bool isVariableVector = type is VectorType vecType && !vecType.IsKnownSize;
                    if (isVariableVector)
                    {
                        throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.Columns),
                                                  "Variable length column '{0}' is not allowed", name);
                    }

                    if (!IsValidColumnType(type.GetItemType()))
                    {
                        throw _host.ExceptUserArg(nameof(MutualInformationFeatureSelectingEstimator.Arguments.Columns),
                                                  "Column '{0}' of type '{1}' does not have compatible type.", name, type);
                    }

                    sourceIndices[k] = sourceIndex;
                    colSizes[k] = type.GetValueCount();
                }

                var scores = new float[size][];

                using (var ch = _host.Start("Computing mutual information scores"))
                using (var pch = _host.StartProgressChannel("Computing mutual information scores"))
                using (var trans = Transposer.Create(_host, input, false, sourceIndices))
                {
                    // The progress callback reads colIdx, so it is declared before the header is registered
                    // and advanced by the scoring loop below.
                    int colIdx = 0;
                    var progressHeader = new ProgressHeader(new[] { "columns" });

                    bool found = trans.Schema.TryGetColumnIndex(labelColumnName, out int transLabelCol);
                    Contracts.Assert(found);

                    GetLabels(trans, labelType, transLabelCol);
                    _contingencyTable = new int[_numLabels][];
                    _labelSums = new int[_numLabels];
                    pch.SetHeader(progressHeader, e => e.SetProgress(0, colIdx, size));

                    for (colIdx = 0; colIdx < size; colIdx++)
                    {
                        found = trans.Schema.TryGetColumnIndex(columns[colIdx], out int transCol);
                        Contracts.Assert(found);
                        ch.Trace("Computing scores for column '{0}'", columns[colIdx]);
                        scores[colIdx] = ComputeMutualInformation(trans, transCol);
#if DEBUG
                        ch.Trace("Scores for column '{0}': {1}", columns[colIdx], string.Join(", ", scores[colIdx]));
#endif
                        pch.Checkpoint(colIdx + 1);
                    }
                }

                return scores;
            }
Example #3
0
        /// <summary>
        /// Exercises <c>Transposer</c> over a mix of sparse/dense vector and scalar columns: first without
        /// forcing save (a mix of passthrough and saved columns), then with forced save and selection by index.
        /// </summary>
        public void TransposerTest()
        {
            const int            rowCount = 1000;
            Random               rgen     = new Random(0);
            ArrayDataViewBuilder builder  = new ArrayDataViewBuilder(Env);

            // NOTE: all columns share rgen, so the order of the GenerateHelper calls below is significant.

            // A is to check the splitting of a sparse-ish column.
            var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (DvInt4)rgen.Next(), 50, 5, 10, 15);

            dataA[rowCount / 2] = new VBuffer<DvInt4>(50, 0, null, null); // Coverage for the null vbuffer case.
            builder.AddColumn("A", NumberType.I4, dataA);
            // B is to check the splitting of a dense-ish column.
            builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49));
            // C is to just have some column we do nothing with.
            builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (DvInt2)1, 30, 3, 10, 24));
            // D is to check some column we don't have to split because it's sufficiently small.
            builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1));
            // E is to check a sparse scalar column.
            builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue)));
            // F is to check a dense-ish scalar column.
            builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (DvInt4)rgen.Next()));

            IDataView view = builder.GetDataView();

            // Do not force save. This will have a mix of passthrough and saved columns. Note the duplicate
            // specification of "D", which tests that specifying a column twice has no ill effects.
            string[] names = { "B", "A", "E", "D", "F", "D" };
            using (Transposer trans = Transposer.Create(Env, view, false, names))
            {
                // Before checking the contents, check the names.
                for (int i = 0; i < names.Length; ++i)
                {
                    Assert.True(trans.Schema.TryGetColumnIndex(names[i], out int index), $"Transpose schema couldn't find column '{names[i]}'");
                    // Use a real test assertion here: Contracts.Assert appears to be a debug-build contract
                    // check, so it would vanish in release test runs and leave trueIndex unset.
                    Assert.True(view.Schema.TryGetColumnIndex(names[i], out int trueIndex), $"Source schema couldn't find column '{names[i]}'");
                    Assert.True(trueIndex == index, $"Transpose schema had column '{names[i]}' at unexpected index");
                }
                // Check the contents
                Assert.Null(trans.TransposeSchema.GetSlotType(2)); // C check to see that it's not transposable.
                TransposeCheckHelper<DvInt4>(view, 0, trans);     // A check.
                TransposeCheckHelper<Double>(view, 1, trans);     // B check.
                TransposeCheckHelper<Double>(view, 3, trans);     // D check.
                TransposeCheckHelper<uint>(view, 4, trans);       // E check.
                TransposeCheckHelper<DvInt4>(view, 5, trans);     // F check.
            }

            // Force save. Recheck columns that would have previously been passthrough columns.
            // The primary benefit of this check is that we check the binary saving / loading
            // functionality of scalars, which would otherwise necessarily be passthrough.
            // Also exercise the select by index functionality while we're at it.
            using (Transposer trans = Transposer.Create(Env, view, true, 3, 5, 4))
            {
                // Check to see that A, B, and C were not transposed somehow.
                Assert.Null(trans.TransposeSchema.GetSlotType(0));
                Assert.Null(trans.TransposeSchema.GetSlotType(1));
                Assert.Null(trans.TransposeSchema.GetSlotType(2));
                TransposeCheckHelper<Double>(view, 3, trans); // D check.
                TransposeCheckHelper<uint>(view, 4, trans);   // E check.
                TransposeCheckHelper<DvInt4>(view, 5, trans); // F check.
            }
        }
Example #4
0
            /// <summary>
            /// Loads the label slot from <paramref name="trans"/> and converts it into an array of integer
            /// label codes stored in <c>_labels</c>, with the number of distinct codes in <c>_numLabels</c>.
            /// </summary>
            /// <param name="trans">Transposer over the input data; the label slot is read via GetSingleSlotValue.</param>
            /// <param name="labelType">Label column type: I4, R4, R8, bool, or (fallback branch) a key type.</param>
            /// <param name="labelCol">Index of the label column in <paramref name="trans"/>'s schema.</param>
            private void GetLabels(Transposer trans, ColumnType labelType, int labelCol)
            {
                var binned = default(VBuffer<int>);
                int lo;
                int hi;

                // Note: NAs have their own separate bin.
                if (labelType == NumberType.I4)
                {
                    var raw = default(VBuffer<DvInt4>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinInts(ref raw, ref binned, _numBins, out lo, out hi);
                }
                else if (labelType == NumberType.R4)
                {
                    var raw = default(VBuffer<float>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinSingles(ref raw, ref binned, _numBins, out lo, out hi);
                }
                else if (labelType == NumberType.R8)
                {
                    var raw = default(VBuffer<double>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinDoubles(ref raw, ref binned, _numBins, out lo, out hi);
                }
                else if (labelType.IsBool)
                {
                    var raw = default(VBuffer<DvBool>);
                    trans.GetSingleSlotValue(labelCol, ref raw);
                    BinBools(ref raw, ref binned);
                    // Bool labels use the fixed code range [-1, 2).
                    lo = -1;
                    hi = 2;
                }
                else
                {
                    // Key-typed label: GetKeyLabels<T> is generic in the key's raw type, which is only
                    // known at runtime, so the closed generic method is built and invoked via reflection.
                    Contracts.Assert(0 < labelType.KeyCount && labelType.KeyCount < Utils.ArrayMaxSize);
                    KeyLabelGetter<int> getter = GetKeyLabels<int>;
                    var closedMethod = getter.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(labelType.RawType);
                    var args = new object[] { trans, labelCol, labelType };
                    _labels = (int[])closedMethod.Invoke(this, args);
                    _numLabels = labelType.KeyCount + 1;

                    // Key labels are already dense and zero-based: nothing to shift.
                    return;
                }

                // Every non-key branch counts its labels as the width of [lo, hi);
                // for bools this is 2 - (-1) = 3.
                _numLabels = hi - lo;

                // Densify, trim the backing array to the logical length, then shift by -lo.
                VBufferUtils.Densify(ref binned);
                Contracts.Assert(binned.IsDense);
                int[] codes = binned.Values;
                if (codes.Length > binned.Length)
                {
                    Array.Resize(ref codes, binned.Length);
                }
                _labels = codes;

                for (int slot = 0; slot < codes.Length; slot++)
                {
                    codes[slot] -= lo;
                    Contracts.Assert(codes[slot] < _numLabels);
                }
            }
Example #5
0
 /// <summary>
 /// Decrement operator: returns <paramref name="note"/> transposed by one step in the
 /// <c>Direction.Bellow</c> direction (presumably "below"; the step size is defined by Transposer).
 /// </summary>
 public static Note operator --(Note note) => Transposer.Transpose(note, 1, Direction.Bellow);
Example #6
0
 /// <summary>
 /// Increment operator: returns <paramref name="note"/> transposed by one step in the
 /// <c>Direction.Above</c> direction (the step size is defined by Transposer).
 /// </summary>
 public static Note operator ++(Note note) => Transposer.Transpose(note, 1, Direction.Above);