Beispiel #1
0
        /// <summary>
        /// Verifies that training through the .NET wrapper (<c>OfflineTrainer.Train</c>) and training with the
        /// <c>vw.exe</c> command-line tool on the same randomly generated contexts yield the same per-example
        /// action order and (approximately) the same probabilities.
        /// </summary>
        /// <remarks>
        /// Requires <c>vw.exe</c> to be startable by name (working directory or PATH). All temporary files are
        /// removed in a <c>finally</c> block so they are cleaned up even when an assertion fails.
        /// </remarks>
        public void WrapperVersusCommandLine()
        {
            var vwArgs                    = "--cb_explore_adf --cb_type dr --epsilon 0.2";
            int numExamples               = 1000;
            var inputFile                 = Path.GetTempFileName();
            var vwFile                    = inputFile + ".vw.gz";
            var wrapperPredictionFile     = inputFile + ".vw.wrapper.pred";
            var commandLinePredictionFile = inputFile + ".vw.commandline.pred";

            try
            {
                var rand     = new Random();
                var contexts = new bool[numExamples].Select((_, i) => TestContext.CreateRandom(rand, i));

                // One JSON-serialized context per line; this file is the wrapper's training input.
                File.WriteAllLines(inputFile, contexts.Select(c => JsonConvert.SerializeObject(c)));

                // Convert the same JSON into gzipped VW text format for the command-line tool.
                using (var reader = new StreamReader(inputFile))
                using (var writer = new StreamWriter(new GZipStream(File.Create(vwFile), CompressionLevel.Optimal)))
                {
                    VowpalWabbitJsonToString.Convert(reader, writer);
                }

                OfflineTrainer.Train(vwArgs, inputFile, wrapperPredictionFile);

                using (var process = Process.Start(new ProcessStartInfo
                {
                    FileName        = "vw.exe",
                    Arguments       = $"{vwArgs} -p {commandLinePredictionFile} -d {vwFile}",
                    CreateNoWindow  = true,
                    UseShellExecute = false
                }))
                {
                    process.WaitForExit();

                    // Fail here with a clear message rather than later on a missing or partial prediction file.
                    Assert.AreEqual(0, process.ExitCode, "vw.exe exited with a non-zero exit code.");
                }

                var wrapperActionScore = File
                                         .ReadAllLines(wrapperPredictionFile)
                                         .Select(l => JsonConvert.DeserializeObject<WrapperPredictionLine>(l))
                                         .ToList();

                // Each non-blank line holds comma-separated "action:probability" pairs.
                // The file is machine output, so parse it culture-independently: with a culture whose
                // decimal separator is ',' (e.g. de-DE), "0.2" would otherwise be misread.
                var commandLineActionScore = File
                                             .ReadAllLines(commandLinePredictionFile)
                                             .Where(l => !string.IsNullOrWhiteSpace(l))
                                             .Select(l => l.Split(new string[] { "," }, StringSplitOptions.RemoveEmptyEntries)
                                                           .Select(pair => pair.Split(':'))
                                                           .Select(parts => new
                                                           {
                                                               Action = Convert.ToInt32(parts[0], System.Globalization.CultureInfo.InvariantCulture),
                                                               Prob   = Convert.ToSingle(parts[1], System.Globalization.CultureInfo.InvariantCulture)
                                                           })
                                                           // Materialize once; each row is enumerated twice below.
                                                           .ToList())
                                             .ToList();

                Assert.AreEqual(wrapperActionScore.Count, commandLineActionScore.Count);

                for (int i = 0; i < wrapperActionScore.Count; i++)
                {
                    Assert.IsTrue(commandLineActionScore[i].Select(ap => ap.Action).SequenceEqual(wrapperActionScore[i].Actions),
                                  $"Action order differs at example {i}.");
                    Assert.IsTrue(commandLineActionScore[i].Select(ap => ap.Prob).SequenceEqual(wrapperActionScore[i].Probs, new FloatComparer()),
                                  $"Probabilities differ at example {i}.");
                }
            }
            finally
            {
                // File.Delete is a no-op for files that were never created.
                File.Delete(inputFile);
                File.Delete(vwFile);
                File.Delete(wrapperPredictionFile);
                File.Delete(commandLinePredictionFile);
            }
        }
