Exemplo n.º 1
0
        /// <summary>
        /// Evaluates the trainer's evaluation function on the validation sampler using TestDevice.
        /// Minibatches are pulled until the sampler returns null or a minibatch reports SweepEnd.
        /// </summary>
        /// <returns>
        /// The unweighted mean of the per-minibatch values returned by Trainer.TestMinibatch, or 0.0
        /// when there is no validation sampler, no evaluation function, or no minibatch was produced.
        /// </returns>
        public double GetValidationMetric()
        {
            if (ValidationSampler == null || Trainer.EvaluationFunction() == null)
            {
                return(0.0);
            }

            double metric = 0.0;
            int    count  = 0;

            Minibatch testData;

            do
            {
                testData = ValidationSampler.GetNextMinibatch(TestDevice);
                if (testData == null)
                {
                    break;
                }

                var arguments = DataNameToInputMap.GetVariableValueMapAsCNTKUnorderedMap(testData);
                // Each minibatch contributes equally, regardless of its sample count.
                metric += Trainer.TestMinibatch(arguments, TestDevice);
                ++count;
            }while (!testData.SweepEnd);

            // An immediately exhausted sampler would otherwise yield 0.0 / 0 == NaN.
            if (count == 0)
            {
                return(0.0);
            }

            return(metric / count);
        }
Exemplo n.º 2
0
        // Builds a training session: wires the learner/scheduler into a CNTK Trainer, stores the
        // training/validation samplers, defaults both devices via DeviceDescriptor.UseDefaultDevice()
        // when not supplied, builds the data-name -> input map over all three functions, and
        // replaces a null callback list with an empty one.
        public TrainingSession(WrappedFunction model, WrappedFunction lossFunction, WrappedFunction evaluationFunction, Learner learner, ILearningScheduler scheduler, ISampler sampler, ISampler validationSampler, Hashtable dataNameToInputMap = null, DeviceDescriptor trainingDevice = null, DeviceDescriptor testDevice = null, ICallback[] callbacks = null)
        {
            Learner               = learner;
            LearningRateScheduler = scheduler;

            var learnerList = new Learner[] { learner };
            Trainer = Trainer.CreateTrainer(model, lossFunction, evaluationFunction, learnerList);

            Sampler           = sampler;
            ValidationSampler = validationSampler;

            // Fall back to the default device independently for training and testing.
            TrainingDevice = trainingDevice ?? DeviceDescriptor.UseDefaultDevice();
            TestDevice     = testDevice ?? DeviceDescriptor.UseDefaultDevice();

            var mappedFunctions = new Function[] { model, lossFunction, evaluationFunction };
            DataNameToInputMap = new DataNameToInputMap(mappedFunctions, dataNameToInputMap);

            Callbacks = callbacks ?? new ICallback[0];
        }
Exemplo n.º 3
0
        // Evaluates Function on the inputs selected by the active parameter set and writes every
        // resulting Value to the pipeline:
        //   "Hashtable"     -> Arguments
        //   "DataSourceSet" -> DataSourceSet
        //   otherwise       -> minibatches pulled from Sampler until SweepEnd (or until the sampler
        //                      returns null), with all per-minibatch outputs concatenated.
        protected override void ProcessRecord()
        {
            IList <Value> results;

            if (ParameterSetName == "Hashtable")
            {
                results = FunctionInvoke.Invoke(Function, Arguments, Device, false);
            }
            else if (ParameterSetName == "DataSourceSet")
            {
                results = FunctionInvoke.Invoke(Function, DataSourceSet, Device, false);
            }
            else
            {
                DataNameToInputMap map   = new DataNameToInputMap(new Function[] { Function }, DataNameToInputMap);
                Minibatch          batch = null;
                var values = new List <Value>();
                do
                {
                    batch = Sampler.GetNextMinibatch(Device);
                    if (batch == null)
                    {
                        // The sampler can run dry without flagging SweepEnd; stop instead of
                        // dereferencing null below.
                        break;
                    }

                    map.InitializeByMinibatch(batch);
                    values.AddRange(FunctionInvoke.Invoke(Function, batch, map, Device));
                }while (!batch.SweepEnd);
                results = values;
            }

            foreach (var r in results)
            {
                WriteObject(r);
            }
        }
