/// <summary>
/// Forking constructor. The new environment gets its own random source, concurrency setting
/// and cancellation lock. It shares the master, root, listener dictionary and progress tracker
/// with <paramref name="source"/>.
/// </summary>
/// <param name="source">Environment being forked; must not be null.</param>
/// <param name="rand">Random source for the fork, or null to create a new one.</param>
/// <param name="verbose">Verbosity flag forwarded to the base constructor.</param>
/// <param name="conc">Concurrency setting stored for this fork.</param>
/// <param name="shortName">Short name forwarded to the base constructor.</param>
/// <param name="parentFullName">Parent full name forwarded to the base constructor.</param>
protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, IRandom rand, bool verbose,
    int? conc, string shortName = null, string parentFullName = null)
    : base(shortName, parentFullName, verbose)
{
    // Argument validation.
    Contracts.CheckValue(source, nameof(source));
    Contracts.CheckValueOrNull(rand);

    // State owned exclusively by this fork.
    if (rand != null)
    {
        _rand = rand;
    }
    else
    {
        _rand = RandomUtils.Create();
    }
    _conc = conc;
    _cancelLock = new object();

    // State shared with the environment we were forked from.
    Master = source;
    Root = source.Root;
    ListenerDict = source.ListenerDict;
    ProgressTracker = source.ProgressTracker;
}
/// <summary>
/// Registers a child host under this environment.
/// </summary>
/// <param name="name">Name of the child host; must be non-empty.</param>
/// <param name="seed">Explicit seed for the child's random source; when null the child's
/// generator is derived from this environment's generator.</param>
/// <param name="verbose">Verbosity override; defaults to this environment's <c>Verbose</c>.</param>
/// <param name="conc">Concurrency override; defaults to this environment's setting.</param>
/// <returns>The newly created host.</returns>
public new IHost Register(string name, int? seed = null, bool? verbose = null, int? conc = null)
{
    Contracts.CheckNonEmpty(name, nameof(name));

    IHost registered;
    lock (_cancelLock)
    {
        IRandom childRandom;
        if (seed.HasValue)
        {
            childRandom = RandomUtils.Create(seed.Value);
        }
        else
        {
            childRandom = RandomUtils.Create(_rand);
        }

        registered = RegisterCore(this, name, Master?.FullName, childRandom, verbose ?? Verbose, conc ?? _conc);

        // Children are tracked weakly, and only while this environment is not cancelled.
        if (!IsCancelled)
        {
            _children.Add(new WeakReference<IHost>(registered));
        }
    }
    return registered;
}
/// <summary>
/// Replaces the key-typed column <paramref name="col"/> of <paramref name="input"/> with a float
/// column. Missing (NA) keys become <see cref="Single.NaN"/>.
/// </summary>
/// <param name="env">Host environment used to create the mapper view.</param>
/// <param name="ch">Channel used for argument checks.</param>
/// <param name="input">Source data view.</param>
/// <param name="col">Name of the key column, which is also the output column name.</param>
/// <param name="type">Key type of the source column.</param>
/// <param name="seed">When 0, the converted key value k maps to k - 1. Otherwise it maps through a
/// random permutation seeded by this value, which requires a key type of known cardinality.</param>
/// <returns>The mapped data view.</returns>
private static IDataView AppendFloatMapper<TInput>(IHostEnvironment env, IChannel ch, IDataView input,
    string col, KeyType type, int seed)
{
    // Any key is convertible to ulong, so rather than add special case handling for all possible
    // key-types we just upfront convert it to the most general type (ulong) and work from there.
    KeyType dstType = new KeyType(DataKind.U8, type.Min, type.Count, type.Contiguous);
    var converter = Conversions.Instance.GetStandardConversion<TInput, ulong>(type, dstType, out _);
    var isNa = Conversions.Instance.GetIsNAPredicate<TInput>(type);

    ValueMapper<TInput, Single> mapper;
    if (seed == 0)
    {
        mapper =
            (in TInput src, ref Single dst) =>
            {
                if (isNa(in src))
                {
                    dst = Single.NaN;
                    return;
                }
                // The scratch value is local to each call. A captured variable would be shared by
                // every cursor using this mapper and would race under parallel cursors.
                ulong temp = 0;
                converter(in src, ref temp);
                dst = (Single)(temp - 1);
            };
    }
    else
    {
        ch.Check(type.Count > 0, "Label must be of known cardinality.");
        int[] permutation = Utils.GetRandomPermutation(RandomUtils.Create(seed), type.Count);
        mapper =
            (in TInput src, ref Single dst) =>
            {
                if (isNa(in src))
                {
                    dst = Single.NaN;
                    return;
                }
                ulong temp = 0;
                converter(in src, ref temp);
                dst = (Single)permutation[(int)(temp - 1)];
            };
    }
    return LambdaColumnMapper.Create(env, "Key to Float Mapper", input, col, col, type, NumberType.Float, mapper);
}
/// <summary>
/// Creates a cursor over the (possibly) shuffled data.
/// </summary>
/// <remarks>
/// Two stages are decided independently. Pool shuffling by this transform is enabled when forced
/// or when a generator is supplied. Source shuffling by the subset input is enabled when forced,
/// or when a generator is supplied and pool-only mode is off. If either stage is needed and no
/// generator was supplied, one is created from <c>_forceShuffleSeed</c>.
/// </remarks>
/// <param name="predicate">Column activity predicate.</param>
/// <param name="rand">Optional caller-supplied random generator.</param>
/// <returns>The source cursor, or a pool-shuffling cursor wrapping it.</returns>
protected override IRowCursor GetRowCursorCore(Func<int, bool> predicate, IRandom rand = null)
{
    Host.AssertValue(predicate, "predicate");
    Host.AssertValueOrNull(rand);

    // REVIEW: This is slightly interesting. Our mechanism for inducing
    // randomness in the source cursor is this Random object, but this can change
    // from release to release. The correct solution, it seems, is to instead have
    // randomness injected into cursor creation by using IRandom (or something akin
    // to it), vs. just a straight system Random.

    // The desired functionality is to support some permutations of whether we allow
    // shuffling at the source level, or not.
    //
    // Pool      | Source    | Options
    // ----------+-----------+--------
    // Randomly  | Never     | poolOnly+
    //    "      | Randomly  | (default)
    //    "      | Always    | forceSource+
    // Always    | Never     | force+ poolOnly+
    // Always    | Randomly  | force+ forceSource-
    // Always    | Always    | force+

    bool callerProvidedRandom = rand != null;
    bool shufflePool = _forceShuffle || callerProvidedRandom;
    bool shuffleSource = _forceShuffleSource || (callerProvidedRandom && !_poolOnly);

    // Base generator: the caller's if given, otherwise a seeded one, but only when some stage needs it.
    IRandom baseRandom = rand;
    if (baseRandom == null && (shufflePool || shuffleSource))
    {
        baseRandom = RandomUtils.Create(_forceShuffleSeed);
    }

    // The pool uses the base generator directly; the source gets a generator derived from it.
    IRandom poolRandom = shufflePool ? baseRandom : null;
    IRandom sourceRandom = shuffleSource ? RandomUtils.Create(baseRandom) : null;

    var input = _subsetInput.GetRowCursor(predicate, sourceRandom);

    // Without pool shuffling, or with a one-row pool (where shuffling would be trivial),
    // the source cursor is returned unchanged.
    if (poolRandom == null || _poolRows == 1)
    {
        return input;
    }
    return new RowCursor(Host, _poolRows, input, poolRandom);
}
/// <summary>
/// Re/Builds internal CDL (Cumulative Distribution List).
/// Must be called after modifying (calling Add or Remove), or it will break.
/// Switches between linear or binary search, depending on which one will be faster.
/// Might generate some garbage (list resize) on first few builds.
/// </summary>
/// <param name="seed">
/// Seed for the internal random generator. -1 (default) keeps the current generator. -2 reseeds
/// with a value drawn from the current generator. Any other value reseeds with that value.
/// </param>
/// <returns>Returns itself</returns>
/// <exception cref="InvalidOperationException">No items have been added.</exception>
public IRandomSelector<T> Build(Int32 seed = -1)
{
    if (_items.Count == 0)
    {
        // InvalidOperationException: the selector is in a state where building is not possible.
        // It still derives from Exception, so existing catch blocks keep working.
        throw new InvalidOperationException("Cannot build with no items.");
    }

    // Clear the list, then transfer the weights.
    _cdl.Clear();
    foreach (Double weight in _weights)
    {
        _cdl.Add(weight);
    }
    RandomMath.BuildCumulativeDistribution(_cdl);

    // Default (-1): keep the same generator to avoid allocating a new one.
    if (seed != -1)
    {
        // -2 requests a randomized seed drawn from the current generator.
        if (seed == -2)
        {
            seed = _random.Next();
        }
        _random = RandomUtils.Create(seed);
    }

    // RandomMath.ListBreakpoint decides between linear and binary search: below the breakpoint use
    // the linear search selector, otherwise the binary search selector.
    if (_cdl.Count < RandomMath.ListBreakpoint)
    {
        _select = RandomMath.SelectIndexLinearSearch;
    }
    else
    {
        _select = RandomMath.SelectIndexBinarySearch;
    }

    return this;
}
/// <summary>
/// Creates a cursor that opens one child cursor per source of <paramref name="parent"/>. It chooses
/// among the sources with a <c>MultinomialWithoutReplacementSampler</c> built over <paramref name="counts"/>.
/// </summary>
/// <param name="parent">View whose sources are read.</param>
/// <param name="columnsNeeded">Columns for which getters are created.</param>
/// <param name="rand">Random generator. Each child cursor gets a generator derived from it, and it
/// also drives the sampler.</param>
/// <param name="counts">Per-source counts for the sampler; one non-negative entry per source.</param>
public RandCursor(AppendRowsDataView parent, IEnumerable<Schema.Column> columnsNeeded, Random rand, int[] counts)
    : base(parent)
{
    Ch.AssertValue(rand);
    _rand = rand;
    Ch.AssertValue(counts);
    Ch.Assert(Sources.Length == counts.Length);

    // Child cursors are opened in source order. Each draws a derived generator from _rand,
    // so the order of these draws matters for reproducibility.
    int sourceCount = counts.Length;
    _cursorSet = new RowCursor[sourceCount];
    for (int src = 0; src < sourceCount; src++)
    {
        Ch.Assert(counts[src] >= 0);
        var source = parent._sources[src];
        _cursorSet[src] = source.GetRowCursor(columnsNeeded, RandomUtils.Create(_rand));
    }

    _sampler = new MultinomialWithoutReplacementSampler(Ch, counts, rand);
    _currentSourceIndex = -1;

    foreach (var column in columnsNeeded)
    {
        int index = column.Index;
        Getters[index] = CreateGetter(index);
    }
}
/// <summary>
/// Builds the per-column state: dimensions, random generator (and its saved state), matrix
/// generator, and the aligned buffers that <c>InitializeFourierCoefficients</c> fills.
/// </summary>
/// <param name="host">Host used for argument checks and, when no seed is set, as the random source.</param>
/// <param name="column">Column options (rank, seed, generator, sin/cos flag).</param>
/// <param name="d">Source dimensionality.</param>
/// <param name="avgDist">Average distance passed to the matrix generator.</param>
public TransformInfo(IHost host, ApproximatedKernelMappingEstimator.ColumnOptions column, int d, float avgDist)
{
    Contracts.AssertValue(host);

    SrcDim = d;
    NewDim = column.Rank;
    host.CheckUserArg(NewDim > 0, nameof(column.Rank));
    _useSin = column.UseCosAndSinBases;

    // Prefer the column's own seed; otherwise derive a generator from the host's.
    var columnSeed = column.Seed;
    if (columnSeed.HasValue)
    {
        _rand = RandomUtils.Create(columnSeed);
    }
    else
    {
        _rand = RandomUtils.Create(host.Rand);
    }
    _state = _rand.GetState();

    _matrixGenerator = column.Generator.GetRandomNumberGenerator(avgDist);

    // Buffer sizes are rounded up to the float alignment.
    int paddedNewDim = RoundUp(NewDim, _cfltAlign);
    int paddedSrcDim = RoundUp(SrcDim, _cfltAlign);
    RndFourierVectors = new AlignedArray(paddedNewDim * paddedSrcDim, CpuMathUtils.GetVectorAlignment());
    if (_useSin)
    {
        RotationTerms = null;
    }
    else
    {
        RotationTerms = new AlignedArray(paddedNewDim, CpuMathUtils.GetVectorAlignment());
    }

    InitializeFourierCoefficients(paddedSrcDim, paddedNewDim);
}
/// <summary>
/// Creates a non-verbose ML.NET <see cref="IHostEnvironment"/> for local execution.
/// </summary>
/// <param name="seed">Random seed; <c>null</c> gives a non-deterministic environment.</param>
/// <param name="conc">Concurrency level: 1 runs single-threaded, 0 picks automatically.</param>
/// <param name="compositionContainerFactory">Factory that returns the composition container; defaults to <c>null</c>.</param>
public LocalEnvironment(int? seed = null, int conc = 0, Func<CompositionContainer> compositionContainerFactory = null)
    : base(RandomUtils.Create(seed), verbose: false, conc)
    => _compositionContainerFactory = compositionContainerFactory;
