Example #1
        /*************************************************************
        *
        * FetchLocalFeedAsDictionary
        *
        *************************************************************/

        private FeedDict FetchLocalFeedAsDictionary()
        {
            FeedDict fd = new FeedDict();

            string path = RootDir;

            try
            {
                var folders = Directory.EnumerateDirectories(path);
                foreach (var folder in folders)
                {
                    DateTime dt;
                    FeedItem fi = null;
                    string   f  = new DirectoryInfo(folder).Name;
                    try
                    {
                        dt = DateTime.ParseExact(f, "yyyy-MMM-dd HHmm", CultureInfo.InvariantCulture);
                    }
                    catch (FormatException)
                    {
                        // Folder name is not a feed timestamp; skip it.
                        continue;
                    }
                    string feedPath = Path.Combine(folder, "feed.xml");
                    try
                    {
                        using (StreamReader sr = new StreamReader(feedPath))
                        {
                            XmlSerializer xr = new XmlSerializer(typeof(FeedItem));
                            fi = (FeedItem)xr.Deserialize(sr);
                        }
                    }
                    catch (Exception ex)
                    {
                        MessageBox.Show(string.Format("Failed to deserialize {0}. {1}. Will skip.", feedPath, ex.Message));
                        continue;
                    }
                    fd[dt] = fi;
                }
            }
            catch (Exception e)
            {
                MessageBox.Show(string.Format("Failed to enumerate locally downloaded podcasts. {0}", e.Message));
            }
            return fd;
        }
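
FeedDict and FeedItem are not defined in this snippet. From their usage here and in Example #2 (DateTime keys taken from PubDate, FeedItem values, ContainsKey), FeedDict in this podcast project is presumably a dictionary keyed by publication date. A minimal sketch under that assumption; the FeedItem field list is illustrative, not from the original:

        using System;
        using System.Collections.Generic;

        // Hypothetical reconstructions, inferred from usage only.
        public class FeedDict : Dictionary<DateTime, FeedItem> { }

        public class FeedItem
        {
            public DateTime PubDate { get; set; } // used as the dictionary key in Example #2
            public string   Title   { get; set; } // placeholder payload field
        }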
Example #2
        /*************************************************************
        *
        * GetFullFeed
        *
        *************************************************************/

        // Note: async void means callers cannot await this method or observe its
        // exceptions; acceptable here only because it runs as fire-and-forget UI flow.
        private async void GetFullFeed()
        {
            FeedDict fd   = FetchLocalFeedAsDictionary();
            Feed     feed = await FetchTastyTradeFeed();

            foreach (FeedItem fi in feed)
            {
                if (!fd.ContainsKey(fi.PubDate))
                {
                    fd[fi.PubDate] = fi;
                }
            }

            m_DomainFeed = (from pair in fd
                            orderby pair.Key descending
                            select pair.Value).ToList();

            m_FavouriteFeed      = (from fi in m_DomainFeed where Favourited(fi) select fi).ToList();
            CurrentOffset        = 0;
            // Reset the offset before paging, and clamp the count so GetRange cannot read past the end.
            m_DisplayedFeed      = m_DomainFeed.GetRange(CurrentOffset, Math.Min(m_DomainFeed.Count - CurrentOffset, 50));
            FeedGrid.DataContext = m_DisplayedFeed;
        }
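
The clamped GetRange pattern above recurs whenever CurrentOffset changes, so it can be worth hoisting into a helper. A minimal sketch; PAGE_SIZE and GetPage are illustrative names, not part of the original class:

        private const int PAGE_SIZE = 50; // hypothetical constant for the page length used above

        // Returns one page of the domain feed, clamped so it never reads past the end.
        private List<FeedItem> GetPage(int offset)
        {
            int count = Math.Min(m_DomainFeed.Count - offset, PAGE_SIZE);
            return count > 0 ? m_DomainFeed.GetRange(offset, count) : new List<FeedItem>();
        }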
Example #3
        public void Train(Session sess)
        {
            var graph     = tf.get_default_graph();
            var stopwatch = Stopwatch.StartNew();

            sess.run(tf.global_variables_initializer());
            var saver = tf.train.Saver(tf.global_variables());

            var train_batches         = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;

            Tensor    is_training = graph.OperationByName("is_training");
            Tensor    model_x     = graph.OperationByName("x");
            Tensor    model_y     = graph.OperationByName("y");
            Tensor    loss        = graph.OperationByName("loss/Mean");
            Operation optimizer   = graph.OperationByName("loss/Adam");
            Tensor    global_step = graph.OperationByName("Variable");
            Tensor    accuracy    = graph.OperationByName("accuracy/accuracy");

            stopwatch = Stopwatch.StartNew();
            int   step         = 0;
            // Declared locally so the snippet compiles; in the original class these may be fields.
            float loss_value   = 0f;
            float max_accuracy = 0f;

            foreach (var (x_batch, y_batch, total) in train_batches)
            {
                // One optimisation step: fetch a value tuple, feeding (tensor, value) pairs.
                (_, step, loss_value) = sess.run((optimizer, global_step, loss),
                                                 (model_x, x_batch), (model_y, y_batch), (is_training, true));
                if (step == 1 || step % 10 == 0)
                {
                    Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}.");
                }

                if (step % 100 == 0)
                {
                    // Test accuracy with validation data for each epoch.
                    var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
                    var (sum_accuracy, cnt) = (0.0f, 0);
                    foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
                    {
                        var valid_feed_dict = new FeedDict
                        {
                            [model_x]     = valid_x_batch,
                            [model_y]     = valid_y_batch,
                            [is_training] = false
                        };
                        // Feed the dictionary built above (it was constructed but left unused in the original).
                        float accuracy_value = sess.run(accuracy, valid_feed_dict);
                        sum_accuracy += accuracy_value;
                        cnt          += 1;
                    }

                    var valid_accuracy = sum_accuracy / cnt;

                    print($"\nValidation Accuracy = {valid_accuracy.ToString("P")}\n");

                    // Save model
                    if (valid_accuracy > max_accuracy)
                    {
                        max_accuracy = valid_accuracy;
                        saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step);
                        print("Model is saved.\n");
                    }
                }
            }
        }
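
batch_iter is not shown in these snippets. From its call sites it yields (x_batch, y_batch, total) tuples, where total is the overall number of batches across all epochs. A minimal sketch under that assumption, using plain arrays for clarity (the SciSharp original slices NDArrays):

        // Hypothetical reconstruction of batch_iter, inferred from how it is called above.
        private static IEnumerable<(T[] x, TY[] y, int total)> batch_iter<T, TY>(
            T[] inputs, TY[] outputs, int batch_size, int num_epochs)
        {
            int num_batches_per_epoch = (inputs.Length - 1) / batch_size + 1;
            int total = num_batches_per_epoch * num_epochs; // "total" reported back to the caller
            for (int epoch = 0; epoch < num_epochs; epoch++)
            {
                for (int b = 0; b < num_batches_per_epoch; b++)
                {
                    int start = b * batch_size;
                    int end   = Math.Min(start + batch_size, inputs.Length);
                    yield return (inputs[start..end], outputs[start..end], total);
                }
            }
        }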
Example #4
        private bool Train(Session sess, Graph graph)
        {
            var stopwatch = Stopwatch.StartNew();

            sess.run(tf.global_variables_initializer());
            var saver = tf.train.Saver(tf.global_variables());

            var    train_batches         = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            var    num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
            double max_accuracy          = 0;

            Tensor    is_training = graph.OperationByName("is_training");
            Tensor    model_x     = graph.OperationByName("x");
            Tensor    model_y     = graph.OperationByName("y");
            Tensor    loss        = graph.OperationByName("loss/Mean");
            Operation optimizer   = graph.OperationByName("loss/Adam");
            Tensor    global_step = graph.OperationByName("Variable");
            Tensor    accuracy    = graph.OperationByName("accuracy/accuracy");

            stopwatch = Stopwatch.StartNew();
            int i = 0;

            foreach (var (x_batch, y_batch, total) in train_batches)
            {
                i++;
                var train_feed_dict = new FeedDict
                {
                    [model_x]     = x_batch,
                    [model_y]     = y_batch,
                    [is_training] = true,
                };

                var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
                var loss_value = result[2];
                var step       = (int)result[1];
                if (step % 10 == 0)
                {
                    var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
                    Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
                }

                if (step % 100 == 0)
                {
                    // Test accuracy with validation data for each epoch.
                    var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
                    var (sum_accuracy, cnt) = (0.0f, 0);
                    foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
                    {
                        var valid_feed_dict = new FeedDict
                        {
                            [model_x]     = valid_x_batch,
                            [model_y]     = valid_y_batch,
                            [is_training] = false
                        };
                        var   result1        = sess.run(accuracy, valid_feed_dict);
                        float accuracy_value = result1;
                        sum_accuracy += accuracy_value;
                        cnt          += 1;
                    }

                    var valid_accuracy = sum_accuracy / cnt;

                    print($"\nValidation Accuracy = {valid_accuracy}\n");

                    // Save model
                    if (valid_accuracy > max_accuracy)
                    {
                        max_accuracy = valid_accuracy;
                        saver.save(sess, $"{dataDir}/word_cnn.ckpt", global_step: step);
                        print("Model is saved.\n");
                    }
                }
            }

            return false;
        }
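
Examples #3 and #4 run the same three graph nodes in the two feeding styles TensorFlow.NET accepts: inline (tensor, value) pairs with a value-tuple fetch, and an explicit FeedDict (the library's feed dictionary, whose indexer maps a graph Tensor to the value substituted for it) with an array fetch. Side by side, reusing the names bound above:

        // Style of Example #3: tuple fetch, inline feed pairs; results land in typed locals.
        (_, step, loss_value) = sess.run((optimizer, global_step, loss),
                                         (model_x, x_batch), (model_y, y_batch), (is_training, true));

        // Style of Example #4: array fetch plus FeedDict; results come back as NDArrays
        // and are cast out by position.
        var results = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
        var lossVal = results[2];
        var stepVal = (int)results[1];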
Example #5
        protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
        {
            var stopwatch = Stopwatch.StartNew();

            Console.WriteLine("Building dataset...");
            var path = UseSubset ? SUBSET_PATH : TRAIN_PATH;

            int[][] x               = null;
            int[]   y               = null;
            int     alphabet_size   = 0;
            int     vocabulary_size = 0;

            if (model_name == "vd_cnn")
            {
                (x, y, alphabet_size) = DataHelpers.build_char_dataset(path, model_name, CHAR_MAX_LEN, DataLimit = null, shuffle: !UseSubset);
            }
            else
            {
                var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
                vocabulary_size = len(word_dict);
                (x, y)          = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
            }

            Console.WriteLine("\tDONE ");

            var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
            Console.WriteLine("Training set size: " + train_x.len);
            Console.WriteLine("Test set size: " + valid_x.len);

            Console.WriteLine("Import graph...");
            var meta_file = model_name + ".meta";

            tf.train.import_meta_graph(Path.Join("graph", meta_file));
            Console.WriteLine("\tDONE " + stopwatch.Elapsed);

            sess.run(tf.global_variables_initializer());
            var saver = tf.train.Saver(tf.global_variables());

            var    train_batches         = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
            var    num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
            double max_accuracy          = 0;

            Tensor    is_training = graph.OperationByName("is_training");
            Tensor    model_x     = graph.OperationByName("x");
            Tensor    model_y     = graph.OperationByName("y");
            Tensor    loss        = graph.OperationByName("loss/Mean"); // word_cnn
            Operation optimizer   = graph.OperationByName("loss/Adam"); // word_cnn
            Tensor    global_step = graph.OperationByName("Variable");
            Tensor    accuracy    = graph.OperationByName("accuracy/accuracy");

            stopwatch = Stopwatch.StartNew();
            int i = 0;

            foreach (var (x_batch, y_batch, total) in train_batches)
            {
                i++;
                var train_feed_dict = new FeedDict
                {
                    [model_x]     = x_batch,
                    [model_y]     = y_batch,
                    [is_training] = true,
                };
                //Console.WriteLine("x: " + x_batch.ToString() + "\n");
                //Console.WriteLine("y: " + y_batch.ToString());
                // original python:
                //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
                var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
                var loss_value = result[2];
                var step = (int)result[1];
                if (step % 10 == 0)
                {
                    var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
                    Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value}. Estimated training time: {estimate}");
                }

                if (step % 100 == 0)
                {
                    // Test accuracy with validation data for each epoch.
                    var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
                    var (sum_accuracy, cnt) = (0.0f, 0);
                    foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
                    {
                        var valid_feed_dict = new FeedDict
                        {
                            [model_x]     = valid_x_batch,
                            [model_y]     = valid_y_batch,
                            [is_training] = false
                        };
                        var   result1        = sess.run(accuracy, valid_feed_dict);
                        float accuracy_value = result1;
                        sum_accuracy += accuracy_value;
                        cnt          += 1;
                    }

                    var valid_accuracy = sum_accuracy / cnt;

                    print($"\nValidation Accuracy = {valid_accuracy}\n");

                    // Save model (note: the checkpoint call below is commented out in this variant).
                    if (valid_accuracy > max_accuracy)
                    {
                        max_accuracy = valid_accuracy;
                        // saver.save(sess, $"{dataDir}/{model_name}.ckpt", global_step: step.ToString());
                        print("Model is saved.\n");
                    }
                }
            }

            return false;
        }
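
train_test_split is likewise not shown. From the call above it shuffles the dataset and splits off test_size (here 15%) as validation data, mirroring scikit-learn's function of the same name. A minimal sketch under those assumptions, again with plain arrays instead of NDArrays:

        // Hypothetical reconstruction of train_test_split, modelled on scikit-learn's API.
        private static (int[][] trainX, int[][] validX, int[] trainY, int[] validY)
            train_test_split(int[][] x, int[] y, float test_size = 0.15f)
        {
            var rng = new Random();
            // Shuffle indices so the split is random rather than positional.
            int[] order      = Enumerable.Range(0, x.Length).OrderBy(_ => rng.Next()).ToArray();
            int   validCount = (int)(x.Length * test_size);
            int[] valid      = order.Take(validCount).ToArray();
            int[] train      = order.Skip(validCount).ToArray();
            return (train.Select(i => x[i]).ToArray(),
                    valid.Select(i => x[i]).ToArray(),
                    train.Select(i => y[i]).ToArray(),
                    valid.Select(i => y[i]).ToArray());
        }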