Exemple #1
0
        /// <summary>
        /// Infers a type and a name for each of <c>args.ColumnCount</c> text columns of <paramref name="fileSource"/>.
        /// Reads up to <c>args.MaxRowsToRead</c> rows as raw text, runs every expert returned by
        /// <c>GetExperts()</c> over the resulting columns, and then decides whether the first row is a header.
        /// </summary>
        /// <param name="env">Host environment used to construct the text loader.</param>
        /// <param name="fileSource">The data to inspect.</param>
        /// <param name="args">Column count, separator, sparse/quoting flags and the row limit for the probe.</param>
        /// <param name="ch">Channel that receives error and info messages.</param>
        /// <returns>
        /// <c>InferenceResult.Fail()</c> when the column count is 0 or at least <c>SmartColumnsLim</c>, or when
        /// fewer than two rows were read. Otherwise <c>InferenceResult.Success</c> with the inferred columns,
        /// the header decision, and each column's raw text values.
        /// </returns>
        private static InferenceResult InferTextFileColumnTypesCore(IHostEnvironment env, IMultiStreamSource fileSource, Arguments args, IChannel ch)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(env);
            ch.AssertValue(fileSource);
            ch.AssertValue(args);

            // NOTE(review): the message says "Too many empty columns", but the condition is that zero
            // columns were requested. Confirm the intended wording.
            if (args.ColumnCount == 0)
            {
                ch.Error("Too many empty columns for automatic inference.");
                return(InferenceResult.Fail());
            }

            if (args.ColumnCount >= SmartColumnsLim)
            {
                ch.Error("Too many columns for automatic inference.");
                return(InferenceResult.Fail());
            }

            // Read the file as the specified number of text columns: one text column named "C" covering
            // source indices 0..ColumnCount-1. The cursor code below reads it as a vector, or as a
            // single value when the column type is not a vector (asserted to mean ColumnCount == 1).
            var textLoaderArgs = new TextLoader.Arguments
            {
                Column       = new[] { TextLoader.Column.Parse(string.Format("C:TX:0-{0}", args.ColumnCount - 1)) },
                Separator    = args.Separator,
                AllowSparse  = args.AllowSparse,
                AllowQuoting = args.AllowQuote,
            };
            var textLoader = new TextLoader(env, textLoaderArgs, fileSource);
            var idv        = textLoader.Take(args.MaxRowsToRead);

            // Read all the data into memory.
            // List items are rows of the dataset; each row holds exactly ColumnCount text values.
            var data = new List <DvText[]>();

            using (var cursor = idv.GetRowCursor(col => true))
            {
                int  columnIndex;
                bool found = cursor.Schema.TryGetColumnIndex("C", out columnIndex);
                Contracts.Assert(found);
                var colType = cursor.Schema.GetColumnType(columnIndex);
                Contracts.Assert(colType.ItemType.IsText);
                ValueGetter <VBuffer <DvText> > vecGetter = null;
                ValueGetter <DvText>            oneGetter = null;
                bool isVector = colType.IsVector;
                if (isVector)
                {
                    vecGetter = cursor.GetGetter <VBuffer <DvText> >(columnIndex);
                }
                else
                {
                    Contracts.Assert(args.ColumnCount == 1);
                    oneGetter = cursor.GetGetter <DvText>(columnIndex);
                }

                // 'line' and 'tsValue' are reused across rows; every row is copied into a fresh array
                // before it is stored.
                VBuffer <DvText> line    = default(VBuffer <DvText>);
                DvText           tsValue = default(DvText);
                while (cursor.MoveNext())
                {
                    if (isVector)
                    {
                        vecGetter(ref line);
                        Contracts.Assert(line.Length == args.ColumnCount);
                        var values = new DvText[args.ColumnCount];
                        line.CopyTo(values);
                        data.Add(values);
                    }
                    else
                    {
                        oneGetter(ref tsValue);
                        var values = new[] { tsValue };
                        data.Add(values);
                    }
                }
            }

            // At least two rows are required for inference.
            if (data.Count < 2)
            {
                ch.Error("Too few rows ({0}) for automatic inference.", data.Count);
                return(InferenceResult.Fail());
            }

            // Transpose the rows into one IntermediateColumn per source column index.
            var cols = new IntermediateColumn[args.ColumnCount];

            for (int i = 0; i < args.ColumnCount; i++)
            {
                cols[i] = new IntermediateColumn(data.Select(x => x[i]).ToArray(), i);
            }

            // Each expert inspects and annotates the columns. They presumably set SuggestedType and
            // HasHeader, which are read below.
            foreach (var expert in GetExperts())
            {
                expert.Apply(cols);
            }

            Contracts.Check(cols.All(x => x.SuggestedType != null), "Column type inference must be conclusive");

            // Aggregating header signals.
            // Voting rules:
            // - A column voting "header" adds 1 when its first value is a name not seen before.
            // - A column voting "header" subtracts ColumnCount when its first value duplicates another column's.
            // - A column voting "no header" subtracts 1.
            // HasHeader appears to be tri-state; columns where it is neither true nor false do not vote.
            int suspect   = 0;
            var usedNames = new HashSet <string>();

            for (int i = 0; i < args.ColumnCount; i++)
            {
                if (cols[i].HasHeader == true)
                {
                    if (usedNames.Add(cols[i].RawData[0].ToString()))
                    {
                        suspect++;
                    }
                    else
                    {
                        // duplicate value in the first column is a strong signal that this is not a header
                        suspect -= args.ColumnCount;
                    }
                }
                else if (cols[i].HasHeader == false)
                {
                    suspect--;
                }
            }

            // REVIEW: Why not use this for column names as well?
            // A schema embedded in the file takes precedence over the vote above.
            TextLoader.Arguments fileArgs;
            bool hasHeader;

            if (TextLoader.FileContainsValidSchema(env, fileSource, out fileArgs))
            {
                hasHeader = fileArgs.HasHeader;
            }
            else
            {
                hasHeader = suspect > 0;
            }

            // suggest names
            // Repeated names are made unique by appending _00, _01, ... to the first suggestion.
            var names = new List <string>();

            usedNames.Clear();
            foreach (var col in cols)
            {
                string name0;
                string name;
                name0 = name = SuggestName(col, hasHeader);
                int i = 0;
                while (!usedNames.Add(name))
                {
                    name = string.Format("{0}_{1:00}", name0, i++);
                }
                names.Add(name);
            }
            var outCols =
                cols.Select((x, i) => new Column(x.ColumnId, names[i], x.SuggestedType)).ToArray();

            var numerics = outCols.Count(x => x.ItemType.IsNumber);

            ch.Info("Detected {0} numeric and {1} text columns.", numerics, outCols.Length - numerics);
            if (hasHeader)
            {
                ch.Info("Generated column names from the file header.");
            }

            return(InferenceResult.Success(outCols, hasHeader, cols.Select(col => col.RawData).ToArray()));
        }
