コード例 #1
0
        /// <summary>
        /// Loads a scenario, trains one <see cref="TrainerHelper"/> per column of its CSV contents,
        /// and stores the result in <c>trainingByScenario[scenarioId]</c>.
        /// </summary>
        /// <param name="scenarioId">
        /// Key of the <see cref="Scenario"/> to load (looked up with <c>Scenarios.Find</c>).
        /// </param>
        /// <returns>
        /// The fully populated <see cref="ScenarioTrainings"/>, or <c>null</c> when no scenario
        /// with <paramref name="scenarioId"/> exists.
        /// </returns>
        public ScenarioTrainings Train(string scenarioId)
        {
            using (var dbContext = new OpenAIEntities1())
            {
                Scenario scenario = dbContext.Scenarios.Find(scenarioId);
                if (scenario == null)
                {
                    return null;
                }

                DataTable table = this.ParseCsv(scenario.Contents);

                // Get hypothesis for each feature / output column.
                var trainings = new ScenarioTrainings
                {
                    ScenarioId          = scenarioId,
                    TrainingByFeatureId = new Dictionary<string, TrainerHelper>()
                };

                foreach (DataColumn column in table.Columns)
                {
                    TrainerHelper training = trainer.Train(table, column.ColumnName);
                    trainings.TrainingByFeatureId[column.ColumnName] = training;
                }

                // Publish only once every column has trained. Registering before the loop left a
                // partially populated entry in the cache (replacing any earlier complete one)
                // whenever trainer.Train threw part-way through.
                trainingByScenario[scenarioId] = trainings;
                return trainings;
            }
        }
コード例 #2
0
        /// <summary>
        /// Trains <paramref name="trainer"/> on <paramref name="data"/> and then hands the result to
        /// <c>CalibratorUtils.TrainCalibratorIfNeeded</c> together with the optional calibrator.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="ch">Channel used for argument checks, tracing and warnings; must not be null.</param>
        /// <param name="data">Training data; may be replaced by a cached view depending on <paramref name="cacheData"/>.</param>
        /// <param name="trainer">Trainer to run; must not be null.</param>
        /// <param name="validData">Optional validation data; cached the same way as <paramref name="data"/>.</param>
        /// <param name="calibrator">Optional calibrator factory; null means no calibrator trainer is created.</param>
        /// <param name="maxCalibrationExamples">Forwarded to the calibrator training step.</param>
        /// <param name="cacheData">Caching preference forwarded to <c>AddCacheIfWanted</c>.</param>
        /// <param name="inputPredictor">
        /// Optional starting predictor; ignored (with a warning) when the trainer does not support incremental training.
        /// </param>
        /// <param name="testData">Optional test data, passed through to the <see cref="TrainContext"/> unchanged.</param>
        /// <returns>The (possibly calibrated) predictor.</returns>
        private static IPredictor TrainCore(IHostEnvironment env, IChannel ch, RoleMappedData data, ITrainer trainer, RoleMappedData validData,
                                            IComponentFactory<ICalibratorTrainer> calibrator, int maxCalibrationExamples, bool? cacheData, IPredictor inputPredictor = null, RoleMappedData testData = null)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ch, nameof(ch));
            ch.CheckValue(data, nameof(data));
            ch.CheckValue(trainer, nameof(trainer));
            ch.CheckValueOrNull(validData);
            ch.CheckValueOrNull(inputPredictor);

            // The training set is cached before the "Training" trace; the validation set after it.
            AddCacheIfWanted(env, ch, trainer, ref data, cacheData);
            ch.Trace("Training");
            bool hasValidation = validData != null;
            if (hasValidation)
            {
                AddCacheIfWanted(env, ch, trainer, ref validData, cacheData);
            }

            bool cannotContinueTraining = inputPredictor != null && !trainer.Info.SupportsIncrementalTraining;
            if (cannotContinueTraining)
            {
                ch.Warning($"Ignoring {nameof(TrainCommand.Arguments.InputModelFile)}: Trainer does not support incremental training.");
                inputPredictor = null;
            }

            ch.Assert(validData == null || trainer.Info.SupportsValidation);

            var context = new TrainContext(data, validData, testData, inputPredictor);
            var trainedPredictor = trainer.Train(context);
            var calibratorTrainer = calibrator?.CreateComponent(env);

            return CalibratorUtils.TrainCalibratorIfNeeded(env, ch, calibratorTrainer, maxCalibrationExamples, trainer, trainedPredictor, data);
        }
