/// <summary>
/// Checks that <c>FloatTensor.AddMatrixMultiply(a, b)</c> adds the matrix product a·b onto the
/// receiving tensor, for a (2x3)·(3x2) product accumulated into a 2x2 tensor and a
/// (3x2)·(2x3) product accumulated into a 3x3 tensor.
/// </summary>
public void AddMatrixMultiplyTest()
{
    // Comparisons use a small tolerance rather than exact float equality.
    const float tolerance = 1e-5f;

    float[] base1_data = new float[] { 1, 2, 3, 4 };
    int[] base1_shape = new int[] { 2, 2 };
    var base1 = new FloatTensor(base1_data, base1_shape);

    float[] base2_data = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    int[] base2_shape = new int[] { 3, 3 };
    var base2 = new FloatTensor(base2_data, base2_shape);

    float[] data = new float[] { 1, 2, 3, 4, 5, 6 };
    int[] tensor1_shape = new int[] { 2, 3 };
    int[] tensor2_shape = new int[] { 3, 2 };
    var tensor1 = new FloatTensor(data, tensor1_shape);
    var tensor2 = new FloatTensor(data, tensor2_shape);

    // Snapshot the starting values before the in-place operation so the expected results do not
    // depend on whether FloatTensor copies its input array or keeps a reference to it.
    float[] base1_initial = (float[])base1_data.Clone();
    float[] base2_initial = (float[])base2_data.Clone();

    base1.AddMatrixMultiply(tensor1, tensor2);
    base2.AddMatrixMultiply(tensor2, tensor1);

    // base1 = base1_initial + tensor1 (2x3) * tensor2 (3x2)
    for (int i = 0; i < base1_shape[0]; i++)
    {
        for (int j = 0; j < base1_shape[1]; j++)
        {
            float expected = base1_initial[i * base1_shape[1] + j];
            for (int k = 0; k < tensor1_shape[1]; k++)
            {
                expected += tensor1[i, k] * tensor2[k, j];
            }
            // Assert.AreEqual takes (expected, actual, delta).
            Assert.AreEqual(expected, base1[i, j], tolerance);
        }
    }

    // base2 = base2_initial + tensor2 (3x2) * tensor1 (2x3)
    for (int i = 0; i < base2_shape[0]; i++)
    {
        for (int j = 0; j < base2_shape[1]; j++)
        {
            float expected = base2_initial[i * base2_shape[1] + j];
            for (int k = 0; k < tensor2_shape[1]; k++)
            {
                expected += tensor2[i, k] * tensor1[k, j];
            }
            Assert.AreEqual(expected, base2[i, j], tolerance);
        }
    }
}
/// <summary>
/// Decodes a JSON-encoded <c>Command</c> and executes it against the registered tensors.
/// </summary>
/// <param name="json_message">JSON text deserialized into a <c>Command</c> via <c>JsonUtility.FromJson</c>.</param>
/// <returns>
/// For <c>createTensor</c> and <c>add</c>: the id of the newly registered tensor.
/// For <c>print</c>: the text returned by <c>FloatTensor.Print()</c>.
/// For the other recognised tensor commands: "&lt;functionCall&gt;: OK".
/// "Invalid objectIndex: N" when the addressed tensor is not registered, and
/// "SyftController.processMessage: Command not found." for anything else.
/// </returns>
public string processMessage(string json_message)
{
    const string CommandNotFound = "SyftController.processMessage: Command not found.";

    Command msgObj = JsonUtility.FromJson<Command>(json_message);

    // createTensor does not address an existing object, so it is handled before objectType is checked.
    if (msgObj.functionCall == "createTensor")
    {
        FloatTensor created = new FloatTensor(msgObj.data, msgObj.shape);
        created.Shader = shader;
        tensors.Add(created.Id, created);
        Debug.LogFormat("<color=magenta>createTensor:</color> {0}", string.Join(", ", created.Data));
        return created.Id.ToString();
    }

    if (msgObj.objectType != "tensor")
    {
        return CommandNotFound;
    }

    // Reject indices that were never issued or are not registered, rather than letting the
    // dictionary indexer throw KeyNotFoundException.
    FloatTensor tensor;
    if (msgObj.objectIndex > FloatTensor.CreatedObjectCount
        || !tensors.TryGetValue(msgObj.objectIndex, out tensor))
    {
        return "Invalid objectIndex: " + msgObj.objectIndex;
    }

    switch (msgObj.functionCall)
    {
        case "init_add_matrix_multiply":
            // NOTE(review): despite its name this command performs an element-wise multiplication.
            // Kept as-is because remote callers may depend on it — confirm the intended operation.
            tensor.ElementwiseMultiplication(tensors[msgObj.tensorIndexParams[0]]);
            break;

        case "inline_elementwise_subtract":
            tensor.ElementwiseSubtract(tensors[msgObj.tensorIndexParams[0]]);
            break;

        case "multiply_derivative":
            tensor.MultiplyDerivative(tensors[msgObj.tensorIndexParams[0]]);
            break;

        case "add_matrix_multiply":
            tensor.AddMatrixMultiply(
                tensors[msgObj.tensorIndexParams[0]],
                tensors[msgObj.tensorIndexParams[1]]);
            break;

        case "print":
            return tensor.Print();

        case "abs":
            tensor.Abs();
            break;

        case "neg":
            tensor.Neg();
            break;

        case "add":
        {
            // The addressed tensor is the left operand (previously the parameter tensor was added
            // to itself and the addressed tensor was ignored). The result is registered as a new
            // tensor and its id is returned.
            FloatTensor output = tensor.Add(tensors[msgObj.tensorIndexParams[0]]);
            tensors.Add(output.Id, output);
            return output.Id.ToString();
        }

        case "scalar_multiply":
            // NOTE(review): the scalar travels in tensorIndexParams[0] and is cast to float, so it is
            // limited to whatever that field's element type can hold — confirm against the Command schema.
            tensor.ScalarMultiplication((float)msgObj.tensorIndexParams[0]);
            break;

        default:
            return CommandNotFound;
    }

    return msgObj.functionCall + ": OK";
}