/// <summary>
/// Evaluates <c>Function</c> according to the active parameter set and writes every
/// resulting <c>Value</c> to the pipeline.
/// </summary>
/// <remarks>
/// "Hashtable" evaluates against <c>Arguments</c>; "DataSourceSet" evaluates against
/// <c>DataSourceSet</c>; any other parameter set pulls minibatches from <c>Sampler</c>
/// until one reports <c>SweepEnd</c>. All results are collected before the first
/// <c>WriteObject</c> call.
/// </remarks>
protected override void ProcessRecord()
{
    IList<Value> results;

    switch (ParameterSetName)
    {
        case "Hashtable":
        {
            results = FunctionInvoke.Invoke(Function, Arguments, Device, false);
            break;
        }
        case "DataSourceSet":
        {
            results = FunctionInvoke.Invoke(Function, DataSourceSet, Device, false);
            break;
        }
        default:
        {
            var inputMap = new DataNameToInputMap(new Function[] { Function }, DataNameToInputMap);
            var collected = new List<Value>();
            for (;;)
            {
                Minibatch current = Sampler.GetNextMinibatch(Device);
                // The Minibatch overload of FunctionInvoke.Invoke appears to repeat this
                // initialization; InitializeByMinibatch returns early once the map has entries.
                inputMap.InitializeByMinibatch(current);
                collected.AddRange(FunctionInvoke.Invoke(Function, current, inputMap, Device));
                if (current.SweepEnd)
                {
                    break;
                }
            }
            results = collected;
            break;
        }
    }

    foreach (var output in results)
    {
        WriteObject(output);
    }
}
/// <summary>
/// Builds the variable-to-value map for <paramref name="batch"/> (see
/// <c>GetVariableValueMap</c>) and copies it into a CNTK <c>UnorderedMapVariableValuePtr</c>.
/// </summary>
/// <param name="batch">Minibatch whose features are mapped onto model input variables.</param>
/// <returns>A new native map containing the same variable/value pairs.</returns>
public UnorderedMapVariableValuePtr GetVariableValueMapAsCNTKUnorderedMap(Minibatch batch)
{
    var managedMap = GetVariableValueMap(batch);

    var nativeMap = new UnorderedMapVariableValuePtr();
    foreach (var pair in managedMap)
    {
        nativeMap.Add(pair.Key, pair.Value);
    }

    return nativeMap;
}
/// <summary>
/// Evaluates <paramref name="func"/> on the data carried by a minibatch.
/// </summary>
/// <param name="func">Function to evaluate.</param>
/// <param name="batch">Minibatch supplying the input data.</param>
/// <param name="map">
/// Name-to-variable map; when null, a fresh map is created for <paramref name="func"/>.
/// </param>
/// <param name="device">Device to evaluate on; passed through unchanged.</param>
/// <param name="errorWhenArgumentUnused">Passed through to the dictionary-based overload.</param>
/// <returns>The values produced by the dictionary-based <c>Invoke</c> overload.</returns>
public static Value[] Invoke(this Function func, Minibatch batch, DataNameToInputMap map = null, DeviceDescriptor device = null, bool errorWhenArgumentUnused = true)
{
    var inputMap = map ?? new DataNameToInputMap(new Function[] { func });
    inputMap.InitializeByMinibatch(batch);

    var feed = inputMap.GetVariableValueMap(batch);
    return Invoke(func, feed, device, errorWhenArgumentUnused);
}
/// <summary>
/// Produces the next minibatch by evaluating <c>Expression</c>, feeding the previously
/// produced value into <c>InputVariable</c> when one is configured.
/// </summary>
/// <param name="device">Evaluation device; defaults to <c>DeviceDescriptor.UseDefaultDevice()</c>.</param>
/// <returns>A minibatch holding a single entry named <c>Name</c>.</returns>
public override Minibatch GetNextMinibatch(DeviceDescriptor device = null)
{
    device = device ?? DeviceDescriptor.UseDefaultDevice();

    // Without an input variable the expression is evaluated with no arguments;
    // otherwise the previous output is fed back in.
    var feed = InputVariable == null
        ? new Dictionary<Variable, Value>()
        : new Dictionary<Variable, Value>() { { InputVariable, PrevValue } };
    Value value = FunctionInvoke.Invoke(Expression, feed, device, false)[0];

    // A rank-0 result counts as one sample; otherwise the last axis is the sample count.
    int rank = value.Shape.Rank;
    int sampleCount = rank == 0 ? 1 : value.Shape[rank - 1];

    ++Iterations;
    // NOTE(review): Iterations is incremented before this test, so the sweep ends when
    // Iterations == k * IterationsPerEpoch - 1. If Iterations starts at 0, the first sweep
    // is one call short. Confirm this is intended.
    var sweepEnd = (Iterations + 1) % IterationsPerEpoch == 0;

    var sampleData = new MinibatchData(value, (uint)sampleCount, sweepEnd);
    var result = new Minibatch();
    result.Add(Name, sampleData);

    PrevValue = value;
    return result;
}
/// <summary>
/// Pulls the next minibatch of <c>MinibatchSize</c> samples from the underlying CNTK
/// minibatch source and re-keys each stream by the name stored in <c>_streamInfos</c>.
/// </summary>
/// <param name="device">Target device; defaults to <c>DeviceDescriptor.UseDefaultDevice()</c>.</param>
/// <returns>A minibatch with one entry per configured stream.</returns>
public override Minibatch GetNextMinibatch(DeviceDescriptor device = null)
{
    var targetDevice = device ?? DeviceDescriptor.UseDefaultDevice();

    var rawBatch = _minibatchSource.GetNextMinibatch((uint)MinibatchSize, targetDevice);

    var renamed = new Minibatch();
    foreach (var stream in _streamInfos)
    {
        renamed.Add(stream.Key, rawBatch[stream.Value]);
    }

    return renamed;
}
/// <summary>
/// Populates the name-to-variable map from the feature names in <paramref name="batch"/>,
/// using <c>FindVariables</c> to resolve each name.
/// </summary>
/// <remarks>
/// Does nothing if the map already holds any entries. Names for which
/// <c>FindVariables</c> returns null are skipped.
/// </remarks>
/// <param name="batch">Minibatch whose feature names are resolved.</param>
public void InitializeByMinibatch(Minibatch batch)
{
    if (_map.Count > 0)
    {
        return;
    }

    foreach (var feature in batch.Features)
    {
        var dataName = feature.Key;
        var matches = FindVariables(dataName);
        if (matches == null)
        {
            continue;
        }

        foreach (var variable in matches)
        {
            AddToMap(dataName, variable);
        }
    }
}
/// <summary>
/// Fills the sample buffer with uniformly drawn values in [Min, Max) and wraps it in a
/// new minibatch named <c>Name</c>.
/// </summary>
/// <param name="device">Target device; defaults to <c>DeviceDescriptor.UseDefaultDevice()</c>.</param>
/// <returns>A minibatch backed by <c>_samples</c>.</returns>
public override Minibatch GetNextMinibatch(DeviceDescriptor device = null)
{
    device = device ?? DeviceDescriptor.UseDefaultDevice();

    // _data is overwritten in place; _samples presumably wraps this same buffer — TODO confirm.
    var elementCount = _dataSize * MinibatchSize;
    for (var index = 0; index < elementCount; index++)
    {
        _data[index] = (float)(_random.NextDouble() * (Max - Min) + Min);
    }

    ++Iterations;
    // NOTE(review): Iterations is incremented before this test, so the sweep ends when
    // Iterations == k * IterationsPerEpoch - 1. If Iterations starts at 0, the first sweep
    // is one call short. Confirm this is intended.
    var sweepEnd = (Iterations + 1) % IterationsPerEpoch == 0;

    var sources = new Dictionary<string, IDataSource<float>>() { { Name, _samples } };
    return new Minibatch(sources, sweepEnd, device);
}
/// <summary>
/// Maps each feature of <paramref name="batch"/> onto every model variable registered
/// under that feature's name.
/// </summary>
/// <param name="batch">Minibatch providing the feature values.</param>
/// <returns>A dictionary from model variable to the feature value bound to it.</returns>
/// <exception cref="ApplicationException">No feature name matched any registered variable.</exception>
/// <exception cref="ArgumentException">
/// Raised by <c>Dictionary.Add</c> if the same variable is registered under two feature names.
/// </exception>
public Dictionary<Variable, Value> GetVariableValueMap(Minibatch batch)
{
    var bindings = new Dictionary<Variable, Value>();

    foreach (var feature in batch.Features)
    {
        List<Variable> targets;
        if (!_map.TryGetValue(feature.Key, out targets))
        {
            continue;
        }

        foreach (var target in targets)
        {
            bindings.Add(target, feature.Value);
        }
    }

    if (bindings.Count == 0)
    {
        throw new ApplicationException("Minibatch is empty or contains no data corresponding to the input variables of the model");
    }

    return bindings;
}
/// <summary>
/// Stores the evaluation options and starts a background task that walks
/// <paramref name="func"/> with <c>NodeWalk</c>, passing this instance to it.
/// </summary>
/// <remarks>
/// If the walk throws, the exception is saved in <c>_exception</c> and the
/// <c>_poison</c> sentinel is added to <c>_queue</c>. This is presumably so a reader of
/// <c>_queue</c> can stop waiting and rethrow; confirm against the consumer.
/// All fields are assigned before the task is started.
/// </remarks>
/// <param name="func">Function whose nodes are walked.</param>
/// <param name="showValues">Stored in <c>_showValues</c>.</param>
/// <param name="arguments">Optional argument table, stored in <c>_arguments</c>.</param>
/// <param name="minibatch">Optional minibatch, stored in <c>_minibatch</c>.</param>
/// <param name="map">Optional name-to-input map, stored in <c>_map</c>.</param>
public FunctionGetNodeInfo(Function func, bool showValues, Hashtable arguments = null, Minibatch minibatch = null, DataNameToInputMap map = null)
{
    _showValues = showValues;
    _arguments = arguments;
    _minibatch = minibatch;
    _map = map;

    _queue = new BlockingCollection<NodeInfo>();
    _history = new Dictionary<string, NodeInfo>();
    _poison = new NodeInfo();
    _exception = null;

    Action walkGraph = () =>
    {
        try
        {
            new NodeWalk(func, this);
        }
        catch (Exception error)
        {
            _exception = error;
            _queue.Add(_poison);
        }
    };

    Task.Run(walkGraph);
}