Exemplo n.º 4
0
        // Evaluates the function wrapped by 'func', dispatching on the runtime type of 'arguments'.
        // Supported: null (no inputs), Dictionary<Variable, Value>, Dictionary<string, Value>,
        // Hashtable, Minibatch (paired with 'map'), IDictionary<string, IDataSource<float>>, and
        // DataSourceSet. Anything else raises ArgumentException.
        public static Value[] Invoke(PSObject func, object arguments = null, DataNameToInputMap map = null)
        {
            Function target = ToFunction(func);

            // No arguments is the same as an empty Variable -> Value dictionary.
            object input = arguments ?? new Dictionary<Variable, Value>();

            if (input is Dictionary<Variable, Value> byVariable)
            {
                return FunctionInvoke.Invoke(target, byVariable, null, false);
            }
            if (input is Dictionary<string, Value> byName)
            {
                return FunctionInvoke.Invoke(target, byName, null, false);
            }
            if (input is Hashtable table)
            {
                return FunctionInvoke.Invoke(target, table, null, false);
            }
            if (input is Minibatch batch)
            {
                return FunctionInvoke.Invoke(target, batch, map, null, false);
            }
            if (input is IDictionary<string, IDataSource<float>> sources)
            {
                return FunctionInvoke.Invoke(target, sources, null, false);
            }
            if (input is DataSourceSet sourceSet)
            {
                return FunctionInvoke.Invoke(target, sourceSet, null, false);
            }

            throw new ArgumentException("Invalid type: arguments");
        }
Exemplo n.º 5
0
        // Evaluates 'func' on a single minibatch. When no input map is supplied, one is built from
        // 'func' alone. The map is (re)initialized from the minibatch before the Variable -> Value
        // inputs are extracted and forwarded to the dictionary-based Invoke overload.
        public static Value[] Invoke(this Function func, Minibatch batch, DataNameToInputMap map = null, DeviceDescriptor device = null, bool errorWhenArgumentUnused = true)
        {
            var inputMap = map ?? new DataNameToInputMap(new Function[] { func });

            inputMap.InitializeByMinibatch(batch);

            return Invoke(func, inputMap.GetVariableValueMap(batch), device, errorWhenArgumentUnused);
        }
Exemplo n.º 6
0
        // Stores the evaluation options and immediately starts a background Task that walks 'func'
        // via NodeWalk, passing this instance as the walker target.
        //   showValues - stored in _showValues (presumably controls whether node values are computed)
        //   arguments  - optional Hashtable of inputs, stored in _arguments
        //   minibatch  - optional Minibatch of inputs, stored in _minibatch
        //   map        - optional data-name -> input map, stored in _map
        // NOTE(review): the Task is neither stored nor awaited (fire-and-forget); consumers are
        // expected to read from _queue. Normal-completion signalling is not visible here —
        // presumably NodeWalk or a callback adds _poison when the walk finishes; confirm.
        public FunctionGetNodeInfo(Function func, bool showValues, Hashtable arguments = null, Minibatch minibatch = null, DataNameToInputMap map = null)
        {
            _showValues = showValues;
            _arguments  = arguments;
            _minibatch  = minibatch;
            _map        = map;

            // All fields must be set before Task.Run below, since the background walk sees 'this'.
            _queue     = new BlockingCollection <NodeInfo>();
            _history   = new Dictionary <string, NodeInfo>();
            _poison    = new NodeInfo();
            _exception = null;

            Task.Run(() => {
                try
                {
                    new NodeWalk(func, this);
                }
                catch (Exception e)
                {
                    // Record the failure and enqueue the poison sentinel so a consumer
                    // blocked on _queue wakes up and can observe _exception.
                    _exception = e;
                    _queue.Add(_poison);
                }
            });
        }
