コード例 #1
0
        /// <summary>
        /// Creates a new <see cref="IntTensor"/>, initializes it through <c>init</c> with the supplied
        /// arguments (plus this factory), and registers it in <c>tensors</c> under its <c>Id</c>.
        /// </summary>
        /// <param name="_shape">Shape of the tensor to create.</param>
        /// <param name="_data">Optional initial values; passed through to <c>init</c>.</param>
        /// <param name="_dataBuffer">Optional compute buffer for the data; passed through to <c>init</c>.</param>
        /// <param name="_shapeBuffer">Optional compute buffer for the shape; passed through to <c>init</c>.</param>
        /// <param name="_shader">Compute shader handed to <c>init</c>.</param>
        /// <param name="_copyData">Passed through to <c>init</c>.</param>
        /// <param name="_dataOnGpu">Passed through to <c>init</c>.</param>
        /// <param name="_autograd">Passed through to <c>init</c>.</param>
        /// <param name="_keepgrads">Passed through to <c>init</c>.</param>
        /// <param name="_creation_op">Optional name of the creating operation; passed through to <c>init</c>.</param>
        /// <returns>The newly created, registered tensor.</returns>
        public IntTensor Create(int[] _shape,
                                int[] _data = null,
                                ComputeBuffer _dataBuffer  = null,
                                ComputeBuffer _shapeBuffer = null,
                                ComputeShader _shader      = null,
                                bool _copyData             = true,
                                bool _dataOnGpu            = false,
                                bool _autograd             = false,
                                bool _keepgrads            = false,
                                string _creation_op        = null)
        {
            var created = new IntTensor();
            created.init(this, _shape, _data, _dataBuffer, _shapeBuffer, _shader,
                         _copyData, _dataOnGpu, _autograd, _keepgrads, _creation_op);

            tensors.Add(created.Id, created);
            return created;
        }
コード例 #2
0
ファイル: TestTorchTensor.cs プロジェクト: hxjj/TorchSharp
        /// <summary>
        /// Verifies <c>Sum</c> on float and int tensors, both with the default result type and with an
        /// explicitly requested <see cref="ScalarType"/>.
        /// </summary>
        public void SumTest()
        {
            var data = new float[] { 1.0f, 2.0f, 3.0f };

            // Default sum of a float tensor, read back as float.
            var res1   = FloatTensor.From(data).Sum();
            var res1_0 = res1.DataItem<float>();

            Assert.Equal(6.0f, res1_0);

            // Requesting Double explicitly lets the result be read back as double.
            var res2   = FloatTensor.From(data).Sum(type: ScalarType.Double);
            var res2_0 = res2.DataItem<double>();

            Assert.Equal(6.0, res2_0);

            // Summing integers gives long unless the type is explicitly specified.
            var dataInt32 = new int[] { 1, 2, 3 };
            var res3      = IntTensor.From(dataInt32).Sum();

            Assert.Equal(ScalarType.Long, res3.Type);
            var res3_0 = res3.DataItem<long>();

            Assert.Equal(6L, res3_0);

            // With an explicit ScalarType.Int the sum stays int.
            var res4 = IntTensor.From(dataInt32).Sum(type: ScalarType.Int);

            Assert.Equal(ScalarType.Int, res4.Type);
            var res4_0 = res4.DataItem<int>();

            // Compare as int, matching the element type asserted just above.
            Assert.Equal(6, res4_0);
        }
コード例 #3
0
        /// <summary>
        /// Round-trips an <see cref="IntTensor"/> through a <see cref="DiskFile"/>: writes 0..size-1,
        /// seeks back, reads into a fresh tensor, and checks that the element count, file position
        /// and every value match what was written.
        /// </summary>
        public void WriteAndReadIntTensorViaDiskFile()
        {
            const int size = 10;

            var file = new DiskFile("test1D.dat", "rwb");

            Assert.NotNull(file);

            try
            {
                Assert.True(file.CanWrite);

                var tensor0 = new IntTensor(size);

                for (var i = 0; i < size; ++i)
                {
                    tensor0[i] = i;
                }

                file.WriteTensor(tensor0);
                Assert.Equal(size * sizeof(int), file.Position);
                file.Seek(0);

                var tensor1 = new IntTensor(size);
                var rd      = file.ReadTensor(tensor1);

                // Expected value first, so failure messages read correctly.
                Assert.Equal(size, rd);
                Assert.Equal(size * sizeof(int), file.Position);

                // Compare what was read against what was written (not the read tensor against itself).
                for (var i = 0; i < rd; ++i)
                {
                    Assert.Equal(tensor0[i], tensor1[i]);
                }
            }
            finally
            {
                // Close even when an assertion above fails, so the file handle is not leaked.
                file.Close();
            }

            Assert.False(file.IsOpen);
        }
コード例 #4
0
        /// <summary>
        /// Creates a new <see cref="IntTensor"/> initialized with the given arguments and the factory's
        /// <c>shader</c>, and registers it in <c>tensors</c> under its <c>Id</c>.
        /// </summary>
        /// <param name="_shape">Shape of the tensor to create.</param>
        /// <param name="_data">Optional initial values; passed through to <c>Init</c>.</param>
        /// <param name="_dataBuffer">Optional compute buffer for the data; passed through to <c>Init</c>.</param>
        /// <param name="_shapeBuffer">Optional compute buffer for the shape; passed through to <c>Init</c>.</param>
        /// <param name="_stridesBuffer">Optional compute buffer for the strides; passed through to <c>Init</c>.</param>
        /// <param name="_copyData">Passed through to <c>Init</c>.</param>
        /// <param name="_dataOnGpu">Passed through to <c>Init</c>.</param>
        /// <param name="_creation_op">Optional name of the creating operation; passed through to <c>Init</c>.</param>
        /// <returns>The newly created, registered tensor.</returns>
        /// <exception cref="Exception">When <c>ctrl.allow_new_tensors</c> is false.</exception>
        public IntTensor Create(int[] _shape,
                                int[] _data = null,
                                ComputeBuffer _dataBuffer    = null,
                                ComputeBuffer _shapeBuffer   = null,
                                ComputeBuffer _stridesBuffer = null,
                                bool _copyData      = true,
                                bool _dataOnGpu     = false,
                                string _creation_op = null)
        {
            // Keep this guard - it is used for testing.
            if (!ctrl.allow_new_tensors)
            {
                throw new Exception("Attempted to Create a new IntTensor");
            }

            var created = new IntTensor();
            created.Init(this, _shape, _data, _dataBuffer, _shapeBuffer, _stridesBuffer,
                         shader, _copyData, _dataOnGpu, _creation_op);

            tensors.Add(created.Id, created);
            return created;
        }
コード例 #5
0
ファイル: TestTorchTensor.cs プロジェクト: hxjj/TorchSharp
        /// <summary>
        /// Checks that <c>IntTensor.Ones</c> yields a tensor with the requested 2x2 shape whose
        /// first and last elements are 1.
        /// </summary>
        public void CreateIntTensorOnes()
        {
            var expectedShape = new long[] { 2, 2 };
            TorchTensor ones = IntTensor.Ones(expectedShape);

            Assert.Equal(expectedShape, ones.Shape);

            var topLeft = ones[0, 0].DataItem<int>();
            Assert.Equal(1, topLeft);

            var bottomRight = ones[1, 1].DataItem<int>();
            Assert.Equal(1, bottomRight);
        }