Exemple #2
0
        /// <summary>
        /// Proposes the next parameter sets for a simplex (Nelder–Mead-style) search. Progress is tracked
        /// by the <c>_stage</c> state machine: reflection, expansion, outer/inner contraction and reduction.
        /// </summary>
        /// <param name="maxSweeps">
        /// Number of sweeps requested by the caller. It is passed to the initial sweeper and to the
        /// reduction-point helpers.
        /// </param>
        /// <param name="previousRuns">
        /// All results observed so far, or null when there are none yet. Null entries are rejected.
        /// </param>
        /// <returns>
        /// The parameter sets to evaluate next. Returns null when the search is finished: the stage is
        /// <c>Done</c>, no reduction points could be produced, or the simplex diameter dropped below
        /// <c>_args.StoppingSimplexDiameter</c>.
        /// </returns>
        public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable <IRunResult> previousRuns = null)
        {
            int numSweeps = Math.Min(maxSweeps, _dim + 1 - _simplexVertices.Count);

            // No results yet: let the initial sweeper propose the first simplex vertices.
            if (previousRuns == null)
            {
                return(_initSweeper.ProposeSweeps(numSweeps, previousRuns));
            }

            // NOTE(review): previousRuns is enumerated several times below (this check, the vertex
            // collection, Any(), and the helper calls). If a caller passes a lazy sequence, it is
            // re-evaluated each time, so prefer passing a materialized collection.
            foreach (var run in previousRuns)
            {
                Contracts.Check(run != null);
            }

            // Fill the simplex with the first _dim + 1 distinct runs. Extremes are computed once, at the
            // moment the simplex becomes complete.
            foreach (var run in previousRuns)
            {
                if (_simplexVertices.Count == _dim + 1)
                {
                    break;
                }

                if (!_simplexVertices.ContainsKey(run))
                {
                    _simplexVertices.Add(run, ParameterSetAsFloatArray(run.ParameterSet));
                }

                if (_simplexVertices.Count == _dim + 1)
                {
                    ComputeExtremes();
                }
            }

            // Simplex still incomplete: ask the initial sweeper for the missing vertices.
            if (_simplexVertices.Count < _dim + 1)
            {
                numSweeps = Math.Min(maxSweeps, _dim + 1 - _simplexVertices.Count);
                return(_initSweeper.ProposeSweeps(numSweeps, previousRuns));
            }

            switch (_stage)
            {
            case OptimizationStage.NeedReflectionPoint:
                _pendingSweeps.Clear();
                var nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaReflection);
                if (OutOfBounds(nextPoint) && _args.ProjectInbounds)
                {
                    // if the reflection point is out of bounds, get the inner contraction point.
                    nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                    _stage    = OptimizationStage.WaitingForInnerContractionResult;
                }
                else
                {
                    _stage = OptimizationStage.WaitingForReflectionResult;
                }
                _pendingSweeps.Add(new KeyValuePair <ParameterSet, Float[]>(FloatArrayAsParameterSet(nextPoint), nextPoint));
                // If the proposed point was already evaluated, switch to a reduction step instead of
                // proposing it again (presumably to avoid re-running an identical configuration).
                if (previousRuns.Any(runResult => runResult.ParameterSet.Equals(_pendingSweeps[0].Key)))
                {
                    _stage = OptimizationStage.WaitingForReductionResult;
                    _pendingSweeps.Clear();
                    if (!TryGetReductionPoints(maxSweeps, previousRuns))
                    {
                        _stage = OptimizationStage.Done;
                        return(null);
                    }
                    return(_pendingSweeps.Select(kvp => kvp.Key).ToArray());
                }
                return(new ParameterSet[] { _pendingSweeps[0].Key });

            case OptimizationStage.WaitingForReflectionResult:
                Contracts.Assert(_pendingSweeps.Count == 1);
                _lastReflectionResult = FindRunResult(previousRuns)[0];
                if (_secondWorst.Key.CompareTo(_lastReflectionResult.Key) < 0 && _lastReflectionResult.Key.CompareTo(_best.Key) <= 0)
                {
                    // the reflection result is better than the second worse, but not better than the best
                    UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
                    goto case OptimizationStage.NeedReflectionPoint;
                }

                if (_lastReflectionResult.Key.CompareTo(_best.Key) > 0)
                {
                    // the reflection result is the best so far
                    nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaExpansion);
                    if (OutOfBounds(nextPoint) && _args.ProjectInbounds)
                    {
                        // if the expansion point is out of bounds, get the inner contraction point.
                        nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                        _stage    = OptimizationStage.WaitingForInnerContractionResult;
                    }
                    else
                    {
                        _stage = OptimizationStage.WaitingForExpansionResult;
                    }
                }
                else if (_lastReflectionResult.Key.CompareTo(_worst.Key) > 0)
                {
                    // other wise, get results for the outer contraction point.
                    nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaOutsideContraction);
                    _stage    = OptimizationStage.WaitingForOuterContractionResult;
                }
                else
                {
                    // other wise, reflection result is not better than worst, get results for the inner contraction point
                    nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                    _stage    = OptimizationStage.WaitingForInnerContractionResult;
                }
                _pendingSweeps.Clear();
                _pendingSweeps.Add(new KeyValuePair <ParameterSet, Float[]>(FloatArrayAsParameterSet(nextPoint), nextPoint));
                // Same "already evaluated" fallback to reduction as in NeedReflectionPoint.
                if (previousRuns.Any(runResult => runResult.ParameterSet.Equals(_pendingSweeps[0].Key)))
                {
                    _stage = OptimizationStage.WaitingForReductionResult;
                    _pendingSweeps.Clear();
                    if (!TryGetReductionPoints(maxSweeps, previousRuns))
                    {
                        _stage = OptimizationStage.Done;
                        return(null);
                    }
                    return(_pendingSweeps.Select(kvp => kvp.Key).ToArray());
                }
                return(new ParameterSet[] { _pendingSweeps[0].Key });

            case OptimizationStage.WaitingForExpansionResult:
                Contracts.Assert(_pendingSweeps.Count == 1);
                var expansionResult = FindRunResult(previousRuns)[0].Key;
                if (expansionResult.CompareTo(_lastReflectionResult.Key) > 0)
                {
                    // expansion point is better than reflection point
                    UpdateSimplex(expansionResult, _pendingSweeps[0].Value);
                    goto case OptimizationStage.NeedReflectionPoint;
                }
                // reflection point is better than expansion point
                UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
                goto case OptimizationStage.NeedReflectionPoint;

            case OptimizationStage.WaitingForOuterContractionResult:
                Contracts.Assert(_pendingSweeps.Count == 1);
                var outerContractionResult = FindRunResult(previousRuns)[0].Key;
                if (outerContractionResult.CompareTo(_lastReflectionResult.Key) > 0)
                {
                    // outer contraction point is better than reflection point
                    UpdateSimplex(outerContractionResult, _pendingSweeps[0].Value);
                    goto case OptimizationStage.NeedReflectionPoint;
                }
                // get the reduction points
                _stage = OptimizationStage.WaitingForReductionResult;
                _pendingSweeps.Clear();
                if (!TryGetReductionPoints(maxSweeps, previousRuns))
                {
                    _stage = OptimizationStage.Done;
                    return(null);
                }
                return(_pendingSweeps.Select(kvp => kvp.Key).ToArray());

            case OptimizationStage.WaitingForInnerContractionResult:
                Contracts.Assert(_pendingSweeps.Count == 1);
                var innerContractionResult = FindRunResult(previousRuns)[0].Key;
                if (innerContractionResult.CompareTo(_worst.Key) > 0)
                {
                    // inner contraction point is better than worst point
                    UpdateSimplex(innerContractionResult, _pendingSweeps[0].Value);
                    goto case OptimizationStage.NeedReflectionPoint;
                }
                // get the reduction points
                _stage = OptimizationStage.WaitingForReductionResult;
                _pendingSweeps.Clear();
                if (!TryGetReductionPoints(maxSweeps, previousRuns))
                {
                    _stage = OptimizationStage.Done;
                    return(null);
                }
                return(_pendingSweeps.Select(kvp => kvp.Key).ToArray());

            case OptimizationStage.WaitingForReductionResult:
                // Reduction points may be submitted across several calls. Keep submitting until all
                // _dim points are pending, then fold their results into the simplex.
                Contracts.Assert(_pendingSweeps.Count + _pendingSweepsNotSubmitted.Count == _dim);
                if (_pendingSweeps.Count < _dim)
                {
                    return(SubmitMoreReductionPoints(maxSweeps));
                }
                ReplaceSimplexVertices(previousRuns);

                // if the diameter of the new simplex has become too small, stop sweeping.
                // NOTE(review): _stage is not set to Done here, so a later call re-enters this case.
                // Confirm this is intended.
                if (SimplexDiameter() < _args.StoppingSimplexDiameter)
                {
                    return(null);
                }

                goto case OptimizationStage.NeedReflectionPoint;

            case OptimizationStage.Done:
            default:
                return(null);
            }
        }
        /// <summary>
        /// Tries the loader settings in <paramref name="args"/> on the first 1000 rows of
        /// <paramref name="source"/> and checks whether they yield a usable column split.
        /// </summary>
        /// <param name="ch">Receives a trace message, plus the loader's buffered messages, when the attempt succeeds.</param>
        /// <param name="args">Loader settings under test; rows are read through the column named "C".</param>
        /// <param name="source">Data to probe.</param>
        /// <param name="skipStrictValidation">
        /// When true, neither the column-count uniformity check nor the "more than one column" check is applied.
        /// </param>
        /// <param name="result">The detected split on success; the default value otherwise.</param>
        /// <returns>
        /// True when the settings are accepted. False when a check fails or the loader throws a marked
        /// exception. Unmarked exceptions propagate.
        /// </returns>
        private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiStreamSource source, bool skipStrictValidation, out ColumnSplitResult result)
        {
            result = default(ColumnSplitResult);
            try
            {
                // Diagnostics from this probe go to a private environment. They are forwarded to the
                // caller's channel only if the settings turn out to work.
                using (var probeEnv = new ConsoleEnvironment(0, verbose: true))
                {
                    var buffered = new ConcurrentBag<ChannelMessage>();
                    probeEnv.AddListener<ChannelMessage>((origin, msg) => buffered.Add(msg));

                    var sample = TextLoader.ReadFile(probeEnv, args, source).Take(1000);
                    bool hasTextColumn = sample.Schema.TryGetColumnIndex("C", out int textCol);
                    ch.Assert(hasTextColumn);

                    // Record how many values each sampled row split into.
                    var rowWidths = new List<int>();
                    using (var cursor = sample.GetRowCursor(col => col == textCol))
                    {
                        var readRow = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(textCol);
                        VBuffer<ReadOnlyMemory<char>> row = default;
                        while (cursor.MoveNext())
                        {
                            readRow(ref row);
                            rowWidths.Add(row.Length);
                        }
                    }

                    Contracts.Check(rowWidths.Count > 0);

                    // Pick the most frequent width. GroupBy keeps first-seen order and OrderByDescending
                    // is stable, so ties go to the width that appeared first.
                    var dominant = rowWidths.GroupBy(w => w).OrderByDescending(g => g.Count()).First();
                    int width = dominant.Key;

                    if (!skipStrictValidation)
                    {
                        // Enough rows must agree on the width...
                        if (dominant.Count() < UniformColumnCountThreshold * rowWidths.Count)
                        {
                            return false;
                        }

                        // ...and there must be more than one column. The strict checks are skipped when
                        // the user specified the separator explicitly, which is what allows a single
                        // column; otherwise the user is told no columns could be detected.
                        if (width <= 1)
                        {
                            return false;
                        }
                    }

                    result = new ColumnSplitResult(true, args.Separator, args.AllowQuoting, args.AllowSparse, width);
                    ch.Trace("Discovered {0} columns using separator '{1}'", width, args.Separator);
                    foreach (var pending in buffered)
                    {
                        ch.Send(pending);
                    }
                    return true;
                }
            }
            catch (Exception ex)
            {
                // A marked exception means these settings don't work: report failure so the caller can
                // move on to the next separator candidate. Anything else is unexpected and is rethrown.
                if (ex.IsMarked())
                {
                    return false;
                }
                throw;
            }
        }
