/// <summary>
/// Trains a small convolutional network on FashionMNIST and saves the learned
/// parameters to "net.params".
/// </summary>
/// <remarks>
/// Steps performed: load the training split and print the shape/dtype of its first
/// sample, wrap both splits in batched loaders (batch size 256) with the same
/// ToTensor + Normalize transform applied to the first element of each sample,
/// build and Xavier-initialize the network, train for 10 epochs with SGD
/// (learning rate 0.1), and print per-epoch loss, training accuracy, validation
/// accuracy and elapsed time.
/// </remarks>
public static void Run()
{
    const int epochs = 10;
    int batch_size = 256;

    var mnist_train = new FashionMNIST(train: true);
    var (x, y) = mnist_train[0];
    Console.WriteLine($"X shape: {x.Shape}, X dtype: {x.DataType}, Y shape: {y.Shape}, Y dtype: {y.DataType}");

    var transformer = new Compose(
        new ToTensor(),
        new Normalize(new MxNet.Tuple<float>(0.13f, 0.31f))
    );

    var train = mnist_train.TransformFirst(transformer);
    var train_data = new DataLoader(train, batch_size: batch_size, shuffle: true);

    // Peek at a single batch to show the batched shapes.
    foreach (var (data, label) in train_data)
    {
        Console.WriteLine(data.Shape + ", " + label.Shape);
        break;
    }

    // The validation split must go through the same transform as the training split;
    // otherwise the network is evaluated on inputs it was never trained on.
    var mnist_valid = new FashionMNIST(train: false);
    var valid = mnist_valid.TransformFirst(transformer);
    var valid_data = new DataLoader(valid, batch_size: batch_size, shuffle: true);

    var net = new Sequential();
    net.Add(new Conv2D(channels: 6, kernel_size: (5, 5), activation: ActivationType.Relu),
            new MaxPool2D(pool_size: (2, 2), strides: (2, 2)),
            new Conv2D(channels: 16, kernel_size: (3, 3), activation: ActivationType.Relu),
            new MaxPool2D(pool_size: (2, 2), strides: (2, 2)),
            new Flatten(),
            new Dense(120, activation: ActivationType.Relu),
            new Dense(84, activation: ActivationType.Relu),
            new Dense(10));
    net.Initialize(new Xavier());

    var softmax_cross_entropy = new SoftmaxCrossEntropyLoss();
    var trainer = new Trainer(net.CollectParams(), new SGD(learning_rate: 0.1f));

    for (int epoch = 0; epoch < epochs; epoch++)
    {
        // Stopwatch is monotonic, unlike DateTime.Now which follows wall-clock adjustments.
        var timer = System.Diagnostics.Stopwatch.StartNew();
        float train_loss = 0;
        float train_acc = 0;
        float valid_acc = 0;

        foreach (var (data, label) in train_data)
        {
            NDArray loss = null;
            NDArray output = null;

            // forward + backward
            using (Autograd.Record())
            {
                output = net.Call(data);
                loss = softmax_cross_entropy.Call(output, label);
            }

            loss.Backward();

            // update parameters
            trainer.Step(batch_size);

            // accumulate training metrics (averaged over batches when reported)
            train_loss += loss.Mean();
            train_acc += Acc(output, label);
        }

        // calculate validation accuracy
        foreach (var (data, label) in valid_data)
        {
            valid_acc += Acc(net.Call(data), label);
        }

        // Report validation accuracy averaged over the validation batches
        // (previously the training accuracy was printed under the "test acc" label).
        Console.WriteLine($"Epoch {epoch}: loss {train_loss / train_data.Length}," +
                          $" train acc {train_acc / train_data.Length}, " +
                          $"test acc {valid_acc / valid_data.Length} " +
                          $"in {timer.Elapsed.TotalMilliseconds} ms");
    }

    net.SaveParameters("net.params");
}