コード例 #1
0
        public void AddMatrixMultiplyTest()
        {
            // Verifies that AddMatrixMultiply accumulates the matrix product of its two arguments
            // into the receiving tensor, for both a (2x3)·(3x2) and a (3x2)·(2x3) product.
            float[] base1_data  = new float[] { 1, 2, 3, 4 };
            int[]   base1_shape = new int[] { 2, 2 };
            var     base1       = new FloatTensor(base1_data, base1_shape);

            float[] base2_data  = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
            int[]   base2_shape = new int[] { 3, 3 };
            var     base2       = new FloatTensor(base2_data, base2_shape);

            float[] data          = new float[] { 1, 2, 3, 4, 5, 6 };
            int[]   tensor1_shape = new int[] { 2, 3 };
            int[]   tensor2_shape = new int[] { 3, 2 };
            var     tensor1       = new FloatTensor(data, tensor1_shape);
            var     tensor2       = new FloatTensor(data, tensor2_shape);

            // Snapshot the initial accumulator values before the in-place update. If FloatTensor
            // keeps a reference to the array it was given, AddMatrixMultiply would also overwrite
            // base1_data/base2_data, and expected values built from them would be wrong.
            float[] base1_initial = (float[])base1_data.Clone();
            float[] base2_initial = (float[])base2_data.Clone();

            base1.AddMatrixMultiply(tensor1, tensor2);
            base2.AddMatrixMultiply(tensor2, tensor1);

            AssertAddMatrixMultiplyResult(base1, base1_initial, base1_shape, tensor1, tensor2, tensor1_shape[1]);
            AssertAddMatrixMultiplyResult(base2, base2_initial, base2_shape, tensor2, tensor1, tensor2_shape[1]);
        }

        // Asserts result[i, j] == initial[i * cols + j] + sum_k left[i, k] * right[k, j] for every
        // cell of a rows x cols result, where `inner` is the shared dimension of left and right.
        private static void AssertAddMatrixMultiplyResult(
            FloatTensor result, float[] initial, int[] shape,
            FloatTensor left, FloatTensor right, int inner)
        {
            // Float sums may be accumulated in a different order by the implementation.
            const float tolerance = 1e-4f;
            int rows = shape[0];
            int cols = shape[1];

            for (int i = 0; i < rows; i++)
            {
                for (int j = 0; j < cols; j++)
                {
                    float expected = initial[i * cols + j];
                    for (int k = 0; k < inner; k++)
                    {
                        expected += left[i, k] * right[k, j];
                    }
                    // Expected first, actual second, so assertion failure messages are labelled correctly.
                    Assert.AreEqual(expected, result[i, j], tolerance);
                }
            }
        }
コード例 #2
0
        /// <summary>
        /// Decodes one JSON-encoded <c>Command</c> and executes it against the registered tensors.
        /// </summary>
        /// <param name="json_message">JSON text deserialized into a <c>Command</c> via <c>JsonUtility.FromJson</c>.</param>
        /// <returns>
        /// The return value depends on the command:
        /// <list type="bullet">
        /// <item><c>createTensor</c> and <c>add</c>: the id of the newly registered tensor.</item>
        /// <item><c>print</c>: the result of <c>FloatTensor.Print()</c>.</item>
        /// <item>Any other recognised tensor command: <c>"&lt;functionCall&gt;: OK"</c>.</item>
        /// <item>Unknown commands, or an objectIndex / tensorIndexParams entry that does not refer to a
        /// registered tensor: an error string.</item>
        /// </list>
        /// </returns>
        public string processMessage(string json_message)
        {
            const string commandNotFound = "SyftController.processMessage: Command not found.";

            Command msgObj = JsonUtility.FromJson <Command>(json_message);

            if (msgObj.functionCall == "createTensor")
            {
                FloatTensor created = new FloatTensor(msgObj.data, msgObj.shape);
                created.Shader = shader;
                tensors.Add(created.Id, created);

                Debug.LogFormat("<color=magenta>createTensor:</color> {0}", string.Join(", ", created.Data));

                return created.Id.ToString();
            }

            if (msgObj.objectType != "tensor")
            {
                return commandNotFound;
            }

            // Look the target up directly. The old "objectIndex > FloatTensor.CreatedObjectCount"
            // bound let negative or unregistered ids through to the indexer, which then threw.
            FloatTensor tensor;
            if (!tensors.TryGetValue(msgObj.objectIndex, out tensor))
            {
                return "Invalid objectIndex: " + msgObj.objectIndex;
            }

            string invalidParams = "Invalid tensorIndexParams for " + msgObj.functionCall;
            FloatTensor tensor_1;
            FloatTensor tensor_2;

            switch (msgObj.functionCall)
            {
                case "init_add_matrix_multiply":
                    // NOTE(review): the command name suggests an add-matrix-multiply, but it dispatches
                    // to ElementwiseMultiplication. Kept as-is for wire compatibility; confirm intent.
                    if (!TryGetParamTensor(msgObj, 0, out tensor_1))
                    {
                        return invalidParams;
                    }
                    tensor.ElementwiseMultiplication(tensor_1);
                    break;

                case "inline_elementwise_subtract":
                    if (!TryGetParamTensor(msgObj, 0, out tensor_1))
                    {
                        return invalidParams;
                    }
                    tensor.ElementwiseSubtract(tensor_1);
                    break;

                case "multiply_derivative":
                    if (!TryGetParamTensor(msgObj, 0, out tensor_1))
                    {
                        return invalidParams;
                    }
                    tensor.MultiplyDerivative(tensor_1);
                    break;

                case "add_matrix_multiply":
                    if (!TryGetParamTensor(msgObj, 0, out tensor_1) || !TryGetParamTensor(msgObj, 1, out tensor_2))
                    {
                        return invalidParams;
                    }
                    tensor.AddMatrixMultiply(tensor_1, tensor_2);
                    break;

                case "print":
                    return tensor.Print();

                case "abs":
                    tensor.Abs();
                    break;

                case "neg":
                    tensor.Neg();
                    break;

                case "add":
                    if (!TryGetParamTensor(msgObj, 0, out tensor_1))
                    {
                        return invalidParams;
                    }
                    // The addressed tensor is the left operand. Previously tensor_1 was added to
                    // itself and the addressed tensor was ignored.
                    FloatTensor output = tensor.Add(tensor_1);
                    tensors.Add(output.Id, output);
                    return output.Id.ToString();

                case "scalar_multiply":
                    if (msgObj.tensorIndexParams == null || msgObj.tensorIndexParams.Length == 0)
                    {
                        return invalidParams;
                    }
                    // The scalar travels in tensorIndexParams, the same field that carries tensor ids.
                    // NOTE(review): presumably an int[], so fractional scalars cannot be sent; confirm.
                    tensor.ScalarMultiplication((float)msgObj.tensorIndexParams[0]);
                    break;

                default:
                    return commandNotFound;
            }

            return msgObj.functionCall + ": OK";
        }

        // Resolves the registered tensor whose id is at `position` in msgObj.tensorIndexParams.
        // Returns false (and sets result to null) when the entry is missing or no tensor has that id.
        private bool TryGetParamTensor(Command msgObj, int position, out FloatTensor result)
        {
            result = null;
            if (msgObj.tensorIndexParams == null || position >= msgObj.tensorIndexParams.Length)
            {
                return false;
            }
            return tensors.TryGetValue(msgObj.tensorIndexParams[position], out result);
        }