Exemplo n.º 7
0
        // Lazily runs the training loop, yielding this session once per trained minibatch.
        // Being an iterator, nothing below (including starting _stopwatch) runs until the caller
        // first calls MoveNext on the returned enumerable.
        //
        // Each step pulls a minibatch from Sampler on TrainingDevice, trains on it, then publishes
        // SampleCount, Loss and (only when an evaluation function exists) Metric; otherwise Metric
        // keeps its previous value. Next, every callback in Callbacks runs, and then the session is
        // yielded.
        //
        // The loop ends when:
        //   - maxIteration steps have completed,
        //   - the sampler returns null, or
        //   - _stop is true after the callbacks ran. In this case the current step is NOT yielded.
        //     _stop is reset to false here; presumably a callback sets it.
        //
        // NOTE(review): with the default maxIteration == int.MaxValue, the loop condition is always
        // true and ++Iteration would wrap to negative in unchecked arithmetic, so termination then
        // relies on the sampler or _stop.
        public IEnumerable <TrainingSession> GetIterator(int maxIteration = int.MaxValue)
        {
            _stopwatch = Stopwatch.StartNew();
            _stop      = false;

            Epoch            = 1;
            EpochIncremented = false;

            // Apply the scheduler's current learning rate before the first step.
            if (LearningRateScheduler != null)
            {
                Learner.ResetLearningRate(new TrainingParameterScheduleDouble(LearningRateScheduler.LearningRate));
            }

            for (Iteration = 1; Iteration <= maxIteration; ++Iteration)
            {
                var minibatch = Sampler.GetNextMinibatch(TrainingDevice);
                if (minibatch == null)
                {
                    break;
                }

                DataNameToInputMap.InitializeByMinibatch(minibatch);

                var arguments = DataNameToInputMap.GetVariableValueMap(minibatch);

                Trainer.TrainMinibatch(arguments, minibatch.SweepEnd, TrainingDevice);

                SampleCount = (int)Trainer.PreviousMinibatchSampleCount();
                Loss        = Trainer.PreviousMinibatchLossAverage();
                if (Trainer.EvaluationFunction() != null)
                {
                    Metric = Trainer.PreviousMinibatchEvaluationAverage();
                }

                foreach (var cb in Callbacks)
                {
                    cb.Run(this);
                }

                if (_stop)
                {
                    break;
                }

                yield return(this);

                // Everything below runs only when the consumer asks for the next item. If the
                // consumer stops enumerating, the scheduler update and epoch bookkeeping for the
                // last yielded step are skipped.

                if (LearningRateScheduler != null)
                {
                    // Epoch here is still the epoch of the minibatch just trained (it is bumped below).
                    bool update = LearningRateScheduler.UpdateLearningRate(Epoch, Iteration, Loss);
                    if (update)
                    {
                        Learner.ResetLearningRate(new TrainingParameterScheduleDouble(LearningRateScheduler.LearningRate));
                    }
                }

                // Epoch and EpochIncremented are updated after the yield. Callbacks and the consumer
                // therefore see EpochIncremented == true on the step AFTER the sweep-ending minibatch.
                EpochIncremented = false;
                if (minibatch.SweepEnd)
                {
                    ++Epoch;
                    EpochIncremented = true;
                }
            }
        }
Exemplo n.º 8
0
        /// <summary>
        /// Walks the function wrapped by <paramref name="func"/> with value display enabled and
        /// returns the NodeInfo sequence produced by FunctionGetNodeInfo.GetNodeInfo().
        /// </summary>
        /// <param name="func">PowerShell object wrapping the function to inspect.</param>
        /// <param name="arguments">
        /// null (no inputs), a Hashtable of inputs, or a Minibatch (used together with <paramref name="map"/>).
        /// </param>
        /// <param name="map">Optional data-name to input map; only used with a Minibatch.</param>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="arguments"/> is of an unsupported type. Previously such input was silently
        /// discarded by an 'as Minibatch' cast, and the walk ran with no inputs.
        /// </exception>
        public static IEnumerable<NodeInfo> GetNodeInfoWithValues(PSObject func, object arguments = null, DataNameToInputMap map = null)
        {
            Function f = ToFunction(func);
            FunctionGetNodeInfo w;

            if (arguments == null)
            {
                w = new FunctionGetNodeInfo(f, true, null, null);
            }
            else if (arguments is Hashtable table)
            {
                w = new FunctionGetNodeInfo(f, true, table);
            }
            else if (arguments is Minibatch minibatch)
            {
                w = new FunctionGetNodeInfo(f, true, null, minibatch, map);
            }
            else
            {
                // Same message as Invoke(PSObject, ...) uses for unsupported argument types.
                throw new System.ArgumentException("Invalid type: arguments", nameof(arguments));
            }

            return w.GetNodeInfo();
        }
Exemplo n.º 9
0
        /// <summary>
        /// Renders the function wrapped by <paramref name="func"/> as a tree with values and
        /// returns FunctionAsTree.Result.
        /// </summary>
        /// <param name="func">PowerShell object wrapping the function to render.</param>
        /// <param name="arguments">
        /// null (no inputs), a Hashtable of inputs, or a Minibatch (used together with <paramref name="map"/>).
        /// </param>
        /// <param name="map">Optional data-name to input map; only used with a Minibatch.</param>
        /// <param name="showUid">Forwarded to FunctionAsTree in every case, including when <paramref name="arguments"/> is null.</param>
        /// <exception cref="System.ArgumentException"><paramref name="arguments"/> is of an unsupported type.</exception>
        public static string AsTreeWithValues(PSObject func, object arguments = null, DataNameToInputMap map = null, bool showUid = true)
        {
            Function f = ToFunction(func);
            FunctionAsTree w;

            if (arguments == null)
            {
                // The previous 5-argument call dropped showUid here, so showUid:false had no effect.
                w = new FunctionAsTree(f, true, (Hashtable)null, null, null, showUid);
            }
            else if (arguments is Hashtable table)
            {
                w = new FunctionAsTree(f, true, table, null, null, showUid);
            }
            else if (arguments is Minibatch minibatch)
            {
                w = new FunctionAsTree(f, true, null, minibatch, map, showUid);
            }
            else
            {
                // Previously an unsupported type was silently rendered without values.
                throw new System.ArgumentException("Invalid type: arguments", nameof(arguments));
            }

            return w.Result;
        }