Exemple #4
0
 /// <summary>
 /// Copies the current value of <c>Src</c> into <paramref name="dst"/>, after checking that
 /// <c>Cursor.IsGood</c> holds.
 /// </summary>
 /// <param name="dst">Destination buffer, passed by reference (presumably so its storage can be reused).</param>
 public void GetValue(ref VBuffer <T> dst)
 {
     bool cursorIsGood = Cursor.IsGood;
     Contracts.Check(cursorIsGood);
     Src.CopyTo(ref dst);
 }
 /// <summary>
 /// Converts <paramref name="kind"/> to <see cref="DataKind"/> with a direct numeric cast.
 /// <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, so values without a
 /// <see cref="DataKind"/> counterpart must be rejected. The only value rejected here is
 /// <c>InternalDataKind.UG</c>.
 /// </summary>
 /// <param name="kind">The internal kind to convert.</param>
 /// <returns><paramref name="kind"/> reinterpreted as a <see cref="DataKind"/>.</returns>
 public static DataKind ToDataKind(this InternalDataKind kind)
 {
     // NOTE(review): if InternalDataKind has other members that are missing from DataKind, they are
     // not caught by this check. Confirm that UG is the only one.
     bool hasPublicCounterpart = kind != InternalDataKind.UG;
     Contracts.Check(hasPublicCounterpart);
     var converted = (DataKind)kind;
     return converted;
 }
 /// <summary>
 /// Stores <paramref name="src"/> into <paramref name="dst"/>. The check fails unless
 /// <paramref name="dst"/> is still in its default state according to <c>IsDefault</c>.
 /// </summary>
 /// <param name="dst">Target slot; must not have been assigned yet.</param>
 /// <param name="src">Value to store.</param>
 public override void Combine(ref ReadOnlyMemory <char> dst, ReadOnlyMemory <char> src)
 {
     bool dstIsUnset = IsDefault(dst);
     Contracts.Check(dstIsUnset);
     dst = src;
 }