コード例 #6
0
ファイル: TestTorchTensor.cs プロジェクト: hxjj/TorchSharp
        /// <summary>
        /// Checks that all four elements of a 2x2 <c>IntTensor.Ones</c> read back as 1 via <c>Data&lt;int&gt;</c>.
        /// </summary>
        public void CreateIntTensorOnesCheckData()
        {
            var tensor = IntTensor.Ones(new long[] { 2, 2 });
            var values = tensor.Data<int>();

            for (var idx = 0; idx < 4; ++idx)
            {
                Assert.Equal(1, values[idx]);
            }
        }
コード例 #7
0
        /// <summary>
        /// Assigning a scalar tensor through the indexer writes the value into the target element.
        /// </summary>
        public void TextIndexSet()
        {
            var tensor = IntTensor.Zeros(new long[] { 2 });

            using (var value = 1.ToTorchTensor())
            {
                tensor[0] = value;
                // NUnit's AreEqual takes (expected, actual); the original had them reversed.
                Assert.AreEqual(1, tensor.Data<int>()[0]);
            }
        }
コード例 #8
0
        /// <summary>
        /// Runs a forward pass on <paramref name="input"/>, samples actions from the prediction along
        /// <paramref name="dim"/>, and appends the selected predictions to <c>history</c>.
        /// </summary>
        /// <remarks>
        /// Side effect: sets <c>Autograd</c> to true on the caller's <paramref name="input"/> tensor.
        /// Each history entry is a two-element array whose second slot is left null here.
        /// </remarks>
        /// <param name="input">Input fed to <c>Forward</c>.</param>
        /// <param name="dim">Dimension passed to <c>Sample</c> on the prediction.</param>
        /// <returns>The sampled actions.</returns>
        public IntTensor Sample(FloatTensor input, int dim = 1)
        {
            input.Autograd = true;

            FloatTensor prediction     = Forward(input);
            IntTensor   sampledActions = prediction.Sample(dim);
            FloatTensor selected       = prediction.IndexSelect(sampledActions, -1);

            var entry = new FloatTensor[] { selected, null };
            history.Add(entry);

            return sampledActions;
        }
コード例 #9
0
ファイル: IntTensorGpuTest.cs プロジェクト: vyomshm/OpenMined
 /// <summary>
 /// Asserts that two int tensors hold equal data in distinct compute buffers.
 /// </summary>
 /// <remarks>
 /// Copies each tensor's <c>DataBuffer</c> into a host array. It then checks that the buffers have the same
 /// count and stride but different native buffer pointers. Finally it compares the arrays element by
 /// element within <paramref name="delta"/>.
 /// </remarks>
 public void AssertEqualTensorsData(IntTensor t1, IntTensor t2, double delta = 0.0d)
 {
     var first = new int[t1.Size];
     t1.DataBuffer.GetData(first);
     var second = new int[t2.Size];
     t2.DataBuffer.GetData(second);

     Assert.AreEqual(t1.DataBuffer.count, t2.DataBuffer.count);
     Assert.AreEqual(t1.DataBuffer.stride, t2.DataBuffer.stride);
     // The two tensors must not share the same underlying buffer.
     Assert.AreNotEqual(t1.DataBuffer.GetNativeBufferPtr(), t2.DataBuffer.GetNativeBufferPtr());
     Assert.AreEqual(first.Length, second.Length);

     for (int idx = 0; idx < first.Length; idx++)
     {
         Assert.AreEqual(first[idx], second[idx], delta);
     }
 }
コード例 #10
0
ファイル: TestTorchTensor.cs プロジェクト: hxjj/TorchSharp
        /// <summary>
        /// Subtracting a ones tensor from a ones tensor in place must leave every element at zero.
        /// </summary>
        public void TestSubInPlace()
        {
            const int rows = 100;
            const int cols = 100;

            var x = IntTensor.Ones(new long[] { rows, cols });
            var y = IntTensor.Ones(new long[] { rows, cols });

            x.SubInPlace(y);

            var xdata = x.Data<int>();

            // Visit every one of the rows * cols elements. The previous index `i + j` only
            // reached flat positions 0..198, leaving almost the whole tensor unchecked.
            for (int i = 0; i < rows; i++)
            {
                for (int j = 0; j < cols; j++)
                {
                    Assert.Equal(0, xdata[i * cols + j]);
                }
            }
        }
コード例 #11
0
ファイル: TestTorchTensor.cs プロジェクト: hxjj/TorchSharp
        /// <summary>
        /// Checks that the random-tensor factories produce tensors of the requested shape:
        /// <c>FloatTensor.Random</c>, plus <c>RandomIntegers</c> for the Short, Int, Long, Byte,
        /// SByte and Half tensor types.
        /// </summary>
        public void RandomTest()
        {
            var floats = FloatTensor.Random(new long[] { 2 });
            Assert.Equal(new long[] { 2 }, floats.Shape);

            var shorts = ShortTensor.RandomIntegers(10, new long[] { 200 });
            Assert.Equal(new long[] { 200 }, shorts.Shape);

            var ints = IntTensor.RandomIntegers(10, new long[] { 200 });
            Assert.Equal(new long[] { 200 }, ints.Shape);

            var longs = LongTensor.RandomIntegers(10, new long[] { 200 });
            Assert.Equal(new long[] { 200 }, longs.Shape);

            var bytes = ByteTensor.RandomIntegers(10, new long[] { 200 });
            Assert.Equal(new long[] { 200 }, bytes.Shape);

            var sbytes = SByteTensor.RandomIntegers(10, new long[] { 200 });
            Assert.Equal(new long[] { 200 }, sbytes.Shape);

            var halves = HalfTensor.RandomIntegers(10, new long[] { 200 });
            Assert.Equal(new long[] { 200 }, halves.Shape);

            // ComplexHalfTensor, ComplexFloatTensor and ComplexDoubleTensor RandomIntegers are not exercised here.
        }
コード例 #12
0
        /// <summary>
        /// Creates a new <see cref="IntTensor"/>, initializes it through <c>init</c>, and registers it in
        /// <c>tensors</c> under its <c>Id</c>. Creation is only permitted while
        /// <c>ctrl.allow_new_tensors</c> is true.
        /// </summary>
        /// <param name="_shape">Shape of the tensor to create.</param>
        /// <param name="_data">Optional initial values; passed through to <c>init</c>.</param>
        /// <param name="_dataBuffer">Optional compute buffer for the data; passed through to <c>init</c>.</param>
        /// <param name="_shapeBuffer">Optional compute buffer for the shape; passed through to <c>init</c>.</param>
        /// <param name="_shader">
        /// Compute shader for the tensor. When null, the factory's own <c>shader</c> is used.
        /// (Previously this argument was accepted but silently ignored.)
        /// </param>
        /// <param name="_copyData">Passed through to <c>init</c>.</param>
        /// <param name="_dataOnGpu">Passed through to <c>init</c>.</param>
        /// <param name="_autograd">Passed through to <c>init</c>.</param>
        /// <param name="_keepgrads">Passed through to <c>init</c>.</param>
        /// <param name="_creation_op">Optional name of the creating operation; passed through to <c>init</c>.</param>
        /// <returns>The newly created, registered tensor.</returns>
        /// <exception cref="Exception">When <c>ctrl.allow_new_tensors</c> is false.</exception>
        public IntTensor Create(int[] _shape,
                                int[] _data = null,
                                ComputeBuffer _dataBuffer  = null,
                                ComputeBuffer _shapeBuffer = null,
                                ComputeShader _shader      = null,
                                bool _copyData             = true,
                                bool _dataOnGpu            = false,
                                bool _autograd             = false,
                                bool _keepgrads            = false,
                                string _creation_op        = null)
        {
            if (!ctrl.allow_new_tensors)
            {
                throw new Exception("Attempted to Create a new IntTensor");
            }

            // Honor an explicitly supplied shader and fall back to the factory's shader otherwise.
            // `!=` (rather than `??`) goes through UnityEngine.Object's equality operator.
            ComputeShader effectiveShader = _shader != null ? _shader : shader;

            IntTensor tensor = new IntTensor();

            tensor.init(this,
                        _shape,
                        _data,
                        _dataBuffer,
                        _shapeBuffer,
                        effectiveShader,
                        _copyData,
                        _dataOnGpu,
                        _autograd,
                        _keepgrads,
                        _creation_op);

            tensors.Add(tensor.Id, tensor);

            return tensor;
        }