コード例 #3
0
        /// <summary>
        /// Times a single <c>trainer.Train(trainingData)</c> call.
        /// </summary>
        /// <param name="trainer">The trainer whose training call is timed.</param>
        /// <param name="trainingData">The data handed unchanged to the trainer.</param>
        /// <returns>Elapsed time of the training call, in whole milliseconds.</returns>
        public static long GetTrainingLength(ITrainer trainer, List<TrainingElement> trainingData)
        {
            Stopwatch timer = Stopwatch.StartNew();
            trainer.Train(trainingData);
            timer.Stop();

            long elapsedMs = timer.ElapsedMilliseconds;
            return elapsedMs;
        }
コード例 #4
0
 /// <summary>
 /// Runs the trainer with the supplied hyper-parameters and returns its result.
 /// </summary>
 /// <param name="iterations">Number of training iterations (explicitly bound from the query string).</param>
 /// <param name="approximationRank">Forwarded unchanged to the trainer.</param>
 /// <param name="learningRate">Forwarded unchanged to the trainer.</param>
 /// <returns>
 /// 200 with the trainer's result, or 500 with the exception message when training throws.
 /// </returns>
 public IActionResult Train([FromQuery] int iterations, int approximationRank, double learningRate)
 {
     try
     {
         return Ok(_trainer.Train(iterations, approximationRank, learningRate));
     }
     catch (Exception e)
     {
         // Use a constant message template (CA2254). Passing e.Message as the template makes every
         // failure a distinct template, and structured-logging providers may treat '{...}' inside
         // the message as placeholders. The exception itself (message + stack) is still logged via e.
         _logger.LogError(
             e,
             "Training failed (iterations: {Iterations}, approximationRank: {ApproximationRank}, learningRate: {LearningRate})",
             iterations,
             approximationRank,
             learningRate);

         // NOTE(review): returning e.Message to the client can expose internal details; consider a generic body.
         return StatusCode(500, e.Message);
     }
 }
