/// <summary>
/// Builds a small regression model: two 32-unit ReLU hidden layers followed by a
/// linear output layer with <paramref name="targetCount"/> units, wired to an Adam
/// trainer that minimizes mean squared error and reports mean absolute error.
/// </summary>
/// <param name="inputVariable">Network input.</param>
/// <param name="targetVariable">Variable holding the regression targets.</param>
/// <param name="targetCount">Number of output units.</param>
/// <param name="dataType">CNTK data type used for the layer parameters.</param>
/// <param name="device">Device the layer parameters are created on.</param>
/// <returns>The assembled <c>Model</c>; its summary is also written to <see cref="Trace"/>.</returns>
static Model CreateModel(Function inputVariable, Variable targetVariable, int targetCount,
                                 DataType dataType, DeviceDescriptor device)
        {
            // Fixed seed so the per-layer weight seeds are the same on every run.
            var rng = new Random(232);
            var zeroBias = Initializers.Zero();

            // Each call draws a new seed, so every layer gets its own Glorot-normal initializer.
            CNTKDictionary NextWeightInit() => Initializers.GlorotNormal(rng.Next());

            // Layers are built in order so the seeds are drawn in order: hidden1, hidden2, output.
            var hidden1 = inputVariable.Dense(32, NextWeightInit(), zeroBias, device, dataType).ReLU();
            var hidden2 = hidden1.Dense(32, NextWeightInit(), zeroBias, device, dataType).ReLU();
            var network = hidden2.Dense(targetCount, NextWeightInit(), zeroBias, device, dataType);

            // Optimize mean squared error; track mean absolute error as the evaluation metric.
            var loss   = Losses.MeanSquaredError(network.Output, targetVariable);
            var metric = Losses.MeanAbsoluteError(network.Output, targetVariable);

            var adam = CntkCatalyst.Learners.Adam(network.Parameters());
            var learners = new LearnerVector();
            learners.Add(adam);
            var trainer = CNTKLib.CreateTrainer(network, loss, metric, learners);

            var model = new Model(trainer, network, dataType, device);
            Trace.WriteLine(model.Summary());
            return model;
        }
// Example no. 2
        /// <summary>
        /// When predictions equal the targets (all zeros), the mean squared error must be zero.
        /// </summary>
        public void MeanSquareError_Zero_Error()
        {
            // Arrange: six zero targets and six matching zero predictions.
            var targetsData = new float[6];
            var targetsVariable = CNTKLib.InputVariable(new[] { targetsData.Length }, m_dataType);

            var predictionsData = new float[targetsData.Length];
            var predictionsVariable = CNTKLib.InputVariable(new[] { predictionsData.Length }, m_dataType);

            // Act
            var loss = Losses.MeanSquaredError(predictionsVariable, targetsVariable);
            var mse = Evaluate(loss,
                               targetsVariable, targetsData,
                               predictionsVariable, predictionsData);

            // Assert: every difference is exactly 0, so an exact comparison (no tolerance) is used.
            Assert.AreEqual(0.0f, mse);
        }
// Example no. 3
        /// <summary>
        /// Checks the mean squared error against a small hand-computed example.
        /// </summary>
        public void MeanSquareError()
        {
            // Arrange
            float[] targetsData = { 1.0f, 2.3f, 3.1f, 4.4f, 5.8f };
            var targetsVariable = CNTKLib.InputVariable(new[] { targetsData.Length }, m_dataType);

            float[] predictionsData = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f };
            var predictionsVariable = CNTKLib.InputVariable(new[] { predictionsData.Length }, m_dataType);

            // Act
            var loss = Losses.MeanSquaredError(predictionsVariable, targetsVariable);
            var mse = Evaluate(loss,
                               targetsVariable, targetsData,
                               predictionsVariable, predictionsData);

            // Assert: squared differences are 0, 0.09, 0.01, 0.16, 0.64 -> sum 0.9, mean 0.9 / 5 = 0.18.
            // A tolerance absorbs single-precision rounding.
            const double tolerance = 0.00001;
            Assert.AreEqual(0.18f, mse, tolerance);
        }