Пример #1
0
        /// <summary>
        /// Builds the character-level dataset for <c>model_name</c>, splits it into training and
        /// validation parts, and constructs the selected text-classification model.
        /// Training itself is not implemented yet.
        /// </summary>
        /// <param name="session">Session the model is intended to be trained in (currently unused).</param>
        /// <returns>Always <c>false</c> until training is implemented.</returns>
        /// <exception cref="NotImplementedException">
        /// <c>model_name</c> is a recognised model that has not been ported yet.
        /// </exception>
        /// <exception cref="NotSupportedException"><c>model_name</c> is not a recognised model.</exception>
        protected virtual bool RunWithBuiltGraph(Session session)
        {
            Console.WriteLine("Building dataset...");
            var(x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

            var(train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            ITextClassificationModel model;

            switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
            {
            case "vd_cnn":
                model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
                break;

            case "word_cnn":
            case "char_cnn":
            case "word_rnn":
            case "att_rnn":
            case "rcnn":
                // No 'break' needed (or reachable) after a throw.
                throw new NotImplementedException($"Model '{model_name}' has not been ported yet.");

            default:
                // Fail loudly instead of silently returning with no model built.
                throw new NotSupportedException(
                    $"Unknown model '{model_name}'. Expected one of: word_cnn, char_cnn, vd_cnn, word_rnn, att_rnn, rcnn.");
            }

            // TODO: train the model
            return false;
        }
Пример #2
0
 /// <summary>
 /// Downloads the DBpedia corpus, builds the character-level "vdcnn" training dataset and
 /// splits it into training and validation parts (test_size 0.15f).
 /// </summary>
 public void Run()
 {
     download_dbpedia();
     Console.WriteLine("Building dataset...");
     var(x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
     // Use a float literal like the other call sites: float converts implicitly to double,
     // but a double literal does not convert implicitly to a float parameter.
     var(train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
 }
Пример #3
0
        /// <summary>
        /// Downloads DBpedia, prepares the "vdcnn" character dataset together with its
        /// train/validation split, then constructs a <c>VdCnn</c> inside a new TensorFlow session.
        /// </summary>
        public void Run()
        {
            download_dbpedia();

            Console.WriteLine("Building dataset...");
            var (inputs, labels, alphabetSize) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);

            // Split results are not consumed yet.
            var (trainInputs, validInputs, trainLabels, validLabels) =
                train_test_split(inputs, labels, test_size: 0.15f);

            // The model instance is discarded; presumably constructing it builds the graph
            // inside the session scope — confirm against VdCnn.
            with(tf.Session(), session =>
            {
                new VdCnn(alphabetSize, CHAR_MAX_LEN, NUM_CLASS);
            });
        }
        /// <summary>
        /// Prepares the data, builds the "vdcnn" character dataset (honouring <c>DataLimit</c>),
        /// splits it, and constructs a <c>VdCnn</c> inside a new TensorFlow session.
        /// </summary>
        /// <returns>The value produced by the session callback, currently always <c>false</c>.</returns>
        public bool Run()
        {
            PrepareData();

            Console.WriteLine("Building dataset...");
            var (inputs, labels, alphabetSize) =
                DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit);

            // Split results are not consumed yet.
            var (trainInputs, validInputs, trainLabels, validLabels) =
                train_test_split(inputs, labels, test_size: 0.15f);

            var outcome = with(tf.Session(), session =>
            {
                // Instance discarded; presumably built for its graph side effects.
                new VdCnn(alphabetSize, CHAR_MAX_LEN, NUM_CLASS);
                return false;
            });
            return outcome;
        }
Пример #5
0
        /// <summary>
        /// Prototype training loop over a graph imported from
        /// <c>graph/&lt;model_name&gt;_untrained.meta</c>.
        /// It builds the dataset, splits it, looks up the model's named operations and iterates the
        /// training batches, building a feed dictionary per batch. The optimizer step itself is not
        /// implemented yet.
        /// </summary>
        /// <param name="sess">Session intended to run the training step (currently unused).</param>
        /// <returns>Always <c>false</c>.</returns>
        protected virtual bool RunWithImportedGraph(Session sess)
        {
            var graph = tf.Graph().as_default();

            Console.WriteLine("Building dataset...");
            var(x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);

            var(train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            var meta_file = model_name + "_untrained.meta";

            tf.train.import_meta_graph(Path.Join("graph", meta_file));

            //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export

            var    train_batches         = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            // Python original: (len(train_x) - 1) // BATCH_SIZE + 1, i.e. the number of batches needed
            // to cover train_x. The Python "//" operator must become integer division, not a comment.
            var    num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
            double max_accuracy          = 0;

            Tensor is_training = graph.get_operation_by_name("is_training");
            Tensor model_x     = graph.get_operation_by_name("x");
            Tensor model_y     = graph.get_operation_by_name("y");
            // NOTE(review): "Variable" looks unlikely to be the loss op — confirm against the exported graph.
            Tensor loss        = graph.get_operation_by_name("Variable");
            Tensor accuracy    = graph.get_operation_by_name("accuracy/accuracy");

            foreach (var(x_batch, y_batch) in train_batches)
            {
                var train_feed_dict = new Hashtable
                {
                    [model_x]     = x_batch,
                    [model_y]     = y_batch,
                    [is_training] = true,
                };

                // TODO: run the training step; original python:
                //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
            }

            return(false);
        }
        /// <summary>
        /// Trains the character-level classifier using the meta graph in <c>graph/&lt;model_name&gt;.meta</c>.
        /// It builds the dataset (the subset path when <c>UseSubset</c> is set), initializes variables
        /// and runs the optimizer over all training batches. Progress and an estimated training time
        /// are printed on the first 10 steps and every 10th step after that. Mean validation accuracy
        /// is printed every 100 steps.
        /// </summary>
        /// <param name="sess">Session in which variables are initialized and the graph is run.</param>
        /// <param name="graph">Graph from which the named tensors are fetched; presumably the graph the
        /// meta graph is imported into — confirm with the caller.</param>
        /// <returns>Always <c>false</c>.</returns>
        protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
        {
            var stopwatch = Stopwatch.StartNew();

            Console.WriteLine("Building dataset...");
            var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;

            // Pass null for the 4th (limit) argument directly. Writing `DataLimit = null` here is an
            // assignment expression: it passes null but also overwrites the configured DataLimit.
            var(x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, null, shuffle: !UseSubset);
            Console.WriteLine("\tDONE ");

            var(train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
            Console.WriteLine("Training set size: " + train_x.len);
            Console.WriteLine("Test set size: " + valid_x.len);

            Console.WriteLine("Import graph...");
            var meta_file = model_name + ".meta";

            tf.train.import_meta_graph(Path.Join("graph", meta_file));
            Console.WriteLine("\tDONE " + stopwatch.Elapsed);

            sess.run(tf.global_variables_initializer());

            var    train_batches         = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            var    num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
            double max_accuracy          = 0; // for the pending "save best model" logic below

            Tensor is_training = graph.get_tensor_by_name("is_training:0");
            Tensor model_x     = graph.get_tensor_by_name("x:0");
            Tensor model_y     = graph.get_tensor_by_name("y:0");
            Tensor loss        = graph.get_tensor_by_name("loss/value:0");
            Tensor optimizer   = graph.get_tensor_by_name("loss/optimizer:0");
            Tensor global_step = graph.get_tensor_by_name("global_step:0");
            Tensor accuracy    = graph.get_tensor_by_name("accuracy/value:0");

            // Restart timing so the estimate covers training only, not dataset/graph setup.
            stopwatch = Stopwatch.StartNew();
            int i = 0;

            foreach (var(x_batch, y_batch, total) in train_batches)
            {
                i++;
                var train_feed_dict = new FeedDict
                {
                    [model_x]     = x_batch,
                    [model_y]     = y_batch,
                    [is_training] = true,
                };
                // original python:
                //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
                var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
                loss_value = result[2];
                var step = (int)result[1];
                if (step % 10 == 0 || step < 10)
                {
                    // Average time per batch so far, extrapolated to all `total` batches.
                    var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
                    Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
                    Console.WriteLine($"Step {step} loss: {loss_value}");
                }

                if (step % 100 == 0)
                {
                    // # Test accuracy with validation data for each epoch.
                    var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
                    var(sum_accuracy, cnt) = (0.0f, 0);
                    foreach (var(valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
                    {
                        var valid_feed_dict = new FeedDict
                        {
                            [model_x]     = valid_x_batch,
                            [model_y]     = valid_y_batch,
                            [is_training] = false
                        };
                        var   result1        = sess.run(accuracy, valid_feed_dict);
                        float accuracy_value = result1;
                        sum_accuracy += accuracy_value;
                        cnt          += 1;
                    }

                    // Mean over validation batches (NaN if the validation set produced no batches).
                    var valid_accuracy = sum_accuracy / cnt;

                    print($"\nValidation Accuracy = {valid_accuracy}\n");

                    //    # Save model
                    //        if valid_accuracy > max_accuracy:
                    //        max_accuracy = valid_accuracy
                    //        saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step)
                    //        print("Model is saved.\n")
                }
            }

            return(false);
        }
        /// <summary>
        /// Earlier variant of the imported-graph training loop. It imports <c>graph/&lt;model_name&gt;.meta</c>,
        /// initializes variables and runs the optimizer on every training batch, logging the loss every
        /// 10 steps. Validation and checkpointing are not implemented yet (see the TODO in the loop).
        /// </summary>
        /// <param name="sess">Session in which variables are initialized and the graph is run.</param>
        /// <param name="graph">Graph from which the named operations are fetched; presumably the graph the
        /// meta graph is imported into — confirm with the caller.</param>
        /// <returns>Always <c>false</c>.</returns>
        protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
        {
            Console.WriteLine("Building dataset...");
            var(x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
            Console.WriteLine("\tDONE");

            var(train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

            Console.WriteLine("Import graph...");
            var meta_file = model_name + ".meta";

            tf.train.import_meta_graph(Path.Join("graph", meta_file));
            Console.WriteLine("\tDONE");
            // definitely necessary, otherwise we get the exception "use uninitialized variable"
            sess.run(tf.global_variables_initializer());

            var    train_batches         = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            // Python original: (len(train_x) - 1) // BATCH_SIZE + 1. The Python "//" is integer
            // division, not a comment.
            var    num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
            double max_accuracy          = 0; // for the pending "save best model" logic

            Tensor is_training = graph.get_operation_by_name("is_training");
            Tensor model_x     = graph.get_operation_by_name("x");
            Tensor model_y     = graph.get_operation_by_name("y");
            Tensor loss        = graph.get_operation_by_name("loss/loss");
            Tensor optimizer   = graph.get_operation_by_name("loss/optimizer");
            Tensor global_step = graph.get_operation_by_name("global_step");
            Tensor accuracy    = graph.get_operation_by_name("accuracy/accuracy");

            int i = 0;

            foreach (var(x_batch, y_batch) in train_batches)
            {
                i++;
                Console.WriteLine("Training on batch " + i);
                var train_feed_dict = new Hashtable
                {
                    [model_x]     = x_batch,
                    [model_y]     = y_batch,
                    [is_training] = true,
                };
                // original python:
                //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
                var result = sess.run(new Tensor[] { optimizer, global_step, loss }, train_feed_dict);
                // NOTE: assigning result[2] to a scalar threw (it seemed to be a float[]), so the loss
                // is only printed for now.
                var step = result[1];
                if (step % 10 == 0)
                {
                    Console.WriteLine($"Step {step} loss: {result[2]}");
                }

                // TODO: every 100 steps, measure accuracy on the validation data and save the best model.
                // Original python for reference:
                //    if step % 100 == 0:
                //        valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1)
                //        sum_accuracy, cnt = 0, 0
                //        for valid_x_batch, valid_y_batch in valid_batches:
                //            valid_feed_dict = {
                //                model.x: valid_x_batch,
                //                model.y: valid_y_batch,
                //                model.is_training: False
                //            }
                //            accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict)
                //            sum_accuracy += accuracy
                //            cnt += 1
                //        valid_accuracy = sum_accuracy / cnt
                //        print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt))
                //        # Save model
                //        if valid_accuracy > max_accuracy:
                //            max_accuracy = valid_accuracy
                //            saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step)
                //            print("Model is saved.\n")
            }

            return(false);
        }