/// <summary>
/// Loads the text file at <paramref name="path"/>, builds the sorted character vocabulary from it and
/// batches the one-hot encoded characters into <c>TrainingSet</c>. <c>TestSet</c> is set to an empty set.
/// </summary>
/// <param name="path">Path of the text file to train on.</param>
/// <param name="progressWriter">Optional progress writer used while batching.</param>
public void PrepareTrainingSets(string path, IProgressWriter progressWriter = null)
        {
            // Read the file exactly once: the vocabulary and the batched text must come from the same
            // snapshot. With two reads, a file modified in between could contain characters that are not
            // in the vocabulary, which the mapper below would encode as an all-zero column.
            string text = File.ReadAllText(path);

            var vocab = text.Distinct().OrderBy(x => x).ToList();
            Vocab = vocab;

            // Capture the local list (not the Vocab property) so the one-hot mapping cannot drift from
            // the vocabulary these batches were built with if Vocab is reassigned later.
            var batcher  = new SequenceBatcher <float, char>(vocab.Count, (ch, i) => vocab[i] == ch ? 1 : 0);
            var matrices = batcher.BatchSamples(text.ToList(), new BatchDimension(BatchDimensionType.BatchSize, _batchSize), progressWriter);

            TrainingSet = new SequentialDataSet <float>(matrices);
            TestSet     = new SequentialDataSet <float>(new List <Matrix <float> >());     // No testing yet.
        }
Exemple #2
0
        /// <summary>
        /// Tries to enable the Math.NET native MKL provider. If that fails and download is enabled, downloads
        /// the MKL package, extracts the files under its <c>build/x64</c> folder into [Retia.dll location]\x64
        /// and tries to enable the provider again.
        /// </summary>
        /// <param name="tryDownload">Whether MKL download is enabled.</param>
        /// <param name="progressWriter">Optional progress writer that receives status messages.</param>
        /// <returns>True if could use MKL provider, false otherwise.</returns>
        public static bool TryUseMkl(bool tryDownload = true, IProgressWriter progressWriter = null)
        {
            if (Control.TryUseNativeMKL())
            {
                progressWriter?.Message("Using MKL.");
                return(true);
            }

            if (!tryDownload)
            {
                progressWriter?.Message("Couldn't use MKL and download is disabled. Using slow math provider.");
                return(false);
            }

            progressWriter?.Message("Couldn't use MKL right away, trying to download...");

            string tempPath   = Path.Combine(Path.GetTempPath(), FileName);
            string extractDir = Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "x64");
            var    downloader = new FileDownloader(progressWriter);

            // CreateDirectory does nothing when the directory already exists, so no Exists check is needed.
            Directory.CreateDirectory(extractDir);

            bool downloaded = downloader.DownloadAndExtract(MKLPackage, tempPath, file =>
            {
                // Flatten every entry found under build/x64 directly into extractDir.
                foreach (var entry in file.SelectEntries("*.*", @"build/x64"))
                {
                    string entryFileName = Path.GetFileName(entry.FileName);

                    // A directory entry has no file-name part. Path.Combine would then return extractDir
                    // itself, and opening a FileStream on a directory throws, so skip such entries.
                    if (string.IsNullOrEmpty(entryFileName))
                    {
                        continue;
                    }

                    using (var stream = new FileStream(Path.Combine(extractDir, entryFileName), FileMode.Create, FileAccess.Write))
                    {
                        entry.Extract(stream);
                    }
                }
            }, true);

            if (!downloaded)
            {
                // Report this failure like every other failure path does, instead of returning silently.
                progressWriter?.Message("Couldn't download MKL. Using slow math provider.");
                return(false);
            }

            if (!Control.TryUseNativeMKL())
            {
                progressWriter?.Message("Still can't use MKL, giving up.");
                return(false);
            }

            progressWriter?.Message("Using MKL.");
            return(true);
        }
Exemple #3
0
        /// <summary>
        /// Splits a flat list of samples into batched samples according to <paramref name="dimension"/>.
        /// </summary>
        /// <param name="samples">Samples to batch; must be non-null and non-empty.</param>
        /// <param name="dimension">
        /// Either the number of columns per batch (<c>BatchSize</c>) or the number of batches to produce
        /// (<c>BatchCount</c>). Its <c>Dimension</c> must be positive.
        /// </param>
        /// <param name="progressWriter">Optional progress writer.</param>
        /// <returns>
        /// The batched samples. Because batch count and size come from integer division, samples beyond
        /// index <c>batchCount * batchSize - 1</c> are not included.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="samples"/> is empty.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// The dimension value is not positive, or the dimension type is not recognized.
        /// </exception>
        public List <Sample <T> > BatchSamples(List <LinearSample <T> > samples, BatchDimension dimension, IProgressWriter progressWriter = null)
        {
            if (samples == null)
            {
                throw new ArgumentNullException(nameof(samples));
            }

            if (samples.Count == 0)
            {
                // The batching overload sizes its batches from samples[0]; fail early with a clear message.
                throw new ArgumentException("Cannot batch an empty sample list.", nameof(samples));
            }

            if (dimension.Dimension <= 0)
            {
                // Guards the integer divisions below against division by zero and negative batch sizes.
                throw new ArgumentOutOfRangeException(nameof(dimension), dimension.Dimension, "Batch dimension must be positive.");
            }

            int batchSize, batchCount;

            switch (dimension.Type)
            {
            case BatchDimensionType.BatchSize:
                batchCount = samples.Count / dimension.Dimension;
                batchSize  = dimension.Dimension;
                break;

            case BatchDimensionType.BatchCount:
                batchCount = dimension.Dimension;
                batchSize  = samples.Count / dimension.Dimension;
                break;

            default:
                throw new ArgumentOutOfRangeException(nameof(dimension), dimension.Type, "Unknown batch dimension type.");
            }

            return(BatchSamples(samples, batchCount, batchSize, progressWriter));
        }