コード例 #5
0
        /// <summary>
        /// Fits the pipeline on <c>ctx.TrainingSet</c>. The stages are: an optional training-only
        /// preprocess, a preprocess (pass-through when none is configured), the predictor (trained
        /// only when <c>_args.predictorType</c> is set), the predictor wrapped as a transform, and
        /// an optional postprocess.
        /// </summary>
        /// <param name="ctx">Training context; only <c>ctx.TrainingSet</c> is read here.</param>
        /// <returns>The result of <c>CreatePredictor()</c> after all stages are fitted.</returns>
        protected override IPredictor Train(TrainContext ctx)
        {
            var data = ctx.TrainingSet;

            Contracts.CheckValue(data, nameof(data));
            Contracts.CheckValue(_trainer, nameof(_trainer));

            IDataView view = data.Data;

            // Preprocess only for training.
            if (_args.preTrainType != null)
            {
                using (var ch2 = Host.Start("PreProcessTraining"))
                {
                    ch2.Info("Applies a preprocess only for training: {0}", _args.preTrainType);
                    var trSett = ScikitSubComponent<IDataTransform, SignatureDataTransform>.AsSubComponent(_args.preTrainType);

                    _preTrainProcess = trSett.CreateInstance(Host, view);
                }
                view = _preTrainProcess;
            }

            // Preprocess (a pass-through transform when none is configured).
            if (_args.preType != null)
            {
                using (var ch2 = Host.Start("PreProcess"))
                {
                    ch2.Info("Applies a preprocess: {0}", _args.preType);
                    var trSett = ScikitSubComponent<IDataTransform, SignatureDataTransform>.AsSubComponent(_args.preType);

                    _preProcess = trSett.CreateInstance(Host, view);
                }
            }
            else
            {
                _preProcess = new PassThroughTransform(Host, new PassThroughTransform.Arguments { }, view);
            }
            view = _preProcess;

            // Re-map the roles of the original training data onto the preprocessed view. The
            // well-known roles are removed and re-bound by column name; every other role is kept as is.
            var roles = data.Schema.GetColumnRoleNames()
                        .Where(kvp => kvp.Key.Value != CR.Feature.Value)
                        .Where(kvp => kvp.Key.Value != CR.Group.Value)
                        .Where(kvp => kvp.Key.Value != CR.Label.Value)
                        .Where(kvp => kvp.Key.Value != CR.Name.Value)
                        .Where(kvp => kvp.Key.Value != CR.Weight.Value);

            if (data.Schema.Feature != null)
            {
                roles = roles.Prepend(CR.Feature.Bind(data.Schema.Feature.Value.Name));
            }
            if (data.Schema.Group != null)
            {
                roles = roles.Prepend(CR.Group.Bind(data.Schema.Group.Value.Name));
            }
            if (data.Schema.Label != null)
            {
                roles = roles.Prepend(CR.Label.Bind(data.Schema.Label.Value.Name));
            }
            if (data.Schema.Weight != null)
            {
                roles = roles.Prepend(CR.Weight.Bind(data.Schema.Weight.Value.Name));
            }
            // Name is filtered out above like the other well-known roles, so it has to be re-bound
            // as well; without this the Name role was silently dropped from the training data.
            if (data.Schema.Name != null)
            {
                roles = roles.Prepend(CR.Name.Bind(data.Schema.Name.Value.Name));
            }
            var td = new RoleMappedData(view, roles);

            // Train.
            if (_args.predictorType != null)
            {
                using (var ch2 = Host.Start("Training"))
                {
                    var sch1 = SchemaHelper.ToString(data.Schema.Schema);
                    var sch2 = SchemaHelper.ToString(td.Schema.Schema);
                    ch2.Info("Initial schema: {0}", sch1);
                    ch2.Info("DataViewSchema before training: {0}", sch2);
                    ch2.Info("Train a predictor: {0}", _args.predictorType);
                    _predictor = _trainer.Train(td);
                }
            }

            // Predictor as a transform over the preprocessed view.
            using (var ch2 = Host.Start("Predictor as Transform"))
            {
                ch2.Info("Creates a transfrom from a predictor");
                string featureColumn = td.Schema.Feature.Value.Name;
                _inputColumn = featureColumn;
                // NOTE(review): `as` yields null when _predictor is not an IValueMapper (or when
                // predictorType is null and _predictor was never assigned); confirm that
                // TransformFromValueMapper rejects a null mapper with a clear error.
                _predictorAsTransform = new TransformFromValueMapper(Host, _predictor as IValueMapper,
                                                                     view, featureColumn, _outputColumn);
            }
            view = _predictorAsTransform;

            // Postprocess.
            if (_args.postType != null)
            {
                using (var ch2 = Host.Start("PostProcess"))
                {
                    ch2.Info("Applies a postprocess: {0}", _args.postType);
                    var postSett = ScikitSubComponent<IDataTransform, SignatureDataTransform>.AsSubComponent(_args.postType);

                    _postProcess = postSett.CreateInstance(Host, view);
                }
            }
            else
            {
                _postProcess = null;
            }
            return CreatePredictor();
        }
