/// <summary>
/// Creates a sampler over a CNTK text-format (CTF) file, inferring the
/// stream layout from the beginning of the file.
/// </summary>
/// <param name="path">Path to the CTF file.</param>
/// <param name="minibatchSize">Number of samples per minibatch.</param>
/// <param name="randomize">Whether the minibatch source shuffles samples.</param>
/// <exception cref="ArgumentException">Thrown when the file contains sparse streams.</exception>
public CTFSampler(string path, int minibatchSize, bool randomize = true)
{
    _minibatchSize = minibatchSize;

    // Build a stream configuration per detected stream.
    // GuessDataFormat(path, 10) is presumably sampling the first 10 lines — TODO confirm.
    // A dimension of -1 marks a stream detected as sparse, which this sampler rejects.
    _streamConfigurations = new List<StreamConfiguration>();
    var elements = GuessDataFormat(path, 10);
    foreach (var e in elements)
    {
        if (e.Value == -1)
        {
            throw new ArgumentException("CTF file contains sparse data");
        }
        var config = new StreamConfiguration(e.Key, e.Value, false);
        _streamConfigurations.Add(config);
    }

    // InfinitelyRepeat: the source never ends; epoch boundaries are the caller's concern.
    _minibatchSource = MinibatchSource.TextFormatMinibatchSource(path, _streamConfigurations, MinibatchSource.InfinitelyRepeat, randomize);

    // Cache per-stream metadata by name for lookup during sampling.
    _streamInfos = new Dictionary<string, StreamInformation>();
    foreach (var name in elements.Keys)
    {
        _streamInfos.Add(name, _minibatchSource.StreamInfo(name));
    }
}
/// <summary>
/// Ensures a clean stream exists: purges an existing stream (deleting its
/// consumers), or creates a new in-memory stream when it does not exist.
/// </summary>
public static void SetupStream(string stream, string subject, Options options, JetStreamOptions jso)
{
    using (IConnection conn = new ConnectionFactory().CreateConnection(options))
    {
        IJetStreamManagement mgmt = conn.CreateJetStreamManagementContext(jso);
        try
        {
            // Existing stream: wipe its messages and drop every consumer.
            mgmt.PurgeStream(stream);
            foreach (string consumerName in mgmt.GetConsumerNames(stream))
            {
                mgmt.DeleteConsumer(stream, consumerName);
            }
            Console.WriteLine("PURGED: " + mgmt.GetStreamInfo(stream));
        }
        catch (NATSJetStreamException)
        {
            // Stream did not exist; create a fresh memory-backed one.
            StreamConfiguration streamConfig = StreamConfiguration.Builder()
                .WithName(stream)
                .WithSubjects(subject)
                .WithStorageType(StorageType.Memory)
                .Build();
            Console.WriteLine("CREATED: " + mgmt.AddStream(streamConfig));
        }

        Thread.Sleep(1000); // give it a little time to setup
    }
}
// Verifies the invalid-argument and invalid-state paths of AddStream/UpdateStream.
public void TestAddUpdateStreamInvalids()
{
    Context.RunInJsServer(c =>
    {
        IJetStreamManagement jsm = c.CreateJetStreamManagementContext();

        // null and nameless configurations are rejected before reaching the server
        StreamConfiguration scNoName = StreamConfiguration.Builder().Build();
        Assert.Throws<ArgumentNullException>(() => jsm.AddStream(null));
        Assert.Throws<ArgumentException>(() => jsm.AddStream(scNoName));
        Assert.Throws<ArgumentNullException>(() => jsm.UpdateStream(null));
        Assert.Throws<ArgumentException>(() => jsm.UpdateStream(scNoName));

        // cannot update non existent stream
        StreamConfiguration sc = GetTestStreamConfiguration();
        // stream not added yet
        Assert.Throws<NATSJetStreamException>(() => jsm.UpdateStream(sc));

        // add the stream
        jsm.AddStream(sc);

        // cannot change MaxConsumers on an existing stream
        StreamConfiguration scMaxCon = GetTestStreamConfigurationBuilder()
            .WithMaxConsumers(2)
            .Build();
        Assert.Throws<NATSJetStreamException>(() => jsm.UpdateStream(scMaxCon));

        // cannot change RetentionPolicy on an existing stream
        StreamConfiguration scReten = GetTestStreamConfigurationBuilder()
            .WithRetentionPolicy(RetentionPolicy.Interest)
            .Build();
        Assert.Throws<NATSJetStreamException>(() => jsm.UpdateStream(scReten));
    });
}
// Builder pre-populated with the canonical test stream (memory storage, two subjects).
private static StreamConfiguration.StreamConfigurationBuilder GetTestStreamConfigurationBuilder()
{
    return StreamConfiguration.Builder()
        .WithSubjects(Subject(0), Subject(1))
        .WithStorageType(StorageType.Memory)
        .WithName(STREAM);
}
// Exercises the replace vs. add semantics of the builder's subject methods:
// WithSubjects(...) replaces the whole list, AddSubjects(...) appends uniquely,
// and null arguments/entries are ignored. Statement order matters — each call
// mutates the same builder instance.
public void TestSubjects()
{
    StreamConfiguration.StreamConfigurationBuilder builder = StreamConfiguration.Builder();

    // subjects(...) replaces
    builder.WithSubjects(Subject(0));
    AssertSubjects(builder.Build(), 0);

    // subjects(...) replaces — empty call clears the list
    builder.WithSubjects();
    AssertSubjects(builder.Build());

    // subjects(...) replaces
    builder.WithSubjects(Subject(1));
    AssertSubjects(builder.Build(), 1);

    // Subjects(...) replaces — a single null also clears the list
    builder.WithSubjects((string)null);
    AssertSubjects(builder.Build());

    // Subjects(...) replaces
    builder.WithSubjects(Subject(2), Subject(3));
    AssertSubjects(builder.Build(), 2, 3);

    // Subjects(...) replaces — interior nulls are skipped
    builder.WithSubjects(Subject(101), null, Subject(102));
    AssertSubjects(builder.Build(), 101, 102);

    // Subjects(...) replaces — list overload
    List<String> list45 = new List<String>();
    list45.Add(Subject(4));
    list45.Add(Subject(5));
    builder.WithSubjects(list45);
    AssertSubjects(builder.Build(), 4, 5);

    // AddSubjects(...) adds unique — duplicate Subject(5) is not re-added
    builder.AddSubjects(Subject(5), Subject(6));
    AssertSubjects(builder.Build(), 4, 5, 6);

    // AddSubjects(...) adds unique — list overload
    List<String> list678 = new List<String>();
    list678.Add(Subject(6));
    list678.Add(Subject(7));
    list678.Add(Subject(8));
    builder.AddSubjects(list678);
    AssertSubjects(builder.Build(), 4, 5, 6, 7, 8);

    // AddSubjects(...) null check — null array is a no-op
    builder.AddSubjects((String[])null);
    AssertSubjects(builder.Build(), 4, 5, 6, 7, 8);

    // AddSubjects(...) null check — null list is a no-op
    builder.AddSubjects((List<String>)null);
    AssertSubjects(builder.Build(), 4, 5, 6, 7, 8);
}
/// <summary>
/// Initializes the service with the given stream configuration.
/// </summary>
/// <param name="configuration">Required stream configuration.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="configuration"/> is null.</exception>
public StreamMarketService(StreamConfiguration configuration)
{
    _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
}
// Adds the canonical test stream, verifies its default configuration, then
// updates it and verifies every changed (and unchanged) field round-trips.
public void TestUpdateStream()
{
    Context.RunInJsServer(c =>
    {
        IJetStreamManagement jsm = c.CreateJetStreamManagementContext();
        StreamInfo si = AddTestStream(jsm);
        StreamConfiguration sc = si.Config;

        // Defaults as reported by the server after AddTestStream.
        Assert.NotNull(sc);
        Assert.Equal(STREAM, sc.Name);
        Assert.NotNull(sc.Subjects);
        Assert.Equal(2, sc.Subjects.Count);
        Assert.Equal(Subject(0), sc.Subjects[0]);
        Assert.Equal(Subject(1), sc.Subjects[1]);
        Assert.Equal(-1, sc.MaxBytes);
        Assert.Equal(-1, sc.MaxValueSize);
        Assert.Equal(Duration.Zero, sc.MaxAge);
        Assert.Equal(StorageType.Memory, sc.StorageType);
        Assert.Equal(DiscardPolicy.Old, sc.DiscardPolicy);
        Assert.Equal(1, sc.Replicas);
        Assert.False(sc.NoAck);
        Assert.Equal(Duration.OfMinutes(2), sc.DuplicateWindow);
        Assert.Empty(sc.TemplateOwner);

        // Update with new subjects, limits, and policies.
        sc = StreamConfiguration.Builder()
            .WithName(STREAM)
            .WithStorageType(StorageType.Memory)
            .WithSubjects(Subject(0), Subject(1), Subject(2))
            .WithMaxBytes(43)
            .WithMaxMsgSize(44)
            .WithMaxAge(Duration.OfDays(100))
            .WithDiscardPolicy(DiscardPolicy.New)
            .WithNoAck(true)
            .WithDuplicateWindow(Duration.OfMinutes(3))
            .Build();
        si = jsm.UpdateStream(sc);
        Assert.NotNull(si);

        // Verify the updated configuration round-tripped through the server.
        sc = si.Config;
        Assert.NotNull(sc);
        Assert.Equal(STREAM, sc.Name);
        Assert.NotNull(sc.Subjects);
        Assert.Equal(3, sc.Subjects.Count);
        Assert.Equal(Subject(0), sc.Subjects[0]);
        Assert.Equal(Subject(1), sc.Subjects[1]);
        Assert.Equal(Subject(2), sc.Subjects[2]);
        Assert.Equal(43u, sc.MaxBytes);
        Assert.Equal(44, sc.MaxValueSize);
        Assert.Equal(Duration.OfDays(100), sc.MaxAge);
        Assert.Equal(StorageType.Memory, sc.StorageType);
        Assert.Equal(DiscardPolicy.New, sc.DiscardPolicy);
        Assert.Equal(1, sc.Replicas);
        Assert.True(sc.NoAck);
        Assert.Equal(Duration.OfMinutes(3), sc.DuplicateWindow);
        Assert.Empty(sc.TemplateOwner);
    });
}
/// <summary>
/// Creates the backing stream for a key-value bucket and returns its status.
/// </summary>
public KeyValueStatus Create(KeyValueConfiguration config)
{
    StreamConfiguration backing = config.BackingConfig;
    if (jsm.Conn.ServerInfo.IsOlderThanVersion("2.7.2"))
    {
        // null discard policy will use default
        backing = StreamConfiguration.Builder(backing).WithDiscardPolicy(null).Build();
    }
    return new KeyValueStatus(jsm.AddStream(backing));
}
// Asserts the configuration's subject list has exactly the given subject ids
// (count and membership; order is not checked).
private void AssertSubjects(StreamConfiguration sc, params int[] subIds)
{
    Assert.Equal(subIds.Length, sc.Subjects.Count);
    foreach (int id in subIds)
    {
        Assert.Contains(Subject(id), sc.Subjects);
    }
}
// DiscardPolicy: defaults to Old, accepts an explicit value, and a null
// argument resets it back to the default.
public void TestDiscardPolicy()
{
    var builder = StreamConfiguration.Builder();

    // default
    Assert.Equal(DiscardPolicy.Old, builder.Build().DiscardPolicy);

    // explicit value
    builder.WithDiscardPolicy(DiscardPolicy.New);
    Assert.Equal(DiscardPolicy.New, builder.Build().DiscardPolicy);

    // null resets to default
    builder.WithDiscardPolicy(null);
    Assert.Equal(DiscardPolicy.Old, builder.Build().DiscardPolicy);
}
// StorageType: defaults to File, accepts an explicit value, and a null
// argument resets it back to the default.
public void TestStorageType()
{
    var builder = StreamConfiguration.Builder();

    // default
    Assert.Equal(StorageType.File, builder.Build().StorageType);

    // explicit value
    builder.WithStorageType(StorageType.Memory);
    Assert.Equal(StorageType.Memory, builder.Build().StorageType);

    // null resets to default
    builder.WithStorageType(null);
    Assert.Equal(StorageType.File, builder.Build().StorageType);
}
// Verifies that MLFactory.CreateIOVariables parses the CTF-style stream
// descriptor strings into input/output variables with the expected names,
// dimensions, and sparsity flags (trailing 1 = sparse, 0 = dense).
public void createStreamConfiguration_test03()
{
    //stream configuration to distinct features and labels in the file
    // NOTE(review): this array is not used below — presumably kept to document
    // the expected equivalent configuration; confirm and remove if truly dead.
    StreamConfiguration[] streamConfig = new StreamConfiguration[]
    {
        new StreamConfiguration("year", 3, true),
        new StreamConfiguration("month", 12, true),
        new StreamConfiguration("shop", 56, true),
        new StreamConfiguration("item", 5100, true),
        new StreamConfiguration("cnt_past3m", 3),
        new StreamConfiguration("item_cnt_month", 1)
    };

    string strFeature = "|year 3 1 |month 12 1 |shop 56 1 |item 5100 1 |cnt_past3m 3 0";
    string strLabels = "|item_cnt_month 1 0";

    MLFactory f = new MLFactory();

    //setup stream configuration
    f.CreateIOVariables(strFeature, strLabels, DataType.Float);

    //
    Assert.Equal(5, f.InputVariables.Count);
    Assert.Single(f.OutputVariables);

    //first feature
    Assert.Equal("year", f.InputVariables[0].Name);
    Assert.Equal(3, f.InputVariables[0].Shape.Dimensions[0]);
    Assert.True(f.InputVariables[0].IsSparse);

    //second feature
    Assert.Equal("month", f.InputVariables[1].Name);
    Assert.Equal(12, f.InputVariables[1].Shape.Dimensions[0]);
    Assert.True(f.InputVariables[1].IsSparse);

    //third feature
    Assert.Equal("shop", f.InputVariables[2].Name);
    Assert.Equal(56, f.InputVariables[2].Shape.Dimensions[0]);
    Assert.True(f.InputVariables[2].IsSparse);

    //fourth feature
    Assert.Equal("item", f.InputVariables[3].Name);
    Assert.Equal(5100, f.InputVariables[3].Shape.Dimensions[0]);
    Assert.True(f.InputVariables[3].IsSparse);

    //fifth feature (dense)
    Assert.Equal("cnt_past3m", f.InputVariables[4].Name);
    Assert.Equal(3, f.InputVariables[4].Shape.Dimensions[0]);
    Assert.False(f.InputVariables[4].IsSparse);

    //first label (dense)
    Assert.Equal("item_cnt_month", f.OutputVariables[0].Name);
    Assert.Equal(1, f.OutputVariables[0].Shape.Dimensions[0]);
    Assert.False(f.OutputVariables[0].IsSparse);
}
// Wires a text-format minibatch source to the given feature/label variables,
// caching the stream infos for later batch lookups.
// NOTE(review): assumes the feature shape has rank 1 or rank 3 (H*W*C);
// a rank-2 shape would fail on Shape[2] — confirm with callers.
internal void LoadTextData(CNTK.Variable feature, CNTK.Variable label)
{
    int featureDim = feature.Shape.Rank == 1
        ? feature.Shape[0]
        : feature.Shape[0] * feature.Shape[1] * feature.Shape[2];
    int labelDim = label.Shape[0];

    var configs = new StreamConfiguration[]
    {
        new StreamConfiguration(featureStreamName, featureDim),
        new StreamConfiguration(labelsStreamName, labelDim)
    };

    miniBatchSource = MinibatchSource.TextFormatMinibatchSource(FileName, configs, MinibatchSource.InfinitelyRepeat);

    featureVariable = feature;
    labelVariable = label;
    featureStreamInfo = miniBatchSource.StreamInfo(featureStreamName);
    labelStreamInfo = miniBatchSource.StreamInfo(labelsStreamName);
}
/// <summary>
/// Initializes the service; when reconnect plus resubscribe-on-reconnect are
/// both enabled, a subscription collection is created to track resubscriptions.
/// </summary>
/// <param name="configuration">Required stream configuration.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="configuration"/> is null.</exception>
public StreamMarketService(StreamConfiguration configuration)
{
    _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));

    bool trackSubscriptions = _configuration.ReconnectEnabled && _configuration.ResubscribeOnReconnect;
    if (trackSubscriptions)
    {
        _subscriptions = new SubscriptionsCollection();
    }
}
// RetentionPolicy: defaults to Limits, accepts each explicit value, and a
// null argument resets it back to the default.
public void TestRetentionPolicy()
{
    var builder = StreamConfiguration.Builder();

    // default
    Assert.Equal(RetentionPolicy.Limits, builder.Build().RetentionPolicy);

    // explicit values
    builder.WithRetentionPolicy(RetentionPolicy.Interest);
    Assert.Equal(RetentionPolicy.Interest, builder.Build().RetentionPolicy);

    builder.WithRetentionPolicy(RetentionPolicy.WorkQueue);
    Assert.Equal(RetentionPolicy.WorkQueue, builder.Build().RetentionPolicy);

    // null resets to default
    builder.WithRetentionPolicy(null);
    Assert.Equal(RetentionPolicy.Limits, builder.Build().RetentionPolicy);
}
// Creates a memory stream and verifies both the server-reported configuration
// defaults and the initial (empty) stream state.
public void TestStreamCreate()
{
    Context.RunInJsServer(c =>
    {
        DateTime now = DateTime.Now;

        IJetStreamManagement jsm = c.CreateJetStreamManagementContext();

        StreamConfiguration sc = StreamConfiguration.Builder()
            .WithName(STREAM)
            .WithStorageType(StorageType.Memory)
            .WithSubjects(Subject(0), Subject(1))
            .Build();

        StreamInfo si = jsm.AddStream(sc);
        Assert.True(now <= si.Created);

        // configuration round-tripped from the server, with defaults filled in
        Assert.NotNull(si.Config);
        sc = si.Config;
        Assert.Equal(STREAM, sc.Name);
        Assert.Equal(2, sc.Subjects.Count);
        Assert.Equal(Subject(0), sc.Subjects[0]);
        Assert.Equal(Subject(1), sc.Subjects[1]);
        Assert.Equal(RetentionPolicy.Limits, sc.RetentionPolicy);
        Assert.Equal(DiscardPolicy.Old, sc.DiscardPolicy);
        Assert.Equal(StorageType.Memory, sc.StorageType);

        Assert.NotNull(si.State);
        Assert.Equal(-1, sc.MaxConsumers);
        Assert.Equal(-1, sc.MaxMsgs);
        Assert.Equal(-1, sc.MaxBytes);
        Assert.Equal(-1, sc.MaxValueSize);
        Assert.Equal(1, sc.Replicas);
        Assert.Equal(Duration.Zero, sc.MaxAge);
        Assert.Equal(Duration.OfSeconds(120), sc.DuplicateWindow);
        Assert.False(sc.NoAck);
        Assert.Empty(sc.TemplateOwner);

        // a freshly created stream has no messages or consumers
        StreamState ss = si.State;
        Assert.Equal(0u, ss.Messages);
        Assert.Equal(0u, ss.Bytes);
        Assert.Equal(0u, ss.FirstSeq);
        Assert.Equal(0u, ss.LastSeq);
        Assert.Equal(0u, ss.ConsumerCount);
    });
}
// Creates a stream with the caller-supplied storage type and subject(s)
// and reports the result on the console.
public static StreamInfo CreateStream(IJetStreamManagement jsm, string streamName, StorageType storageType, params string[] subjects)
{
    // Build the configuration from the caller's storage type and subjects.
    StreamConfiguration sc = StreamConfiguration.Builder()
        .WithName(streamName)
        .WithStorageType(storageType)
        .WithSubjects(subjects)
        .Build();

    // Add the stream. NOTE(review): presumably AddStream succeeds for an
    // existing stream only when configurations match — confirm server behavior.
    StreamInfo si = jsm.AddStream(sc);
    Console.WriteLine("Created stream '{0}' with subject(s) [{1}]\n",
        streamName, string.Join(",", si.Config.Subjects));
    return si;
}
/// <summary>
/// Create the stream configurations for the given feature and label variables:
/// one configuration per variable, features first, then labels.
/// </summary>
/// <param name="lstFeaturesVars">Feature variables.</param>
/// <param name="lstLabelVars">Label variables.</param>
/// <returns>Stream configurations in feature-then-label order.</returns>
private static List<StreamConfiguration> createStreamConfiguration(List<Variable> lstFeaturesVars, List<Variable> lstLabelVars)
{
    var configs = new List<StreamConfiguration>();

    // Features and labels get identical treatment: the stream dimension is the
    // variable's last shape dimension, carrying through the sparsity flag.
    foreach (var variable in lstFeaturesVars.Concat(lstLabelVars))
    {
        configs.Add(new StreamConfiguration(variable.Name, variable.Shape.Dimensions.Last(), variable.IsSparse, "", false));
    }

    return configs;
}
// Deletes any existing stream with this name, then creates a fresh
// memory-backed stream covering the given subjects.
public static void CreateMemoryStream(IJetStreamManagement jsm, string streamName, params string[] subjects)
{
    try
    {
        jsm.DeleteStream(streamName); // since the server is re-used, we want a fresh stream
    }
    catch (NATSJetStreamException)
    {
        // it might not have existed; that's fine
    }

    jsm.AddStream(StreamConfiguration.Builder()
        .WithName(streamName)
        .WithStorageType(StorageType.Memory)
        .WithSubjects(subjects)
        .Build()
    );
}
// Verifies a StreamConfiguration survives every construction route:
// JSON parse, copy-construction via Builder(sc), explicit With* setters,
// and AddSources with null/duplicate entries (both must be no-ops).
public void TestConstruction()
{
    StreamConfiguration testSc = getTestConfiguration();

    // from json node and from its string form
    Validate(testSc);
    Validate(new StreamConfiguration(testSc.ToJsonNode()));
    Validate(new StreamConfiguration(testSc.ToJsonNode().ToString()));

    // copy-constructing builder preserves every field
    StreamConfiguration.StreamConfigurationBuilder builder = StreamConfiguration.Builder(testSc);
    Validate(builder.Build());

    // re-setting every field explicitly yields the same configuration
    builder.WithName(testSc.Name)
        .WithSubjects(testSc.Subjects)
        .WithRetentionPolicy(testSc.RetentionPolicy)
        .WithMaxConsumers(testSc.MaxConsumers)
        .WithMaxMessages(testSc.MaxMsgs)
        .WithMaxMessagesPerSubject(testSc.MaxMsgsPerSubject)
        .WithMaxBytes(testSc.MaxBytes)
        .WithMaxAge(testSc.MaxAge)
        .WithMaxMsgSize(testSc.MaxValueSize)
        .WithStorageType(testSc.StorageType)
        .WithReplicas(testSc.Replicas)
        .WithNoAck(testSc.NoAck)
        .WithTemplateOwner(testSc.TemplateOwner)
        .WithDiscardPolicy(testSc.DiscardPolicy)
        .WithDuplicateWindow(testSc.DuplicateWindow)
        .WithPlacement(testSc.Placement)
        .WithMirror(testSc.Mirror)
        .WithSources(testSc.Sources)
        ;
    Validate(builder.Build());

    // adding a null source is a no-op
    Validate(builder.AddSources((Source)null).Build());

    // adding a list containing null and an exact duplicate is also a no-op
    List<Source> sources = new List<Source>(testSc.Sources);
    sources.Add(null);
    Source copy = new Source(sources[0].ToJsonNode());
    sources.Add(copy);
    Validate(builder.AddSources(sources).Build());
}
// Asserts every field of the canonical test configuration ("sname").
//
// BUG FIX: the Assert.Collection element inspectors were written as
// `item => item.Equals("foo")` — Assert.Collection inspectors are Action<T>,
// so the boolean result of Equals was silently discarded and the element
// values were never actually checked (only the element count). The inspectors
// now assert the expected values. Same fix applied to Placement.Tags.
private void Validate(StreamConfiguration sc)
{
    Assert.Equal("sname", sc.Name);
    Assert.Collection(sc.Subjects,
        item => Assert.Equal("foo", item),
        item => Assert.Equal("bar", item));
    Assert.Equal(RetentionPolicy.Interest, sc.RetentionPolicy);
    Assert.Equal(730, sc.MaxConsumers);
    Assert.Equal(731, sc.MaxMsgs);
    Assert.Equal(7311, sc.MaxMsgsPerSubject);
    Assert.Equal(732, sc.MaxBytes);
    Assert.Equal(Duration.OfNanos(42000000000L), sc.MaxAge);
    Assert.Equal(734, sc.MaxValueSize);
    Assert.Equal(StorageType.Memory, sc.StorageType);
    Assert.Equal(5, sc.Replicas);
    Assert.False(sc.NoAck);
    Assert.Equal("twnr", sc.TemplateOwner);
    Assert.Equal(DiscardPolicy.New, sc.DiscardPolicy);
    Assert.Equal(Duration.OfNanos(73000000000L), sc.DuplicateWindow);

    Assert.NotNull(sc.Placement);
    Assert.Equal("clstr", sc.Placement.Cluster);
    Assert.Collection(sc.Placement.Tags,
        item => Assert.Equal("tag1", item),
        item => Assert.Equal("tag2", item));

    DateTime zdt = AsDateTime("2020-11-05T19:33:21.163377Z");

    // mirror configuration
    Assert.NotNull(sc.Mirror);
    Assert.Equal("eman", sc.Mirror.Name);
    Assert.Equal(736u, sc.Mirror.StartSeq);
    Assert.Equal(zdt, sc.Mirror.StartTime);
    Assert.Equal("mfsub", sc.Mirror.FilterSubject);
    Assert.NotNull(sc.Mirror.External);
    Assert.Equal("apithing", sc.Mirror.External.Api);
    Assert.Equal("dlvrsub", sc.Mirror.External.Deliver);

    // source configurations
    Assert.Equal(2, sc.Sources.Count);
    Assert.Collection(sc.Sources,
        item => ValidateSource(item, "s0", 737, "s0sub", "s0api", "s0dlvrsub", zdt),
        item => ValidateSource(item, "s1", 738, "s1sub", "s1api", "s1dlvrsub", zdt));
}
// Creates a memory stream covering the given subjects, but only when no
// stream with that name already exists.
public static void CreateStreamWhenDoesNotExist(IJetStreamManagement jsm, string stream, params string[] subjects)
{
    try
    {
        jsm.GetStreamInfo(stream); // this throws if the stream does not exist
        return;
    }
    catch (NATSJetStreamException)
    {
        // stream does not exist — fall through and create it
    }

    jsm.AddStream(StreamConfiguration.Builder()
        .WithName(stream)
        .WithStorageType(StorageType.Memory)
        .WithSubjects(subjects)
        .Build());
}
/// <summary>
/// Persists a batch of events inside the current transaction, assigning each
/// event a sequential ID, then upserts the per-stream info document.
/// An empty batch is a no-op.
/// </summary>
/// <param name="s">Batch to save; must contain at most EventID.MaxEventSequence events.</param>
public async Task Save <T>(EventBatch <T> s)
{
    // BUG FIX: was Check.NotNull(s, nameof(T)) — nameof(T) evaluates to the
    // literal string "T" (the type-parameter name), so a null-argument failure
    // reported the wrong parameter name. It must name the argument "s".
    Check.NotNull(s, nameof(s));
    Check.Requires(
        s.Events.Count <= EventID.MaxEventSequence,
        nameof(s),
        "A single batch can contain at most {0} events.",
        EventID.MaxEventSequence
        );

    StreamConfiguration config = _store.GetStreamConfig <T>();

    // Events of unregistered types are rejected by the serializer itself,
    // so no pre-validation is required here.

    if (s.Events.Any())
    {
        // Assign sequential IDs from the store's batch generator before insert.
        EventIDGenerator idGenerator = await _store.GetBatch();
        foreach (RecordedEvent e in s.Events)
        {
            e.ID = idGenerator.Next();
        }

        await _transaction
            .GetCollection <RecordedEvent>(MongoEventStore.CollectionName)
            .InsertManyAsync(s.Events);

        // Upsert the stream-info marker document keyed by stream ID.
        StreamInfo info = new StreamInfo(s.StreamID);
        await _transaction
            .GetCollection <StreamInfo>(MongoEventStore.GetStreamInfoName(config))
            .UpsertAsync(x => x.StreamID, info.StreamID, info);
    }
}
// Creates the stream if missing; otherwise ensures its configuration covers
// all the requested subjects, updating the stream only when a subject is new.
public static StreamInfo CreateStreamOrUpdateSubjects(IJetStreamManagement jsm, string streamName, StorageType storageType, params string[] subjects)
{
    StreamInfo info = GetStreamInfoOrNullWhenNotExist(jsm, streamName);
    if (info == null)
    {
        return CreateStream(jsm, streamName, storageType, subjects);
    }

    // Append any subject the existing configuration does not already cover.
    StreamConfiguration config = info.Config;
    bool updated = false;
    foreach (string subject in subjects)
    {
        if (!config.Subjects.Contains(subject))
        {
            config.Subjects.Add(subject);
            updated = true;
        }
    }

    if (updated)
    {
        info = jsm.UpdateStream(config);
        Console.WriteLine("Existing stream '{0}' was updated, has subject(s) [{1}]\n",
            streamName, string.Join(",", info.Config.Subjects));
        // Existing stream 'scratch' [sub1, sub2]
    }
    else
    {
        Console.WriteLine("Existing stream '{0}' already contained subject(s) [{1}]\n",
            streamName, string.Join(",", info.Config.Subjects));
    }

    return info;
}
// Trains the given trainer on "XORdataset.txt" using 4-sample minibatches,
// counting one epoch each time the source completes a sweep of the file.
static void TrainFromMiniBatchFile(Trainer trainer, Variable inputs, Variable labels, DeviceDescriptor device, int epochs = 1000, int outputFrequencyInMinibatches = 50)
{
    var configs = new StreamConfiguration[]
    {
        new StreamConfiguration("features", inputs.Shape[0]),
        new StreamConfiguration("labels", labels.Shape[0])
    };
    var source = MinibatchSource.TextFormatMinibatchSource("XORdataset.txt", configs, MinibatchSource.InfinitelyRepeat, true);

    int minibatchCount = 0;
    while (epochs >= 0)
    {
        var batch = source.GetNextMinibatch(4, device);
        var arguments = new Dictionary<Variable, MinibatchData>
        {
            { inputs, batch[source.StreamInfo("features")] },
            { labels, batch[source.StreamInfo("labels")] }
        };
        trainer.TrainMinibatch(arguments, device);
        PrintTrainingProgress(trainer, minibatchCount++, outputFrequencyInMinibatches);

        // The source repeats forever; a sweep-end marker signals one epoch done.
        if (batch.Values.Any(a => a.sweepEnd))
        {
            epochs--;
        }
    }
}
/// <summary>
/// Add a configuration corresponding to the specified stream name.
/// </summary>
/// <param name="name">
/// Stream name.
/// </param>
/// <param name="configuration">
/// Stream configuration.
/// </param>
protected void AddStreamConfiguration(string name, StreamConfiguration configuration)
{
    this.streamHandlerConfiguration.Add(name, configuration);
    // Invalidate the cached supported-streams value; presumably it is lazily
    // rebuilt by its accessor — confirm where supportedStreams is recomputed.
    this.supportedStreams = null;
}
// Trains a small fully-connected classifier on SimpleDataTrain_cntk_text.txt:
// computes per-dimension input normalization from one full data sweep, builds
// a sigmoid MLP, round-trips the model through save/reload, then trains with
// SGD, checkpointing periodically, and throws if the final loss/error is high.
internal static void TrainSimpleFeedForwardClassifier(DeviceDescriptor device)
{
    // network and training hyper-parameters
    int inputDim = 2;
    int numOutputClasses = 2;
    int hiddenLayerDim = 50;
    int numHiddenLayers = 2;

    int minibatchSize = 50;
    int numSamplesPerSweep = 10000;
    int numSweepsToTrainWith = 2;
    int numMinibatchesToTrain = (numSamplesPerSweep * numSweepsToTrainWith) / minibatchSize;

    var featureStreamName = "features";
    var labelsStreamName = "labels";
    var input = Variable.InputVariable(new int[] { inputDim }, DataType.Float, "features");
    var labels = Variable.InputVariable(new int[] { numOutputClasses }, DataType.Float, "labels");

    Function classifierOutput;
    Function trainingLoss;
    Function prediction;

    IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
    { new StreamConfiguration(featureStreamName, inputDim), new StreamConfiguration(labelsStreamName, numOutputClasses) };

    // First pass: FullDataSweep source used only to compute input statistics
    // and build (and save/reload-test) the network.
    using (var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
               Path.Combine(DataFolder, "SimpleDataTrain_cntk_text.txt"),
               streamConfigurations, MinibatchSource.FullDataSweep, true, MinibatchSource.DefaultRandomizationWindowInChunks))
    {
        var featureStreamInfo = minibatchSource.StreamInfo(featureStreamName);
        var labelStreamInfo = minibatchSource.StreamInfo(labelsStreamName);

        // Compute per-dimension mean/inv-stddev of the feature stream.
        IDictionary<StreamInformation, Tuple<NDArrayView, NDArrayView>> inputMeansAndInvStdDevs =
            new Dictionary<StreamInformation, Tuple<NDArrayView, NDArrayView>>
        { { featureStreamInfo, new Tuple<NDArrayView, NDArrayView>(null, null) } };
        MinibatchSource.ComputeInputPerDimMeansAndInvStdDevs(minibatchSource, inputMeansAndInvStdDevs, device);

        // Normalize the input, then stack sigmoid fully-connected layers.
        var normalizedinput = CNTKLib.PerDimMeanVarianceNormalize(input,
            inputMeansAndInvStdDevs[featureStreamInfo].Item1, inputMeansAndInvStdDevs[featureStreamInfo].Item2);
        Function fullyConnected = TestHelper.FullyConnectedLinearLayer(normalizedinput, hiddenLayerDim, device, "");
        classifierOutput = CNTKLib.Sigmoid(fullyConnected, "");
        for (int i = 1; i < numHiddenLayers; ++i)
        {
            fullyConnected = TestHelper.FullyConnectedLinearLayer(classifierOutput, hiddenLayerDim, device, "");
            classifierOutput = CNTKLib.Sigmoid(fullyConnected, "");
        }

        // Final linear output layer with randomly initialized parameters.
        var outputTimesParam = new Parameter(NDArrayView.RandomUniform<float>(
            new int[] { numOutputClasses, hiddenLayerDim }, -0.05, 0.05, 1, device));
        var outputBiasParam = new Parameter(NDArrayView.RandomUniform<float>(
            new int[] { numOutputClasses }, -0.05, 0.05, 1, device));
        classifierOutput = CNTKLib.Plus(outputBiasParam, outputTimesParam * classifierOutput, "classifierOutput");

        trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labels, "lossFunction");; // NOTE(review): stray extra ';'
        prediction = CNTKLib.ClassificationError(classifierOutput, labels, "classificationError");

        // Test save and reload of model
        {
            Variable classifierOutputVar = classifierOutput;
            Variable trainingLossVar = trainingLoss;
            Variable predictionVar = prediction;
            var combinedNet = Function.Combine(new List<Variable>() { trainingLoss, prediction, classifierOutput },
                "feedForwardClassifier");
            TestHelper.SaveAndReloadModel(ref combinedNet,
                new List<Variable>() { input, labels, trainingLossVar, predictionVar, classifierOutputVar }, device);

            // Rebind to the reloaded graph's variables.
            classifierOutput = classifierOutputVar;
            trainingLoss = trainingLossVar;
            prediction = predictionVar;
        }
    }

    CNTK.TrainingParameterScheduleDouble learningRatePerSample = new CNTK.TrainingParameterScheduleDouble(
        0.02, TrainingParameterScheduleDouble.UnitType.Sample);

    // Second pass: fresh source over the same file for actual training.
    using (var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
               Path.Combine(DataFolder, "SimpleDataTrain_cntk_text.txt"), streamConfigurations))
    {
        var featureStreamInfo = minibatchSource.StreamInfo(featureStreamName);
        var labelStreamInfo = minibatchSource.StreamInfo(labelsStreamName);

        // NOTE(review): this reassignment happens after the source above was
        // created and appears unused afterwards — looks like dead code; confirm.
        streamConfigurations = new StreamConfiguration[]
        { new StreamConfiguration("features", inputDim), new StreamConfiguration("labels", numOutputClasses) };

        IList<Learner> parameterLearners = new List<Learner>()
        { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) };
        var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

        int outputFrequencyInMinibatches = 20;
        int trainingCheckpointFrequency = 100;
        for (int i = 0; i < numMinibatchesToTrain; ++i)
        {
            var minibatchData = minibatchSource.GetNextMinibatch((uint)minibatchSize, device);
            var arguments = new Dictionary<Variable, MinibatchData>
            {
                { input, minibatchData[featureStreamInfo] },
                { labels, minibatchData[labelStreamInfo] }
            };
            trainer.TrainMinibatch(arguments, device);
            TestHelper.PrintTrainingProgress(trainer, i, outputFrequencyInMinibatches);

            // Periodically exercise checkpoint save/restore.
            if ((i % trainingCheckpointFrequency) == (trainingCheckpointFrequency - 1))
            {
                string ckpName = "feedForward.net";
                trainer.SaveCheckpoint(ckpName);
                trainer.RestoreFromCheckpoint(ckpName);
            }
        }

        // Sanity-check the final training metrics.
        double trainLossValue = trainer.PreviousMinibatchLossAverage();
        double evaluationValue = trainer.PreviousMinibatchEvaluationAverage();
        if (trainLossValue > 0.3 || evaluationValue > 0.2)
        {
            throw new Exception($"TrainSimpleFeedForwardClassifier resulted in unusual high training loss (= {trainLossValue}) or error rate (= {evaluationValue})");
        }
    }
}
/// <summary>
/// Build and train a RNN model.
/// </summary>
/// <param name="device">CPU or GPU device to train and run the model</param>
public static void Train(DeviceDescriptor device)
{
    // model hyper-parameters
    const int inputDim = 2000;
    const int cellDim = 25;
    const int hiddenDim = 25;
    const int embeddingDim = 50;
    const int numOutputClasses = 5;

    // build the model: sparse one-hot features, dense labels on the batch axis
    var featuresName = "features";
    var features = Variable.InputVariable(new int[] { inputDim }, DataType.Float, featuresName, null, true /*isSparse*/);
    var labelsName = "labels";
    var labels = Variable.InputVariable(new int[] { numOutputClasses }, DataType.Float, labelsName,
        new List<Axis>() { Axis.DefaultBatchAxis() }, true);

    var classifierOutput = LSTMSequenceClassifierNet(features, numOutputClasses, embeddingDim, hiddenDim, cellDim, device, "classifierOutput");

    Function trainingLoss = CNTKLib.CrossEntropyWithSoftmax(classifierOutput, labels, "lossFunction");
    Function prediction = CNTKLib.ClassificationError(classifierOutput, labels, "classificationError");

    // prepare training data: CTF aliases "x" (sparse features) and "y" (labels)
    IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
    { new StreamConfiguration(featuresName, inputDim, true, "x"), new StreamConfiguration(labelsName, numOutputClasses, false, "y") };
    var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
        Path.Combine(DataFolder, "Train.ctf"), streamConfigurations,
        MinibatchSource.InfinitelyRepeat, true);
    var featureStreamInfo = minibatchSource.StreamInfo(featuresName);
    var labelStreamInfo = minibatchSource.StreamInfo(labelsName);

    // prepare for training: momentum SGD, per-sample learning rate
    TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(
        0.0005, 1);
    TrainingParameterScheduleDouble momentumTimeConstant = CNTKLib.MomentumAsTimeConstantSchedule(256);
    IList<Learner> parameterLearners = new List<Learner>()
    {
        Learner.MomentumSGDLearner(classifierOutput.Parameters(),
            learningRatePerSample, momentumTimeConstant, /*unitGainMomentum = */ true)
    };
    var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

    // train the model
    uint minibatchSize = 200;
    int outputFrequencyInMinibatches = 20;
    int miniBatchCount = 0;
    int numEpochs = 5;
    while (numEpochs > 0)
    {
        var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);

        var arguments = new Dictionary<Variable, MinibatchData>
        {
            { features, minibatchData[featureStreamInfo] },
            { labels, minibatchData[labelStreamInfo] }
        };

        trainer.TrainMinibatch(arguments, device);
        TestHelper.PrintTrainingProgress(trainer, miniBatchCount++, outputFrequencyInMinibatches);

        // Because minibatchSource is created with MinibatchSource.InfinitelyRepeat,
        // batching will not end. Each time minibatchSource completes an sweep (epoch),
        // the last minibatch data will be marked as end of a sweep. We use this flag
        // to count number of epochs.
        if (TestHelper.MiniBatchDataIsSweepEnd(minibatchData.Values))
        {
            numEpochs--;
        }
    }
}
/// <summary>
/// Train and evaluate a image classifier for MNIST data.
/// </summary>
/// <param name="device">CPU or GPU device to run training and evaluation</param>
/// <param name="useConvolution">option to use convolution network or to use multilayer perceptron</param>
/// <param name="forceRetrain">whether to override an existing model.
/// if true, any existing model will be overridden and the new one evaluated.
/// if false and there is an existing model, the existing model is evaluated.</param>
public static void TrainAndEvaluate(DeviceDescriptor device, bool useConvolution, bool forceRetrain)
{
    var featureStreamName = "features";
    var labelsStreamName = "labels";
    var classifierName = "classifierOutput";
    Function classifierOutput;
    // CNN consumes 28x28x1 images; MLP consumes the flattened 784 vector.
    int[] imageDim = useConvolution ? new int[] { 28, 28, 1 } : new int[] { 784 };
    int imageSize = 28 * 28;
    int numClasses = 10;

    IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
    { new StreamConfiguration(featureStreamName, imageSize), new StreamConfiguration(labelsStreamName, numClasses) };

    string modelFile = useConvolution ? "MNISTConvolution.model" : "MNISTMLP.model";

    // If a model already exists and not set to force retrain, validate the model and return.
    if (File.Exists(modelFile) && !forceRetrain)
    {
        var minibatchSourceExistModel = MinibatchSource.TextFormatMinibatchSource(
            Path.Combine(ImageDataFolder, "Test_cntk_text.txt"), streamConfigurations);
        TestHelper.ValidateModelWithMinibatchSource(modelFile, minibatchSourceExistModel,
            imageDim, numClasses, featureStreamName, labelsStreamName, classifierName, device);
        return;
    }

    // build the network (inputs scaled by 1/256 into [0,1))
    var input = CNTKLib.InputVariable(imageDim, DataType.Float, featureStreamName);
    if (useConvolution)
    {
        var scaledInput = CNTKLib.ElementTimes(Constant.Scalar<float>(0.00390625f, device), input);
        classifierOutput = CreateConvolutionalNeuralNetwork(scaledInput, numClasses, device, classifierName);
    }
    else
    {
        // For MLP, we like to have the middle layer to have certain amount of states.
        int hiddenLayerDim = 200;
        var scaledInput = CNTKLib.ElementTimes(Constant.Scalar<float>(0.00390625f, device), input);
        classifierOutput = CreateMLPClassifier(device, numClasses, hiddenLayerDim, scaledInput, classifierName);
    }

    var labels = CNTKLib.InputVariable(new int[] { numClasses }, DataType.Float, labelsStreamName);
    var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(new Variable(classifierOutput), labels, "lossFunction");
    var prediction = CNTKLib.ClassificationError(new Variable(classifierOutput), labels, "classificationError");

    // prepare training data
    var minibatchSource = MinibatchSource.TextFormatMinibatchSource(
        Path.Combine(ImageDataFolder, "Train_cntk_text.txt"), streamConfigurations, MinibatchSource.InfinitelyRepeat);
    var featureStreamInfo = minibatchSource.StreamInfo(featureStreamName);
    var labelStreamInfo = minibatchSource.StreamInfo(labelsStreamName);

    // set per sample learning rate
    CNTK.TrainingParameterScheduleDouble learningRatePerSample = new CNTK.TrainingParameterScheduleDouble(
        0.003125, TrainingParameterScheduleDouble.UnitType.Sample);

    IList<Learner> parameterLearners = new List<Learner>()
    { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) };
    var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);

    //
    const uint minibatchSize = 64;
    int outputFrequencyInMinibatches = 20, i = 0;
    int epochs = 5;
    while (epochs > 0)
    {
        var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);
        var arguments = new Dictionary<Variable, MinibatchData>
        {
            { input, minibatchData[featureStreamInfo] },
            { labels, minibatchData[labelStreamInfo] }
        };

        trainer.TrainMinibatch(arguments, device);
        TestHelper.PrintTrainingProgress(trainer, i++, outputFrequencyInMinibatches);

        // MinibatchSource is created with MinibatchSource.InfinitelyRepeat.
        // Batching will not end. Each time minibatchSource completes an sweep (epoch),
        // the last minibatch data will be marked as end of a sweep. We use this flag
        // to count number of epochs.
        if (TestHelper.MiniBatchDataIsSweepEnd(minibatchData.Values))
        {
            epochs--;
        }
    }

    // save the trained model
    classifierOutput.Save(modelFile);

    // validate the model with a single full sweep over the test set
    var minibatchSourceNewModel = MinibatchSource.TextFormatMinibatchSource(
        Path.Combine(ImageDataFolder, "Test_cntk_text.txt"), streamConfigurations, MinibatchSource.FullDataSweep);
    TestHelper.ValidateModelWithMinibatchSource(modelFile, minibatchSourceNewModel,
        imageDim, numClasses, featureStreamName, labelsStreamName, classifierName, device);
}
/// <summary>
/// Update the MLCamera characteristics by querying every supported metadata
/// field from the native camera-characteristics handle and caching the
/// results on this object.
/// </summary>
/// <returns>
/// MLResult.Result will be <c>MLResult.Code.Ok</c> if obtained camera characteristic handle successfully.
/// MLResult.Result will be <c>MLResult.Code.InvalidParam</c> if failed to obtain camera characteristic handle due to invalid input parameter.
/// MLResult.Result will be <c>MLResult.Code.MediaGenericUnexpectedNull</c> if failed to capture raw image due to null pointer.
/// MLResult.Result will be <c>MLResult.Code.AllocFailed</c> if failed to allocate memory.
/// MLResult.Result will be <c>MLResult.Code.PrivilegeDenied</c> if a required permission is missing.
/// </returns>
internal MLResult PopulateCharacteristics()
{
    // NOTE(review): the characteristics handle and the native buffers returned by the
    // metadata getters are never explicitly released here; presumably the native layer
    // owns that memory — confirm against the MLCamera C API docs.
    ulong cameraCharacteristicsHandle = MagicLeapNativeBindings.InvalidHandle;
    MLResult.Code resultCode = MLCameraNativeBindings.MLCameraGetCameraCharacteristics(ref cameraCharacteristicsHandle);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get camera characteristics for MLCamera");
    }

    // Auto-exposure modes supported by the camera.
    ulong controlAEModeCount = 0;
    IntPtr controlAEAvailableModesData = IntPtr.Zero;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetControlAEAvailableModes(cameraCharacteristicsHandle, ref controlAEAvailableModesData, ref controlAEModeCount);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get camera control AE available modes for MLCamera");
    }

    this.ControlAEModesAvailable = new List<MLCamera.MetadataControlAEMode>((int)controlAEModeCount);
    foreach (int mode in MarshalIntArray(controlAEAvailableModesData, controlAEModeCount))
    {
        this.ControlAEModesAvailable.Add((MLCamera.MetadataControlAEMode)mode);
    }

    // Color-correction aberration modes.
    ulong colorCorrectionAberrationModeCount = 0;
    IntPtr colorCorrectionAberrationModesData = IntPtr.Zero;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetColorCorrectionAvailableAberrationModes(cameraCharacteristicsHandle, ref colorCorrectionAberrationModesData, ref colorCorrectionAberrationModeCount);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get color correction aberration modes available for MLCamera");
    }

    this.ColorCorrectionAberrationModesAvailable = new List<MLCamera.MetadataColorCorrectionAberrationMode>((int)colorCorrectionAberrationModeCount);
    foreach (int mode in MarshalIntArray(colorCorrectionAberrationModesData, colorCorrectionAberrationModeCount))
    {
        this.ColorCorrectionAberrationModesAvailable.Add((MLCamera.MetadataColorCorrectionAberrationMode)mode);
    }

    // AE compensation range: native fills a fixed [min, max] pair.
    int[] compensationRange = new int[2];
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetControlAECompensationRange(cameraCharacteristicsHandle, compensationRange);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get control AE compensation range");
    }

    this.AECompensationRange = new AECompensationRangeValues(compensationRange[0], compensationRange[1]);

    // AE compensation step, reported as a rational (numerator/denominator).
    MLCameraNativeBindings.MLCameraMetadataRationalNative rational = new MLCameraNativeBindings.MLCameraMetadataRationalNative();
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetControlAECompensationStep(cameraCharacteristicsHandle, ref rational);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get control AE compensation step");
    }

    this.AECompensationStepNumerator = rational.Numerator;
    this.AECompensationStepDenominator = rational.Denominator;
    this.AECompensationStep = (float)this.AECompensationStepNumerator / (float)this.AECompensationStepDenominator;

    // Maximum digital zoom factor.
    float availableMaxDigitalZoom = 0.0f;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetScalerAvailableMaxDigitalZoom(cameraCharacteristicsHandle, ref availableMaxDigitalZoom);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get max available digital zoom");
    }

    this.AvailableMaxDigitalZoom = availableMaxDigitalZoom;

    // Sensor orientation.
    int sensorOrientation = 0;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetSensorOrientation(cameraCharacteristicsHandle, ref sensorOrientation);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get sensor orientation");
    }

    this.SensorOrientation = sensorOrientation;

    // Whether AE lock is available.
    MLCamera.MetadataControlAELock controlAELockAvailable = 0;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetControlAELockAvailable(cameraCharacteristicsHandle, ref controlAELockAvailable);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get control AE lock available");
    }

    this.ControlAELockAvailable = controlAELockAvailable;

    // Auto-white-balance modes supported by the camera.
    ulong controlAWBModeCount = 0;
    IntPtr controlAWBAvailableModesData = IntPtr.Zero;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetControlAWBAvailableModes(cameraCharacteristicsHandle, ref controlAWBAvailableModesData, ref controlAWBModeCount);
    if (!MLResult.IsOK(resultCode))
    {
        // Fixed copy-paste typo in the original message: "ABW" -> "AWB".
        return Fail(resultCode, "get camera control AWB available modes for MLCamera");
    }

    this.ControlAWBModesAvailable = new List<MLCamera.MetadataControlAWBMode>((int)controlAWBModeCount);
    foreach (int mode in MarshalIntArray(controlAWBAvailableModesData, controlAWBModeCount))
    {
        this.ControlAWBModesAvailable.Add((MLCamera.MetadataControlAWBMode)mode);
    }

    // Whether AWB lock is available.
    MLCamera.MetadataControlAWBLock controlAWBLockAvailable = 0;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetControlAWBLockAvailable(cameraCharacteristicsHandle, ref controlAWBLockAvailable);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get control AWB lock available");
    }

    this.ControlAWBLockAvailable = controlAWBLockAvailable;

    // Sensor active array size: native fills a fixed 4-element rectangle.
    int[] sensorInfoActiveArraySize = new int[4];
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetSensorInfoActiveArraySize(cameraCharacteristicsHandle, sensorInfoActiveArraySize);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get sensor info active array size");
    }

    this.SensorInfoActiveArraySize = new SensorInfoActiveArraySizeValues(
        sensorInfoActiveArraySize[0],
        sensorInfoActiveArraySize[1],
        sensorInfoActiveArraySize[2],
        sensorInfoActiveArraySize[3]);

    // Scaler processed sizes: flat array of (width, height) pairs.
    ulong scalerProcessedSizesCount = 0;
    IntPtr scalerProcessedSizesData = IntPtr.Zero;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetScalerProcessedSizes(cameraCharacteristicsHandle, ref scalerProcessedSizesData, ref scalerProcessedSizesCount);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get scaler processed sizes");
    }

    int[] scalerProcessedSizesDataArray = MarshalIntArray(scalerProcessedSizesData, scalerProcessedSizesCount);
    this.ScalerProcessedSizes = new List<ScalerProcessedSize>((int)scalerProcessedSizesCount / 2);

    // i + 1 bound guards against an odd element count from the native layer,
    // which would otherwise throw IndexOutOfRangeException.
    for (int i = 0; i + 1 < scalerProcessedSizesDataArray.Length; i += 2)
    {
        this.ScalerProcessedSizes.Add(new ScalerProcessedSize(scalerProcessedSizesDataArray[i], scalerProcessedSizesDataArray[i + 1]));
    }

    // Stream configurations: flat array of (format, width, height, type) quads.
    ulong streamConfigurationsCount = 0;
    IntPtr streamConfigurationsData = IntPtr.Zero;
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetScalerAvailableStreamConfigurations(cameraCharacteristicsHandle, ref streamConfigurationsData, ref streamConfigurationsCount);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get scaler available stream configurations");
    }

    int[] streamConfigurationsDataArray = MarshalIntArray(streamConfigurationsData, streamConfigurationsCount);
    this.ScalerAvailableStreamConfigurations = new List<StreamConfiguration>((int)streamConfigurationsCount / 4);

    // i + 3 bound guards against a count that is not a multiple of four.
    for (int i = 0; i + 3 < streamConfigurationsDataArray.Length; i += 4)
    {
        this.ScalerAvailableStreamConfigurations.Add(new StreamConfiguration(
            (MLCamera.MetadataScalerAvailableFormats)streamConfigurationsDataArray[i],
            streamConfigurationsDataArray[i + 1],
            streamConfigurationsDataArray[i + 2],
            (MLCamera.MetadataScalerAvailableStreamConfigurations)streamConfigurationsDataArray[i + 3]));
    }

    // Sensor sensitivity (ISO) range: [min, max].
    int[] sensorInfoSensitivityRange = new int[2];
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetSensorInfoSensitivityRange(cameraCharacteristicsHandle, sensorInfoSensitivityRange);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get sensor info sensitivity range");
    }

    this.SensorInfoSensitivityRange = new SensorInfoSensitivityRangeValues(sensorInfoSensitivityRange[0], sensorInfoSensitivityRange[1]);

    // Sensor exposure-time range: [min, max].
    long[] sensorInfoExposureTimeRange = new long[2];
    resultCode = MLCameraNativeBindings.MLCameraMetadataGetSensorInfoExposureTimeRange(cameraCharacteristicsHandle, sensorInfoExposureTimeRange);
    if (!MLResult.IsOK(resultCode))
    {
        return Fail(resultCode, "get sensor info exposure time range");
    }

    this.SensorInfoExposureTimeRange = new SensorInfoExposureTimeRangeValues(sensorInfoExposureTimeRange[0], sensorInfoExposureTimeRange[1]);

    return MLResult.Create(MLResult.Code.Ok);
}

/// <summary>
/// Logs a PopulateCharacteristics failure for the given action and converts
/// the native result code into an MLResult for the caller to return.
/// </summary>
/// <param name="resultCode">The failing native result code.</param>
/// <param name="what">Description of the action that failed, e.g. "get sensor orientation".</param>
/// <returns>The MLResult created from <paramref name="resultCode"/>.</returns>
private static MLResult Fail(MLResult.Code resultCode, string what)
{
    MLResult result = MLResult.Create(resultCode);
    MLPluginLog.ErrorFormat("MLCamera.GeneralSettings.PopulateCharacteristics failed to {0}. Reason: {1}", what, result);
    return result;
}

/// <summary>
/// Copies <paramref name="count"/> 32-bit integers from a native buffer into a managed array.
/// </summary>
/// <param name="data">Pointer to the native int buffer.</param>
/// <param name="count">Number of elements to copy.</param>
/// <returns>A managed array containing the copied values.</returns>
private static int[] MarshalIntArray(IntPtr data, ulong count)
{
    int[] values = new int[count];
    Marshal.Copy(data, values, 0, (int)count);
    return values;
}