Exemple #7
0
 /// <summary>
 /// Returns the stored getter for this column. The check fails with "column is not active" when
 /// <c>IsActive</c> is false.
 /// </summary>
 public override ValueGetter <T> GetGetter()
 {
     bool active = IsActive;
     Contracts.Check(active, "column is not active");
     var getter = _getter;
     return getter;
 }
 /// <summary>
 /// Instantiates <paramref name="type"/> through <see cref="Activator.CreateInstance(Type)"/>, which uses
 /// the type's public parameterless constructor.
 /// </summary>
 /// <param name="type">A non-null type assignable to <see cref="ScikitSubComponent"/>.</param>
 /// <returns>The new instance, typed as <see cref="ScikitSubComponent"/>.</returns>
 public static ScikitSubComponent Create(Type type)
 {
     Contracts.Check(type != null);
     Contracts.Check(typeof(ScikitSubComponent).IsAssignableFrom(type));
     object instance = Activator.CreateInstance(type);
     return (ScikitSubComponent)instance;
 }
Exemple #9
0
 /// <summary>
 /// Creates getter info that holds <paramref name="value"/>. The raw type of <paramref name="type"/>
 /// must be exactly <c>TValue</c>; otherwise the check fails with "Incompatible types".
 /// </summary>
 public GetterInfoPrimitive(string kind, ColumnType type, TValue value)
     : base(kind, type)
 {
     bool rawTypeMatches = typeof(TValue) == type.RawType;
     Contracts.Check(rawTypeMatches, "Incompatible types");
     Value = value;
 }
Exemple #10
0
 /// <summary>
 /// Creates a generator over the discrete values listed in <paramref name="args"/>. At least one value is required.
 /// </summary>
 public DiscreteValueGenerator(DiscreteParamArguments args)
 {
     int valueCount = args.Values.Length;
     Contracts.Check(valueCount > 0);
     _args = args;
 }
 /// <summary>
 /// Creates a generator over the discrete values listed in <paramref name="options"/>. At least one value is required.
 /// </summary>
 public DiscreteValueGenerator(DiscreteParamOptions options)
 {
     int valueCount = options.Values.Length;
     Contracts.Check(valueCount > 0);
     _options = options;
 }
 /// <summary>
 /// Checks that both sizes are non-negative, then forwards them to
 /// <c>LdaInterface.AllocateDataMemory</c> for <c>_engine</c>.
 /// </summary>
 /// <param name="docNum">Number of documents; must be &gt;= 0.</param>
 /// <param name="corpusSize">Corpus size; must be &gt;= 0.</param>
 public void AllocateDataMemory(int docNum, long corpusSize)
 {
     Contracts.Check(docNum >= 0 && corpusSize >= 0);
     LdaInterface.AllocateDataMemory(_engine, docNum, corpusSize);
 }
Exemple #13
0
 /// <summary>
 /// Always returns null. Index 0 is the only accepted index, and it has no associated path.
 /// </summary>
 public string GetPathOrNull(int index)
 {
     Contracts.Check(index == 0);
     const string noPath = null;
     return noPath;
 }
Exemple #14
0
        /// <summary>
        /// Predict with data.
        /// This function uses a modified API that does not use XGBoost's internal caches: buffers are
        /// sized and allocated on the Microsoft.ML side and handed to XGBoost.
        /// </summary>
        /// <param name="vbuf">One row of features; dense or sparse.</param>
        /// <param name="predictedValues">Results of the prediction; resized through <c>internalBuffer.ResizeOutputs</c>.</param>
        /// <param name="internalBuffer">Buffers allocated by Microsoft.ML and given to XGBoost, so that XGBoost does not allocate caches on its own. Must not be null.</param>
        /// <param name="outputMargin">Whether to output the raw untransformed margin value (sets bit 0x01 of the option mask).</param>
        /// <param name="ntreeLimit">Limit on the number of trees used in the prediction; defaults to 0 (use all trees).</param>
        public void PredictOneOff(ref VBuffer <Float> vbuf, ref VBuffer <Float> predictedValues,
                                  ref XGBoostTreeBuffer internalBuffer, bool outputMargin = true, int ntreeLimit = 0)
        {
            // REVIEW xadupre: XGBoost can produce an output per tree (pred_leaf=true).
            // When this option is on, the output is a matrix of (nsample, ntrees),
            // with each record indicating the predicted leaf index of each sample in each tree.
            // Note that the leaf index of a tree is unique per tree, so you may find leaf 1
            // in both tree 1 and tree 0.
            // if (pred_leaf)
            //    option_mask |= 0x02;
            // This might be an interesting feature to implement.

            int optionMask = 0x00;

            if (outputMargin)
            {
                optionMask |= 0x01;
            }

            Contracts.Check(internalBuffer != null);

            // length / lengthBuffer are filled in by XGBoosterPredictOutputSize below.
            uint length       = 0;
            uint lengthBuffer = 0;
            uint nb           = (uint)vbuf.Count;

            // This function relies on a modified API. Instead of letting XGBoost handle its own caches,
            // the function calls XGBoosterPredictOutputSize to find out the required cache size, and
            // Microsoft.ML allocates the caches and gives them to XGBoost.
            // First, we allocate the cache for the features. Only then is XGBoost
            // able to know the required cache size.
#if (XGB_EXTENDED)
            internalBuffer.ResizeEntries(nb, vbuf.Length);
#else
            internalBuffer.ResizeEntries(nb);
#endif

            // Pin the feature arrays and the entry buffer. Indices are passed only for sparse rows.
            // nb is passed by ref and may be updated by XGBoosterCopyEntries before it is used below.
            unsafe
            {
                fixed(float *p = vbuf.Values)
                fixed(int *i        = vbuf.Indices)
                fixed(byte *entries = internalBuffer.XGBoostEntries)
                {
                    WrappedXGBoostInterface.XGBoosterCopyEntries((IntPtr)entries, ref nb, p, vbuf.IsDense ? null : i, float.NaN);
                    WrappedXGBoostInterface.XGBoosterPredictOutputSize(_handle,
                                                                       (IntPtr)entries, nb, optionMask, (uint)ntreeLimit, ref length, ref lengthBuffer);
                }
            }

            // Then we allocate the cache for the prediction.
            internalBuffer.ResizeOutputs(length, lengthBuffer, ref predictedValues);

            // Run the prediction, writing into the buffers sized above.
            unsafe
            {
                fixed(byte *entries = internalBuffer.XGBoostEntries)
                fixed(float *ppreds      = predictedValues.Values)
                fixed(float *ppredBuffer = internalBuffer.PredBuffer)
                fixed(uint *ppredCounter = internalBuffer.PredCounter)
                {
                    WrappedXGBoostInterface.XGBoosterPredictNoInsideCache(_handle,
                                                                          (IntPtr)entries, nb, optionMask, (uint)ntreeLimit, length, lengthBuffer, ppreds, ppredBuffer, ppredCounter
#if (XGB_EXTENDED)
                                                                          , internalBuffer.RegTreeFVec
#endif
                                                                          );
                }
            }
        }