コード例 #6
0
 /// <summary>
 /// Shorthand for training when only a training set is available (no auxiliary information):
 /// builds a <see cref="TrainContext"/> from <paramref name="trainData"/> alone and passes it to
 /// <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>.
 /// </summary>
 /// <param name="trainer">The trainer to run.</param>
 /// <param name="trainData">The training data.</param>
 /// <returns>The predictor produced by <paramref name="trainer"/>.</returns>
 public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData)
     where TPredictor : IPredictor
 {
     var context = new TrainContext(trainData);
     return trainer.Train(context);
 }
コード例 #7
0
 /// <summary>
 /// Shorthand for training when only a training set is available (no auxiliary information):
 /// builds a <see cref="TrainContext"/> from <paramref name="trainData"/> alone and passes it to
 /// <see cref="ITrainer.Train(TrainContext)"/>.
 /// </summary>
 /// <param name="trainer">The trainer to run.</param>
 /// <param name="trainData">The training data.</param>
 /// <returns>The predictor produced by <paramref name="trainer"/>.</returns>
 public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
 {
     var context = new TrainContext(trainData);
     return trainer.Train(context);
 }
コード例 #8
0
 /// <summary>
 /// Runs one <c>_trainer.Train</c> call on <c>_trainingData</c>, passed as is.
 /// </summary>
 public void StochasticGradientDescentNewBenchmark() => _trainer.Train(_trainingData);
コード例 #9
0
 /// <summary>
 /// Delegates training to the wrapped <c>_trainer</c> with the same context.
 /// </summary>
 /// <param name="context">The training context, forwarded unchanged.</param>
 /// <returns>The predictor returned by <c>_trainer</c>.</returns>
 public IPredictor Train(TrainContext context) => _trainer.Train(context);
コード例 #10
0
 /// <summary>
 /// Runs one <c>_trainer.Train</c> call on a fresh list copy of <c>_trainingData</c>.
 /// The <c>ToList</c> copy happens inside this method, so it is part of whatever time is measured for it.
 /// </summary>
 public void StochasticGradientDescentBenchmark() => _trainer.Train(_trainingData.ToList());
コード例 #11
0
ファイル: Program.cs プロジェクト: ostorc/NeuralNetowork
        /// <summary>
        /// Entry point. Prints the phone-number-like match found in one line of console input,
        /// then trains the image network and prints its success rate on the trainer's test data.
        /// </summary>
        /// <param name="args">Command line arguments (unused).</param>
        private static void Main(string[] args)
        {
            var phonePattern = new System.Text.RegularExpressions.Regex("^\\s*(?:\\+?(\\d{1,3}))?[-. (]*(\\d{3})[-. )]*(\\d{3})[-. ]*(\\d{4})(?: *x(\\d+))?\\s*$");

            // ReadLine returns null at end of input; Regex.Matches would throw on null.
            string input = Console.ReadLine() ?? string.Empty;

            // string.Join produces "" when nothing matches; the previous seedless Aggregate threw
            // InvalidOperationException on an empty sequence.
            Console.WriteLine(string.Join(" \n", phonePattern.Matches(input)
                                                              .Cast<System.Text.RegularExpressions.Match>()
                                                              .Select(m => m.Value)));
            Console.ReadLine();

            var trainerManager = new TrainerManager();
            ITrainer imageTrainer = trainerManager.GetTrainer(LearningSubject.Image, LearningObject.Network);
            var imageNetwork = imageTrainer.Train<ICalculatableNetwork>();

            var testData = imageTrainer.GetTestData();

            int correct = 0;
            foreach (var pair in testData)
            {
                // Expected label = index of the 1 in the target vector (assumes one-hot targets — TODO confirm).
                var digit = new DigitImage(pair.Key, (byte)pair.Value.ToList().IndexOf(1));
                var output = imageNetwork.Calculate(pair.Key).ToList();
                if (digit.Label == output.IndexOf(output.Max()))
                {
                    correct++;
                }
            }

            // Divide once instead of summing 1.0 / Count per hit, which accumulates rounding error.
            double successRate = testData.Count == 0 ? 0 : (double)correct / testData.Count;
            Console.WriteLine(successRate);
            Console.ReadLine();
        }