/// <summary>
/// Evaluates the model on minibatches drawn from <c>ValidationSampler</c> until a minibatch
/// reports <c>SweepEnd</c> (or the sampler returns null), and returns the unweighted mean of
/// the values returned by <c>Trainer.TestMinibatch</c> for those minibatches.
/// </summary>
/// <returns>
/// The mean per-minibatch test result; 0.0 when there is no validation sampler, no evaluation
/// function, or the sampler produced no minibatch at all.
/// </returns>
public double GetValidationMetric()
{
    if (ValidationSampler == null || Trainer.EvaluationFunction() == null)
    {
        return (0.0);
    }

    double metric = 0.0;
    int count = 0;
    Minibatch testData;
    do
    {
        testData = ValidationSampler.GetNextMinibatch(TestDevice);
        if (testData == null)
        {
            break;
        }
        var arguments = DataNameToInputMap.GetVariableValueMapAsCNTKUnorderedMap(testData);
        metric += Trainer.TestMinibatch(arguments, TestDevice);
        ++count;
    } while (!testData.SweepEnd);

    // If the very first minibatch was null nothing was evaluated; 0.0 / 0 would be NaN.
    if (count == 0)
    {
        return (0.0);
    }
    return (metric / count);
}
/// <summary>
/// Creates a training session that wires a learner, an optional learning-rate scheduler,
/// training/validation samplers, devices and callbacks around a trainer built from the
/// given model, loss and evaluation functions.
/// </summary>
/// <param name="model">Model function passed to the trainer and to the input map.</param>
/// <param name="lossFunction">Loss function passed to the trainer and to the input map.</param>
/// <param name="evaluationFunction">Evaluation function passed to the trainer and to the input map.</param>
/// <param name="learner">The single learner the trainer is created with.</param>
/// <param name="scheduler">Learning-rate scheduler stored as <c>LearningRateScheduler</c>.</param>
/// <param name="sampler">Training sampler.</param>
/// <param name="validationSampler">Validation sampler.</param>
/// <param name="dataNameToInputMap">Optional data-name-to-input mapping forwarded to <c>DataNameToInputMap</c>.</param>
/// <param name="trainingDevice">Training device; <c>DeviceDescriptor.UseDefaultDevice()</c> when null.</param>
/// <param name="testDevice">Test device; <c>DeviceDescriptor.UseDefaultDevice()</c> when null.</param>
/// <param name="callbacks">Callbacks; an empty array when null.</param>
public TrainingSession(WrappedFunction model, WrappedFunction lossFunction, WrappedFunction evaluationFunction, Learner learner, ILearningScheduler scheduler, ISampler sampler, ISampler validationSampler, Hashtable dataNameToInputMap = null, DeviceDescriptor trainingDevice = null, DeviceDescriptor testDevice = null, ICallback[] callbacks = null)
{
    // Optimization setup.
    Learner = learner;
    LearningRateScheduler = scheduler;
    Trainer = Trainer.CreateTrainer(model, lossFunction, evaluationFunction, new Learner[] { learner });

    // Data sources.
    Sampler = sampler;
    ValidationSampler = validationSampler;

    // Devices fall back to the default device when not supplied.
    TrainingDevice = trainingDevice ?? DeviceDescriptor.UseDefaultDevice();
    TestDevice = testDevice ?? DeviceDescriptor.UseDefaultDevice();

    DataNameToInputMap = new DataNameToInputMap(new Function[] { model, lossFunction, evaluationFunction }, dataNameToInputMap);

    Callbacks = callbacks ?? new ICallback[0];
}
/// <summary>
/// Evaluates <c>Function</c> using the inputs of the active parameter set and writes every
/// resulting <c>Value</c> to the pipeline.
/// </summary>
/// <remarks>
/// Parameter sets: "Hashtable" uses <c>Arguments</c>; "DataSourceSet" uses <c>DataSourceSet</c>;
/// any other set draws minibatches from <c>Sampler</c> until one reports <c>SweepEnd</c>
/// or the sampler returns null.
/// </remarks>
protected override void ProcessRecord()
{
    IList<Value> results;
    if (ParameterSetName == "Hashtable")
    {
        results = FunctionInvoke.Invoke(Function, Arguments, Device, false);
    }
    else if (ParameterSetName == "DataSourceSet")
    {
        results = FunctionInvoke.Invoke(Function, DataSourceSet, Device, false);
    }
    else
    {
        DataNameToInputMap map = new DataNameToInputMap(new Function[] { Function }, DataNameToInputMap);
        Minibatch batch = null;
        var values = new List<Value>();
        do
        {
            batch = Sampler.GetNextMinibatch(Device);
            // The training and validation loops treat a null minibatch as "no more data";
            // do the same here instead of dereferencing null below.
            if (batch == null)
            {
                break;
            }
            map.InitializeByMinibatch(batch);
            values.AddRange(FunctionInvoke.Invoke(Function, batch, map, Device));
        } while (!batch.SweepEnd);
        results = values;
    }

    foreach (var r in results)
    {
        WriteObject(r);
    }
}
/// <summary>
/// Invokes the function wrapped by <paramref name="func"/>, dispatching on the runtime type
/// of <paramref name="arguments"/> to the matching <c>FunctionInvoke.Invoke</c> overload.
/// </summary>
/// <param name="func">Object converted to a <c>Function</c> via <c>ToFunction</c>.</param>
/// <param name="arguments">
/// null (empty input map), <c>Dictionary&lt;Variable, Value&gt;</c>, <c>Dictionary&lt;string, Value&gt;</c>,
/// <c>Hashtable</c>, <c>Minibatch</c>, <c>IDictionary&lt;string, IDataSource&lt;float&gt;&gt;</c> or <c>DataSourceSet</c>,
/// tested in that order.
/// </param>
/// <param name="map">Input map; only forwarded when <paramref name="arguments"/> is a <c>Minibatch</c>.</param>
/// <returns>The output values produced by the invocation.</returns>
/// <exception cref="ArgumentException"><paramref name="arguments"/> has an unsupported type.</exception>
public static Value[] Invoke(PSObject func, object arguments = null, DataNameToInputMap map = null)
{
    Function fn = ToFunction(func);

    if (arguments is null)
    {
        return FunctionInvoke.Invoke(fn, new Dictionary<Variable, Value>(), null, false);
    }
    if (arguments is Dictionary<Variable, Value> byVariable)
    {
        return FunctionInvoke.Invoke(fn, byVariable, null, false);
    }
    if (arguments is Dictionary<string, Value> byName)
    {
        return FunctionInvoke.Invoke(fn, byName, null, false);
    }
    if (arguments is Hashtable table)
    {
        return FunctionInvoke.Invoke(fn, table, null, false);
    }
    if (arguments is Minibatch batch)
    {
        return FunctionInvoke.Invoke(fn, batch, map, null, false);
    }
    if (arguments is IDictionary<string, IDataSource<float>> sources)
    {
        return FunctionInvoke.Invoke(fn, sources, null, false);
    }
    if (arguments is DataSourceSet sourceSet)
    {
        return FunctionInvoke.Invoke(fn, sourceSet, null, false);
    }

    throw new ArgumentException("Invalid type: arguments");
}
/// <summary>
/// Evaluates <paramref name="func"/> on a minibatch, using <paramref name="map"/> to turn
/// the minibatch into a variable-to-value input map.
/// </summary>
/// <param name="func">Function to evaluate.</param>
/// <param name="batch">Minibatch supplying the input data.</param>
/// <param name="map">Input map; when null, a new map is built from <paramref name="func"/> alone.</param>
/// <param name="device">Device forwarded to the underlying <c>Invoke</c> overload.</param>
/// <param name="errorWhenArgumentUnused">Forwarded to the underlying <c>Invoke</c> overload.</param>
/// <returns>The output values of the evaluation.</returns>
public static Value[] Invoke(this Function func, Minibatch batch, DataNameToInputMap map = null, DeviceDescriptor device = null, bool errorWhenArgumentUnused = true)
{
    DataNameToInputMap inputMap = map ?? new DataNameToInputMap(new Function[] { func });
    inputMap.InitializeByMinibatch(batch);
    return Invoke(func, inputMap.GetVariableValueMap(batch), device, errorWhenArgumentUnused);
}
/// <summary>
/// Stores the walk options and starts a background task that runs <c>NodeWalk</c> over
/// <paramref name="func"/> with this instance.
/// </summary>
/// <remarks>
/// If the walk throws, the exception is stored in <c>_exception</c> and the <c>_poison</c>
/// sentinel is added to <c>_queue</c>, presumably so a consumer blocked on the queue wakes up
/// and can observe the failure (consumer side not visible here).
/// </remarks>
/// <param name="func">Function whose graph is walked.</param>
/// <param name="showValues">Whether node values are requested.</param>
/// <param name="arguments">Optional input values as a hashtable.</param>
/// <param name="minibatch">Optional minibatch supplying input values.</param>
/// <param name="map">Optional input map used with <paramref name="minibatch"/>.</param>
public FunctionGetNodeInfo(Function func, bool showValues, Hashtable arguments = null, Minibatch minibatch = null, DataNameToInputMap map = null)
{
    // Options and inputs for the walk.
    _showValues = showValues;
    _arguments = arguments;
    _minibatch = minibatch;
    _map = map;

    // Shared state. It must all be initialized before the task starts, because the walk
    // receives `this` and runs concurrently with the rest of the object's lifetime.
    _history = new Dictionary<string, NodeInfo>();
    _queue = new BlockingCollection<NodeInfo>();
    _poison = new NodeInfo();
    _exception = null;

    void Walk()
    {
        try
        {
            new NodeWalk(func, this);
        }
        catch (Exception e)
        {
            _exception = e;
            _queue.Add(_poison);
        }
    }

    Task.Run(() => Walk());
}
/// <summary>
/// Runs the training loop. Each iteration pulls one minibatch from <c>Sampler</c>, trains on it,
/// updates <c>SampleCount</c>, <c>Loss</c> and (when an evaluation function exists) <c>Metric</c>,
/// runs every callback, and then yields this session.
/// </summary>
/// <remarks>
/// This is an iterator, so none of the body (including starting <c>_stopwatch</c> and the initial
/// learning-rate reset) executes until enumeration begins.
/// The loop ends when the sampler returns null, when <c>_stop</c> is true after the callbacks
/// have run, or when <c>Iteration</c> exceeds <paramref name="maxIteration"/>.
/// NOTE(review): with the default <c>int.MaxValue</c>, <c>Iteration &lt;= maxIteration</c> is always
/// true, so only the first two conditions can end the loop in practice.
/// </remarks>
/// <param name="maxIteration">Maximum number of iterations to run (<c>Iteration</c> is 1-based).</param>
/// <returns>This session, once per trained minibatch that was not stopped by <c>_stop</c>.</returns>
public IEnumerable <TrainingSession> GetIterator(int maxIteration = int.MaxValue)
{
    _stopwatch = Stopwatch.StartNew();
    _stop = false;
    Epoch = 1;
    EpochIncremented = false;
    // Start training from the scheduler's current learning rate.
    if (LearningRateScheduler != null)
    {
        Learner.ResetLearningRate(new TrainingParameterScheduleDouble(LearningRateScheduler.LearningRate));
    }
    for (Iteration = 1; Iteration <= maxIteration; ++Iteration)
    {
        var minibatch = Sampler.GetNextMinibatch(TrainingDevice);
        // A null minibatch is treated as "no more data".
        if (minibatch == null)
        {
            break;
        }
        DataNameToInputMap.InitializeByMinibatch(minibatch);
        var arguments = DataNameToInputMap.GetVariableValueMap(minibatch);
        Trainer.TrainMinibatch(arguments, minibatch.SweepEnd, TrainingDevice);
        SampleCount = (int)Trainer.PreviousMinibatchSampleCount();
        Loss = Trainer.PreviousMinibatchLossAverage();
        // Without an evaluation function, Metric is not assigned and keeps its previous value.
        if (Trainer.EvaluationFunction() != null)
        {
            Metric = Trainer.PreviousMinibatchEvaluationAverage();
        }
        foreach (var cb in Callbacks)
        {
            cb.Run(this);
        }
        // Checked after the callbacks and before yielding. Presumably _stop is set by a callback
        // or a stop method on this session (not visible here); in that case this iteration is not yielded.
        if (_stop)
        {
            break;
        }
        yield return(this);
        // The code below runs only when the consumer requests the next item.
        // Epoch is still the epoch of the minibatch that was just trained.
        if (LearningRateScheduler != null)
        {
            bool update = LearningRateScheduler.UpdateLearningRate(Epoch, Iteration, Loss);
            if (update)
            {
                Learner.ResetLearningRate(new TrainingParameterScheduleDouble(LearningRateScheduler.LearningRate));
            }
        }
        // After a sweep ends, Epoch is advanced and EpochIncremented stays true while the next
        // iteration's callbacks run and that iteration is yielded; it is reset afterwards.
        EpochIncremented = false;
        if (minibatch.SweepEnd)
        {
            ++Epoch;
            EpochIncremented = true;
        }
    }
}
/// <summary>
/// Walks the graph of the function wrapped by <paramref name="func"/> with value display
/// enabled and returns the node information produced by <c>FunctionGetNodeInfo</c>.
/// </summary>
/// <param name="func">Object converted to a <c>Function</c> via <c>ToFunction</c>.</param>
/// <param name="arguments">null (no inputs), a <c>Hashtable</c> of inputs, or a <c>Minibatch</c>.</param>
/// <param name="map">Input map; only used when <paramref name="arguments"/> is a <c>Minibatch</c>.</param>
/// <returns>The sequence returned by <c>FunctionGetNodeInfo.GetNodeInfo()</c>.</returns>
/// <exception cref="ArgumentException"><paramref name="arguments"/> is of any other type.</exception>
public static IEnumerable<NodeInfo> GetNodeInfoWithValues(PSObject func, object arguments = null, DataNameToInputMap map = null)
{
    Function f = ToFunction(func);
    FunctionGetNodeInfo w;
    if (arguments == null)
    {
        w = new FunctionGetNodeInfo(f, true, null, null);
    }
    else if (arguments is Hashtable ht)
    {
        w = new FunctionGetNodeInfo(f, true, ht);
    }
    else if (arguments is Minibatch mb)
    {
        w = new FunctionGetNodeInfo(f, true, null, mb, map);
    }
    else
    {
        // An `as Minibatch` cast would turn any other type into null and silently run
        // without inputs; reject it instead, as Invoke does.
        throw new ArgumentException("Invalid type: arguments", nameof(arguments));
    }
    return w.GetNodeInfo();
}
/// <summary>
/// Renders the graph of the function wrapped by <paramref name="func"/> as a tree string
/// with value display enabled, using <c>FunctionAsTree</c>.
/// </summary>
/// <param name="func">Object converted to a <c>Function</c> via <c>ToFunction</c>.</param>
/// <param name="arguments">null (no inputs), a <c>Hashtable</c> of inputs, or a <c>Minibatch</c>.</param>
/// <param name="map">Input map; only used when <paramref name="arguments"/> is a <c>Minibatch</c>.</param>
/// <param name="showUid">Forwarded to <c>FunctionAsTree</c> in every case.</param>
/// <returns>The <c>Result</c> string of the <c>FunctionAsTree</c> walk.</returns>
/// <exception cref="ArgumentException"><paramref name="arguments"/> is of any other type.</exception>
public static string AsTreeWithValues(PSObject func, object arguments = null, DataNameToInputMap map = null, bool showUid = true)
{
    Function f = ToFunction(func);
    FunctionAsTree w;
    if (arguments == null)
    {
        // showUid was previously dropped in this branch, so showUid=false had no effect
        // when no arguments were given.
        w = new FunctionAsTree(f, true, null, null, null, showUid);
    }
    else if (arguments is Hashtable ht)
    {
        w = new FunctionAsTree(f, true, ht, null, null, showUid);
    }
    else if (arguments is Minibatch mb)
    {
        w = new FunctionAsTree(f, true, null, mb, map, showUid);
    }
    else
    {
        // An `as Minibatch` cast would turn any other type into null and silently run
        // without inputs; reject it instead, as Invoke does.
        throw new ArgumentException("Invalid type: arguments", nameof(arguments));
    }
    return w.Result;
}