/// <summary>
/// Builds a small fully connected network and wraps it, together with an Adam-based trainer,
/// in a <see cref="Model"/>. The network has two hidden dense layers of 32 units, each followed
/// by ReLU, and then a dense output layer of <paramref name="targetCount"/> units.
/// </summary>
/// <param name="inputVariable">Input that the first dense layer is attached to.</param>
/// <param name="targetVariable">Target that the loss and metric compare the network output against.</param>
/// <param name="targetCount">Number of units in the output dense layer.</param>
/// <param name="dataType">Data type passed to every layer and to the model.</param>
/// <param name="device">Device passed to every layer and to the model.</param>
/// <returns>The assembled model. Its summary is also written via <see cref="Trace.WriteLine(string)"/>.</returns>
static Model CreateModel(Function inputVariable, Variable targetVariable, int targetCount, DataType dataType, DeviceDescriptor device)
{
    // Fixed seed: each layer's Glorot-normal initializer receives a reproducible seed.
    var seedSource = new Random(232);
    Func<CNTKDictionary> nextWeightInitializer = () => Initializers.GlorotNormal(seedSource.Next());
    var zeroBias = Initializers.Zero();

    // Each Dense call draws its weight seed in layer order (first hidden, second hidden, output).
    var firstHidden = inputVariable.Dense(32, nextWeightInitializer(), zeroBias, device, dataType).ReLU();
    var secondHidden = firstHidden.Dense(32, nextWeightInitializer(), zeroBias, device, dataType).ReLU();
    // No ReLU is applied after the output layer here.
    var network = secondHidden.Dense(targetCount, nextWeightInitializer(), zeroBias, device, dataType);

    // Optimize mean squared error; report mean absolute error as the metric.
    var loss = Losses.MeanSquaredError(network.Output, targetVariable);
    var metric = Losses.MeanAbsoluteError(network.Output, targetVariable);

    var adam = CntkCatalyst.Learners.Adam(network.Parameters());
    var learners = new LearnerVector();
    learners.Add(adam);
    var trainer = CNTKLib.CreateTrainer(network, loss, metric, learners);

    var model = new Model(trainer, network, dataType, device);
    Trace.WriteLine(model.Summary());
    return model;
}
/// <summary>
/// Verifies that <c>Losses.MeanSquaredError</c> evaluates to zero when the
/// predictions are identical to the targets (six zeros each).
/// </summary>
public void MeanSquareError_Zero_Error()
{
    // Arrange: targets and predictions are both all zeros.
    float[] expectedValues = { 0, 0, 0, 0, 0, 0 };
    var expectedInput = CNTKLib.InputVariable(new[] { expectedValues.Length }, m_dataType);

    float[] predictedValues = { 0, 0, 0, 0, 0, 0 };
    var predictedInput = CNTKLib.InputVariable(new[] { predictedValues.Length }, m_dataType);

    // Act
    var lossFunction = Losses.MeanSquaredError(predictedInput, expectedInput);
    var result = Evaluate(lossFunction, expectedInput, expectedValues, predictedInput, predictedValues);

    // Assert: identical inputs give an exact zero, so no tolerance is needed.
    Assert.AreEqual(0.0f, result);
}
/// <summary>
/// Verifies that <c>Losses.MeanSquaredError</c> returns the mean of the squared
/// differences between predictions and targets.
/// </summary>
public void MeanSquareError()
{
    // Arrange
    float[] expectedValues = { 1.0f, 2.3f, 3.1f, 4.4f, 5.8f };
    var expectedInput = CNTKLib.InputVariable(new[] { expectedValues.Length }, m_dataType);

    float[] predictedValues = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f };
    var predictedInput = CNTKLib.InputVariable(new[] { predictedValues.Length }, m_dataType);

    // Act
    var lossFunction = Losses.MeanSquaredError(predictedInput, expectedInput);
    var result = Evaluate(lossFunction, expectedInput, expectedValues, predictedInput, predictedValues);

    // Assert: squared differences are 0, 0.09, 0.01, 0.16, 0.64 -> sum 0.9 / 5 = 0.18.
    // A tolerance absorbs single-precision rounding.
    Assert.AreEqual(0.18f, result, 0.00001);
}