/// <summary>
/// Trains for <c>epochs</c> epochs of mini-batch optimization. Every <c>display_freq</c>
/// iterations it logs the current batch's loss/accuracy, together with the milliseconds
/// elapsed since the previous log line. After each epoch it logs loss/accuracy on the
/// validation split.
/// </summary>
public override void Train()
        {
            // Integer division: only whole batches are trained, any trailing remainder is skipped.
            var iterationsPerEpoch = mnist.Train.Labels.shape[0] / batch_size;

            sess.run(tf.global_variables_initializer());

            float lossValue     = 100.0f;
            float accuracyValue = 0f;

            // Writes a message between two horizontal separator lines.
            void PrintFramed(string message)
            {
                const string Separator = "---------------------------------------------------------";
                print(Separator);
                print(message);
                print(Separator);
            }

            // Wall-clock time since the previous progress line; restarted after each one is printed.
            var timer = Stopwatch.StartNew();

            foreach (var epoch in range(epochs))
            {
                var epochNumber = epoch + 1;
                print($"Training epoch: {epochNumber}");

                // Reshuffle the training split every epoch so batches come in a different order.
                var (shuffledImages, shuffledLabels) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels);

                foreach (var iteration in range(iterationsPerEpoch))
                {
                    // Take the batch between offsets iteration*batch_size and (iteration+1)*batch_size.
                    var (imageBatch, labelBatch) = mnist.GetNextBatch(
                        shuffledImages,
                        shuffledLabels,
                        iteration * batch_size,
                        (iteration + 1) * batch_size);

                    // One optimizer step (backprop) on this batch.
                    sess.run(optimizer, (x, imageBatch), (y, labelBatch));

                    if (iteration % display_freq != 0)
                    {
                        continue;
                    }

                    // Re-evaluate the same batch to report its loss and accuracy.
                    (lossValue, accuracyValue) = sess.run((loss, accuracy), (x, imageBatch), (y, labelBatch));
                    print($"iter {iteration.ToString("000")}: Loss={lossValue.ToString("0.0000")}, Training Accuracy={accuracyValue.ToString("P")} {timer.ElapsedMilliseconds}ms");
                    timer.Restart();
                }

                // End-of-epoch evaluation on the whole validation split.
                (lossValue, accuracyValue) = sess.run((loss, accuracy), (x, mnist.Validation.Data), (y, mnist.Validation.Labels));
                PrintFramed($"Epoch: {epochNumber}, validation loss: {lossValue.ToString("0.0000")}, validation accuracy: {accuracyValue.ToString("P")}");
            }
        }
        /// <summary>
        /// Runs the training loop on the supplied <paramref name="sess"/>. It initializes all
        /// global variables, then trains for <c>epochs</c> epochs of mini-batch updates. Every
        /// <c>display_freq</c> iterations it logs the current batch's loss/accuracy, and after
        /// each epoch it logs validation loss/accuracy.
        /// </summary>
        /// <param name="sess">Session used for variable initialization, training and evaluation.</param>
        public void Train(Session sess)
        {
            // Integer division: a trailing partial batch is never trained on.
            var iterationsPerEpoch = mnist.Train.Labels.len / batch_size;

            sess.run(tf.global_variables_initializer());

            // Writes a message between two horizontal separator lines.
            void PrintFramed(string message)
            {
                const string Separator = "---------------------------------------------------------";
                print(Separator);
                print(message);
                print(Separator);
            }

            foreach (var epoch in range(epochs))
            {
                var epochNumber = epoch + 1;
                print($"Training epoch: {epochNumber}");

                // Reshuffle the training split at the start of every epoch.
                var (shuffledImages, shuffledLabels) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels);

                foreach (var iteration in range(iterationsPerEpoch))
                {
                    var (imageBatch, labelBatch) = mnist.GetNextBatch(
                        shuffledImages,
                        shuffledLabels,
                        iteration * batch_size,
                        (iteration + 1) * batch_size);

                    // One optimizer step (backprop) on this batch.
                    sess.run(optimizer, new FeedItem(x, imageBatch), new FeedItem(y, labelBatch));

                    if (iteration % display_freq != 0)
                    {
                        continue;
                    }

                    // Re-evaluate the same batch; fetches come back in the order { loss, accuracy }.
                    var batchMetrics = sess.run(new[] { loss, accuracy }, new FeedItem(x, imageBatch), new FeedItem(y, labelBatch));
                    float batchLoss     = batchMetrics[0];
                    float batchAccuracy = batchMetrics[1];
                    print($"iter {iteration.ToString("000")}: Loss={batchLoss.ToString("0.0000")}, Training Accuracy={batchAccuracy.ToString("P")}");
                }

                // End-of-epoch evaluation on the whole validation split.
                var validationMetrics = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels));
                float validationLoss     = validationMetrics[0];
                float validationAccuracy = validationMetrics[1];
                PrintFramed($"Epoch: {epochNumber}, validation loss: {validationLoss.ToString("0.0000")}, validation accuracy: {validationAccuracy.ToString("P")}");
            }
        }
