Example #1
0
        /// <summary>
        /// Trains the sequential model described by <paramref name="job"/> on the input and
        /// target tensors the job references, then writes the trained model back to IPFS.
        /// </summary>
        /// <param name="job">
        /// Training job. <c>job.input</c> and <c>job.target</c> are fetched with <c>Ipfs.Get</c>
        /// and must resolve to JSON objects with a <c>data</c> (float array) and a <c>shape</c>
        /// (int array) field. <c>job.config</c> supplies the loss criterion, learning rate and
        /// iteration count.
        /// </param>
        /// <returns>The hash of the result job written to IPFS.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="job"/> is null.</exception>
        public string TrainModel(IpfsJob job)
        {
            if (job == null)
            {
                throw new System.ArgumentNullException("job");
            }

            var tmpInput  = Ipfs.Get <JToken>(job.input);
            var tmpTarget = Ipfs.Get <JToken>(job.target);

            var seq = CreateSequential(job.Model);

            var inputData   = tmpInput.SelectToken("data").ToObject <float[]>();
            var inputShape  = tmpInput.SelectToken("shape").ToObject <int[]>();
            var inputTensor = controller.floatTensorFactory.Create(_data: inputData, _shape: inputShape, _autograd: true);

            var targetData   = tmpTarget.SelectToken("data").ToObject <float[]>();
            var targetShape  = tmpTarget.SelectToken("shape").ToObject <int[]>();
            var targetTensor = controller.floatTensorFactory.Create(_data: targetData, _shape: targetShape, _autograd: true);

            // Pick the loss from the configured criterion name; unknown (or null) names fall
            // back to mean-squared error.
            Loss loss;

            switch (job.config.criterion)
            {
            case "categorical_crossentropy":
                loss = new CategoricalCrossEntropyLoss(this.controller);
                break;

            case "cross_entropy_loss":
                loss = new CrossEntropyLoss(this.controller, 1);     // TODO -- real value
                break;

            case "nll_loss":
                loss = new NLLLoss(this.controller);
                break;

            case "mseloss":
            default:
                loss = new MSELoss(this.controller);
                break;
            }

            var optimizer = new SGD(this.controller, seq.getParameters(), job.config.lr, 0, 0);

            for (var i = 0; i < job.config.iters; ++i)
            {
                var pred = seq.Forward(inputTensor);
                var l    = loss.Forward(pred, targetTensor);
                l.Backward();

                // TODO -- better batch size
                optimizer.Step(100, i);
            }

            // Persist the trained model config alongside the original data references and config.
            var ipfs     = new Ipfs();
            var response = ipfs.Write(new IpfsJob(job.input, job.target, seq.GetConfig(), job.config));

            return response.Hash;
        }
Example #2
0
        /// <summary>
        /// Trains the model described by <paramref name="job"/> on the tensors stored at
        /// <paramref name="input"/> and <paramref name="target"/>. It then writes the trained model to
        /// IPFS and starts a coroutine on <paramref name="owner"/> that submits the resulting
        /// hash through <c>Request.AddWeights</c> for <paramref name="modelId"/>.
        /// </summary>
        /// <param name="owner">Unity behaviour used to run the upload coroutine.</param>
        /// <param name="input">
        /// Reference passed to <c>Ipfs.Get</c>; must resolve to JSON with <c>data</c> and <c>shape</c>.
        /// </param>
        /// <param name="target">
        /// Reference passed to <c>Ipfs.Get</c>; must resolve to JSON with <c>data</c> and <c>shape</c>.
        /// </param>
        /// <param name="job">Job holding the model description and config (learning rate).</param>
        /// <param name="modelId">Id the trained weights are submitted under.</param>
        /// <param name="epochs">Number of full-batch training iterations (defaults to 10).</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="owner"/> or <paramref name="job"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="epochs"/> is negative.</exception>
        public void TrainModel(MonoBehaviour owner, string input, string target, IpfsJob job, int modelId, int epochs = 10)
        {
            // Fail fast: previously a null owner was only detected after training and the IPFS write.
            if (owner == null)
            {
                throw new System.ArgumentNullException("owner");
            }
            if (job == null)
            {
                throw new System.ArgumentNullException("job");
            }
            if (epochs < 0)
            {
                throw new System.ArgumentOutOfRangeException("epochs", epochs, "epochs must be non-negative");
            }

            var tmpInput  = Ipfs.Get <JToken>(input);
            var tmpTarget = Ipfs.Get <JToken>(target);

            var seq = CreateSequential(job.Model);

            var inputData   = tmpInput.SelectToken("data").ToObject <float[]>();
            var inputShape  = tmpInput.SelectToken("shape").ToObject <int[]>();
            var inputTensor = controller.floatTensorFactory.Create(_data: inputData, _shape: inputShape, _autograd: true);

            var targetData   = tmpTarget.SelectToken("data").ToObject <float[]>();
            var targetShape  = tmpTarget.SelectToken("shape").ToObject <int[]>();
            var targetTensor = controller.floatTensorFactory.Create(_data: targetData, _shape: targetShape, _autograd: true);

            // Seed gradient for Backward: all ones, shaped like the element-wise loss
            // (pred - target)^2. Assumes pred and target share a shape, so that the loss has the
            // target's shape. Previously this was a hard-coded 4x1 tensor, which only matched
            // 4-sample, single-output data.
            var gradData = new float[targetData.Length];
            for (var k = 0; k < gradData.Length; ++k)
            {
                gradData[k] = 1f;
            }
            var grad = controller.floatTensorFactory.Create(_data: gradData, _shape: targetShape);

            for (var i = 0; i < epochs; ++i)
            {
                var pred = seq.Forward(inputTensor);

                var loss = pred.Sub(targetTensor).Pow(2);
                loss.Backward(grad);

                // NOTE(review): parameters are stepped by their raw gradient (step size 1);
                // job.config.lr is only forwarded to the result config below. Confirm intended.
                foreach (var p in seq.getParameters())
                {
                    var pTensor = controller.floatTensorFactory.Get(p);
                    pTensor.Sub(pTensor.Grad, inline: true);
                }
            }

            var ipfs     = new Ipfs();
            var config   = new IpfsJobConfig(job.config.lr);
            var response = ipfs.Write(new IpfsJob(seq.GetConfig(), config));

            var req = new Request();

            owner.StartCoroutine(req.AddWeights(owner, modelId, response.Hash));
        }