Exemple #15
0
        /// <summary>
        /// Converts the booster's text model (from <c>GetModelString()</c>) into a <see cref="TreeEnsemble"/>.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Each section that starts with a <c>Tree=</c> line is read as <c>key=value</c> lines. A section
        /// ends at the next <c>Tree=</c> line, a blank line, or the end of the text; a model string with
        /// no trailing blank line is handled.
        /// </para>
        /// <para>
        /// Trees with more than one leaf are rebuilt node by node. A single-leaf tree becomes
        /// <c>new RegressionTree(2)</c>, or, when its constant output is non-zero, a two-leaf tree whose
        /// leaves both carry that constant.
        /// </para>
        /// </remarks>
        /// <param name="categoricalFeatureBoudaries">
        /// Optional table applied to every split feature index (<c>splitFeature[node] = table[splitFeature[node]]</c>).
        /// For categorical splits the category values are then added on top of the mapped index. May be null.
        /// </param>
        /// <returns>An ensemble with one tree per parsed tree section.</returns>
        public TreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
        {
            TreeEnsemble res         = new TreeEnsemble();
            string       modelString = GetModelString();

            string[] lines = modelString.Split('\n');
            int      i     = 0;

            while (i < lines.Length)
            {
                // Skip everything that is not the header line of a tree section.
                // Ordinal comparison: the model text is machine-generated, not linguistic.
                if (!lines[i].StartsWith("Tree=", System.StringComparison.Ordinal))
                {
                    ++i;
                    continue;
                }

                Dictionary <string, string> kvPairs = new Dictionary <string, string>();
                ++i;
                // Without the i < lines.Length guard, a model string that ends right after a tree's last
                // key=value line (no trailing blank line) would index past the end of 'lines'.
                while (i < lines.Length
                       && !lines[i].StartsWith("Tree=", System.StringComparison.Ordinal)
                       && lines[i].Trim().Length != 0)
                {
                    string[] kv = lines[i].Split('=');
                    Contracts.Check(kv.Length == 2);
                    kvPairs[kv[0].Trim()] = kv[1].Trim();
                    ++i;
                }

                // Parse culture-independently; the integers are written by the library, not by a user.
                int numLeaves = int.Parse(kvPairs["num_leaves"], System.Globalization.CultureInfo.InvariantCulture);
                int numCat    = int.Parse(kvPairs["num_cat"], System.Globalization.CultureInfo.InvariantCulture);
                if (numLeaves > 1)
                {
                    var leftChild                = Str2IntArray(kvPairs["left_child"], ' ');
                    var rightChild               = Str2IntArray(kvPairs["right_child"], ' ');
                    var splitFeature             = Str2IntArray(kvPairs["split_feature"], ' ');
                    var threshold                = Str2DoubleArray(kvPairs["threshold"], ' ');
                    var splitGain                = Str2DoubleArray(kvPairs["split_gain"], ' ');
                    var leafOutput               = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                    var decisionType             = Str2UIntArray(kvPairs["decision_type"], ' ');
                    var defaultValue             = GetDefalutValue(threshold, decisionType);
                    var categoricalSplitFeatures = new int[numLeaves - 1][];
                    var categoricalSplit         = new bool[numLeaves - 1];
                    if (categoricalFeatureBoudaries != null)
                    {
                        // Add offsets to split features.
                        for (int node = 0; node < numLeaves - 1; ++node)
                        {
                            splitFeature[node] = categoricalFeatureBoudaries[splitFeature[node]];
                        }
                    }

                    if (numCat > 0)
                    {
                        var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                        var catThreshold  = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                        for (int node = 0; node < numLeaves - 1; ++node)
                        {
                            if (GetIsCategoricalSplit(decisionType[node]))
                            {
                                // For categorical nodes, the threshold holds an index into cat_boundaries.
                                int catIdx = (int)threshold[node];
                                var cats   = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                                categoricalSplitFeatures[node] = new int[cats.Length];
                                // Convert Cat thresholds to feature indices.
                                for (int j = 0; j < cats.Length; ++j)
                                {
                                    categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
                                }

                                splitFeature[node]     = -1;
                                categoricalSplit[node] = true;
                                // Swap left and right child.
                                int t = leftChild[node];
                                leftChild[node]  = rightChild[node];
                                rightChild[node] = t;
                            }
                            else
                            {
                                categoricalSplit[node] = false;
                            }
                        }
                    }
                    RegressionTree tree = RegressionTree.Create(numLeaves, splitFeature, splitGain,
                                                                threshold.Select(x => (float)(x)).ToArray(), defaultValue.Select(x => (float)(x)).ToArray(), leftChild, rightChild, leafOutput,
                                                                categoricalSplitFeatures, categoricalSplit);
                    res.AddTree(tree);
                }
                else
                {
                    RegressionTree tree       = new RegressionTree(2);
                    var            leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                    if (leafOutput[0] != 0)
                    {
                        // Convert a constant tree into a two-leaf tree so that it is not filtered out by TLC.
                        var categoricalSplitFeatures = new int[1][];
                        var categoricalSplit         = new bool[1];
                        tree = RegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                                                     new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 }, new double[] { leafOutput[0], leafOutput[0] },
                                                     categoricalSplitFeatures, categoricalSplit);
                    }
                    res.AddTree(tree);
                }
            }
            return(res);
        }
