protected virtual bool RunWithBuiltGraph(Session session)
{
    Console.WriteLine("Building dataset...");
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

    ITextClassificationModel model = null;
    switch (model_name) // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn
    {
        case "word_cnn":
        case "char_cnn":
        case "word_rnn":
        case "att_rnn":
        case "rcnn":
            throw new NotImplementedException();
        case "vd_cnn":
            model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
            break;
    }

    // todo: train the model
    return false;
}
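// A minimal sketch of what the ITextClassificationModel contract used above could look like. The
// commented-out Python in the training code references model.x, model.y, model.is_training,
// model.optimizer, model.global_step, model.loss and model.accuracy, so the members below mirror
// those names; the C# property types are assumptions for illustration, not the shipped interface.
public interface ITextClassificationModel
{
    Tensor x { get; }            // character ids, shape [batch, CHAR_MAX_LEN]
    Tensor y { get; }            // class labels, shape [batch]
    Tensor is_training { get; }  // dropout / batch-norm switch
    Tensor loss { get; }
    Tensor accuracy { get; }
    Tensor global_step { get; }
    Operation optimizer { get; }
}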
public void Run()
{
    download_dbpedia();
    Console.WriteLine("Building dataset...");
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
}
public void Run()
{
    download_dbpedia();
    Console.WriteLine("Building dataset...");
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

    with(tf.Session(), sess =>
    {
        new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
    });
}
public bool Run()
{
    PrepareData();
    Console.WriteLine("Building dataset...");
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit);
    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

    return with(tf.Session(), sess =>
    {
        new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
        return false;
    });
}
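// A rough sketch of the train_test_split helper used above, assuming NumSharp NDArray inputs and
// numpy-style string slicing; the real helper may shuffle the data before splitting. It returns
// (train_x, valid_x, train_y, valid_y), where test_size is the fraction held out for validation.
private static (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
{
    int total = x.shape[0];
    int train_size = (int)Math.Round(total * (1.0f - test_size));
    var train_x = x[$"0:{train_size}"];
    var valid_x = x[$"{train_size}:{total}"];
    var train_y = y[$"0:{train_size}"];
    var valid_y = y[$"{train_size}:{total}"];
    return (train_x, valid_x, train_y, valid_y);
}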
protected virtual bool RunWithImportedGraph(Session sess)
{
    var graph = tf.Graph().as_default();
    Console.WriteLine("Building dataset...");
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

    var meta_file = model_name + "_untrained.meta";
    tf.train.import_meta_graph(Path.Join("graph", meta_file));
    //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export

    var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
    var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
    double max_accuracy = 0;

    Tensor is_training = graph.get_operation_by_name("is_training");
    Tensor model_x = graph.get_operation_by_name("x");
    Tensor model_y = graph.get_operation_by_name("y");
    Tensor loss = graph.get_operation_by_name("Variable");
    Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");

    foreach (var (x_batch, y_batch) in train_batches)
    {
        var train_feed_dict = new Hashtable
        {
            [model_x] = x_batch,
            [model_y] = y_batch,
            [is_training] = true,
        };
        // original python:
        //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
    }

    return false;
}
protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
{
    var stopwatch = Stopwatch.StartNew();
    Console.WriteLine("Building dataset...");
    var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle: !UseSubset);
    Console.WriteLine("\tDONE ");

    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
    Console.WriteLine("Training set size: " + train_x.len);
    Console.WriteLine("Test set size: " + valid_x.len);

    Console.WriteLine("Import graph...");
    var meta_file = model_name + ".meta";
    tf.train.import_meta_graph(Path.Join("graph", meta_file));
    Console.WriteLine("\tDONE " + stopwatch.Elapsed);

    sess.run(tf.global_variables_initializer());

    var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
    var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
    double max_accuracy = 0;

    Tensor is_training = graph.get_tensor_by_name("is_training:0");
    Tensor model_x = graph.get_tensor_by_name("x:0");
    Tensor model_y = graph.get_tensor_by_name("y:0");
    Tensor loss = graph.get_tensor_by_name("loss/value:0");
    Tensor optimizer = graph.get_tensor_by_name("loss/optimizer:0");
    Tensor global_step = graph.get_tensor_by_name("global_step:0");
    Tensor accuracy = graph.get_tensor_by_name("accuracy/value:0");

    stopwatch = Stopwatch.StartNew();
    int i = 0;
    foreach (var (x_batch, y_batch, total) in train_batches)
    {
        i++;
        var train_feed_dict = new FeedDict
        {
            [model_x] = x_batch,
            [model_y] = y_batch,
            [is_training] = true,
        };
        //Console.WriteLine("x: " + x_batch.ToString() + "\n");
        //Console.WriteLine("y: " + y_batch.ToString());

        // original python:
        //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
        var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
        var loss_value = result[2];
        var step = (int)result[1];
        if (step % 10 == 0 || step < 10)
        {
            var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
            Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
            Console.WriteLine($"Step {step} loss: {loss_value}");
        }

        if (step % 100 == 0)
        {
            // Test accuracy with validation data for each epoch.
            var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
            var (sum_accuracy, cnt) = (0.0f, 0);
            foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
            {
                var valid_feed_dict = new FeedDict
                {
                    [model_x] = valid_x_batch,
                    [model_y] = valid_y_batch,
                    [is_training] = false
                };
                var result1 = sess.run(accuracy, valid_feed_dict);
                float accuracy_value = result1;
                sum_accuracy += accuracy_value;
                cnt += 1;
            }
            var valid_accuracy = sum_accuracy / cnt;
            print($"\nValidation Accuracy = {valid_accuracy}\n");

            // # Save model
            // if valid_accuracy > max_accuracy:
            //     max_accuracy = valid_accuracy
            //     saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step)
            //     print("Model is saved.\n")
        }
    }

    return false;
}
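// A sketch of a batch_iter compatible with the loop above: it yields (x_batch, y_batch, total)
// tuples, where total is the number of batches over all epochs so the caller can estimate the
// remaining training time. Per-epoch shuffling is omitted here and is an obvious refinement.
private static IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
    var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
    var total_batches = num_batches_per_epoch * num_epochs;
    for (int epoch = 0; epoch < num_epochs; epoch++)
    {
        for (int batch_num = 0; batch_num < num_batches_per_epoch; batch_num++)
        {
            var start_index = batch_num * batch_size;
            var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
            yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"], total_batches);
        }
    }
}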
protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
{
    Console.WriteLine("Building dataset...");
    var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
    Console.WriteLine("\tDONE");

    var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

    Console.WriteLine("Import graph...");
    var meta_file = model_name + ".meta";
    tf.train.import_meta_graph(Path.Join("graph", meta_file));
    Console.WriteLine("\tDONE");

    // definitely necessary, otherwise we get a "use of uninitialized variable" exception
    sess.run(tf.global_variables_initializer());

    var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
    var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
    double max_accuracy = 0;

    Tensor is_training = graph.get_operation_by_name("is_training");
    Tensor model_x = graph.get_operation_by_name("x");
    Tensor model_y = graph.get_operation_by_name("y");
    Tensor loss = graph.get_operation_by_name("loss/loss");
    //var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray();
    Tensor optimizer = graph.get_operation_by_name("loss/optimizer");
    Tensor global_step = graph.get_operation_by_name("global_step");
    Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");

    int i = 0;
    foreach (var (x_batch, y_batch) in train_batches)
    {
        i++;
        Console.WriteLine("Training on batch " + i);
        var train_feed_dict = new Hashtable
        {
            [model_x] = x_batch,
            [model_y] = y_batch,
            [is_training] = true,
        };

        // original python:
        //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
        var result = sess.run(new Tensor[] { optimizer, global_step, loss }, train_feed_dict);
        // exception here, loss value seems like a float[]
        //loss_value = result[2];
        var step = result[1];
        if (step % 10 == 0)
        {
            Console.WriteLine($"Step {step} loss: {result[2]}");
        }

        if (step % 100 == 0)
        {
            continue; // validation below is not ported yet, skip it for now

            // # Test accuracy with validation data for each epoch.
            var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
            var (sum_accuracy, cnt) = (0, 0);
            foreach (var (valid_x_batch, valid_y_batch) in valid_batches)
            {
                // valid_feed_dict = {
                //     model.x: valid_x_batch,
                //     model.y: valid_y_batch,
                //     model.is_training: False
                // }
                // accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict)
                // sum_accuracy += accuracy
                // cnt += 1
            }
            // valid_accuracy = sum_accuracy / cnt
            // print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt))

            // # Save model
            // if valid_accuracy > max_accuracy:
            //     max_accuracy = valid_accuracy
            //     saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step)
            //     print("Model is saved.\n")
        }
    }

    return false;
}
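// A hedged sketch of the checkpoint step that the commented-out Python above refers to, assuming
// tf.train.Saver() and saver.save(sess, path) are available in this TensorFlow.NET build. It would
// be called from the validation block whenever valid_accuracy improves on max_accuracy; saving
// next to the imported meta graph (the "graph" directory) is an assumption, not the repo layout.
private void SaveCheckpointIfImproved(Session sess, Saver saver, float valid_accuracy, ref double max_accuracy)
{
    if (valid_accuracy > max_accuracy)
    {
        max_accuracy = valid_accuracy;
        saver.save(sess, Path.Join("graph", model_name + ".ckpt"));
        Console.WriteLine("Model is saved.\n");
    }
}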