コード例 #13
0
 /// <summary>
 ///   Reads raw ints from this file into <paramref name="tensor"/>.
 /// </summary>
 /// <param name="tensor">Destination tensor; up to <c>tensor.NumElements</c> ints are requested.</param>
 /// <returns>The number of ints read.</returns>
 public long ReadTensor(IntTensor tensor) => THFile_readIntRaw(handle, tensor.Data, tensor.NumElements);
コード例 #14
0
        /// <summary>
        /// Decodes a JSON <see cref="Command"/> and dispatches it according to its <c>objectType</c>:
        /// "Optimizer", "FloatTensor", "IntTensor", "agent", "model", "controller" or "Grid".
        /// </summary>
        /// <param name="json_message">JSON-serialized <see cref="Command"/>.</param>
        /// <param name="owner">Passed through to <c>Grid.Run</c> for the "Grid" / "learn" command.</param>
        /// <returns>
        /// The handler's reply, typically the id of a newly created object or the result of a delegated
        /// <c>ProcessMessage</c> call. Any exception thrown while handling the command is logged and returned
        /// as "Unity Error: " followed by the exception text. Unknown commands yield a
        /// "Unity Error: SyftController.processMessage: Command not found:..." string.
        /// </returns>
        public string processMessage(string json_message, MonoBehaviour owner)
        {
            Command msgObj = JsonUtility.FromJson <Command> (json_message);

            try
            {
                switch (msgObj.objectType)
                {
                case "Optimizer":
                {
                    if (msgObj.functionCall == "create")
                    {
                        // tensorIndexParams = [optimizer_type, param_id, param_id, ...]
                        string optimizer_type = msgObj.tensorIndexParams[0];

                        List <int> p = new List <int>();
                        for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            p.Add(int.Parse(msgObj.tensorIndexParams[i]));
                        }

                        // Hyperparameters arrive as strings. Parse them with the invariant culture so that
                        // "0.01" means the same thing on comma-decimal locales.
                        List <float> hp = new List <float>();
                        for (int i = 0; i < msgObj.hyperParams.Length; i++)
                        {
                            hp.Add(float.Parse(msgObj.hyperParams[i], System.Globalization.CultureInfo.InvariantCulture));
                        }

                        Optimizer optim;

                        if (optimizer_type == "sgd")
                        {
                            optim = new SGD(this, p, hp[0], hp[1], hp[2]);
                        }
                        else if (optimizer_type == "rmsprop")
                        {
                            optim = new RMSProp(this, p, hp[0], hp[1], hp[2], hp[3]);
                        }
                        else if (optimizer_type == "adam")
                        {
                            optim = new Adam(this, p, hp[0], hp[1], hp[2], hp[3], hp[4]);
                        }
                        else
                        {
                            // Report the unknown type instead of dereferencing a null optimizer below.
                            throw new ArgumentException("Unknown optimizer type: " + optimizer_type);
                        }

                        return(optim.Id.ToString());
                    }
                    else
                    {
                        Optimizer optim = this.getOptimizer(msgObj.objectIndex);
                        return(optim.ProcessMessage(msgObj, this));
                    }
                }

                case "FloatTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(_shape: msgObj.shape, _data: msgObj.data, _shader: this.Shader);
                        return(tensor.Id.ToString());
                    }
                    else
                    {
                        FloatTensor tensor = floatTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        return(tensor.ProcessMessage(msgObj, this));
                    }
                }

                case "IntTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        // The command carries its data in the same field used for float tensors; cast each value to int.
                        int[] data = new int[msgObj.data.Length];
                        for (int i = 0; i < msgObj.data.Length; i++)
                        {
                            data[i] = (int)msgObj.data[i];
                        }
                        IntTensor tensor = intTensorFactory.Create(_shape: msgObj.shape, _data: data, _shader: this.Shader);
                        return(tensor.Id.ToString());
                    }
                    else
                    {
                        IntTensor tensor = intTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        return(tensor.ProcessMessage(msgObj, this));
                    }
                }

                case "agent":
                {
                    if (msgObj.functionCall == "create")
                    {
                        // tensorIndexParams = [model_id, optimizer_id]
                        Layer     model     = (Layer)getModel(int.Parse(msgObj.tensorIndexParams[0]));
                        Optimizer optimizer = optimizers[int.Parse(msgObj.tensorIndexParams[1])];
                        return(new Syft.NN.RL.Agent(this, model, optimizer).Id.ToString());
                    }

                    Syft.NN.RL.Agent agent = this.getAgent(msgObj.objectIndex);
                    return(agent.ProcessMessageLocal(msgObj, this));
                }

                case "model":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string model_type = msgObj.tensorIndexParams[0];

                        Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);

                        if (model_type == "linear")
                        {
                            return(this.BuildLinear(msgObj.tensorIndexParams).Id.ToString());
                        }
                        else if (model_type == "relu")
                        {
                            return(this.BuildReLU().Id.ToString());
                        }
                        else if (model_type == "log")
                        {
                            return(this.BuildLog().Id.ToString());
                        }
                        else if (model_type == "dropout")
                        {
                            return(this.BuildDropout(msgObj.tensorIndexParams).Id.ToString());
                        }
                        else if (model_type == "sigmoid")
                        {
                            return(this.BuildSigmoid().Id.ToString());
                        }
                        else if (model_type == "sequential")
                        {
                            return(this.BuildSequential().Id.ToString());
                        }
                        else if (model_type == "softmax")
                        {
                            return(this.BuildSoftmax(msgObj.tensorIndexParams).Id.ToString());
                        }
                        else if (model_type == "logsoftmax")
                        {
                            return(this.BuildLogSoftmax(msgObj.tensorIndexParams).Id.ToString());
                        }
                        else if (model_type == "tanh")
                        {
                            return(new Tanh(this).Id.ToString());
                        }
                        else if (model_type == "crossentropyloss")
                        {
                            return(new CrossEntropyLoss(this, int.Parse(msgObj.tensorIndexParams[1])).Id.ToString());
                        }
                        else if (model_type == "categorical_crossentropy")
                        {
                            return(new CategoricalCrossEntropyLoss(this).Id.ToString());
                        }
                        else if (model_type == "nllloss")
                        {
                            return(new NLLLoss(this).Id.ToString());
                        }
                        else if (model_type == "mseloss")
                        {
                            return(new MSELoss(this).Id.ToString());
                        }
                        else if (model_type == "embedding")
                        {
                            return(new Embedding(this, int.Parse(msgObj.tensorIndexParams[1]), int.Parse(msgObj.tensorIndexParams[2])).Id.ToString());
                        }
                        else
                        {
                            Debug.LogFormat("<color=red>Model Type Not Found:</color> {0}", model_type);
                        }
                    }
                    else
                    {
                        Model model = this.getModel(msgObj.objectIndex);
                        return(model.ProcessMessage(msgObj, this));
                    }
                    return("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                }

                case "controller":
                {
                    if (msgObj.functionCall == "num_tensors")
                    {
                        // Counts float tensors only.
                        return(floatTensorFactory.Count() + "");
                    }
                    else if (msgObj.functionCall == "num_models")
                    {
                        return(models.Count + "");
                    }
                    else if (msgObj.functionCall == "new_tensors_allowed")
                    {
                        Debug.LogFormat("New Tensors Allowed:{0}", msgObj.tensorIndexParams[0]);
                        if (msgObj.tensorIndexParams[0] == "True")
                        {
                            allow_new_tensors = true;
                        }
                        else if (msgObj.tensorIndexParams[0] == "False")
                        {
                            allow_new_tensors = false;
                        }
                        else
                        {
                            throw new Exception("Invalid parameter for new_tensors_allowed. Did you mean true or false?");
                        }

                        return(allow_new_tensors + "");
                    }
                    else if (msgObj.functionCall == "load_floattensor")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(filepath: msgObj.tensorIndexParams[0], _shader: this.Shader);
                        return(tensor.Id.ToString());
                    }
                    else if (msgObj.functionCall == "set_seed")
                    {
                        Random.InitState(int.Parse(msgObj.tensorIndexParams[0]));
                        return("Random seed set!");
                    }
                    else if (msgObj.functionCall == "concatenate")
                    {
                        // tensorIndexParams = [axis, tensor_id, tensor_id, ...]
                        List <int> tensor_ids = new List <int>();
                        for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            tensor_ids.Add(int.Parse(msgObj.tensorIndexParams[i]));
                        }
                        FloatTensor result = Functional.Concatenate(floatTensorFactory, tensor_ids, int.Parse(msgObj.tensorIndexParams[0]));
                        return(result.Id.ToString());
                    }
                    else if (msgObj.functionCall == "ones")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Ones(floatTensorFactory, dims);
                        return(result.Id.ToString());
                    }
                    else if (msgObj.functionCall == "randn")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Randn(floatTensorFactory, dims);
                        return(result.Id.ToString());
                    }
                    else if (msgObj.functionCall == "random")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Random(floatTensorFactory, dims);
                        return(result.Id.ToString());
                    }
                    else if (msgObj.functionCall == "zeros")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Zeros(floatTensorFactory, dims);
                        return(result.Id.ToString());
                    }
                    else if (msgObj.functionCall == "model_from_json")
                    {
                        Debug.Log("Loading Model from JSON:");
                        var json_str = msgObj.tensorIndexParams[0];
                        var config   = JObject.Parse(json_str);

                        Sequential model;

                        if ((string)config["class_name"] == "Sequential")
                        {
                            model = this.BuildSequential();
                        }
                        else
                        {
                            return("Unity Error: SyftController.processMessage: while Loading model, Class :" + config["class_name"] + " is not implemented");
                        }

                        var layer_descs = config["config"].ToList();

                        // Output width ("units") of the most recently added Linear layer. It becomes the
                        // input width of the next Linear layer.
                        int previous_output_dim = 0;

                        for (int i = 0; i < layer_descs.Count; i++)
                        {
                            var layer_desc        = layer_descs[i];
                            var layer_config_desc = layer_desc["config"];

                            if ((string)layer_desc["class_name"] == "Linear")
                            {
                                int input_dim;

                                if (i == 0)
                                {
                                    // The first layer's input width is the last entry of batch_input_shape.
                                    input_dim = (int)layer_config_desc["batch_input_shape"][layer_config_desc["batch_input_shape"].ToList().Count - 1];
                                }
                                else
                                {
                                    // Chain from the previous layer. This layer's own "units" is its OUTPUT width;
                                    // using it as the input width would mis-size every layer after the first.
                                    input_dim = previous_output_dim;
                                }

                                int units = (int)layer_config_desc["units"];

                                string[] parameters = new string[] { "linear", input_dim.ToString(), layer_config_desc["units"].ToString(), "Xavier" };
                                Layer    layer      = this.BuildLinear(parameters);
                                model.AddLayer(layer);
                                previous_output_dim = units;

                                string activation_name = layer_config_desc["activation"].ToString();

                                if (activation_name != "linear")
                                {
                                    Layer activation;
                                    if (activation_name == "softmax")
                                    {
                                        parameters = new string[] { activation_name, "1" };
                                        activation = this.BuildSoftmax(parameters);
                                    }
                                    else if (activation_name == "relu")
                                    {
                                        activation = this.BuildReLU();
                                    }
                                    else
                                    {
                                        return("Unity Error: SyftController.processMessage: while Loading activations, Activation :" + activation_name + " is not implemented");
                                    }
                                    model.AddLayer(activation);
                                }
                            }
                            else
                            {
                                return("Unity Error: SyftController.processMessage: while Loading layers, Layer :" + layer_desc["class_name"] + " is not implemented");
                            }
                        }

                        return(model.Id.ToString());
                    }
                    return("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                }

                case "Grid":
                    if (msgObj.functionCall == "learn")
                    {
                        var inputId  = int.Parse(msgObj.tensorIndexParams[0]);
                        var targetId = int.Parse(msgObj.tensorIndexParams[1]);

                        var g = new Grid(this);
                        g.Run(inputId, targetId, msgObj.configurations, owner);

                        return("");
                    }
                    break;

                default:
                    break;
                }
            }
            catch (Exception e)
            {
                // Report failures to the caller as a string rather than propagating the exception.
                Debug.LogFormat("<color=red>{0}</color>", e.ToString());
                return("Unity Error: " + e.ToString());
            }

            // If not executing createTensor or tensor function, return default error.
            return("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
        }