/// <summary>
/// Creates an environment from <paramref name="source"/>, with a random generator built from
/// <paramref name="seed"/> via <c>RandomUtils.Create</c>.
/// </summary>
/// <param name="source">Environment forwarded to the chained constructor.</param>
/// <param name="seed">Seed for the random generator (presumably <c>null</c> means non-deterministic; confirm in RandomUtils).</param>
/// <param name="verbose">Verbosity flag forwarded to the chained constructor.</param>
public RmlEnvironment(RmlEnvironment source, int? seed = null, bool verbose = false)
    : this(source, RandomUtils.Create(seed), verbose)
{
}
/// <summary>
/// Builds one KD-tree per data chunk (a single chunk, or one per thread when
/// <c>args.numThreads</c> is greater than 1). It then merges the per-chunk id-to-(label, weight)
/// dictionaries, shifting each chunk's ids so they are globally unique, and cross-checks the ids
/// stored in the trees against those in the merged dictionary.
/// </summary>
/// <typeparam name="TLabel">Label type.</typeparam>
/// <param name="ch">Channel used for error reporting.</param>
/// <param name="data">Input data.</param>
/// <param name="featureIndex">Index of the features column.</param>
/// <param name="labelIndex">Index of the label column.</param>
/// <param name="idIndex">Index of the id column, or -1 when absent; when present it must be a scalar I8 column.</param>
/// <param name="weightIndex">Index of the weight column.</param>
/// <param name="outLabelsWeights">Receives the merged id to (label, weight) map.</param>
/// <param name="args">Options: <c>numThreads</c> (default 1) sets the number of chunks and <c>seed</c> seeds cursor creation.</param>
/// <returns>The merged nearest-neighbors trees.</returns>
public static NearestNeighborsTrees NearestNeighborsBuild<TLabel>(IChannel ch, IDataView data,
    int featureIndex, int labelIndex, int idIndex, int weightIndex,
    out Dictionary<long, Tuple<TLabel, float>> outLabelsWeights, NearestNeighborsArguments args)
    where TLabel : IComparable<TLabel>
{
    var indexes = new HashSet<int>() { featureIndex, labelIndex, weightIndex, idIndex };
    if (idIndex != -1)
    {
        var colType = data.Schema[idIndex].Type;
        if (colType.IsVector() || colType.RawKind() != DataKind.I8)
        {
            throw ch.Except("Column '{0}' must be of type '{1}' not '{2}'", args.colId, DataKind.I8, colType);
        }
    }

    int nt = args.numThreads ?? 1;
    Random rand = RandomUtils.Create(args.seed);
    // NOTE(review): these cursors are not disposed in this method; confirm BuildKDTree disposes them.
    var cursors = (nt == 1)
        ? new RowCursor[] { data.GetRowCursor(i => indexes.Contains(i), rand) }
        : data.GetRowCursorSet(i => indexes.Contains(i), nt, rand);

    KdTree[] kdtrees;
    Dictionary<long, Tuple<TLabel, float>>[] labelsWeights;
    if (nt == 1)
    {
        labelsWeights = new Dictionary<long, Tuple<TLabel, float>>[1];
        kdtrees = new KdTree[] { BuildKDTree<TLabel>(data, cursors[0], featureIndex, labelIndex, idIndex, weightIndex, out labelsWeights[0], args) };
    }
    else
    {
        // Multithreading. We assume the distributed set of cursor is well distributed.
        // No KdTree will be much smaller than the others.
        Action[] ops = new Action[cursors.Length];
        kdtrees = new KdTree[cursors.Length];
        labelsWeights = new Dictionary<long, Tuple<TLabel, float>>[cursors.Length];
        for (int i = 0; i < ops.Length; ++i)
        {
            int chunkId = i;
            ops[i] = new Action(() =>
            {
                kdtrees[chunkId] = BuildKDTree<TLabel>(data, cursors[chunkId], featureIndex, labelIndex, idIndex, weightIndex, out labelsWeights[chunkId], args);
            });
        }
        Parallel.Invoke(new ParallelOptions() { MaxDegreeOfParallelism = cursors.Length }, ops);
    }

    // Drop chunks that received no rows. Trees and dictionaries are filtered together so index i
    // still refers to the same chunk in both arrays. Filtering them separately could make MoveId
    // shift a tree paired with another chunk's labels. A chunk where only one side is empty is
    // kept so the id checks below report the inconsistency.
    var keptTrees = new List<KdTree>(kdtrees.Length);
    var keptLabelsWeights = new List<Dictionary<long, Tuple<TLabel, float>>>(labelsWeights.Length);
    for (int i = 0; i < kdtrees.Length; ++i)
    {
        if (kdtrees[i].Any() || labelsWeights[i].Any())
        {
            keptTrees.Add(kdtrees[i]);
            keptLabelsWeights.Add(labelsWeights[i]);
        }
    }
    kdtrees = keptTrees.ToArray();
    labelsWeights = keptLabelsWeights.ToArray();
    if (labelsWeights.Length == 0)
    {
        throw ch.Except("No rows available to build the nearest neighbors trees.");
    }

    // Merge into the first dictionary. Each subsequent chunk's ids (in the dictionary and in its
    // tree) are shifted by the number of entries already merged.
    var merged = labelsWeights[0];
    long start = merged.Count;
    for (int i = 1; i < labelsWeights.Length; ++i)
    {
        kdtrees[i].MoveId(start);
        foreach (var pair in labelsWeights[i])
        {
            long newKey = pair.Key + start;
            if (merged.ContainsKey(newKey))
            {
                throw ch.Except("The same key appeared twice in two differents threads: {0}", newKey);
            }
            merged.Add(newKey, pair.Value);
        }
        start += labelsWeights[i].Count;
    }

    // Id checking: the ids in the trees must be exactly the keys of the merged dictionary.
    var labelId = merged.Select(c => c.Key).ToList();
    var treeId = new List<long>();
    for (int i = 0; i < kdtrees.Length; ++i)
    {
        treeId.AddRange(kdtrees[i].EnumeratePoints().Select(c => c.id));
    }
    var h1 = new HashSet<long>(labelId);
    var h2 = new HashSet<long>(treeId);
    if (h1.Count != labelId.Count)
    {
        throw ch.Except("Duplicated label ids.");
    }
    if (h2.Count != treeId.Count)
    {
        throw ch.Except("Duplicated tree ids.");
    }
    if (h1.Count != h2.Count)
    {
        throw ch.Except("Mismatch (1) in ids.");
    }
    if (!h1.SetEquals(h2))
    {
        throw ch.Except("Mismatch (2) in ids.");
    }

    outLabelsWeights = merged;
    return new NearestNeighborsTrees(ch, kdtrees);
}
/// <summary>
/// Creates an environment whose random generator is built from <paramref name="seed"/> and which
/// stores <paramref name="checkDelegate"/> in <c>CheckCancelled</c>.
/// </summary>
/// <param name="checkDelegate">Delegate assigned to <c>CheckCancelled</c>.</param>
/// <param name="seed">Seed passed to <c>RandomUtils.Create</c>.</param>
/// <param name="verbose">Verbosity flag forwarded to the chained constructor.</param>
public RmlEnvironment(Bridge.CheckCancelled checkDelegate, int? seed = null, bool verbose = false)
    : this(RandomUtils.Create(seed), verbose)
    => CheckCancelled = checkDelegate;
