private static InferenceResult InferTextFileColumnTypesCore(IHostEnvironment env, IMultiStreamSource fileSource, Arguments args, IChannel ch)
{
    Contracts.AssertValue(ch);
    ch.AssertValue(env);
    ch.AssertValue(fileSource);
    ch.AssertValue(args);

    if (args.ColumnCount == 0)
    {
        ch.Error("Too many empty columns for automatic inference.");
        return InferenceResult.Fail();
    }

    if (args.ColumnCount >= SmartColumnsLim)
    {
        ch.Error("Too many columns for automatic inference.");
        return InferenceResult.Fail();
    }

    // Read the file as the specified number of text columns.
    var textLoaderArgs = new TextLoader.Arguments
    {
        Column = new[] { TextLoader.Column.Parse(string.Format("C:TX:0-{0}", args.ColumnCount - 1)) },
        Separator = args.Separator,
        AllowSparse = args.AllowSparse,
        AllowQuoting = args.AllowQuote,
    };
    var textLoader = new TextLoader(env, textLoaderArgs, fileSource);
    var idv = textLoader.Take(args.MaxRowsToRead);

    // Read all the data into memory. List items are rows of the dataset.
    var data = new List<DvText[]>();
    using (var cursor = idv.GetRowCursor(col => true))
    {
        int columnIndex;
        bool found = cursor.Schema.TryGetColumnIndex("C", out columnIndex);
        Contracts.Assert(found);
        var colType = cursor.Schema.GetColumnType(columnIndex);
        Contracts.Assert(colType.ItemType.IsText);

        ValueGetter<VBuffer<DvText>> vecGetter = null;
        ValueGetter<DvText> oneGetter = null;
        bool isVector = colType.IsVector;
        if (isVector)
        {
            vecGetter = cursor.GetGetter<VBuffer<DvText>>(columnIndex);
        }
        else
        {
            Contracts.Assert(args.ColumnCount == 1);
            oneGetter = cursor.GetGetter<DvText>(columnIndex);
        }

        VBuffer<DvText> line = default(VBuffer<DvText>);
        DvText tsValue = default(DvText);
        while (cursor.MoveNext())
        {
            if (isVector)
            {
                vecGetter(ref line);
                Contracts.Assert(line.Length == args.ColumnCount);
                var values = new DvText[args.ColumnCount];
                line.CopyTo(values);
                data.Add(values);
            }
            else
            {
                oneGetter(ref tsValue);
                var values = new[] { tsValue };
                data.Add(values);
            }
        }
    }

    if (data.Count < 2)
    {
        ch.Error("Too few rows ({0}) for automatic inference.", data.Count);
        return InferenceResult.Fail();
    }

    var cols = new IntermediateColumn[args.ColumnCount];
    for (int i = 0; i < args.ColumnCount; i++)
    {
        cols[i] = new IntermediateColumn(data.Select(x => x[i]).ToArray(), i);
    }

    foreach (var expert in GetExperts())
    {
        expert.Apply(cols);
    }

    Contracts.Check(cols.All(x => x.SuggestedType != null), "Column type inference must be conclusive");

    // Aggregate the header signals.
    int suspect = 0;
    var usedNames = new HashSet<string>();
    for (int i = 0; i < args.ColumnCount; i++)
    {
        if (cols[i].HasHeader == true)
        {
            if (usedNames.Add(cols[i].RawData[0].ToString()))
                suspect++;
            else
            {
                // A duplicate value in the first row is a strong signal that this is not a header.
                suspect -= args.ColumnCount;
            }
        }
        else if (cols[i].HasHeader == false)
            suspect--;
    }

    // REVIEW: Why not use this for column names as well?
    TextLoader.Arguments fileArgs;
    bool hasHeader;
    if (TextLoader.FileContainsValidSchema(env, fileSource, out fileArgs))
        hasHeader = fileArgs.HasHeader;
    else
        hasHeader = suspect > 0;

    // Suggest names.
    var names = new List<string>();
    usedNames.Clear();
    foreach (var col in cols)
    {
        string name0;
        string name;
        name0 = name = SuggestName(col, hasHeader);
        int i = 0;
        while (!usedNames.Add(name))
            name = string.Format("{0}_{1:00}", name0, i++);
        names.Add(name);
    }

    var outCols = cols.Select((x, i) => new Column(x.ColumnId, names[i], x.SuggestedType)).ToArray();
    var numerics = outCols.Count(x => x.ItemType.IsNumber);

    ch.Info("Detected {0} numeric and {1} text columns.", numerics, outCols.Length - numerics);
    if (hasHeader)
        ch.Info("Generated column names from the file header.");

    return InferenceResult.Success(outCols, hasHeader, cols.Select(col => col.RawData).ToArray());
}
public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
{
    int numSweeps = Math.Min(maxSweeps, _dim + 1 - _simplexVertices.Count);

    if (previousRuns == null)
        return _initSweeper.ProposeSweeps(numSweeps, previousRuns);

    foreach (var run in previousRuns)
        Contracts.Check(run != null);

    foreach (var run in previousRuns)
    {
        if (_simplexVertices.Count == _dim + 1)
            break;
        if (!_simplexVertices.ContainsKey(run))
            _simplexVertices.Add(run, ParameterSetAsFloatArray(run.ParameterSet));
        if (_simplexVertices.Count == _dim + 1)
            ComputeExtremes();
    }

    if (_simplexVertices.Count < _dim + 1)
    {
        numSweeps = Math.Min(maxSweeps, _dim + 1 - _simplexVertices.Count);
        return _initSweeper.ProposeSweeps(numSweeps, previousRuns);
    }

    switch (_stage)
    {
        case OptimizationStage.NeedReflectionPoint:
            _pendingSweeps.Clear();
            var nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaReflection);
            if (OutOfBounds(nextPoint) && _args.ProjectInbounds)
            {
                // If the reflection point is out of bounds, get the inner contraction point instead.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                _stage = OptimizationStage.WaitingForInnerContractionResult;
            }
            else
                _stage = OptimizationStage.WaitingForReflectionResult;
            _pendingSweeps.Add(new KeyValuePair<ParameterSet, Float[]>(FloatArrayAsParameterSet(nextPoint), nextPoint));
            if (previousRuns.Any(runResult => runResult.ParameterSet.Equals(_pendingSweeps[0].Key)))
            {
                _stage = OptimizationStage.WaitingForReductionResult;
                _pendingSweeps.Clear();
                if (!TryGetReductionPoints(maxSweeps, previousRuns))
                {
                    _stage = OptimizationStage.Done;
                    return null;
                }
                return _pendingSweeps.Select(kvp => kvp.Key).ToArray();
            }
            return new ParameterSet[] { _pendingSweeps[0].Key };

        case OptimizationStage.WaitingForReflectionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            _lastReflectionResult = FindRunResult(previousRuns)[0];
            if (_secondWorst.Key.CompareTo(_lastReflectionResult.Key) < 0 && _lastReflectionResult.Key.CompareTo(_best.Key) <= 0)
            {
                // The reflection result is better than the second worst, but not better than the best.
                UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }

            if (_lastReflectionResult.Key.CompareTo(_best.Key) > 0)
            {
                // The reflection result is the best so far; try the expansion point.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaExpansion);
                if (OutOfBounds(nextPoint) && _args.ProjectInbounds)
                {
                    // If the expansion point is out of bounds, get the inner contraction point instead.
                    nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                    _stage = OptimizationStage.WaitingForInnerContractionResult;
                }
                else
                    _stage = OptimizationStage.WaitingForExpansionResult;
            }
            else if (_lastReflectionResult.Key.CompareTo(_worst.Key) > 0)
            {
                // Otherwise, if it still beats the worst vertex, get results for the outer contraction point.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaOutsideContraction);
                _stage = OptimizationStage.WaitingForOuterContractionResult;
            }
            else
            {
                // Otherwise the reflection result is not better than the worst; get results for the inner contraction point.
                nextPoint = GetNewPoint(_centroid, _worst.Value, _args.DeltaInsideContraction);
                _stage = OptimizationStage.WaitingForInnerContractionResult;
            }
            _pendingSweeps.Clear();
            _pendingSweeps.Add(new KeyValuePair<ParameterSet, Float[]>(FloatArrayAsParameterSet(nextPoint), nextPoint));
            if (previousRuns.Any(runResult => runResult.ParameterSet.Equals(_pendingSweeps[0].Key)))
            {
                _stage = OptimizationStage.WaitingForReductionResult;
                _pendingSweeps.Clear();
                if (!TryGetReductionPoints(maxSweeps, previousRuns))
                {
                    _stage = OptimizationStage.Done;
                    return null;
                }
                return _pendingSweeps.Select(kvp => kvp.Key).ToArray();
            }
            return new ParameterSet[] { _pendingSweeps[0].Key };

        case OptimizationStage.WaitingForExpansionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            var expansionResult = FindRunResult(previousRuns)[0].Key;
            if (expansionResult.CompareTo(_lastReflectionResult.Key) > 0)
            {
                // The expansion point is better than the reflection point.
                UpdateSimplex(expansionResult, _pendingSweeps[0].Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }
            // The reflection point is better than the expansion point.
            UpdateSimplex(_lastReflectionResult.Key, _lastReflectionResult.Value);
            goto case OptimizationStage.NeedReflectionPoint;

        case OptimizationStage.WaitingForOuterContractionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            var outerContractionResult = FindRunResult(previousRuns)[0].Key;
            if (outerContractionResult.CompareTo(_lastReflectionResult.Key) > 0)
            {
                // The outer contraction point is better than the reflection point.
                UpdateSimplex(outerContractionResult, _pendingSweeps[0].Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }
            // Get the reduction points.
            _stage = OptimizationStage.WaitingForReductionResult;
            _pendingSweeps.Clear();
            if (!TryGetReductionPoints(maxSweeps, previousRuns))
            {
                _stage = OptimizationStage.Done;
                return null;
            }
            return _pendingSweeps.Select(kvp => kvp.Key).ToArray();

        case OptimizationStage.WaitingForInnerContractionResult:
            Contracts.Assert(_pendingSweeps.Count == 1);
            var innerContractionResult = FindRunResult(previousRuns)[0].Key;
            if (innerContractionResult.CompareTo(_worst.Key) > 0)
            {
                // The inner contraction point is better than the worst point.
                UpdateSimplex(innerContractionResult, _pendingSweeps[0].Value);
                goto case OptimizationStage.NeedReflectionPoint;
            }
            // Get the reduction points.
            _stage = OptimizationStage.WaitingForReductionResult;
            _pendingSweeps.Clear();
            if (!TryGetReductionPoints(maxSweeps, previousRuns))
            {
                _stage = OptimizationStage.Done;
                return null;
            }
            return _pendingSweeps.Select(kvp => kvp.Key).ToArray();

        case OptimizationStage.WaitingForReductionResult:
            Contracts.Assert(_pendingSweeps.Count + _pendingSweepsNotSubmitted.Count == _dim);
            if (_pendingSweeps.Count < _dim)
                return SubmitMoreReductionPoints(maxSweeps);
            ReplaceSimplexVertices(previousRuns);
            // If the diameter of the new simplex has become too small, stop sweeping.
            if (SimplexDiameter() < _args.StoppingSimplexDiameter)
                return null;
            goto case OptimizationStage.NeedReflectionPoint;

        case OptimizationStage.Done:
        default:
            return null;
    }
}
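// The sweeper above derives every candidate (reflection, expansion, inner and outer
// contraction) from the simplex centroid and the current worst vertex via GetNewPoint,
// varying only the delta argument. Below is a minimal sketch of the standard Nelder-Mead
// update it presumably applies -- an assumption for illustration only; the actual
// GetNewPoint body is not shown in this listing.
private static float[] GetNewPointSketch(float[] centroid, float[] worst, float delta)
{
    // candidate[i] = centroid[i] + delta * (centroid[i] - worst[i]).
    // E.g. delta = 1 reflects the worst vertex through the centroid, a larger delta
    // expands, a small positive delta contracts outside, a negative delta contracts inside.
    var candidate = new float[centroid.Length];
    for (int i = 0; i < centroid.Length; i++)
        candidate[i] = centroid[i] + delta * (centroid[i] - worst[i]);
    return candidate;
}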
private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiStreamSource source, bool skipStrictValidation, out ColumnSplitResult result)
{
    result = default(ColumnSplitResult);
    try
    {
        // There is no need to surface messages from an unsuccessful loader, so we create a
        // temporary environment and forward its messages only in case of success.
        using (var loaderEnv = new ConsoleEnvironment(0, verbose: true))
        {
            var messages = new ConcurrentBag<ChannelMessage>();
            loaderEnv.AddListener<ChannelMessage>((src, msg) => { messages.Add(msg); });

            var idv = TextLoader.ReadFile(loaderEnv, args, source).Take(1000);
            var columnCounts = new List<int>();
            int columnIndex;
            bool found = idv.Schema.TryGetColumnIndex("C", out columnIndex);
            ch.Assert(found);

            using (var cursor = idv.GetRowCursor(x => x == columnIndex))
            {
                var getter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(columnIndex);
                VBuffer<ReadOnlyMemory<char>> line = default;
                while (cursor.MoveNext())
                {
                    getter(ref line);
                    columnCounts.Add(line.Length);
                }
            }

            Contracts.Check(columnCounts.Count > 0);
            var mostCommon = columnCounts.GroupBy(x => x).OrderByDescending(x => x.Count()).First();
            if (!skipStrictValidation && mostCommon.Count() < UniformColumnCountThreshold * columnCounts.Count)
                return false;

            // If the user explicitly specified a separator, we allow the "single column" case;
            // otherwise the user will see a message saying that we were not able to detect any columns.
            if (!skipStrictValidation && mostCommon.Key <= 1)
                return false;

            result = new ColumnSplitResult(true, args.Separator, args.AllowQuoting, args.AllowSparse, mostCommon.Key);
            ch.Trace("Discovered {0} columns using separator '{1}'", mostCommon.Key, args.Separator);
            foreach (var msg in messages)
                ch.Send(msg);
            return true;
        }
    }
    catch (Exception ex)
    {
        if (!ex.IsMarked())
            throw;
        // For known (marked) exceptions, we just continue to the next separator candidate.
    }
    return false;
}
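// A self-contained sketch of the "most common column count" heuristic used above: group the
// per-line column counts and take the largest group as the consensus. The sample data below
// is illustrative, not taken from a real file.
using System;
using System.Collections.Generic;
using System.Linq;

class ColumnCountModeDemo
{
    static void Main()
    {
        var columnCounts = new List<int> { 5, 5, 4, 5, 5 };
        var mostCommon = columnCounts.GroupBy(x => x).OrderByDescending(g => g.Count()).First();
        // Prints "5 columns in 4 of 5 sampled lines"; TryParseFile additionally requires this
        // fraction to reach UniformColumnCountThreshold unless strict validation is skipped.
        Console.WriteLine($"{mostCommon.Key} columns in {mostCommon.Count()} of {columnCounts.Count} sampled lines");
    }
}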
public void GetValue(ref VBuffer<T> dst)
{
    Contracts.Check(Cursor.IsGood);
    Src.CopyTo(ref dst);
}
/// <summary>
/// Converts <paramref name="kind"/> to <see cref="DataKind"/>.
/// Because <see cref="DataKind"/> is a subset of <see cref="InternalDataKind"/>, we must first
/// check that <paramref name="kind"/> has a counterpart in <see cref="DataKind"/>.
/// </summary>
public static DataKind ToDataKind(this InternalDataKind kind)
{
    Contracts.Check(kind != InternalDataKind.UG);
    return (DataKind)kind;
}
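// Hedged usage sketch: any member with a DataKind counterpart converts cleanly, while UG
// (the one member excluded above) trips the check. InternalDataKind.R4 is assumed to exist
// here purely for illustration.
var kind = InternalDataKind.R4.ToDataKind();   // succeeds: R4 has a DataKind counterpart
// InternalDataKind.UG.ToDataKind();           // would throw: the Contracts.Check rejects UG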
public override void Combine(ref ReadOnlyMemory<char> dst, ReadOnlyMemory<char> src)
{
    Contracts.Check(IsDefault(dst));
    dst = src;
}
public override ValueGetter<T> GetGetter()
{
    Contracts.Check(IsActive, "column is not active");
    return _getter;
}
public static ScikitSubComponent Create(Type type)
{
    Contracts.Check(type != null && typeof(ScikitSubComponent).IsAssignableFrom(type));
    return (ScikitSubComponent)Activator.CreateInstance(type);
}
public GetterInfoPrimitive(string kind, ColumnType type, TValue value)
    : base(kind, type)
{
    Contracts.Check(type.RawType == typeof(TValue), "Incompatible types");
    Value = value;
}
public DiscreteValueGenerator(DiscreteParamArguments args)
{
    Contracts.Check(args.Values.Length > 0);
    _args = args;
}
public DiscreteValueGenerator(DiscreteParamOptions options)
{
    Contracts.Check(options.Values.Length > 0);
    _options = options;
}
public void AllocateDataMemory(int docNum, long corpusSize)
{
    Contracts.Check(docNum >= 0);
    Contracts.Check(corpusSize >= 0);
    LdaInterface.AllocateDataMemory(_engine, docNum, corpusSize);
}
public string GetPathOrNull(int index)
{
    Contracts.Check(index == 0);
    return null;
}
/// <summary>
/// Predict with data.
/// This function uses a modified API which does not use caches.
/// </summary>
/// <param name="vbuf">One row.</param>
/// <param name="predictedValues">Results of the prediction.</param>
/// <param name="internalBuffer">Buffers allocated by Microsoft.ML and handed to XGBoost, so that XGBoost does not allocate caches on its own.</param>
/// <param name="outputMargin">Whether to output the raw untransformed margin value.</param>
/// <param name="ntreeLimit">Limit on the number of trees used in the prediction; defaults to 0 (use all trees).</param>
public void PredictOneOff(ref VBuffer<Float> vbuf, ref VBuffer<Float> predictedValues, ref XGBoostTreeBuffer internalBuffer, bool outputMargin = true, int ntreeLimit = 0)
{
    // REVIEW xadupre: XGBoost can produce an output per tree (pred_leaf=true).
    // When this option is on, the output is a matrix of (nsample, ntrees),
    // with each record indicating the predicted leaf index of each sample in each tree.
    // Note that the leaf index of a tree is unique per tree, so you may find leaf 1
    // in both tree 1 and tree 0.
    //     if (pred_leaf)
    //         option_mask |= 0x02;
    // This might be an interesting feature to implement.

    int optionMask = 0x00;
    if (outputMargin)
        optionMask |= 0x01;

    Contracts.Check(internalBuffer != null);

    uint length = 0;
    uint lengthBuffer = 0;
    uint nb = (uint)vbuf.Count;

    // This function relies on a modified API. Instead of letting XGBoost handle its own caches,
    // it calls XGBoosterPredictOutputSize to learn what cache size is required.
    // Microsoft.ML allocates the caches and gives them to XGBoost.
    // First, we allocate the cache for the features; only then can XGBoost
    // report the required cache size.
#if (XGB_EXTENDED)
    internalBuffer.ResizeEntries(nb, vbuf.Length);
#else
    internalBuffer.ResizeEntries(nb);
#endif

    unsafe
    {
        fixed (float* p = vbuf.Values)
        fixed (int* i = vbuf.Indices)
        fixed (byte* entries = internalBuffer.XGBoostEntries)
        {
            WrappedXGBoostInterface.XGBoosterCopyEntries((IntPtr)entries, ref nb, p, vbuf.IsDense ? null : i, float.NaN);
            WrappedXGBoostInterface.XGBoosterPredictOutputSize(_handle, (IntPtr)entries, nb, optionMask, (uint)ntreeLimit, ref length, ref lengthBuffer);
        }
    }

    // Then we allocate the cache for the predictions.
    internalBuffer.ResizeOutputs(length, lengthBuffer, ref predictedValues);

    unsafe
    {
        fixed (byte* entries = internalBuffer.XGBoostEntries)
        fixed (float* ppreds = predictedValues.Values)
        fixed (float* ppredBuffer = internalBuffer.PredBuffer)
        fixed (uint* ppredCounter = internalBuffer.PredCounter)
        {
            WrappedXGBoostInterface.XGBoosterPredictNoInsideCache(_handle, (IntPtr)entries, nb, optionMask, (uint)ntreeLimit, length, lengthBuffer, ppreds, ppredBuffer, ppredCounter
#if (XGB_EXTENDED)
                , internalBuffer.RegTreeFVec
#endif
                );
        }
    }
}
public TreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
{
    TreeEnsemble res = new TreeEnsemble();
    string modelString = GetModelString();
    string[] lines = modelString.Split('\n');
    int i = 0;
    for (; i < lines.Length;)
    {
        if (lines[i].StartsWith("Tree="))
        {
            Dictionary<string, string> kvPairs = new Dictionary<string, string>();
            ++i;
            while (!lines[i].StartsWith("Tree=") && lines[i].Trim().Length != 0)
            {
                string[] kv = lines[i].Split('=');
                Contracts.Check(kv.Length == 2);
                kvPairs[kv[0].Trim()] = kv[1].Trim();
                ++i;
            }
            int numLeaves = int.Parse(kvPairs["num_leaves"]);
            int numCat = int.Parse(kvPairs["num_cat"]);
            if (numLeaves > 1)
            {
                var leftChild = Str2IntArray(kvPairs["left_child"], ' ');
                var rightChild = Str2IntArray(kvPairs["right_child"], ' ');
                var splitFeature = Str2IntArray(kvPairs["split_feature"], ' ');
                var threshold = Str2DoubleArray(kvPairs["threshold"], ' ');
                var splitGain = Str2DoubleArray(kvPairs["split_gain"], ' ');
                var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                var decisionType = Str2UIntArray(kvPairs["decision_type"], ' ');
                var defaultValue = GetDefalutValue(threshold, decisionType);
                var categoricalSplitFeatures = new int[numLeaves - 1][];
                var categoricalSplit = new bool[numLeaves - 1];
                if (categoricalFeatureBoudaries != null)
                {
                    // Add offsets to split features.
                    for (int node = 0; node < numLeaves - 1; ++node)
                        splitFeature[node] = categoricalFeatureBoudaries[splitFeature[node]];
                }

                if (numCat > 0)
                {
                    var catBoundaries = Str2IntArray(kvPairs["cat_boundaries"], ' ');
                    var catThreshold = Str2UIntArray(kvPairs["cat_threshold"], ' ');
                    for (int node = 0; node < numLeaves - 1; ++node)
                    {
                        if (GetIsCategoricalSplit(decisionType[node]))
                        {
                            int catIdx = (int)threshold[node];
                            var cats = GetCatThresholds(catThreshold, catBoundaries[catIdx], catBoundaries[catIdx + 1]);
                            categoricalSplitFeatures[node] = new int[cats.Length];
                            // Convert categorical thresholds to feature indices.
                            for (int j = 0; j < cats.Length; ++j)
                                categoricalSplitFeatures[node][j] = splitFeature[node] + cats[j] - 1;
                            splitFeature[node] = -1;
                            categoricalSplit[node] = true;
                            // Swap the left and right children.
                            int t = leftChild[node];
                            leftChild[node] = rightChild[node];
                            rightChild[node] = t;
                        }
                        else
                            categoricalSplit[node] = false;
                    }
                }
                RegressionTree tree = RegressionTree.Create(numLeaves, splitFeature, splitGain,
                    threshold.Select(x => (float)x).ToArray(), defaultValue.Select(x => (float)x).ToArray(),
                    leftChild, rightChild, leafOutput, categoricalSplitFeatures, categoricalSplit);
                res.AddTree(tree);
            }
            else
            {
                RegressionTree tree = new RegressionTree(2);
                var leafOutput = Str2DoubleArray(kvPairs["leaf_value"], ' ');
                if (leafOutput[0] != 0)
                {
                    // Convert the constant tree to a two-leaf tree, to avoid it being filtered out by TLC.
                    var categoricalSplitFeatures = new int[1][];
                    var categoricalSplit = new bool[1];
                    tree = RegressionTree.Create(2, new int[] { 0 }, new double[] { 0 },
                        new float[] { 0 }, new float[] { 0 }, new int[] { -1 }, new int[] { -2 },
                        new double[] { leafOutput[0], leafOutput[0] },
                        categoricalSplitFeatures, categoricalSplit);
                }
                res.AddTree(tree);
            }
        }
        else
        {
            ++i;
        }
    }
    return res;
}
// Project the covariance matrix A onto Omega: Y <- A * Omega,
// where A = X' * X / n and X = data - mean.
// Note that the covariance matrix is not computed explicitly.
private void Project(IDataView trainingData, Float[][] mean, Float[][][] omega, Float[][][] y, TransformInfo[] transformInfos)
{
    Host.Assert(mean.Length == omega.Length && omega.Length == y.Length && y.Length == Infos.Length);
    for (int i = 0; i < omega.Length; i++)
        Contracts.Assert(omega[i].Length == y[i].Length);

    // Set y to all zeros.
    for (int iinfo = 0; iinfo < y.Length; iinfo++)
    {
        for (int i = 0; i < y[iinfo].Length; i++)
            Array.Clear(y[iinfo][i], 0, y[iinfo][i].Length);
    }

    bool[] center = Enumerable.Range(0, mean.Length).Select(i => mean[i] != null).ToArray();
    Double[] totalColWeight = new Double[Infos.Length];

    bool[] activeColumns = new bool[Source.Schema.ColumnCount];
    for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
    {
        activeColumns[Infos[iinfo].Source] = true;
        if (_weightColumnIndex[iinfo] >= 0)
            activeColumns[_weightColumnIndex[iinfo]] = true;
    }

    using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
    {
        var weightGetters = new ValueGetter<Float>[Infos.Length];
        var columnGetters = new ValueGetter<VBuffer<Float>>[Infos.Length];
        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            if (_weightColumnIndex[iinfo] >= 0)
                weightGetters[iinfo] = cursor.GetGetter<Float>(_weightColumnIndex[iinfo]);
            columnGetters[iinfo] = cursor.GetGetter<VBuffer<Float>>(Infos[iinfo].Source);
        }

        var features = default(VBuffer<Float>);
        while (cursor.MoveNext())
        {
            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                Contracts.Check(Infos[iinfo].TypeSrc.IsVector && Infos[iinfo].TypeSrc.ItemType.IsNumber,
                    "PCA transform can only be performed on numeric columns of dimension > 1");
                Float weight = 1;
                if (weightGetters[iinfo] != null)
                    weightGetters[iinfo](ref weight);
                columnGetters[iinfo](ref features);
                if (FloatUtils.IsFinite(weight) && weight >= 0 && (features.Count == 0 || FloatUtils.IsFinite(features.Values, features.Count)))
                {
                    totalColWeight[iinfo] += weight;
                    if (center[iinfo])
                        VectorUtils.AddMult(ref features, mean[iinfo], weight);
                    for (int i = 0; i < omega[iinfo].Length; i++)
                        VectorUtils.AddMult(ref features, y[iinfo][i], weight * VectorUtils.DotProductWithOffset(omega[iinfo][i], 0, ref features));
                }
            }
        }

        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            if (totalColWeight[iinfo] <= 0)
                throw Host.Except("Empty data in column '{0}'", Source.Schema.GetColumnName(Infos[iinfo].Source));
        }

        for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
        {
            var invn = (Float)(1 / totalColWeight[iinfo]);
            for (var i = 0; i < omega[iinfo].Length; ++i)
                VectorUtils.ScaleBy(y[iinfo][i], invn);
            if (center[iinfo])
            {
                VectorUtils.ScaleBy(mean[iinfo], invn);
                for (int i = 0; i < omega[iinfo].Length; i++)
                    VectorUtils.AddMult(mean[iinfo], y[iinfo][i], -VectorUtils.DotProduct(omega[iinfo][i], mean[iinfo]));
            }
        }
    }
}
private static Session LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
{
    Contracts.Check(env != null, nameof(env));
    env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));
    return Session.LoadFromSavedModel(exportDirSavedModel);
}
public uint Hash(uint seed)
{
    Contracts.Check(!IsNA);
    return Hashing.MurmurHash(seed, _outerBuffer, _ichMin, IchLim);
}
/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column)
{
    Contracts.Check(column.Index < _getters.Length);
    return _getters[column.Index] != null;
}
// REVIEW: Add a method to NormStr.Pool that deals with DvText instead of the other way around.
public NormStr AddToPool(NormStr.Pool pool)
{
    Contracts.Check(!IsNA);
    Contracts.CheckValue(pool, nameof(pool));
    return pool.Add(_outerBuffer, _ichMin, IchLim);
}
public void GetValue(ref T dst)
{
    Contracts.Check(Cursor.IsGood);
    dst = Src;
}
/// <summary>
/// Sets the progress value for the index <paramref name="index"/> to <paramref name="value"/>,
/// and marks the corresponding limit value as 'unknown'.
/// </summary>
public void SetProgress(int index, Double value)
{
    Contracts.Check(0 <= index && index < Progress.Length);
    Progress[index] = value;
    ProgressLim[index] = null;
}
/// <summary>
/// Maps an index into an array of size KindCount to the corresponding InternalDataKind.
/// </summary>
public static InternalDataKind FromIndex(int index)
{
    Contracts.Check(0 <= index && index < KindCount);
    return (InternalDataKind)(index + (int)KindMin);
}
/// <summary>
/// Sets the metric with index <paramref name="index"/> to <paramref name="value"/>.
/// </summary>
public void SetMetric(int index, Double value)
{
    Contracts.Check(0 <= index && index < Metrics.Length);
    Metrics[index] = value;
}
public static ModelArgs GetModelArgs(DataViewType type, string colName, List<long> dims = null, List<bool> dimsParams = null)
{
    Contracts.CheckValue(type, nameof(type));
    Contracts.CheckNonEmpty(colName, nameof(colName));

    TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
    Type rawType;
    if (type is VectorDataViewType vectorType)
        rawType = vectorType.ItemType.RawType;
    else
        rawType = type.RawType;

    if (rawType == typeof(bool))
        dataType = TensorProto.Types.DataType.Float;
    else if (rawType == typeof(ReadOnlyMemory<char>))
        dataType = TensorProto.Types.DataType.String;
    else if (rawType == typeof(sbyte))
        dataType = TensorProto.Types.DataType.Int8;
    else if (rawType == typeof(byte))
        dataType = TensorProto.Types.DataType.Uint8;
    else if (rawType == typeof(short))
        dataType = TensorProto.Types.DataType.Int16;
    else if (rawType == typeof(ushort))
        dataType = TensorProto.Types.DataType.Uint16;
    else if (rawType == typeof(int))
        dataType = TensorProto.Types.DataType.Int32;
    else if (rawType == typeof(uint))
        dataType = TensorProto.Types.DataType.Int64;
    else if (rawType == typeof(long))
        dataType = TensorProto.Types.DataType.Int64;
    else if (rawType == typeof(ulong))
        dataType = TensorProto.Types.DataType.Uint64;
    else if (rawType == typeof(float))
        dataType = TensorProto.Types.DataType.Float;
    else if (rawType == typeof(double))
        dataType = TensorProto.Types.DataType.Double;
    else
    {
        string msg = "Unsupported type: " + type.ToString();
        Contracts.Check(false, msg);
    }

    string name = colName;
    List<long> dimsLocal = null;
    List<bool> dimsParamLocal = null;
    if (dims != null)
    {
        dimsLocal = dims;
        dimsParamLocal = dimsParams;
    }
    else
    {
        dimsLocal = new List<long>();
        int valueCount = type.GetValueCount();
        if (valueCount == 0) // Unknown size.
        {
            dimsLocal.Add(1);
            dimsParamLocal = new List<bool>() { false, true }; // false for batch size, true for dims.
        }
        else if (valueCount == 1)
            dimsLocal.Add(1);
        else if (valueCount > 1)
        {
            var vec = (VectorDataViewType)type;
            for (int i = 0; i < vec.Dimensions.Length; i++)
                dimsLocal.Add(vec.Dimensions[i]);
        }
    }
    // Batch size.
    dimsLocal?.Insert(0, 1);

    return new ModelArgs(name, dataType, dimsLocal, dimsParamLocal);
}
public JObject GetJsonObject(object instance, Dictionary<string, List<ParameterBinding>> inputBindingMap, Dictionary<ParameterBinding, VariableBinding> inputMap)
{
    Contracts.CheckValue(instance, nameof(instance));
    Contracts.Check(instance.GetType() == _type);

    var result = new JObject();
    var defaults = Activator.CreateInstance(_type);
    for (int i = 0; i < _fields.Length; i++)
    {
        var field = _fields[i];
        var attr = _attrs[i];
        var instanceVal = field.GetValue(instance);
        var defaultsVal = field.GetValue(defaults);

        if (inputBindingMap.TryGetValue(field.Name, out List<ParameterBinding> bindings))
        {
            // Handle variables.
            Contracts.Assert(bindings.Count > 0);
            VariableBinding varBinding;
            var paramBinding = bindings[0];
            if (paramBinding is SimpleParameterBinding)
            {
                Contracts.Assert(bindings.Count == 1);
                bool success = inputMap.TryGetValue(paramBinding, out varBinding);
                Contracts.Assert(success);
                Contracts.AssertValue(varBinding);
                result.Add(field.Name, new JValue(varBinding.ToJson()));
            }
            else if (paramBinding is ArrayIndexParameterBinding)
            {
                // Array parameter bindings.
                var array = new JArray();
                foreach (var parameterBinding in bindings)
                {
                    Contracts.Assert(parameterBinding is ArrayIndexParameterBinding);
                    bool success = inputMap.TryGetValue(parameterBinding, out varBinding);
                    Contracts.Assert(success);
                    Contracts.AssertValue(varBinding);
                    array.Add(new JValue(varBinding.ToJson()));
                }
                result.Add(field.Name, array);
            }
            else
            {
                // Dictionary parameter bindings. Not supported yet.
                Contracts.Assert(paramBinding is DictionaryKeyParameterBinding);
                throw Contracts.ExceptNotImpl("Dictionary of variables not yet implemented.");
            }
        }
        else if (instanceVal == null && defaultsVal != null)
        {
            // Handle null values.
            result.Add(field.Name, new JValue(instanceVal));
        }
        else if (instanceVal != null && (attr.Input.IsRequired || !instanceVal.Equals(defaultsVal)))
        {
            // A required field is serialized regardless of whether its value is identical to the default.
            var type = instanceVal.GetType();
            if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>))
            {
                var isExplicit = ExtractOptional(ref instanceVal, ref type);
                if (!isExplicit)
                    continue;
            }

            if (type == typeof(JArray))
                result.Add(field.Name, (JArray)instanceVal);
            else if (type.IsGenericType &&
                ((type.GetGenericTypeDefinition() == typeof(Var<>)) ||
                type.GetGenericTypeDefinition() == typeof(ArrayVar<>) ||
                type.GetGenericTypeDefinition() == typeof(DictionaryVar<>)))
            {
                result.Add(field.Name, new JValue($"${((IVarSerializationHelper)instanceVal).VarName}"));
            }
            else if (type == typeof(bool) || type == typeof(string) || type == typeof(char) ||
                type == typeof(double) || type == typeof(float) || type == typeof(int) ||
                type == typeof(long) || type == typeof(uint) || type == typeof(ulong))
            {
                // Handle simple types.
                result.Add(field.Name, new JValue(instanceVal));
            }
            else if (type.IsEnum)
            {
                // Handle enums.
                result.Add(field.Name, new JValue(instanceVal.ToString()));
            }
            else if (type.IsArray)
            {
                // Handle arrays.
                var array = (Array)instanceVal;
                var jarray = new JArray();
                var elementType = type.GetElementType();
                if (elementType == typeof(bool) || elementType == typeof(string) || elementType == typeof(char) ||
                    elementType == typeof(double) || elementType == typeof(float) || elementType == typeof(int) ||
                    elementType == typeof(long) || elementType == typeof(uint) || elementType == typeof(ulong))
                {
                    foreach (object item in array)
                        jarray.Add(new JValue(item));
                }
                else
                {
                    var builder = new InputBuilder(_ectx, elementType, _catalog);
                    foreach (object item in array)
                        jarray.Add(builder.GetJsonObject(item, inputBindingMap, inputMap));
                }
                result.Add(field.Name, jarray);
            }
            else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary<,>) &&
                type.GetGenericArguments()[0] == typeof(string))
            {
                // Handle dictionaries.
                // REVIEW: Needs to be implemented when we have entry-point arguments that contain dictionaries.
            }
            else if (typeof(IComponentFactory).IsAssignableFrom(type))
            {
                // Handle component factories.
                bool success = _catalog.TryFindComponent(type, out ModuleCatalog.ComponentInfo instanceInfo);
                Contracts.Assert(success);
                var builder = new InputBuilder(_ectx, type, _catalog);
                var instSettings = builder.GetJsonObject(instanceVal, inputBindingMap, inputMap);

                ModuleCatalog.ComponentInfo defaultInfo = null;
                JObject defSettings = new JObject();
                if (defaultsVal != null)
                {
                    var deftype = defaultsVal.GetType();
                    if (deftype.IsGenericType && deftype.GetGenericTypeDefinition() == typeof(Optional<>))
                        ExtractOptional(ref defaultsVal, ref deftype);
                    success = _catalog.TryFindComponent(deftype, out defaultInfo);
                    Contracts.Assert(success);
                    builder = new InputBuilder(_ectx, deftype, _catalog);
                    defSettings = builder.GetJsonObject(defaultsVal, inputBindingMap, inputMap);
                }

                if (instanceInfo.Name != defaultInfo?.Name || instSettings.ToString() != defSettings.ToString())
                {
                    var jcomponent = new JObject
                    {
                        { FieldNames.Name, new JValue(instanceInfo.Name) }
                    };
                    if (instSettings.ToString() != defSettings.ToString())
                        jcomponent.Add(FieldNames.Settings, instSettings);
                    result.Add(field.Name, jcomponent);
                }
            }
            else
            {
                // REVIEW: Pass in the bindings once we support variables in inner fields.
                // Handle structs.
                var builder = new InputBuilder(_ectx, type, _catalog);
                result.Add(field.Name, builder.GetJsonObject(instanceVal,
                    new Dictionary<string, List<ParameterBinding>>(),
                    new Dictionary<ParameterBinding, VariableBinding>()));
            }
        }
    }
    return result;
}
public override bool IsColumnActive(int col)
{
    Contracts.Check(0 <= col && col < _getters.Length);
    return _getters[col] != null;
}
public static float DotProduct(float[] a, float[] b)
{
    Contracts.Check(Utils.Size(a) == Utils.Size(b), "Arrays must have the same length");
    Contracts.Check(Utils.Size(a) > 0);
    return CpuMathUtils.DotProductDense(a, b, a.Length);
}
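// Hedged usage sketch: with equal-length, non-empty dense arrays, the result is the ordinary
// dot product sum(a[i] * b[i]); the two checks above reject mismatched or empty inputs.
// The values below are illustrative.
float dot = DotProduct(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }); // 1*4 + 2*5 + 3*6 = 32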
protected override DataViewType GetColumnTypeCore(int iinfo)
{
    Contracts.Check(iinfo == 0);
    return _outputColumnType;
}
public Stream Open(int index)
{
    Contracts.Check(index == 0, "Index must be 0");
    return new MemoryStream(_buffer);
}