Exemple #16
0
        //Project the covariance matrix A on to Omega: Y <- A * Omega
        //A = X' * X / n, where X = data - mean
        //Note that the covariance matrix is not computed explicitly
        /// <summary>
        /// Makes one weighted pass over <paramref name="trainingData"/> and, for every transformed column,
        /// accumulates into <paramref name="y"/> the projection of the data onto the vectors in
        /// <paramref name="omega"/>. When centering is requested (non-null <paramref name="mean"/> entry),
        /// the weighted mean is also accumulated into <paramref name="mean"/> and removed from the projection.
        /// </summary>
        /// <param name="trainingData">Data to read; the source (and optional weight) columns of each
        /// transformed column are activated on the cursor.</param>
        /// <param name="mean">Per-column mean accumulator; a null entry disables centering for that column.
        /// Non-null entries are overwritten.</param>
        /// <param name="omega">Per-column set of vectors to project onto.</param>
        /// <param name="y">Per-column output buffers, one per omega vector; overwritten.</param>
        /// <param name="transformInfos">Not read by this method.</param>
        /// <remarks>
        /// Rows with a non-finite or negative weight, or with non-finite feature values, are skipped.
        /// Throws if a source column is not a numeric vector, or if a column ends up with no positive total weight.
        /// </remarks>
        private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, Float[][][] y, TransformInfo[] transformInfos)
        {
            Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == Infos.Length);
            for (int i = 0; i < omega.Length; i++)
            {
                Contracts.Assert(omega[i].Length == y[i].Length);
            }

            // Column types cannot change during the pass, so validate them once up front rather than per row.
            // This also reports a readable error before any getter is requested for an unsupported column.
            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                Contracts.Check(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsNumber,
                                "PCA transform can only be performed on numeric columns of dimension > 1");
            }

            // Reset the y accumulators to all zeros.
            for (int iinfo = 0; iinfo < y.Length; iinfo++)
            {
                for (int i = 0; i < y[iinfo].Length; i++)
                {
                    Array.Clear(y[iinfo][i], 0, y[iinfo][i].Length);
                }
            }

            // The mean buffers are summed into below and then rescaled, so they must also start at zero.
            for (int iinfo = 0; iinfo < mean.Length; iinfo++)
            {
                if (mean[iinfo] != null)
                {
                    Array.Clear(mean[iinfo], 0, mean[iinfo].Length);
                }
            }

            bool[] center = Enumerable.Range(0, mean.Length).Select(i => mean[i] != null).ToArray();

            Double[] totalColWeight = new Double[Infos.Length];

            // Activate only the source columns and their optional weight columns.
            bool[] activeColumns = new bool[Source.Schema.ColumnCount];
            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                activeColumns[Infos[iinfo].Source] = true;
                if (_weightColumnIndex[iinfo] >= 0)
                {
                    activeColumns[_weightColumnIndex[iinfo]] = true;
                }
            }
            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                var weightGetters = new ValueGetter <Float> [Infos.Length];
                var columnGetters = new ValueGetter <VBuffer <Float> > [Infos.Length];
                for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                {
                    if (_weightColumnIndex[iinfo] >= 0)
                    {
                        weightGetters[iinfo] = cursor.GetGetter <Float>(_weightColumnIndex[iinfo]);
                    }
                    columnGetters[iinfo] = cursor.GetGetter <VBuffer <Float> >(Infos[iinfo].Source);
                }

                var features = default(VBuffer <Float>);
                while (cursor.MoveNext())
                {
                    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                    {
                        // Unweighted columns count every row with weight 1.
                        Float weight = 1;
                        if (weightGetters[iinfo] != null)
                        {
                            weightGetters[iinfo](ref weight);
                        }
                        columnGetters[iinfo](ref features);

                        // Skip rows with invalid weights or non-finite feature values.
                        if (FloatUtils.IsFinite(weight) && weight >= 0 && (features.Count == 0 || FloatUtils.IsFinite(features.Values, features.Count)))
                        {
                            totalColWeight[iinfo] += weight;

                            if (center[iinfo])
                            {
                                VectorUtils.AddMult(ref features, mean[iinfo], weight);
                            }

                            for (int i = 0; i < omega[iinfo].Length; i++)
                            {
                                VectorUtils.AddMult(ref features, y[iinfo][i], weight * VectorUtils.DotProductWithOffset(omega[iinfo][i], 0, ref features));
                            }
                        }
                    }
                }

                // Fail before any buffer is rescaled if some column saw no usable data.
                for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                {
                    if (totalColWeight[iinfo] <= 0)
                    {
                        throw Host.Except("Empty data in column '{0}'", Source.Schema.GetColumnName(Infos[iinfo].Source));
                    }
                }

                for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
                {
                    var invn = (Float)(1 / totalColWeight[iinfo]);

                    for (var i = 0; i < omega[iinfo].Length; ++i)
                    {
                        VectorUtils.ScaleBy(y[iinfo][i], invn);
                    }

                    if (center[iinfo])
                    {
                        // Turn the weighted sum into the weighted mean, then remove the mean's contribution
                        // from each projection.
                        VectorUtils.ScaleBy(mean[iinfo], invn);
                        for (int i = 0; i < omega[iinfo].Length; i++)
                        {
                            VectorUtils.AddMult(mean[iinfo], y[iinfo][i], -VectorUtils.DotProduct(omega[iinfo][i], mean[iinfo]));
                        }
                    }
                }
            }
        }
Exemple #17
0
 /// <summary>
 /// Loads a TensorFlow <see cref="Session"/> from the SavedModel directory <paramref name="exportDirSavedModel"/>.
 /// Both <paramref name="env"/> and the directory path are validated first.
 /// </summary>
 private static Session LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
 {
     bool haveEnv = env != null;
     Contracts.Check(haveEnv, nameof(env));
     env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));

     var session = Session.LoadFromSavedModel(exportDirSavedModel);
     return session;
 }
Exemple #18
0
 /// <summary>
 /// Hashes the character range [_ichMin, IchLim) of the underlying buffer with MurmurHash,
 /// starting from <paramref name="seed"/>. Not valid on an NA value.
 /// </summary>
 public uint Hash(uint seed)
 {
     Contracts.Check(!IsNA);
     var hashValue = Hashing.MurmurHash(seed, _outerBuffer, _ichMin, IchLim);
     return hashValue;
 }
Exemple #19
0
 /// <summary>
 /// Returns whether the given column is active in this row.
 /// </summary>
 /// <param name="column">The column to query; its index must lie in [0, number of getters).</param>
 /// <returns><c>true</c> when a getter exists for the column; otherwise <c>false</c>.</returns>
 public override bool IsColumnActive(DataViewSchema.Column column)
 {
     // Check both bounds so a negative index fails the contract check rather than
     // surfacing as an IndexOutOfRangeException from the array access below.
     Contracts.Check(0 <= column.Index && column.Index < _getters.Length);
     return(_getters[column.Index] != null);
 }
Exemple #20
0
 // REVIEW: Add a method to NormStr.Pool that deals with DvText instead of the other way around.
 /// <summary>
 /// Adds the character range [_ichMin, IchLim) of the underlying buffer to <paramref name="pool"/>
 /// and returns the resulting pooled string. Not valid on an NA value; <paramref name="pool"/> must be non-null.
 /// </summary>
 public NormStr AddToPool(NormStr.Pool pool)
 {
     Contracts.Check(!IsNA);
     Contracts.CheckValue(pool, nameof(pool));

     var pooled = pool.Add(_outerBuffer, _ichMin, IchLim);
     return pooled;
 }
Exemple #21
0
 /// <summary>
 /// Writes <c>Src</c> into <paramref name="dst"/>; the cursor must be in a good state.
 /// </summary>
 public void GetValue(ref T dst)
 {
     bool cursorReady = Cursor.IsGood;
     Contracts.Check(cursorReady);
     dst = Src;
 }
 /// <summary>
 /// Stores <paramref name="value"/> as the progress at position <paramref name="index"/> and
 /// clears that position's limit (setting it to null), so the limit becomes unknown.
 /// </summary>
 /// <param name="index">Progress slot; must be in [0, Progress.Length).</param>
 /// <param name="value">New progress value.</param>
 public void SetProgress(int index, Double value)
 {
     bool validIndex = index >= 0 && index < Progress.Length;
     Contracts.Check(validIndex);

     Progress[index] = value;
     ProgressLim[index] = null;
 }
Exemple #23
0
 /// <summary>
 /// Converts a zero-based index in [0, KindCount) back to its <see cref="InternalDataKind"/>
 /// by offsetting it from KindMin.
 /// </summary>
 public static InternalDataKind FromIndex(int index)
 {
     Contracts.Check(index >= 0 && index < KindCount);
     int kindValue = (int)KindMin + index;
     return (InternalDataKind)kindValue;
 }
 /// <summary>
 /// Stores <paramref name="value"/> in the metric slot <paramref name="index"/>.
 /// </summary>
 /// <param name="index">Metric slot; must be in [0, Metrics.Length).</param>
 /// <param name="value">New metric value.</param>
 public void SetMetric(int index, Double value)
 {
     bool validIndex = index >= 0 && index < Metrics.Length;
     Contracts.Check(validIndex);
     Metrics[index] = value;
 }