/// <summary>
/// Creates a non-verbose ML.NET <see cref="IHostEnvironment"/> for local execution.
/// </summary>
/// <param name="seed">Random seed; pass <c>null</c> for a non-deterministic environment.</param>
/// <param name="conc">Concurrency level: 1 runs single-threaded, 0 picks automatically.</param>
public LocalEnvironment(int? seed = null, int conc = 0)
    : base(RandomUtils.Create(seed), verbose: false, conc)
{
}
/// <summary>
/// Caches the parent transform as a <c>PValueTransform</c> and creates the random generator
/// from the parent's seed.
/// </summary>
private protected override void InitializeStateCore()
{
    var parent = (PValueTransform)ParentTransform;
    _parent = parent;
    _randomGen = RandomUtils.Create(parent._seed);
}
/// <summary>
/// Creates an ML.NET <see cref="IHostEnvironment"/> for local execution that reports to the console
/// (or to the supplied writers).
/// </summary>
/// <param name="seed">Random seed; <c>null</c> gives a non-deterministic environment.</param>
/// <param name="verbose"><c>true</c> for fully verbose logging.</param>
/// <param name="sensitivity">Which message sensitivities may be emitted.</param>
/// <param name="conc">Concurrency level: 1 runs single-threaded, 0 picks automatically.</param>
/// <param name="outWriter">Writer for normal messages.</param>
/// <param name="errWriter">Writer for error messages.</param>
public ConsoleEnvironment(
    int? seed = null,
    bool verbose = false,
    MessageSensitivity sensitivity = MessageSensitivity.All,
    int conc = 0,
    TextWriter outWriter = null,
    TextWriter errWriter = null)
    : this(RandomUtils.Create(seed), verbose, sensitivity, conc, outWriter, errWriter)
{
}
/// <summary>
/// Creates a non-verbose ML.NET <see cref="IHostEnvironment"/> for local execution.
/// </summary>
/// <param name="seed">Random seed; pass <c>null</c> for a non-deterministic environment.</param>
public LocalEnvironment(int? seed = null)
    : base(RandomUtils.Create(seed), verbose: false)
{
}
/// <summary>
/// Perform WebSocket client upgrade
/// </summary>
/// <remarks>
/// The response must have status 101 and must carry three headers: 'Connection' containing the
/// 'Upgrade' token, 'Upgrade: websocket', and a 'Sec-WebSocket-Accept' value equal to
/// base64(SHA-1(key + GUID)) for the key derived from <paramref name="id"/>. On success a new send
/// mask is generated and the handler is notified; on failure the handler receives an error.
/// </remarks>
/// <param name="response">WebSocket upgrade HTTP response</param>
/// <param name="id">WebSocket client Id</param>
/// <returns>'true' if the WebSocket was successfully upgraded, 'false' otherwise</returns>
public Boolean PerformClientUpgrade(HttpNetworkResponse response, Guid id)
{
    if (response.Status != 101)
    {
        return false;
    }

    Boolean error = false;
    Boolean accept = false;
    Boolean connection = false;
    Boolean upgrade = false;

    // Validate WebSocket handshake headers. Header names, the 'Upgrade' connection token and the
    // 'websocket' value are all case-insensitive (RFC 7230 section 3.2, RFC 6455 section 4.1).
    for (Int32 i = 0; i < response.Headers; ++i)
    {
        Tuple<String, String> header = response.Header(i);
        String key = header.Item1;
        String value = header.Item2 ?? String.Empty;

        if (String.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase))
        {
            // 'Connection' is a comma-separated token list, e.g. "keep-alive, Upgrade".
            Boolean hasUpgradeToken = false;
            foreach (String token in value.Split(','))
            {
                if (String.Equals(token.Trim(), "Upgrade", StringComparison.OrdinalIgnoreCase))
                {
                    hasUpgradeToken = true;
                    break;
                }
            }
            if (!hasUpgradeToken)
            {
                error = true;
                _wsHandler.OnWsError("Invalid WebSocket handshaked response: 'Connection' header value must be 'Upgrade'");
                break;
            }
            connection = true;
        }
        else if (String.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase))
        {
            if (!String.Equals(value.Trim(), "websocket", StringComparison.OrdinalIgnoreCase))
            {
                error = true;
                _wsHandler.OnWsError("Invalid WebSocket handshaked response: 'Upgrade' header value must be 'websocket'");
                break;
            }
            upgrade = true;
        }
        else if (String.Equals(key, "Sec-WebSocket-Accept", StringComparison.OrdinalIgnoreCase))
        {
            // Expected value: base64(SHA-1(client key + RFC 6455 GUID)), with the client key derived
            // from the id exactly as before.
            String wskey = Convert.ToBase64String(Encoding.UTF8.GetBytes(id.ToString())) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            String expected;
            using (SHA1 sha1 = SHA1.Create())
            {
                expected = Convert.ToBase64String(sha1.ComputeHash(Encoding.UTF8.GetBytes(wskey)));
            }

            // Compare in base64 form. Decoding the binary hash as UTF-8 text is lossy (invalid byte
            // sequences collapse to U+FFFD), and decoding the server value could throw FormatException.
            if (!String.Equals(value.Trim(), expected, StringComparison.Ordinal))
            {
                error = true;
                _wsHandler.OnWsError("Invalid WebSocket handshaked response: 'Sec-WebSocket-Accept' value validation failed");
                break;
            }
            accept = true;
        }
    }

    // Failed to perform WebSocket handshake
    if (!accept || !connection || !upgrade)
    {
        if (!error)
        {
            _wsHandler.OnWsError("Invalid WebSocket response");
        }
        return false;
    }

    // WebSocket successfully handshaked!
    WsHandshaked = true;
    // RFC 6455 section 5.3: the masking key MUST come from a strong source of entropy.
    using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
    {
        rng.GetBytes(WsSendMask);
    }
    _wsHandler.OnWsConnected(response);
    return true;
}
/// <summary>
/// Runs the experiment loop. Each iteration creates a fresh <see cref="MLContext"/>, asks
/// <c>PipelineSuggester</c> for the next pipeline, runs it, and records the result.
/// </summary>
/// <remarks>
/// The loop ends when any of these happens:
/// no pipeline is suggested; the model is perfect; <c>MaxModels</c> is reached; the settings'
/// cancellation token is signaled; the experiment timer has expired; or training is canceled, in
/// which case the results gathered so far are returned.
/// Both timers started here are disposed before returning.
/// </remarks>
/// <returns>Run details of every completed iteration, in order.</returns>
/// <exception cref="InvalidOperationException">The first three runs all failed.</exception>
public IList<TRunDetail> Execute()
{
    var iterationResults = new List<TRunDetail>();

    try
    {
        // Create a timer for the max duration of experiment. When given time has
        // elapsed, MaxExperimentTimeExpiredEvent is called to interrupt training
        // of current model. Timer is not used if no experiment time is given, or
        // is not a positive number.
        if (_experimentSettings.MaxExperimentTimeInSeconds > 0)
        {
            _maxExperimentTimeTimer = new Timer(
                new TimerCallback(MaxExperimentTimeExpiredEvent), null,
                _experimentSettings.MaxExperimentTimeInSeconds * 1000, Timeout.Infinite
            );
        }
        // If given max duration of experiment is 0, only 1 model will be trained.
        // _experimentSettings.MaxExperimentTimeInSeconds is of type uint, it is
        // either 0 or >0.
        else
        {
            _experimentTimerExpired = true;
        }

        // Add second timer to check for the cancelation signal from the main MLContext
        // to the active child MLContext. This timer will propagate the cancelation
        // signal from the main to the child MLContexts if the main MLContext is
        // canceled.
        _mainContextCanceledTimer = new Timer(new TimerCallback(MainContextCanceledEvent), null, 1000, 1000);

        // Pseudo random number generator to result in deterministic runs with the provided main MLContext's seed and to
        // maintain variability between training iterations.
        int? mainContextSeed = ((IHostEnvironmentInternal)_context.Model.GetEnvironment()).Seed;
        _newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : null;

        do
        {
            try
            {
                var iterationStopwatch = Stopwatch.StartNew();

                // get next pipeline
                var getPipelineStopwatch = Stopwatch.StartNew();

                // A new MLContext is needed per model run. When max experiment time is reached, each used
                // context is canceled to stop further model training. The cancellation of the main MLContext
                // a user has instantiated is not desirable, thus additional MLContexts are used.
                _currentModelMLContext = _newContextSeedGenerator == null
                    ? new MLContext()
                    : new MLContext(_newContextSeedGenerator.Next());
                _currentModelMLContext.Log += RelayCurrentContextLogsToLogger;
                var pipeline = PipelineSuggester.GetNextInferredPipeline(_currentModelMLContext, _history, _datasetColumnInfo, _task,
                    _optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList);

                // Stop timing here so PipelineInferenceTimeInSeconds covers pipeline inference only,
                // not the training and evaluation below.
                getPipelineStopwatch.Stop();

                // break if no candidates returned, means no valid pipeline available
                if (pipeline == null)
                {
                    break;
                }

                // evaluate pipeline
                _logger.Trace($"Evaluating pipeline {pipeline.ToString()}");
                (SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
                    = _runner.Run(pipeline, _modelDirectory, _history.Count + 1);

                _history.Add(suggestedPipelineRunDetail);
                WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);

                runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
                runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;

                ReportProgress(runDetail);
                iterationResults.Add(runDetail);

                // if model is perfect, break
                if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
                {
                    break;
                }

                // If after third run, all runs have failed so far, throw exception
                if (_history.Count == 3 && _history.All(r => !r.RunSucceeded))
                {
                    throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}");
                }
            }
            catch (OperationCanceledException e)
            {
                // This exception is thrown when the IHost/MLContext of the trainer is canceled due to
                // reaching maximum experiment time. Simply catch this exception and return finished
                // iteration results.
                _logger.Warning(_operationCancelledMessage, e.Message);
                return iterationResults;
            }
            catch (AggregateException e)
            {
                // This exception is thrown when the IHost/MLContext of the trainer is canceled due to
                // reaching maximum experiment time. Simply catch this exception and return finished
                // iteration results. For some trainers, like FastTree, because training is done in parallel
                // it can throw multiple OperationCancelledExceptions. This causes them to be returned as an
                // AggregateException and misses the first catch block. This is to handle that case.
                if (e.InnerExceptions.All(exception => exception is OperationCanceledException))
                {
                    _logger.Warning(_operationCancelledMessage, e.Message);
                    return iterationResults;
                }

                throw;
            }
        } while (_history.Count < _experimentSettings.MaxModels &&
                 !_experimentSettings.CancellationToken.IsCancellationRequested &&
                 !_experimentTimerExpired);

        return iterationResults;
    }
    finally
    {
        // Without this, the 1-second cancellation poll keeps firing after Execute returns, and its
        // callback keeps this instance reachable. Disposing an already-disposed timer is a no-op.
        _maxExperimentTimeTimer?.Dispose();
        _mainContextCanceledTimer?.Dispose();
    }
}
/// <summary>
/// Creates a selector over <paramref name="items"/> that uses linear search in the cumulative
/// distribution array <paramref name="cda"/>. Used by StaticRandomSelectorBuilder.
/// </summary>
/// <param name="items">Items to select from.</param>
/// <param name="cda">Cumulative Distribution Array aligned with <paramref name="items"/>.</param>
/// <param name="seed">Seed for the internal random generator.</param>
public StaticRandomSelectorLinear(T[] items, Double[] cda, Int32 seed)
{
    _random = RandomUtils.Create(seed);
    _cda = cda;
    _items = items;
}
public void TestComparableDvText()
{
    const int count = 100;

    // Deterministic pool of ASCII characters to slice texts from.
    var rand = RandomUtils.Create(42);
    var chars = new char[2000];
    for (int i = 0; i < chars.Length; i++)
    {
        chars[i] = (char)rand.Next(128);
    }
    var str = new string(chars);

    // values[i] and values[values.Length - 1 - i] hold the same text, so duplicates are guaranteed.
    var values = new DvText[2 * count];
    for (int i = 0; i < count; i++)
    {
        int len = rand.Next(20);
        int ich = rand.Next(str.Length - len + 1);
        var v = values[i] = new DvText(str, ich, ich + len);
        values[values.Length - i - 1] = v;
    }

    // Assign two NA's and an empty at three distinct random positions.
    int iv1 = rand.Next(values.Length);
    int iv2 = rand.Next(values.Length - 1);
    if (iv2 >= iv1)
    {
        iv2++;
    }
    // iv3 must skip both taken slots, lower one first. Skipping iv1 before iv2 would let iv3 land
    // on iv1 whenever iv2 < iv1, so the Empty would overwrite an NA.
    int ivLow = Math.Min(iv1, iv2);
    int ivHigh = Math.Max(iv1, iv2);
    int iv3 = rand.Next(values.Length - 2);
    if (iv3 >= ivLow)
    {
        iv3++;
    }
    if (iv3 >= ivHigh)
    {
        iv3++;
    }
    values[iv1] = DvText.NA;
    values[iv2] = DvText.NA;
    values[iv3] = DvText.Empty;
    Array.Sort(values);

    // NA sorts first, then Empty.
    Assert.True(values[0].IsNA);
    Assert.True(values[1].IsNA);
    Assert.True(values[2].IsEmpty);

    Assert.True((values[0] == values[1]).IsNA);
    Assert.True((values[0] != values[1]).IsNA);
    Assert.True(values[0].Equals(values[1]));
    Assert.True(values[0].CompareTo(values[1]) == 0);

    Assert.True((values[1] == values[2]).IsNA);
    Assert.True((values[1] != values[2]).IsNA);
    Assert.True(!values[1].Equals(values[2]));
    Assert.True(values[1].CompareTo(values[2]) < 0);

    // Among non-NA values, ==, !=, Equals and CompareTo must all agree and be sorted.
    for (int i = 3; i < values.Length; i++)
    {
        DvBool eq = values[i - 1] == values[i];
        DvBool ne = values[i - 1] != values[i];
        bool feq = values[i - 1].Equals(values[i]);
        int cmp = values[i - 1].CompareTo(values[i]);
        Assert.True(!eq.IsNA);
        Assert.True(!ne.IsNA);
        Assert.True(eq.IsTrue == ne.IsFalse);
        Assert.True(feq == eq.IsTrue);
        Assert.True(cmp <= 0);
        Assert.True(feq == (cmp == 0));
    }
}
/// <summary>
/// Computes, for each column in <paramref name="infos"/>, the median of sampled pairwise distances
/// between rows of <paramref name="trainingData"/>. The distance is squared L2 when the column's
/// matrix generator is a <c>GaussianFourierSampler</c>, L1 otherwise.
/// </summary>
/// <param name="host">Host providing assertions, the fallback random generator and component creation.</param>
/// <param name="infos">Columns to process; must be non-empty.</param>
/// <param name="args">Arguments supplying the optional seed and the per-column or global matrix generator.</param>
/// <param name="trainingData">Data sampled to estimate the distances.</param>
/// <returns>
/// One value per column: the median distance, or 1 when at most one instance was sampled or the
/// median is 0.
/// </returns>
private static Float[] Train(IHost host, ColInfo[] infos, Arguments args, IDataView trainingData)
{
    Contracts.AssertValue(host, "host");
    host.AssertNonEmpty(infos);

    var avgDistances = new Float[infos.Length];
    // Capacity of each reservoir sampler.
    const int reservoirSize = 5000;

    // Activate only the source columns that are read below.
    bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
    for (int i = 0; i < infos.Length; i++)
    {
        activeColumns[infos[i].Source] = true;
    }

    var reservoirSamplers = new ReservoirSamplerWithReplacement<VBuffer<Float>>[infos.Length];
    using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
    {
        // One generator shared by all samplers: from args.Seed when set, otherwise the host's.
        var rng = args.Seed.HasValue ? RandomUtils.Create(args.Seed) : host.Rand;
        for (int i = 0; i < infos.Length; i++)
        {
            if (infos[i].TypeSrc.IsVector)
            {
                var get = cursor.GetGetter<VBuffer<Float>>(infos[i].Source);
                reservoirSamplers[i] = new ReservoirSamplerWithReplacement<VBuffer<Float>>(rng, reservoirSize, get);
            }
            else
            {
                // Scalar columns are wrapped as length-1 vectors so every column goes through the
                // same VBuffer-based sampling and distance code.
                var getOne = cursor.GetGetter<Float>(infos[i].Source);
                Float val = 0;
                ValueGetter<VBuffer<Float>> get =
                    (ref VBuffer<Float> dst) =>
                    {
                        getOne(ref val);
                        dst = new VBuffer<float>(1, new[] { val });
                    };
                reservoirSamplers[i] = new ReservoirSamplerWithReplacement<VBuffer<Float>>(rng, reservoirSize, get);
            }
        }

        // A single pass over the data offers every row to every column's sampler.
        while (cursor.MoveNext())
        {
            for (int i = 0; i < infos.Length; i++)
            {
                reservoirSamplers[i].Sample();
            }
        }
        for (int i = 0; i < infos.Length; i++)
        {
            reservoirSamplers[i].Lock();
        }
    }

    for (int iinfo = 0; iinfo < infos.Length; iinfo++)
    {
        var instanceCount = reservoirSamplers[iinfo].NumSampled;

        // If the number of pairs is at most the maximum reservoir size / 2, we go over all the pairs,
        // so we get all the examples. Otherwise, get a sample with replacement.
        VBuffer<Float>[] res;
        int resLength;
        if (instanceCount < reservoirSize && instanceCount * (instanceCount - 1) <= reservoirSize)
        {
            res = reservoirSamplers[iinfo].GetCache();
            resLength = reservoirSamplers[iinfo].Size;
            Contracts.Assert(resLength == instanceCount);
        }
        else
        {
            res = reservoirSamplers[iinfo].GetSample().ToArray();
            resLength = res.Length;
        }

        // If the dataset contains only one valid Instance, then we can't learn anything anyway, so just return 1.
        if (instanceCount <= 1)
        {
            avgDistances[iinfo] = 1;
        }
        else
        {
            Float[] distances;
            // The column's own matrix generator, falling back to the transform-wide one.
            var sub = args.Column[iinfo].MatrixGenerator;
            if (sub == null)
            {
                sub = args.MatrixGenerator;
            }
            // create a dummy generator in order to get its type.
            // REVIEW this should be refactored. See https://github.com/dotnet/machinelearning/issues/699
            var matrixGenerator = sub.CreateComponent(host, 1);
            bool gaussian = matrixGenerator is GaussianFourierSampler;

            // If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
            // NOTE(review): this branch indexes res[0..instanceCount). It assumes res holds at least
            // instanceCount items whenever resLength < reservoirSize (true on the GetCache path above).
            // Confirm GetSample() never returns fewer items than were sampled.
            if (resLength < reservoirSize)
            {
                distances = new Float[instanceCount * (instanceCount - 1) / 2];
                int count = 0;
                for (int i = 0; i < instanceCount; i++)
                {
                    for (int j = i + 1; j < instanceCount; j++)
                    {
                        distances[count++] = gaussian ? VectorUtils.L2DistSquared(ref res[i], ref res[j])
                            : VectorUtils.L1Distance(ref res[i], ref res[j]);
                    }
                }
                host.Assert(count == distances.Length);
            }
            else
            {
                // Otherwise pair consecutive sampled rows (0,1), (2,3), ..., giving reservoirSize / 2 distances.
                distances = new Float[reservoirSize / 2];
                for (int i = 0; i < reservoirSize - 1; i += 2)
                {
                    // For Gaussian kernels, we scale by the L2 distance squared, since the kernel function is exp(-gamma ||x-y||^2).
                    // For Laplacian kernels, we scale by the L1 distance, since the kernel function is exp(-gamma ||x-y||_1).
                    distances[i / 2] = gaussian ? VectorUtils.L2DistSquared(ref res[i], ref res[i + 1]) :
                        VectorUtils.L1Distance(ref res[i], ref res[i + 1]);
                }
            }

            // If by chance, in the random permutation all the pairs are the same instance we return 1.
            Float median = MathUtils.GetMedianInPlace(distances, distances.Length);
            avgDistances[iinfo] = median == 0 ? 1 : median;
        }
    }
    return (avgDistances);
}
/// <summary>
/// Fills <c>_cacheReplica</c> once; it does nothing if the cache is already populated.
/// Without a class column, every row id gets a count drawn by <c>NextPoisson(_args.lambda, rand)</c>.
/// With a class column, <c>_args.classValue</c> is parsed according to the column's raw kind and
/// passed to the typed <c>LoadCache&lt;T&gt;</c> overload.
/// </summary>
/// <param name="rand">Random generator to use; when null, one is created from <c>_args.seed</c>.</param>
void LoadCache(Random rand)
{
    if (_cacheReplica != null)
    {
        // Already done.
        return;
    }

    uint? useed = _args.seed.HasValue ? (uint)_args.seed.Value : (uint?)null;
    if (rand == null)
    {
        rand = RandomUtils.Create(useed);
    }

    using (var ch = _host.Start("Resample: fill the cache"))
    {
        var indexClass = SchemaHelper.GetColumnIndexDC(_input.Schema, _args.column, true);
        using (var cur = _input.GetRowCursor(Schema.Where(c => c.Index == indexClass.Index)))
        {
            if (string.IsNullOrEmpty(_args.column))
            {
                _cacheReplica = new Dictionary<DataViewRowId, int>();
                var gid = cur.GetIdGetter();
                DataViewRowId did = default(DataViewRowId);
                int rep;
                while (cur.MoveNext())
                {
                    gid(ref did);
                    rep = NextPoisson(_args.lambda, rand);
                    _cacheReplica[did] = rep;
                }
            }
            else
            {
                // classValue is a machine-readable argument. Numeric values are parsed with the
                // invariant culture so that, e.g., "0.5" does not become 5 under a culture whose
                // decimal separator is ','.
                var type = _input.Schema[indexClass.Index].Type;
                switch (type.RawKind())
                {
                    case DataKind.Boolean:
                        bool clbool;
                        if (!bool.TryParse(_args.classValue, out clbool))
                        {
                            throw ch.Except("Unable to parse '{0}'.", _args.classValue);
                        }
                        LoadCache<bool>(rand, cur, indexClass, clbool, ch);
                        break;
                    case DataKind.UInt32:
                        uint cluint;
                        if (!uint.TryParse(_args.classValue, System.Globalization.NumberStyles.Integer,
                                System.Globalization.CultureInfo.InvariantCulture, out cluint))
                        {
                            throw ch.Except("Unable to parse '{0}'.", _args.classValue);
                        }
                        LoadCache<uint>(rand, cur, indexClass, cluint, ch);
                        break;
                    case DataKind.Single:
                        float clfloat;
                        if (!float.TryParse(_args.classValue,
                                System.Globalization.NumberStyles.Float | System.Globalization.NumberStyles.AllowThousands,
                                System.Globalization.CultureInfo.InvariantCulture, out clfloat))
                        {
                            throw ch.Except("Unable to parse '{0}'.", _args.classValue);
                        }
                        LoadCache<float>(rand, cur, indexClass, clfloat, ch);
                        break;
                    case DataKind.String:
                        var cltext = new ReadOnlyMemory<char>(_args.classValue.ToCharArray());
                        LoadCache<ReadOnlyMemory<char>>(rand, cur, indexClass, cltext, ch);
                        break;
                    default:
                        throw _host.Except("Unsupported type '{0}'", type);
                }
            }
        }
    }
}
GetImportanceMetricsMatrix(
    IHostEnvironment env,
    IPredictionTransformer<TModel> model,
    IDataView data,
    Func<TResult> resultInitializer,
    Func<IDataView, TMetric> evaluationFunc,
    Func<TMetric, TMetric, TMetric> deltaFunc,
    string features,
    int permutationCount,
    bool useFeatureWeightFilter = false,
    int? topExamples = null)
{
    // Permutation feature importance over the vector column `features`.
    //
    // The baseline metrics are computed once with evaluationFunc(model.Transform(data)).
    // Then, for every evaluated slot, that slot's cached values are shuffled across rows
    // `permutationCount` times. Each time, the model is re-scored and evaluated, and
    // deltaFunc(permutedMetrics, baselineMetrics) is added to a TResult created by
    // resultInitializer().
    //
    // Parameters:
    //   useFeatureWeightFilter - when true and model.Model implements
    //       IPredictorWithFeatureWeights<Single>, slots whose weight is 0 are skipped.
    //   topExamples - maximum number of rows used; null means Utils.ArrayMaxSize.
    //
    // Returns one TResult per evaluated slot, in evaluation order. It is empty when every
    // slot has zero weight or when the data yields no rows.
    //
    // Throws (via the host checks) when env/model/data is null, when features is empty,
    // or when topExamples is not positive.
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(nameof(PermutationFeatureImportance<TModel, TMetric, TResult>));
    host.CheckValue(model, nameof(model));
    host.CheckValue(data, nameof(data));
    host.CheckNonEmpty(features, nameof(features));

    topExamples = topExamples ?? Utils.ArrayMaxSize;
    host.Check(topExamples > 0, "Provide how many examples to use (positive number) or set to null to use whole dataset.");

    VBuffer<ReadOnlyMemory<char>> slotNames = default;
    var metricsDelta = new List<TResult>();

    using (var ch = host.Start("GetImportanceMetrics"))
    {
        ch.Trace("Scoring and evaluating baseline.");
        var baselineMetrics = evaluationFunc(model.Transform(data));

        // Get slot names.
        var featuresColumn = data.Schema[features];
        int numSlots = featuresColumn.Type.GetVectorSize();
        data.Schema.TryGetColumnIndex(features, out int featuresColumnIndex);

        ch.Info("Number of slots: " + numSlots);
        if (data.Schema[featuresColumnIndex].HasSlotNames(numSlots))
        {
            data.Schema[featuresColumnIndex].Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref slotNames);
        }

        if (slotNames.Length != numSlots)
        {
            slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(numSlots);
        }

        VBuffer<float> weights = default;
        // Start EMPTY. The weight filter below adds only slots with non-zero weight.
        // When no filtering happened, the fallback after the filter adds every slot.
        // Pre-seeding with all slots would make the filter add duplicates and
        // still evaluate the zero-weight slots.
        var workingFeatureIndices = new List<int>();
        int zeroWeightsCount = 0;

        if (useFeatureWeightFilter)
        {
            var predictorWithWeights = model.Model as IPredictorWithFeatureWeights<Single>;
            if (predictorWithWeights != null)
            {
                predictorWithWeights.GetFeatureWeights(ref weights);

                const int maxReportedZeroFeatures = 10;
                StringBuilder msgFilteredOutFeatures = new StringBuilder("The following features have zero weight and will not be evaluated: \n \t");
                var prefix = "";
                foreach (var k in weights.Items(all: true))
                {
                    if (k.Value == 0)
                    {
                        zeroWeightsCount++;

                        // Print info about first few features we're not going to evaluate.
                        if (zeroWeightsCount <= maxReportedZeroFeatures)
                        {
                            msgFilteredOutFeatures.Append(prefix);
                            msgFilteredOutFeatures.Append(GetSlotName(slotNames, k.Key));
                            prefix = ", ";
                        }
                    }
                    else
                    {
                        workingFeatureIndices.Add(k.Key);
                    }
                }

                // Old FastTree models has less weights than slots.
                if (weights.Length < numSlots)
                {
                    ch.Warning(
                        "Predictor had fewer features than slots. All unknown features will get default 0 weight.");
                    zeroWeightsCount += numSlots - weights.Length;
                    var indexes = weights.GetIndices().ToArray();
                    var values = weights.GetValues().ToArray();
                    var count = values.Length;
                    weights = new VBuffer<float>(numSlots, count, values, indexes);
                }

                ch.Info("Number of zero weights: {0} out of {1}.", zeroWeightsCount, weights.Length);

                // Print what features have 0 weight
                if (zeroWeightsCount > 0)
                {
                    if (zeroWeightsCount > maxReportedZeroFeatures)
                    {
                        msgFilteredOutFeatures.Append(string.Format("... (printing out {0} features here).\n Use 'Index' column in the report for info on what features are not evaluated.", maxReportedZeroFeatures));
                    }
                    ch.Info(msgFilteredOutFeatures.ToString());
                }
            }
        }

        if (workingFeatureIndices.Count == 0 && zeroWeightsCount == 0)
        {
            // No filtering happened (flag off, or the model exposes no weights): use all features.
            workingFeatureIndices.AddRange(Enumerable.Range(0, numSlots));
        }

        if (zeroWeightsCount == numSlots)
        {
            ch.Warning("All features have 0 weight thus can not do thorough evaluation");
            return metricsDelta.ToImmutableArray();
        }

        // Note: this will not work on the huge dataset.
        var maxSize = topExamples;
        List<float> initialfeatureValuesList = new List<float>();

        // Cursor through the data to cache the values of the first working slot for the upcoming permutation.
        var valuesRowCount = 0;
        // REVIEW: Seems like if the labels are NaN, so that all metrics are NaN, this command will be useless.
        // In which case probably erroring out is probably the most useful thing.
        using (var cursor = data.GetRowCursor(featuresColumn))
        {
            var featuresGetter = cursor.GetGetter<VBuffer<float>>(featuresColumn);
            var featuresBuffer = default(VBuffer<float>);

            while (initialfeatureValuesList.Count < maxSize && cursor.MoveNext())
            {
                featuresGetter(ref featuresBuffer);
                initialfeatureValuesList.Add(featuresBuffer.GetItemOrDefault(workingFeatureIndices[0]));
            }

            valuesRowCount = initialfeatureValuesList.Count;
        }

        if (valuesRowCount > 0)
        {
            ch.Info("Detected {0} examples for evaluation.", valuesRowCount);
        }
        else
        {
            ch.Warning("Detected no examples for evaluation.");
            return metricsDelta.ToImmutableArray();
        }

        float[] featureValuesBuffer = initialfeatureValuesList.ToArray();
        float[] nextValues = new float[valuesRowCount];

        // Now iterate through all the working slots, do permutation and calc the delta of metrics.
        int processedCnt = 0;
        int nextFeatureIndex = 0;
        var shuffleRand = RandomUtils.Create(host.Rand.Next());
        using (var pch = host.StartProgressChannel("Calculating Permutation Feature Importance"))
        {
            pch.SetHeader(new ProgressHeader("processed slots"), e => e.SetProgress(0, processedCnt));

            foreach (var workingIndx in workingFeatureIndices)
            {
                // Index for the feature we will permute next. Needed to build in advance a buffer for the permutation.
                if (processedCnt < workingFeatureIndices.Count - 1)
                {
                    nextFeatureIndex = workingFeatureIndices[processedCnt + 1];
                }

                // Used for pre-caching the next feature
                int nextValuesIndex = 0;

                SchemaDefinition input = SchemaDefinition.Create(typeof(FeaturesBuffer));
                Contracts.Assert(input.Count == 1);
                input[0].ColumnName = features;

                SchemaDefinition output = SchemaDefinition.Create(typeof(FeaturesBuffer));
                Contracts.Assert(output.Count == 1);
                output[0].ColumnName = features;
                output[0].ColumnType = featuresColumn.Type;

                // Perform multiple permutations for one feature to build a confidence interval
                var metricsDeltaForFeature = resultInitializer();
                for (int permutationIteration = 0; permutationIteration < permutationCount; permutationIteration++)
                {
                    Utils.Shuffle<float>(shuffleRand, featureValuesBuffer);

                    Action<FeaturesBuffer, FeaturesBuffer, PermuterState> permuter = (src, dst, state) =>
                    {
                        src.Features.CopyTo(ref dst.Features);
                        VBufferUtils.ApplyAt(ref dst.Features, workingIndx,
                            (int ii, ref float d) => d = featureValuesBuffer[state.SampleIndex++]);

                        // Is it time to pre-cache the next feature?
                        if (permutationIteration == permutationCount - 1 &&
                            processedCnt < workingFeatureIndices.Count - 1)
                        {
                            // Fill out the featureValueBuffer for the next feature while updating the current feature
                            // This is the reason I need PermuterState in LambdaTransform.CreateMap.
                            nextValues[nextValuesIndex] = src.Features.GetItemOrDefault(nextFeatureIndex);
                            if (nextValuesIndex < valuesRowCount - 1)
                            {
                                nextValuesIndex++;
                            }
                        }
                    };

                    IDataView viewPermuted = LambdaTransform.CreateMap(
                        host, data, permuter, null, input, output);
                    if (valuesRowCount == topExamples)
                    {
                        viewPermuted = SkipTakeFilter.Create(host, new SkipTakeFilter.TakeOptions() { Count = valuesRowCount }, viewPermuted);
                    }

                    var metrics = evaluationFunc(model.Transform(viewPermuted));

                    var delta = deltaFunc(metrics, baselineMetrics);
                    metricsDeltaForFeature.Add(delta);
                }

                // Add the metrics delta to the list
                metricsDelta.Add(metricsDeltaForFeature);

                // Swap values for next iteration of permutation.
                if (processedCnt < workingFeatureIndices.Count - 1)
                {
                    Array.Clear(featureValuesBuffer, 0, featureValuesBuffer.Length);
                    nextValues.CopyTo(featureValuesBuffer, 0);
                    Array.Clear(nextValues, 0, nextValues.Length);
                }
                processedCnt++;
            }
            pch.Checkpoint(processedCnt, processedCnt);
        }
    }

    return metricsDelta.ToImmutableArray();
}
void LoadCache(IRandom rand)
{
    // The replica cache is built only once; subsequent calls return immediately.
    if (_cacheReplica != null)
        return;

    // Without a caller-supplied generator, create one from _args.seed (null when no seed was given).
    uint? optionalSeed = _args.seed.HasValue ? (uint)_args.seed.Value : (uint?)null;
    rand = rand ?? RandomUtils.Create(optionalSeed);

    using (var ch = _host.Start("Resample: fill the cache"))
    {
        bool byClass = !string.IsNullOrEmpty(_args.column);

        // indexClass stays -1 when no class column is requested; the cursor predicate below captures it.
        int indexClass = -1;
        if (byClass && !_input.Schema.TryGetColumnIndex(_args.column, out indexClass))
            throw ch.Except("Unable to find column '{0}'.", _args.column);

        using (var cur = _input.GetRowCursor(i => i == indexClass))
        {
            if (!byClass)
            {
                // No class column: one replica count per row id, drawn via NextPoisson(_args.lambda, rand).
                _cacheReplica = new Dictionary<UInt128, int>();
                var idGetter = cur.GetIdGetter();
                UInt128 rowId = default(UInt128);
                while (cur.MoveNext())
                {
                    idGetter(ref rowId);
                    int replicas = NextPoisson(_args.lambda, rand);
                    _cacheReplica[rowId] = replicas;
                }
            }
            else
            {
                // Class column: convert _args.classValue to the column's raw type, then
                // hand off to the typed LoadCache<T> overload.
                var type = _input.Schema.GetColumnType(indexClass);
                var kind = type.RawKind();
                if (kind == DataKind.BL)
                {
                    if (!bool.TryParse(_args.classValue, out bool boolValue))
                        throw ch.Except("Unable to parse '{0}'.", _args.classValue);
                    LoadCache<bool>(rand, cur, indexClass, boolValue, ch);
                }
                else if (kind == DataKind.U4)
                {
                    if (!uint.TryParse(_args.classValue, out uint uintValue))
                        throw ch.Except("Unable to parse '{0}'.", _args.classValue);
                    LoadCache<uint>(rand, cur, indexClass, uintValue, ch);
                }
                else if (kind == DataKind.R4)
                {
                    if (!float.TryParse(_args.classValue, out float floatValue))
                        throw ch.Except("Unable to parse '{0}'.", _args.classValue);
                    LoadCache<float>(rand, cur, indexClass, floatValue, ch);
                }
                else if (kind == DataKind.TX)
                {
                    var textValue = new ReadOnlyMemory<char>(_args.classValue.ToCharArray());
                    LoadCache<ReadOnlyMemory<char>>(rand, cur, indexClass, textValue, ch);
                }
                else
                {
                    throw _host.Except("Unsupported type '{0}'", type);
                }
            }
        }
    }
}