Exemple #4
0
        /// <summary>
        /// Packs <paramref name="batchCount"/> batches of <paramref name="batchSize"/> columns each.
        /// Column <c>col</c> of batch <c>b</c> holds the input and target vectors of
        /// <c>samples[b + col * batchCount]</c>.
        /// </summary>
        /// <param name="samples">Source samples; the first one determines the input/target vector sizes.</param>
        /// <param name="batchCount">Number of batches to produce.</param>
        /// <param name="batchSize">Number of columns in each batch.</param>
        /// <param name="progressWriter">Optional progress writer.</param>
        /// <returns>The list of batched samples.</returns>
        private List <Sample <T> > BatchSamples(List <LinearSample <T> > samples, int batchCount, int batchSize, IProgressWriter progressWriter)
        {
            // Every batch is allocated with the vector sizes of the first sample.
            int inputRows  = samples[0].Input.Length;
            int targetRows = samples[0].Target.Length;

            var batches = new List <Sample <T> >();

            // NOTE(review): the tracker is built from batchSize while the loop below runs to batchCount;
            // confirm which value ProgressTracker's constructor expects.
            var progress = new ProgressTracker(batchSize);

            for (int batchIdx = 0; batchIdx < batchCount; batchIdx++)
            {
                if (progress.ShouldReport(batchIdx))
                {
                    progressWriter?.SetItemProgress(batchIdx, batchCount, "Batching");
                }

                var batch = new Sample <T>(inputRows, targetRows, batchSize);

                for (int col = 0; col < batchSize; col++)
                {
                    // Columns are drawn with stride batchCount, so column col of consecutive batches
                    // walks through samples[col * batchCount], samples[col * batchCount + 1], ...
                    int sourceIdx = batchIdx + col * batchCount;
                    var source    = samples[sourceIdx];

                    for (int row = 0; row < source.Input.Length; row++)
                    {
                        batch.Input[row, col] = source.Input[row];
                    }

                    for (int row = 0; row < source.Target.Length; row++)
                    {
                        batch.Target[row, col] = source.Target[row];
                    }
                }

                batches.Add(batch);
            }

            progressWriter?.ItemComplete();

            return batches;
        }
Exemple #5
0
        /// <summary>
        /// Builds <paramref name="batchCount"/> dense matrices of <c>_size</c> rows by
        /// <paramref name="batchSize"/> columns. Cell <c>[row, col]</c> of matrix <c>b</c> is
        /// <c>_mapper(samples[b + col * batchCount], row)</c>.
        /// </summary>
        /// <param name="samples">Source items to map into matrix columns.</param>
        /// <param name="batchCount">Number of matrices to produce.</param>
        /// <param name="batchSize">Number of columns in each matrix.</param>
        /// <param name="progressWriter">Optional progress writer.</param>
        /// <returns>The list of batch matrices.</returns>
        private List <Matrix <TType> > BatchSamples(List <TInput> samples, int batchCount, int batchSize, IProgressWriter progressWriter)
        {
            var batches = new List <Matrix <TType> >();

            // NOTE(review): the tracker is built from batchSize while the loop below runs to batchCount;
            // confirm which value ProgressTracker's constructor expects.
            var progress = new ProgressTracker(batchSize);

            for (int batchIdx = 0; batchIdx < batchCount; batchIdx++)
            {
                if (progress.ShouldReport(batchIdx))
                {
                    progressWriter?.SetItemProgress(batchIdx, batchCount, "Batching");
                }

                // _size rows (one per mapped feature index), batchSize columns.
                var batch = Matrix <TType> .Build.Dense(_size, batchSize);

                for (int col = 0; col < batchSize; col++)
                {
                    int sourceIdx = batchIdx + col * batchCount;
                    var item      = samples[sourceIdx];

                    for (int row = 0; row < _size; row++)
                    {
                        batch[row, col] = _mapper(item, row);
                    }
                }

                batches.Add(batch);
            }

            progressWriter?.ItemComplete();

            return batches;
        }
Exemple #6
0
 /// <summary>
 /// Creates callback state bound to <paramref name="writer"/>.
 /// </summary>
 /// <param name="writer">Progress writer to store; it is kept as-is (null is not rejected here).</param>
 public ProgressCallbackState(IProgressWriter writer)
 {
     this._writer = writer;
 }
Exemple #7
0
 /// <summary>
 /// Creates a file downloader that reports through the supplied progress writer, if any.
 /// </summary>
 /// <param name="progressWriter">Optional progress writer; stored as-is, may be null.</param>
 public FileDownloader(IProgressWriter progressWriter = null)
 {
     this._progressWriter = progressWriter;
 }