Exemple #25
0
        /// <summary>
        /// Builds the ONNX <see cref="ModelArgs"/> (name, tensor element type and shape) describing a column.
        /// </summary>
        /// <param name="type">Column type; for vector types the item type selects the tensor element type.</param>
        /// <param name="colName">Column name; must be non-empty.</param>
        /// <param name="dims">Optional explicit dimensions. When null, dimensions are derived from
        /// <paramref name="type"/>. The list is copied and never modified.</param>
        /// <param name="dimsParams">Optional per-dimension flags; used only when <paramref name="dims"/> is supplied.</param>
        /// <returns>The model arguments. A leading batch dimension of 1 is always prepended to the shape.</returns>
        public static ModelArgs GetModelArgs(DataViewType type, string colName,
                                             List <long> dims = null, List <bool> dimsParams = null)
        {
            Contracts.CheckValue(type, nameof(type));
            Contracts.CheckNonEmpty(colName, nameof(colName));

            // For vectors, the tensor element type comes from the item type.
            Type rawType = type is VectorDataViewType vectorType ? vectorType.ItemType.RawType : type.RawType;

            TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
            if (rawType == typeof(bool))
            {
                // Booleans are exported as floats.
                dataType = TensorProto.Types.DataType.Float;
            }
            else if (rawType == typeof(ReadOnlyMemory <char>))
            {
                dataType = TensorProto.Types.DataType.String;
            }
            else if (rawType == typeof(sbyte))
            {
                dataType = TensorProto.Types.DataType.Int8;
            }
            else if (rawType == typeof(byte))
            {
                dataType = TensorProto.Types.DataType.Uint8;
            }
            else if (rawType == typeof(short))
            {
                dataType = TensorProto.Types.DataType.Int16;
            }
            else if (rawType == typeof(ushort))
            {
                dataType = TensorProto.Types.DataType.Uint16;
            }
            else if (rawType == typeof(int))
            {
                dataType = TensorProto.Types.DataType.Int32;
            }
            else if (rawType == typeof(uint))
            {
                // NOTE(review): uint is widened to Int64 rather than mapped to Uint32; presumably intentional
                // (e.g. for key-typed columns consumed by int64 ONNX ops) — confirm before changing.
                dataType = TensorProto.Types.DataType.Int64;
            }
            else if (rawType == typeof(long))
            {
                dataType = TensorProto.Types.DataType.Int64;
            }
            else if (rawType == typeof(ulong))
            {
                dataType = TensorProto.Types.DataType.Uint64;
            }
            else if (rawType == typeof(float))
            {
                dataType = TensorProto.Types.DataType.Float;
            }
            else if (rawType == typeof(double))
            {
                dataType = TensorProto.Types.DataType.Double;
            }
            else
            {
                string msg = "Unsupported type: " + type.ToString();
                Contracts.Check(false, msg);
            }

            List <long> dimsLocal;
            List <bool> dimsParamLocal = null;

            if (dims != null)
            {
                // Copy so that prepending the batch dimension below does not mutate the caller's list.
                dimsLocal      = new List <long>(dims);
                dimsParamLocal = dimsParams;
            }
            else
            {
                dimsLocal = new List <long>();
                int valueCount = type.GetValueCount();
                if (valueCount == 0) // Unknown size.
                {
                    dimsLocal.Add(1);
                    // false for the batch size, true for the variable-size dimension.
                    dimsParamLocal = new List <bool>()
                    {
                        false, true
                    };
                }
                else if (valueCount == 1)
                {
                    dimsLocal.Add(1);
                }
                else if (valueCount > 1)
                {
                    // More than one value implies a vector type.
                    var vec = (VectorDataViewType)type;
                    for (int i = 0; i < vec.Dimensions.Length; i++)
                    {
                        dimsLocal.Add(vec.Dimensions[i]);
                    }
                }
            }

            // Prepend the batch size.
            dimsLocal.Insert(0, 1);

            return(new ModelArgs(colName, dataType, dimsLocal, dimsParamLocal));
        }