Beispiel #2
0
        /// <summary>
        /// Offline experiment driver.
        /// </summary>
        /// <remarks>
        /// Steps:
        /// 1. Downloads logged events for a fixed time window from Azure blob storage.
        /// 2. Runs two JSON pre-processing passes.
        /// 3. Converts the result to gzipped VW format.
        /// 4. Trains and evaluates one model per argument combination, appending a tab-separated
        ///    summary line per run to a history file under the user's ApplicationData folder.
        /// Errors are printed to the console and set a non-zero process exit code.
        /// </remarks>
        /// <param name="args">Unused.</param>
        static void Main(string[] args)
        {
            try
            {
                var stopwatch = Stopwatch.StartNew();

                // NOTE(review): placeholder credentials. Do not commit real keys here; load them from
                // configuration or environment variables instead.
                var storageAccount = new CloudStorageAccount(new StorageCredentials("storage name", "storage key"), false);

                var outputDirectory = @"c:\temp\";
                Directory.CreateDirectory(outputDirectory);
                var startTimeInclusive = new DateTime(2016, 8, 11, 19, 0, 0);
                var endTimeExclusive   = new DateTime(2016, 8, 18, 0, 0, 0);
                var outputFile         = Path.Combine(outputDirectory, $"{startTimeInclusive:yyyy-MM-dd_HH}-{endTimeExclusive:yyyy-MM-dd_HH}.json");

                // Download and merge blob data. GetAwaiter().GetResult() rethrows the original exception
                // instead of wrapping it in an AggregateException.
                using (var writer = new StreamWriter(outputFile))
                {
                    AzureBlobDownloader.Download(storageAccount, startTimeInclusive, endTimeExclusive, writer, outputDirectory).GetAwaiter().GetResult();
                }

                // Pre-process JSON into <outputFile>.small (presumably strips the listed property names).
                JsonTransform.TransformIgnoreProperties(outputFile, outputFile + ".small",
                                                        "Somefeatures");

                outputFile += ".small";

                // Filter broken events.
                // NOTE(review): the result is assumed to land in <outputFile>.fixed (the suffix appended
                // below); confirm against JsonTransform.TransformFixMarginal.
                JsonTransform.TransformFixMarginal(outputFile,
                                                   numExpectedActions: 10, // examples with different number of actions are ignored
                                                   startingNamespace: 'G', // starting namespace of the marginal features, if more than one marginal features then the next letter is used, e.g. G for the first one, H for second, and so on.
                                                   marginalProperties: new TupleList<string, string>
                                                   {
                                                       // The property parent and name to create marginal features for
                                                       { "DVideoFeatures", "VideoId" },
                                                       //{ "DVideoFeatures", "VideoTitle" }, // uncomment if more marginal features are needed
                                                   });

                outputFile += ".fixed";

                // Convert the cleaned JSON to gzipped VW text format.
                using (var reader = new StreamReader(outputFile))
                using (var writer = new StreamWriter(new GZipStream(File.Create(outputFile + ".vw.gz"), CompressionLevel.Optimal)))
                {
                    VowpalWabbitJsonToString.Convert(reader, writer);
                }

                // Hyper-parameter grid. Fractional values are formatted with the invariant culture so VW
                // always receives '.' as the decimal separator, regardless of the machine's regional settings
                // (matching how the learning rates are formatted below).
                var bags      = new[] { 1, 2, 4, 6, 8, 10 }.Select(a => "--bag " + a);
                var softmaxes = new[] { 0, 1, 2, 4, 8, 16, 32 }.Select(a => "--softmax --lambda " + a);
                var epsilons  = new[] { .33333f, .2f, .1f, .05f }.Select(a => "--epsilon " + a.ToString(CultureInfo.InvariantCulture));

                var arguments = Util.Expand(
                                    epsilons.Union(bags).Union(softmaxes),
                                    new[] { "--cb_type ips", "--cb_type mtr", "--cb_type dr" },
                                    new[] { "-q AB -q UD" },
                                    new[] { 0.005, 0.01, 0.02, 0.1 }.Select(l => string.Format(CultureInfo.InvariantCulture, "-l {0}", l)))
                                .Select(a => $"--cb_explore_adf {a} --interact ud ")
                                .ToList();

                var sep         = "\t";
                var historyFile = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "mwt.experiments");

                // Append so earlier runs are kept. FileMode.OpenOrCreate would start writing at offset 0
                // and overwrite the beginning of the existing history.
                using (var historyWriter = new StreamWriter(historyFile, append: true))
                {
                    for (int i = 0; i < arguments.Count; i++)
                    {
                        var startTime = DateTime.UtcNow;

                        // NOTE(review): this path is neither indexed by run nor passed to Train below;
                        // confirm that Metrics.Compute (or an earlier step) is what produces/expects it.
                        var outputPredictionFile   = $"{outputFile}.prediction";
                        var outputPrediction2hFile = $"{outputFile}.{i + 1}.2h.prediction";

                        // VW training
                        OfflineTrainer.Train(arguments[i],
                                             outputFile,
                                             predictionFile: outputPrediction2hFile,
                                             reloadInterval: TimeSpan.FromHours(2),
                                             cacheFilePrefix: null); // null to use input file's name for cache, see the method documentation for more details

                        var metricResult = Metrics.Compute(outputFile, outputPredictionFile, outputPrediction2hFile);

                        historyWriter.WriteLine($"{startTime}{sep}{arguments[i]}{sep}{string.Join(sep, metricResult.Select(m => m.Name + sep + m.Value))}");

                        // Persist each completed run immediately, so results survive a failure in a later run.
                        historyWriter.Flush();
                    }
                }

                Console.WriteLine("\ndone " + stopwatch.Elapsed);
                Console.WriteLine("Run information is added to: " + historyFile);
            }
            catch (Exception ex)
            {
                // ToString() includes type, message, stack trace and any inner exceptions.
                Console.WriteLine($"Exception: {ex}");
                Environment.ExitCode = 1;
            }

            // Keep the console window open when run interactively; ReadKey throws if input is redirected.
            if (!Console.IsInputRedirected)
            {
                Console.ReadKey();
            }
        }