示例#1
0
        /// <summary>
        /// Produces the next minibatch by evaluating <c>Expression</c>. When
        /// <c>InputVariable</c> is set, the output of the previous call
        /// (<c>PrevValue</c>) is bound to it as the expression's input.
        /// </summary>
        /// <param name="device">Device to evaluate on; the default device is used when null.</param>
        /// <returns>A <see cref="Minibatch"/> with a single entry stored under <c>Name</c>.</returns>
        public override Minibatch GetNextMinibatch(DeviceDescriptor device = null)
        {
            var targetDevice = device ?? DeviceDescriptor.UseDefaultDevice();

            // With no input variable the expression is evaluated with no bound
            // arguments; otherwise the previous output is fed back in.
            var arguments = new Dictionary<Variable, Value>();
            if (InputVariable != null)
            {
                arguments.Add(InputVariable, PrevValue);
            }

            var output = FunctionInvoke.Invoke(Expression, arguments, targetDevice, false)[0];

            // A rank-0 (scalar) result counts as one sample; otherwise the sample
            // count is taken from the last axis of the output shape.
            var shape = output.Shape;
            int samples = shape.Rank == 0 ? 1 : shape[shape.Rank - 1];

            Iterations++;
            // NOTE(review): Iterations has already been incremented above, so the
            // "+ 1" marks the sweep end one call earlier than
            // Iterations % IterationsPerEpoch would. If Iterations starts at 0, the
            // first epoch is one call short. Confirm against where Iterations is
            // initialised/reset.
            bool isSweepEnd = (Iterations + 1) % IterationsPerEpoch == 0;

            var sampleData = new MinibatchData(output, (uint)samples, isSweepEnd);
            var batch = new Minibatch();
            batch.Add(Name, sampleData);

            // Keep this output so the next call can bind it to InputVariable.
            PrevValue = output;

            return batch;
        }
示例#2
0
        /// <summary>
        /// Pulls the next minibatch of <c>MinibatchSize</c> samples from the wrapped
        /// minibatch source. It then copies each registered stream's data into a new
        /// <see cref="Minibatch"/>, keyed by that stream's <c>_streamInfos</c> key.
        /// </summary>
        /// <param name="device">Device to load data onto; the default device is used when null.</param>
        /// <returns>A <see cref="Minibatch"/> holding one entry per entry in <c>_streamInfos</c>.</returns>
        public override Minibatch GetNextMinibatch(DeviceDescriptor device = null)
        {
            var targetDevice = device ?? DeviceDescriptor.UseDefaultDevice();
            var sourceData = _minibatchSource.GetNextMinibatch((uint)MinibatchSize, targetDevice);

            var result = new Minibatch();
            foreach (var entry in _streamInfos)
            {
                var streamKey = entry.Key;
                var streamInfo = entry.Value;
                result.Add(streamKey, sourceData[streamInfo]);
            }

            return result;
        }