// Example #3
// 0
        /// <summary>
        /// Trains for <c>n_epochs</c> epochs of mini-batch optimization. Each image batch is
        /// reshaped to (batch, n_steps, n_inputs), and each label batch is reduced with argmax
        /// over axis 1 before being fed. Every <c>display_freq</c> iterations it logs the batch
        /// loss/accuracy, and after each epoch it logs metrics on <c>x_valid</c>/<c>y_valid</c>.
        /// Note: the <c>x_train</c>/<c>y_train</c> fields are overwritten with their shuffled
        /// versions every epoch.
        /// </summary>
        public override void Train()
        {
            float lossValue     = 100.0f;
            float accuracyValue = 0f;

            // Integer division: only whole batches are trained, any remainder is skipped.
            var batchesPerEpoch = y_train.shape[0] / batch_size;

            sess.run(tf.global_variables_initializer());

            // Writes a message between two horizontal separator lines.
            void PrintFramed(string message)
            {
                const string Separator = "---------------------------------------------------------";
                print(Separator);
                print(message);
                print(Separator);
            }

            foreach (var epoch in range(n_epochs))
            {
                var epochNumber = epoch + 1;
                print($"Training epoch: {epochNumber}");

                // Reshuffle and store the shuffled arrays back into the training fields.
                (x_train, y_train) = mnist.Randomize(x_train, y_train);

                foreach (var iteration in range(batchesPerEpoch))
                {
                    var (rawImages, rawLabels) = mnist.GetNextBatch(
                        x_train,
                        y_train,
                        iteration * batch_size,
                        (iteration + 1) * batch_size);

                    // Shape the images as (batch, n_steps, n_inputs); -1 lets reshape infer the batch size.
                    var xBatch = rawImages.reshape(-1, n_steps, n_inputs);
                    // Index of the maximum along axis 1 (labels presumably one-hot - confirm against the y placeholder).
                    var yBatch = np.argmax(rawLabels, axis: 1);

                    // One optimizer step (backprop) on this batch.
                    sess.run(optimizer, new FeedItem(X, xBatch), new FeedItem(y, yBatch));

                    if (iteration % display_freq != 0)
                    {
                        continue;
                    }

                    // Re-evaluate the same batch to report its loss and accuracy.
                    (lossValue, accuracyValue) = sess.run((loss, accuracy), new FeedItem(X, xBatch), new FeedItem(y, yBatch));
                    print($"iter {iteration.ToString("000")}: Loss={lossValue.ToString("0.0000")}, Training Accuracy={accuracyValue.ToString("P")}");
                }

                // End-of-epoch evaluation on the validation fields.
                (lossValue, accuracyValue) = sess.run((loss, accuracy), new FeedItem(X, x_valid), new FeedItem(y, y_valid));
                PrintFramed($"Epoch: {epochNumber}, validation loss: {lossValue.ToString("0.0000")}, validation accuracy: {accuracyValue.ToString("P")}");
            }
        }