コード例 #15
0
ファイル: SyftController.cs プロジェクト: korymath/OpenMined
        public string processMessage(string json_message)
        {
            //Debug.LogFormat("<color=green>SyftController.processMessage {0}</color>", json_message);

            Command msgObj = JsonUtility.FromJson <Command> (json_message);

            try
            {
                switch (msgObj.objectType)
                {
                case "FloatTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(_shape: msgObj.shape, _data: msgObj.data, _shader: this.Shader);
                        return(tensor.Id.ToString());
                    }
                    else
                    {
                        FloatTensor tensor = floatTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        return(tensor.ProcessMessage(msgObj, this));
                    }
                }

                case "IntTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        int[] data = new int[msgObj.data.Length];
                        for (int i = 0; i < msgObj.data.Length; i++)
                        {
                            data[i] = (int)msgObj.data[i];
                        }
                        IntTensor tensor = intTensorFactory.Create(_shape: msgObj.shape, _data: data, _shader: this.Shader);
                        return(tensor.Id.ToString());
                    }
                    else
                    {
                        IntTensor tensor = intTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        return(tensor.ProcessMessage(msgObj, this));
                    }
                }

                case "model":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string model_type = msgObj.tensorIndexParams[0];

                        if (model_type == "linear")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0} : {1} {2}", model_type,
                                            msgObj.tensorIndexParams[1], msgObj.tensorIndexParams[2]);
                            Linear model = new Linear(this, int.Parse(msgObj.tensorIndexParams[1]), int.Parse(msgObj.tensorIndexParams[2]));
                            return(model.Id.ToString());
                        }
                        else if (model_type == "sigmoid")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);
                            Sigmoid model = new Sigmoid(this);
                            return(model.Id.ToString());
                        }
                        else if (model_type == "sequential")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);
                            Sequential model = new Sequential(this);
                            return(model.Id.ToString());
                        }
                        else if (model_type == "policy")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);
                            Policy model = new Policy(this, (Layer)getModel(int.Parse(msgObj.tensorIndexParams[1])));
                            return(model.Id.ToString());
                        }
                        else if (model_type == "tanh")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);
                            Tanh model = new Tanh(this);
                            return(model.Id.ToString());
                        }
                        else if (model_type == "crossentropyloss")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);
                            CrossEntropyLoss model = new CrossEntropyLoss(this);
                            return(model.Id.ToString());
                        }
                        else if (model_type == "mseloss")
                        {
                            Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);
                            MSELoss model = new MSELoss(this);
                            return(model.Id.ToString());
                        }
                    }
                    else
                    {
                        Model model = this.getModel(msgObj.objectIndex);
                        return(model.ProcessMessage(msgObj, this));
                    }
                    return("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                }

                case "controller":
                {
                    if (msgObj.functionCall == "num_tensors")
                    {
                        return(floatTensorFactory.Count() + "");
                    }
                    else if (msgObj.functionCall == "num_models")
                    {
                        return(models.Count + "");
                    }
                    else if (msgObj.functionCall == "new_tensors_allowed")
                    {
                        Debug.LogFormat("New Tensors Allowed:{0}", msgObj.tensorIndexParams[0]);
                        if (msgObj.tensorIndexParams[0] == "True")
                        {
                            allow_new_tensors = true;
                        }
                        else if (msgObj.tensorIndexParams[0] == "False")
                        {
                            allow_new_tensors = false;
                        }
                        else
                        {
                            throw new Exception("Invalid parameter for new_tensors_allowed. Did you mean true or false?");
                        }

                        return(allow_new_tensors + "");
                    }
                    return("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                }

                default:
                    break;
                }
            }
            catch (Exception e)
            {
                Debug.LogFormat("<color=red>{0}</color>", e.ToString());
                return("Unity Error: " + e.ToString());
            }

            // If not executing createTensor or tensor function, return default error.
            return("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
        }
コード例 #16
0
 /// <summary>
 ///   Writes the raw int contents of <paramref name="tensor"/> into this file.
 /// </summary>
 /// <param name="tensor">Tensor whose elements are written to the file.</param>
 /// <returns>The number of ints written, as reported by <c>THFile_writeIntRaw</c>.</returns>
 public long WriteTensor(IntTensor tensor)
 {
     long written = THFile_writeIntRaw(handle, tensor.Data, tensor.NumElements);
     return written;
 }
コード例 #17
0
ファイル: SyftController.cs プロジェクト: iamtrask/ml-agents
        /// <summary>
        ///   Deserializes a JSON <c>Command</c> and dispatches it on <c>objectType</c> to the optimizer,
        ///   float/int tensor, agent, model or controller handling code, returning the command's string result.
        /// </summary>
        /// <param name="json_message">JSON text that <c>JsonUtility.FromJson</c> turns into a <c>Command</c>.</param>
        /// <returns>
        ///   The command's result (for "create"-style calls, the id of the new object), or a message starting
        ///   with "Unity Error:" when the command is not recognized or throws while executing.
        /// </returns>
        /// <remarks>
        ///   Numeric parameters are parsed with the invariant culture so the wire format does not depend on the
        ///   locale Unity runs under (a culture-sensitive float.Parse misreads "0.01" under cultures that use
        ///   ',' as the decimal separator).
        /// </remarks>
        public string processMessage(string json_message)
        {
            //Debug.LogFormat("<color=green>SyftController.processMessage {0}</color>", json_message);

            Command msgObj = JsonUtility.FromJson <Command> (json_message);

            try
            {
                switch (msgObj.objectType)
                {
                case "Optimizer":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string optimizer_type = msgObj.tensorIndexParams[0];

                        // tensorIndexParams[1..] are integer ids handed to the optimizer (presumably the
                        // parameters it updates — confirm); hyperParams are its float hyperparameters, in order.
                        List<int> p = ParseInvariantInts(msgObj.tensorIndexParams, 1);
                        List<float> hp = new List<float>();
                        for (int i = 0; i < msgObj.hyperParams.Length; i++)
                        {
                            hp.Add(ParseInvariantFloat(msgObj.hyperParams[i]));
                        }

                        Optimizer optim = null;

                        if (optimizer_type == "sgd")
                        {
                            optim = new SGD(this, p, hp[0], hp[1], hp[2]);
                        }
                        else if (optimizer_type == "rmsprop")
                        {
                            optim = new RMSProp(this, p, hp[0], hp[1], hp[2], hp[3]);
                        }
                        else if (optimizer_type == "adam")
                        {
                            optim = new Adam(this, p, hp[0], hp[1], hp[2], hp[3], hp[4]);
                        }

                        if (optim == null)
                        {
                            // Unknown optimizer type: report it like an unknown model type instead of
                            // dereferencing null and returning a NullReferenceException dump.
                            Debug.LogFormat("<color=red>Optimizer Type Not Found:</color> {0}", optimizer_type);
                            return CommandNotFoundMessage(msgObj);
                        }

                        return optim.Id.ToString();
                    }
                    else
                    {
                        Optimizer optim = this.getOptimizer(msgObj.objectIndex);
                        return optim.ProcessMessage(msgObj, this);
                    }
                }

                case "FloatTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(_shape: msgObj.shape, _data: msgObj.data, _shader: this.Shader);
                        return tensor.Id.ToString();
                    }
                    else
                    {
                        FloatTensor tensor = floatTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        return tensor.ProcessMessage(msgObj, this);
                    }
                }

                case "IntTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        // The wire format carries data as floats; truncate each value to int.
                        int[] data = new int[msgObj.data.Length];
                        for (int i = 0; i < msgObj.data.Length; i++)
                        {
                            data[i] = (int)msgObj.data[i];
                        }
                        IntTensor tensor = intTensorFactory.Create(_shape: msgObj.shape, _data: data, _shader: this.Shader);
                        return tensor.Id.ToString();
                    }
                    else
                    {
                        IntTensor tensor = intTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        return tensor.ProcessMessage(msgObj, this);
                    }
                }

                case "agent":
                {
                    if (msgObj.functionCall == "create")
                    {
                        Layer     model     = (Layer)getModel(ParseInvariantInt(msgObj.tensorIndexParams[0]));
                        Optimizer optimizer = optimizers[ParseInvariantInt(msgObj.tensorIndexParams[1])];
                        return new Syft.NN.RL.Agent(this, model, optimizer).Id.ToString();
                    }

                    Syft.NN.RL.Agent agent = this.getAgent(msgObj.objectIndex);
                    return agent.ProcessMessageLocal(msgObj, this);
                }

                case "model":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string model_type = msgObj.tensorIndexParams[0];

                        Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);

                        if (model_type == "linear")
                        {
                            return new Linear(this, ParseInvariantInt(msgObj.tensorIndexParams[1]),
                                              ParseInvariantInt(msgObj.tensorIndexParams[2]),
                                              msgObj.tensorIndexParams[3]).Id.ToString();
                        }
                        else if (model_type == "relu")
                        {
                            return new ReLU(this).Id.ToString();
                        }
                        else if (model_type == "log")
                        {
                            return new Log(this).Id.ToString();
                        }
                        else if (model_type == "dropout")
                        {
                            return new Dropout(this, ParseInvariantFloat(msgObj.tensorIndexParams[1])).Id.ToString();
                        }
                        else if (model_type == "sigmoid")
                        {
                            return new Sigmoid(this).Id.ToString();
                        }
                        else if (model_type == "sequential")
                        {
                            return new Sequential(this).Id.ToString();
                        }
                        else if (model_type == "softmax")
                        {
                            return new Softmax(this, ParseInvariantInt(msgObj.tensorIndexParams[1])).Id.ToString();
                        }
                        else if (model_type == "logsoftmax")
                        {
                            return new LogSoftmax(this, ParseInvariantInt(msgObj.tensorIndexParams[1])).Id.ToString();
                        }
                        else if (model_type == "tanh")
                        {
                            return new Tanh(this).Id.ToString();
                        }
                        else if (model_type == "crossentropyloss")
                        {
                            return new CrossEntropyLoss(this, ParseInvariantInt(msgObj.tensorIndexParams[1])).Id.ToString();
                        }
                        else if (model_type == "nllloss")
                        {
                            return new NLLLoss(this).Id.ToString();
                        }
                        else if (model_type == "mseloss")
                        {
                            return new MSELoss(this).Id.ToString();
                        }
                        else if (model_type == "embedding")
                        {
                            return new Embedding(this, ParseInvariantInt(msgObj.tensorIndexParams[1]),
                                                 ParseInvariantInt(msgObj.tensorIndexParams[2])).Id.ToString();
                        }
                        else
                        {
                            Debug.LogFormat("<color=red>Model Type Not Found:</color> {0}", model_type);
                        }
                    }
                    else
                    {
                        Model model = this.getModel(msgObj.objectIndex);
                        return model.ProcessMessage(msgObj, this);
                    }
                    return CommandNotFoundMessage(msgObj);
                }

                case "controller":
                {
                    if (msgObj.functionCall == "num_tensors")
                    {
                        return floatTensorFactory.Count() + "";
                    }
                    else if (msgObj.functionCall == "num_models")
                    {
                        return models.Count + "";
                    }
                    else if (msgObj.functionCall == "new_tensors_allowed")
                    {
                        Debug.LogFormat("New Tensors Allowed:{0}", msgObj.tensorIndexParams[0]);
                        if (msgObj.tensorIndexParams[0] == "True")
                        {
                            allow_new_tensors = true;
                        }
                        else if (msgObj.tensorIndexParams[0] == "False")
                        {
                            allow_new_tensors = false;
                        }
                        else
                        {
                            throw new Exception("Invalid parameter for new_tensors_allowed. Did you mean true or false?");
                        }

                        return allow_new_tensors + "";
                    }
                    else if (msgObj.functionCall == "load_floattensor")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(filepath: msgObj.tensorIndexParams[0], _shader: this.Shader);
                        return tensor.Id.ToString();
                    }
                    else if (msgObj.functionCall == "set_seed")
                    {
                        Random.InitState(ParseInvariantInt(msgObj.tensorIndexParams[0]));
                        return "Random seed set!";
                    }
                    else if (msgObj.functionCall == "concatenate")
                    {
                        // tensorIndexParams[0] is the concatenation axis; the rest are tensor ids.
                        List<int> tensor_ids = ParseInvariantInts(msgObj.tensorIndexParams, 1);
                        FloatTensor result = Functional.Concatenate(floatTensorFactory, tensor_ids, ParseInvariantInt(msgObj.tensorIndexParams[0]));
                        return result.Id.ToString();
                    }
                    else if (msgObj.functionCall == "ones")
                    {
                        FloatTensor result = Functional.Ones(floatTensorFactory, ParseInvariantInts(msgObj.tensorIndexParams, 0).ToArray());
                        return result.Id.ToString();
                    }
                    else if (msgObj.functionCall == "randn")
                    {
                        FloatTensor result = Functional.Randn(floatTensorFactory, ParseInvariantInts(msgObj.tensorIndexParams, 0).ToArray());
                        return result.Id.ToString();
                    }
                    else if (msgObj.functionCall == "random")
                    {
                        FloatTensor result = Functional.Random(floatTensorFactory, ParseInvariantInts(msgObj.tensorIndexParams, 0).ToArray());
                        return result.Id.ToString();
                    }
                    else if (msgObj.functionCall == "zeros")
                    {
                        FloatTensor result = Functional.Zeros(floatTensorFactory, ParseInvariantInts(msgObj.tensorIndexParams, 0).ToArray());
                        return result.Id.ToString();
                    }
                    return CommandNotFoundMessage(msgObj);
                }

                default:
                    break;
                }
            }
            catch (Exception e)
            {
                Debug.LogFormat("<color=red>{0}</color>", e.ToString());
                return "Unity Error: " + e.ToString();
            }

            // If not executing createTensor or tensor function, return default error.
            return CommandNotFoundMessage(msgObj);
        }

        /// <summary>
        ///   Builds the standard "command not found" error string for <paramref name="msgObj"/>.
        /// </summary>
        private static string CommandNotFoundMessage(Command msgObj)
        {
            return "Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall;
        }

        /// <summary>
        ///   Parses an integer message parameter independently of the current thread culture.
        /// </summary>
        private static int ParseInvariantInt(string value)
        {
            return int.Parse(value, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        ///   Parses a float message parameter independently of the current thread culture.
        /// </summary>
        private static float ParseInvariantFloat(string value)
        {
            return float.Parse(value, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        ///   Parses <paramref name="values"/>[<paramref name="start"/>..] as integers, in order.
        ///   Returns an empty list when <paramref name="start"/> is at or past the end of the array.
        /// </summary>
        private static List<int> ParseInvariantInts(string[] values, int start)
        {
            List<int> result = new List<int>();
            for (int i = start; i < values.Length; i++)
            {
                result.Add(ParseInvariantInt(values[i]));
            }
            return result;
        }
コード例 #18
0
    /// <summary>
    ///   Collects the brain's states, rewards and dones, then decides one action per agent and sends
    ///   them to the brain. Continuous action spaces use the player's currently pressed keys; discrete
    ///   action spaces use the Syft policy looked up under a fixed agent id, or 0 when no policy exists.
    /// </summary>
    public void DecideAction()
    {
        // Fixed id the policy agent is looked up under.
        // NOTE(review): presumably the id the external Syft client registers its agent with — confirm.
        const int policyAgentId = 1234;

        if (ctrl == null)
        {
            ctrl = brain.brainParameters.syft.controller;
        }

        // Whether a policy was already attached before this call.
        found_policy = policy != null;

        // Single controller lookup; a null result is treated as "no policy available".
        var registeredPolicy = ctrl.getAgent(policyAgentId);
        if (registeredPolicy != null)
        {
            policy = registeredPolicy;
            if (!found_policy)
            {
                // A policy has just become available: reset every agent before it starts acting.
                foreach (KeyValuePair<int, global::Agent> idAgent in brain.agents)
                {
                    idAgent.Value.Reset();
                }
            }
        }
        else
        {
            policy = null;
        }

        //The states are collected in order to debug the CollectStates method.
        // NOTE(review): rewards and dones are not read below; the calls are kept in case collecting
        // them has side effects inside the brain — unverified.
        Dictionary<int, List<float>> states  = brain.CollectStates();
        Dictionary<int, float>        rewards = brain.CollectRewards();
        Dictionary<int, bool>         dones   = brain.CollectDones();

        if (brain.brainParameters.actionSpaceType == StateType.continuous)
        {
            float[] action = new float[brain.brainParameters.actionSize];
            foreach (ContinuousPlayerAction cha in continuousPlayerActions)
            {
                if (Input.GetKey(cha.key))
                {
                    action[cha.index] = cha.value;
                }
            }
            // Every agent receives the same action array instance.
            Dictionary<int, float[]> actions = new Dictionary<int, float[]>();
            foreach (KeyValuePair<int, global::Agent> idAgent in brain.agents)
            {
                actions.Add(idAgent.Key, action);
            }
            brain.SendActions(actions);
        }
        else
        {
            // NOTE(review): the player's discrete choice computed here is never used below — agents get
            // either 0 (no policy) or the policy's sample. Confirm whether it should replace the 0 fallback.
            float[] action = new float[1] {
                defaultAction
            };
            foreach (DiscretePlayerAction dha in discretePlayerActions)
            {
                if (Input.GetKey(dha.key))
                {
                    action[0] = (float)dha.value;
                    break;
                }
            }
            Dictionary<int, float[]> actions = new Dictionary<int, float[]>();
            foreach (KeyValuePair<int, global::Agent> idAgent in brain.agents)
            {
                if (policy == null)
                {
                    // do nothing - you don't have a network
                    actions.Add(idAgent.Key, new float[1] {
                        0
                    });
                }
                else
                {
                    // input shape is [1 x state size] for this agent.
                    // NOTE(review): a new tensor (and a prediction tensor) is created per agent per step and
                    // not released here; confirm whether the factory needs an explicit delete.
                    FloatTensor input = ctrl.floatTensorFactory.Create(_shape: new int[] { 1, states[idAgent.Key].Count },
                                                                       _data: states[idAgent.Key].ToArray());

                    IntTensor pred = policy.Sample(input);
                    actions.Add(idAgent.Key, new float[1] {
                        pred.Data[0]
                    });
                }
            }

            brain.SendActions(actions);
        }
    }
コード例 #19
0
ファイル: IntTensorGpuTest.cs プロジェクト: withai/OpenMined
 /// <summary>
 ///   Asserts that two int tensors hold approximately equal data by delegating to
 ///   <c>AssertEqualTensorsData</c> with a fixed third argument of 0.0001f (the tolerance,
 ///   judging by this method's name).
 /// </summary>
 public void AssertApproximatelyEqualTensorsData(IntTensor t1, IntTensor t2)
 {
     const float tolerance = 0.0001f;
     AssertEqualTensorsData(t1, t2, tolerance);
 }
コード例 #20
0
ファイル: SyftController.cs プロジェクト: ygambhir/OpenMined
        public void ProcessMessage(string json_message, MonoBehaviour owner, Action <string> response)
        {
            Command msgObj = JsonUtility.FromJson <Command> (json_message);

            try
            {
                switch (msgObj.objectType)
                {
                case "Optimizer":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string optimizer_type = msgObj.tensorIndexParams[0];

                        // Extract parameters
                        List <int> p = new List <int>();
                        for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            p.Add(int.Parse(msgObj.tensorIndexParams[i]));
                        }
                        List <float> hp = new List <float>();
                        for (int i = 0; i < msgObj.hyperParams.Length; i++)
                        {
                            hp.Add(float.Parse(msgObj.hyperParams[i]));
                        }

                        Optimizer optim = null;

                        if (optimizer_type == "sgd")
                        {
                            optim = new SGD(this, p, hp[0], hp[1], hp[2]);
                        }
                        else if (optimizer_type == "rmsprop")
                        {
                            optim = new RMSProp(this, p, hp[0], hp[1], hp[2], hp[3]);
                        }
                        else if (optimizer_type == "adam")
                        {
                            optim = new Adam(this, p, hp[0], hp[1], hp[2], hp[3], hp[4]);
                        }

                        response(optim.Id.ToString());
                        return;
                    }
                    else
                    {
                        Optimizer optim = this.GetOptimizer(msgObj.objectIndex);
                        response(optim.ProcessMessage(msgObj, this));

                        return;
                    }
                }

                case "FloatTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(_shape: msgObj.shape, _data: msgObj.data, _shader: this.Shader);
                        response(tensor.Id.ToString());
                        return;
                    }
                    else
                    {
                        FloatTensor tensor = floatTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        response(tensor.ProcessMessage(msgObj, this));
                        return;
                    }
                }

                case "IntTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        int[] data = new int[msgObj.data.Length];
                        for (int i = 0; i < msgObj.data.Length; i++)
                        {
                            data[i] = (int)msgObj.data[i];
                        }
                        IntTensor tensor = intTensorFactory.Create(_shape: msgObj.shape, _data: data);
                        response(tensor.Id.ToString());
                        return;
                    }
                    else
                    {
                        IntTensor tensor = intTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        response(tensor.ProcessMessage(msgObj, this));
                        return;
                    }
                }

                case "agent":
                {
                    if (msgObj.functionCall == "create")
                    {
                        Layer     model     = (Layer)GetModel(int.Parse(msgObj.tensorIndexParams[0]));
                        Optimizer optimizer = optimizers[int.Parse(msgObj.tensorIndexParams[1])];
                        response(new Syft.NN.RL.Agent(this, model, optimizer).Id.ToString());
                        return;
                    }

                    //Debug.Log("Getting Model:" + msgObj.objectIndex);
                    Syft.NN.RL.Agent agent = this.GetAgent(msgObj.objectIndex);
                    response(agent.ProcessMessageLocal(msgObj, this));
                    return;
                }

                case "model":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string model_type = msgObj.tensorIndexParams[0];

                        Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);

                        if (model_type == "linear")
                        {
                            response(this.BuildLinear(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "relu")
                        {
                            response(this.BuildReLU().Id.ToString());
                            return;
                        }
                        else if (model_type == "log")
                        {
                            response(this.BuildLog().Id.ToString());
                            return;
                        }
                        else if (model_type == "dropout")
                        {
                            response(this.BuildDropout(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "sigmoid")
                        {
                            response(this.BuildSigmoid().Id.ToString());
                            return;
                        }
                        else if (model_type == "sequential")
                        {
                            response(this.BuildSequential().Id.ToString());
                            return;
                        }
                        else if (model_type == "softmax")
                        {
                            response(this.BuildSoftmax(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "logsoftmax")
                        {
                            response(this.BuildLogSoftmax(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "tanh")
                        {
                            response(new Tanh(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "crossentropyloss")
                        {
                            response(new CrossEntropyLoss(this, int.Parse(msgObj.tensorIndexParams[1])).Id.ToString());
                            return;
                        }
                        else if (model_type == "categorical_crossentropy")
                        {
                            response(new CategoricalCrossEntropyLoss(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "nllloss")
                        {
                            response(new NLLLoss(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "mseloss")
                        {
                            response(new MSELoss(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "embedding")
                        {
                            response(new Embedding(this, int.Parse(msgObj.tensorIndexParams[1]), int.Parse(msgObj.tensorIndexParams[2])).Id.ToString());
                            return;
                        }
                        else
                        {
                            Debug.LogFormat("<color=red>Model Type Not Found:</color> {0}", model_type);
                        }
                    }
                    else
                    {
                        //Debug.Log("Getting Model:" + msgObj.objectIndex);
                        Model model = this.GetModel(msgObj.objectIndex);
                        response(model.ProcessMessage(msgObj, this));
                        return;
                    }
                    response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                    return;
                }

                case "controller":
                {
                    if (msgObj.functionCall == "num_tensors")
                    {
                        response(floatTensorFactory.Count() + "");
                        return;
                    }
                    else if (msgObj.functionCall == "num_models")
                    {
                        response(models.Count + "");
                        return;
                    }
                    else if (msgObj.functionCall == "new_tensors_allowed")
                    {
                        Debug.LogFormat("New Tensors Allowed:{0}", msgObj.tensorIndexParams[0]);
                        if (msgObj.tensorIndexParams[0] == "True")
                        {
                            allow_new_tensors = true;
                        }
                        else if (msgObj.tensorIndexParams[0] == "False")
                        {
                            allow_new_tensors = false;
                        }
                        else
                        {
                            throw new Exception("Invalid parameter for new_tensors_allowed. Did you mean true or false?");
                        }

                        response(allow_new_tensors + "");
                        return;
                    }
                    else if (msgObj.functionCall == "load_floattensor")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(filepath: msgObj.tensorIndexParams[0], _shader: this.Shader);
                        response(tensor.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "set_seed")
                    {
                        Random.InitState(int.Parse(msgObj.tensorIndexParams[0]));
                        response("Random seed set!");
                        return;
                    }
                    else if (msgObj.functionCall == "concatenate")
                    {
                        List <int> tensor_ids = new List <int>();
                        for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            tensor_ids.Add(int.Parse(msgObj.tensorIndexParams[i]));
                        }
                        FloatTensor result = Functional.Concatenate(floatTensorFactory, tensor_ids, int.Parse(msgObj.tensorIndexParams[0]));
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "ones")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Ones(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "randn")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Randn(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "random")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Random(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "zeros")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Zeros(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "model_from_json")
                    {
                        Debug.Log("Loading Model from JSON:");
                        var json_str = msgObj.tensorIndexParams[0];
                        var config   = JObject.Parse(json_str);

                        Sequential model;

                        if ((string)config["class_name"] == "Sequential")
                        {
                            model = this.BuildSequential();
                        }
                        else
                        {
                            response("Unity Error: SyftController.processMessage: while Loading model, Class :" + config["class_name"] + " is not implemented");
                            return;
                        }

                        for (int i = 0; i < config["config"].ToList().Count; i++)
                        {
                            var layer_desc        = config["config"][i];
                            var layer_config_desc = layer_desc["config"];

                            if ((string)layer_desc["class_name"] == "Linear")
                            {
                                int previous_output_dim;

                                if (i == 0)
                                {
                                    previous_output_dim = (int)layer_config_desc["batch_input_shape"][layer_config_desc["batch_input_shape"].ToList().Count - 1];
                                }
                                else
                                {
                                    previous_output_dim = (int)layer_config_desc["units"];
                                }

                                string[] parameters = { "linear", previous_output_dim.ToString(), layer_config_desc["units"].ToString(), "Xavier" };
                                Layer    layer      = this.BuildLinear(parameters);
                                model.AddLayer(layer);

                                string activation_name = layer_config_desc["activation"].ToString();

                                if (activation_name != "linear")
                                {
                                    Layer activation;
                                    if (activation_name == "softmax")
                                    {
                                        parameters = new string[] { activation_name, "1" };
                                        activation = this.BuildSoftmax(parameters);
                                    }
                                    else if (activation_name == "relu")
                                    {
                                        activation = this.BuildReLU();
                                    }
                                    else
                                    {
                                        response("Unity Error: SyftController.processMessage: while Loading activations, Activation :" + activation_name + " is not implemented");
                                        return;
                                    }
                                    model.AddLayer(activation);
                                }
                            }
                            else
                            {
                                response("Unity Error: SyftController.processMessage: while Loading layers, Layer :" + layer_desc["class_name"] + " is not implemented");
                                return;
                            }
                        }

                        response(model.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "from_proto")
                    {
                        Debug.Log("Loading Model from ONNX:");
                        var filename = msgObj.tensorIndexParams[0];

                        var        input      = File.OpenRead(filename);
                        ModelProto modelProto = ModelProto.Parser.ParseFrom(input);

                        Sequential model = this.BuildSequential();

                        foreach (NodeProto node in modelProto.Graph.Node)
                        {
                            Layer      layer;
                            GraphProto g = ONNXTools.GetSubGraphFromNodeAndMainGraph(node, modelProto.Graph);
                            if (node.OpType == "Gemm")
                            {
                                layer = new Linear(this, g);
                            }
                            else if (node.OpType == "Dropout")
                            {
                                layer = new Dropout(this, g);
                            }
                            else if (node.OpType == "Relu")
                            {
                                layer = new ReLU(this, g);
                            }
                            else if (node.OpType == "Softmax")
                            {
                                layer = new Softmax(this, g);
                            }
                            else
                            {
                                response("Unity Error: SyftController.processMessage: Layer not yet implemented for deserialization:");
                                return;
                            }
                            model.AddLayer(layer);
                        }

                        response(model.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "to_proto")
                    {
                        ModelProto model    = this.ToProto(msgObj.tensorIndexParams);
                        string     filename = msgObj.tensorIndexParams[2];
                        string     type     = msgObj.tensorIndexParams[3];
                        if (type == "json")
                        {
                            response(model.ToString());
                        }
                        else
                        {
                            using (var output = File.Create(filename))
                            {
                                model.WriteTo(output);
                            }
                            response(new FileInfo(filename).FullName);
                        }
                        return;
                    }

                    response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                    return;
                }

                case "Grid":
                    if (msgObj.functionCall == "learn")
                    {
                        var inputId  = int.Parse(msgObj.tensorIndexParams[0]);
                        var targetId = int.Parse(msgObj.tensorIndexParams[1]);

                        response(this.grid.Run(inputId, targetId, msgObj.configurations, owner));
                        return;
                    }

                    if (msgObj.functionCall == "getResults")
                    {
                        this.grid.GetResults(msgObj.experimentId, response);
                        return;
                    }

                    // like getResults but doesn't pause to wait for results
                    // this function will return right away telling you if
                    // it knows whether or not it is done
                    if (msgObj.functionCall == "checkStatus")
                    {
                        this.grid.CheckStatus(msgObj.experimentId, response);
                        return;
                    }

                    break;

                default:
                    break;
                }
            }
            catch (Exception e)
            {
                Debug.LogFormat("<color=red>{0}</color>", e.ToString());
                response("Unity Error: " + e.ToString());
                return;
            }

            // If not executing createTensor or tensor function, return default error.

            response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
            return;
        }