Exemple #26
0
        /// <summary>
        /// Serializes the fields of <paramref name="instance"/> (whose runtime type must equal this builder's
        /// <c>_type</c>) into a <see cref="JObject"/>. A field is emitted when it is bound to graph variables,
        /// when it is null but its default is not, or when it is non-null and either required or different
        /// from the value on a freshly created default instance.
        /// </summary>
        /// <param name="instance">Object to serialize; must be non-null.</param>
        /// <param name="inputBindingMap">Maps a field name to the parameter bindings attached to it.</param>
        /// <param name="inputMap">Maps each parameter binding to the variable it is bound to.</param>
        /// <returns>The JSON representation of the non-default / bound fields.</returns>
        /// <remarks>Dictionary-keyed variable bindings throw a not-implemented exception.</remarks>
        public JObject GetJsonObject(object instance, Dictionary <string, List <ParameterBinding> > inputBindingMap, Dictionary <ParameterBinding, VariableBinding> inputMap)
        {
            Contracts.CheckValue(instance, nameof(instance));
            Contracts.Check(instance.GetType() == _type);

            var result   = new JObject();
            // A default-constructed instance supplies the baseline values used to skip unchanged fields.
            var defaults = Activator.CreateInstance(_type);

            for (int i = 0; i < _fields.Length; i++)
            {
                var field       = _fields[i];
                var attr        = _attrs[i];
                var instanceVal = field.GetValue(instance);
                var defaultsVal = field.GetValue(defaults);

                if (inputBindingMap.TryGetValue(field.Name, out List <ParameterBinding> bindings))
                {
                    // Handle variables.
                    Contracts.Assert(bindings.Count > 0);
                    VariableBinding varBinding;
                    var             paramBinding = bindings[0];
                    if (paramBinding is SimpleParameterBinding)
                    {
                        // A simple binding maps the whole field to a single variable.
                        Contracts.Assert(bindings.Count == 1);
                        bool success = inputMap.TryGetValue(paramBinding, out varBinding);
                        Contracts.Assert(success);
                        Contracts.AssertValue(varBinding);

                        result.Add(field.Name, new JValue(varBinding.ToJson()));
                    }
                    else if (paramBinding is ArrayIndexParameterBinding)
                    {
                        // Array parameter bindings.
                        var array = new JArray();
                        foreach (var parameterBinding in bindings)
                        {
                            Contracts.Assert(parameterBinding is ArrayIndexParameterBinding);
                            bool success = inputMap.TryGetValue(parameterBinding, out varBinding);
                            Contracts.Assert(success);
                            Contracts.AssertValue(varBinding);
                            array.Add(new JValue(varBinding.ToJson()));
                        }

                        result.Add(field.Name, array);
                    }
                    else
                    {
                        // Dictionary parameter bindings. Not supported yet.
                        Contracts.Assert(paramBinding is DictionaryKeyParameterBinding);
                        throw Contracts.ExceptNotImpl("Dictionary of variables not yet implemented.");
                    }
                }
                else if (instanceVal == null && defaultsVal != null)
                {
                    // Handle null values: emit an explicit null so it overrides the non-null default.
                    result.Add(field.Name, new JValue(instanceVal));
                }
                else if (instanceVal != null && (attr.Input.IsRequired || !instanceVal.Equals(defaultsVal)))
                {
                    // A required field will be serialized regardless of whether or not its value is identical to the default.
                    var type = instanceVal.GetType();
                    if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional <>))
                    {
                        // ExtractOptional presumably unwraps the Optional<> value in place and reports whether
                        // it was explicitly set; implicit (unset) optionals are skipped — confirm in ExtractOptional.
                        var isExplicit = ExtractOptional(ref instanceVal, ref type);
                        if (!isExplicit)
                        {
                            continue;
                        }
                    }

                    if (type == typeof(JArray))
                    {
                        result.Add(field.Name, (JArray)instanceVal);
                    }
                    else if (type.IsGenericType &&
                             ((type.GetGenericTypeDefinition() == typeof(Var <>)) ||
                              type.GetGenericTypeDefinition() == typeof(ArrayVar <>) ||
                              type.GetGenericTypeDefinition() == typeof(DictionaryVar <>)))
                    {
                        // Variables are referenced by name with a leading '$'.
                        result.Add(field.Name, new JValue($"${((IVarSerializationHelper)instanceVal).VarName}"));
                    }
                    else if (type == typeof(bool) ||
                             type == typeof(string) ||
                             type == typeof(char) ||
                             type == typeof(double) ||
                             type == typeof(float) ||
                             type == typeof(int) ||
                             type == typeof(long) ||
                             type == typeof(uint) ||
                             type == typeof(ulong))
                    {
                        // Handle simple types.
                        result.Add(field.Name, new JValue(instanceVal));
                    }
                    else if (type.IsEnum)
                    {
                        // Handle enums: serialized by member name.
                        result.Add(field.Name, new JValue(instanceVal.ToString()));
                    }
                    else if (type.IsArray)
                    {
                        // Handle arrays.
                        var array       = (Array)instanceVal;
                        var jarray      = new JArray();
                        var elementType = type.GetElementType();
                        if (elementType == typeof(bool) ||
                            elementType == typeof(string) ||
                            elementType == typeof(char) ||
                            elementType == typeof(double) ||
                            elementType == typeof(float) ||
                            elementType == typeof(int) ||
                            elementType == typeof(long) ||
                            elementType == typeof(uint) ||
                            elementType == typeof(ulong))
                        {
                            foreach (object item in array)
                            {
                                jarray.Add(new JValue(item));
                            }
                        }
                        else
                        {
                            // Complex elements are serialized recursively with a builder for the element type.
                            var builder = new InputBuilder(_ectx, elementType, _catalog);
                            foreach (object item in array)
                            {
                                jarray.Add(builder.GetJsonObject(item, inputBindingMap, inputMap));
                            }
                        }
                        result.Add(field.Name, jarray);
                    }
                    else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary <,>) &&
                             type.GetGenericArguments()[0] == typeof(string))
                    {
                        // Handle dictionaries.
                        // NOTE: currently nothing is emitted, so string-keyed dictionary fields are silently dropped.
                        // REVIEW: Needs to be implemented when we will have entry point arguments that contain dictionaries.
                    }
                    else if (typeof(IComponentFactory).IsAssignableFrom(type))
                    {
                        // Handle component factories.
                        bool success = _catalog.TryFindComponent(type, out ModuleCatalog.ComponentInfo instanceInfo);
                        Contracts.Assert(success);
                        var builder      = new InputBuilder(_ectx, type, _catalog);
                        var instSettings = builder.GetJsonObject(instanceVal, inputBindingMap, inputMap);

                        // Serialize the default component too, so we can emit only what differs from it.
                        ModuleCatalog.ComponentInfo defaultInfo = null;
                        JObject defSettings = new JObject();
                        if (defaultsVal != null)
                        {
                            var deftype = defaultsVal.GetType();
                            if (deftype.IsGenericType && deftype.GetGenericTypeDefinition() == typeof(Optional <>))
                            {
                                ExtractOptional(ref defaultsVal, ref deftype);
                            }
                            success = _catalog.TryFindComponent(deftype, out defaultInfo);
                            Contracts.Assert(success);
                            builder     = new InputBuilder(_ectx, deftype, _catalog);
                            defSettings = builder.GetJsonObject(defaultsVal, inputBindingMap, inputMap);
                        }

                        // Emit the component only if its name or settings differ from the default;
                        // the settings object is included only when the settings themselves differ.
                        if (instanceInfo.Name != defaultInfo?.Name || instSettings.ToString() != defSettings.ToString())
                        {
                            var jcomponent = new JObject
                            {
                                { FieldNames.Name, new JValue(instanceInfo.Name) }
                            };
                            if (instSettings.ToString() != defSettings.ToString())
                            {
                                jcomponent.Add(FieldNames.Settings, instSettings);
                            }
                            result.Add(field.Name, jcomponent);
                        }
                    }
                    else
                    {
                        // REVIEW: pass in the bindings once we support variables in inner fields.

                        // Handle structs: serialized recursively with empty binding maps.
                        var builder = new InputBuilder(_ectx, type, _catalog);
                        result.Add(field.Name, builder.GetJsonObject(instanceVal, new Dictionary <string, List <ParameterBinding> >(),
                                                                     new Dictionary <ParameterBinding, VariableBinding>()));
                    }
                }
            }

            return(result);
        }
Exemple #27
0
 /// <summary>
 /// Reports whether column <paramref name="col"/> is active, i.e. has a getter.
 /// <paramref name="col"/> must lie in [0, number of getters).
 /// </summary>
 public override bool IsColumnActive(int col)
 {
     bool inRange = col >= 0 && col < _getters.Length;
     Contracts.Check(inRange);
     var getter = _getters[col];
     return getter != null;
 }
Exemple #28
0
 /// <summary>
 /// Dense dot product of <paramref name="a"/> and <paramref name="b"/>, which must have equal,
 /// non-zero sizes.
 /// </summary>
 public static float DotProduct(float[] a, float[] b)
 {
     int sizeOfA = Utils.Size(a);
     Contracts.Check(sizeOfA == Utils.Size(b), "Arrays must have the same length");
     Contracts.Check(sizeOfA > 0);

     var product = CpuMathUtils.DotProductDense(a, b, a.Length);
     return product;
 }
 /// <summary>
 /// Returns the type of the single output column; only <paramref name="iinfo"/> == 0 is accepted.
 /// </summary>
 protected override DataViewType GetColumnTypeCore(int iinfo)
 {
     Contracts.Check(iinfo == 0);
     DataViewType columnType = _outputColumnType;
     return columnType;
 }
 /// <summary>
 /// Opens a read-only stream over the in-memory buffer. Only <paramref name="index"/> 0 is valid.
 /// </summary>
 /// <param name="index">Stream index; must be 0.</param>
 /// <returns>A new, non-writable <see cref="MemoryStream"/> positioned at the start of the buffer.</returns>
 public Stream Open(int index)
 {
     Contracts.Check(index == 0, "Index must be 0");
     // Every call wraps the same _buffer, so the stream must be read-only: a writable stream
     // would let one consumer overwrite the bytes that later Open calls return.
     return new MemoryStream(_buffer, writable: false);
 }