Code example #1
File: OnnxUtils.cs Project: sazae657/machinelearning
        /// <summary>
        /// Constructs an OnnxModel object from a file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary"></param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            ModelFile = modelFile;
            // If we don't own the model file, _disposed should be false to prevent deleting the user's file.
            _ownModelFile = ownModelFile;
            _disposed     = false;

            if (gpuDeviceId != null)
            {
                try
                {
                    _session = new InferenceSession(modelFile,
                                                    SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
                }
                catch (OnnxRuntimeException)
                {
                    if (fallbackToCpu)
                    {
                        _session = new InferenceSession(modelFile);
                    }
                    else
                    {
                        // If called from OnnxTransform, this exception is caught there and rethrown
                        throw;
                    }
                }
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            try
            {
                // Load the ONNX model file and parse its input and output schema, because ONNXRuntime
                // doesn't expose full type information via its C# APIs.
                var model = new OnnxCSharpToProtoWrapper.ModelProto();
                using (var modelStream = File.OpenRead(modelFile))
                    using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
                        model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary <string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                var initializerTypePool = new Dictionary <string, DataViewType>();
                foreach (var valueInfo in model.Graph.Initializer)
                {
                    initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
                }

                var outputTypePool = new Dictionary <string, DataViewType>();
                // Build casters which map NamedOnnxValue to .NET objects.
                var casterPool = new Dictionary <string, Func <NamedOnnxValue, object> >();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
                }

                var inputInfos  = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
                var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
                var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

                // Create a view to the used ONNX model from ONNXRuntime's perspective.
                ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

                Graph = model.Graph;
            }
            catch
            {
                _session.Dispose();
                _session = null;
                throw;
            }
        }
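
A minimal usage sketch for this constructor; the file path and device id below are placeholders, not from the original project:

 // Hypothetical usage: request GPU device 0, quietly falling back to CPU if CUDA is unavailable.
 var model = new OnnxModel("model.onnx", gpuDeviceId: 0, fallbackToCpu: true);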
Code example #2
 public static void NodeAddAttributes(NodeProto node, string argName, GraphProto value)
 => node.Attribute.Add(MakeAttribute(argName, value));
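
A hedged usage sketch: graph-valued attributes usually carry subgraphs for ONNX control-flow operators, for example the "body" attribute of a Loop node. Here loopNode and bodyGraph are assumed to exist:

 // Hypothetical usage: attach bodyGraph as the "body" attribute of an existing Loop node.
 NodeAddAttributes(loopNode, "body", bodyGraph);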
Code example #3
File: OnnxUtils.cs Project: LuanNg/machinelearning
        /// <summary>
        /// Constructs an OnnxModel object from a file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary"></param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
        {
            ModelFile = modelFile;
            // If we don't own the model file, _disposed should be false to prevent deleting the user's file.
            _ownModelFile = ownModelFile;
            _disposed     = false;

            if (gpuDeviceId != null)
            {
                // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
                // This code path will be re-enabled when there is appropriate support in onnxruntime
                throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
            }
            else
            {
                _session = new InferenceSession(modelFile);
            }

            // Load the ONNX model file and parse its input and output schema, because ONNXRuntime
            // doesn't expose full type information via its C# APIs.
            var model = new OnnxCSharpToProtoWrapper.ModelProto();

            using (var modelStream = File.OpenRead(modelFile))
                using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
                    model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

            // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
            var inputTypePool = new Dictionary <string, DataViewType>();

            foreach (var valueInfo in model.Graph.Input)
            {
                inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            }

            var initializerTypePool = new Dictionary <string, DataViewType>();

            foreach (var valueInfo in model.Graph.Initializer)
            {
                initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
            }

            var outputTypePool = new Dictionary <string, DataViewType>();
            // Build casters which map NamedOnnxValue to .NET objects.
            var casterPool = new Dictionary <string, Func <NamedOnnxValue, object> >();

            foreach (var valueInfo in model.Graph.Output)
            {
                outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
            }

            var inputInfos  = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
            var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
            var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

            // Create a view to the used ONNX model from ONNXRuntime's perspective.
            ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

            Graph = model.Graph;
        }
Code example #4
        /// <summary>
        /// Constructs an OnnxModel object from a file.
        /// </summary>
        /// <param name="modelFile">Model file path.</param>
        /// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
        /// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
        /// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
        /// no longer needed.</param>
        /// <param name="shapeDictionary"></param>
        /// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
        /// <param name="interOpNumThreads">Controls the number of threads used to parallelize the execution of the graph (across nodes).</param>
        /// <param name="intraOpNumThreads">Controls the number of threads to use to run the model.</param>
        public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
                         bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100,
                         int? interOpNumThreads = null, int? intraOpNumThreads = null)
        {
            // If we don't own the model file, _disposed should be false to prevent deleting the user's file.
            _disposed = false;

            if (gpuDeviceId != null)
            {
                try
                {
                    _session = new InferenceSession(modelFile,
                                                    SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
                }
                catch (OnnxRuntimeException)
                {
                    if (fallbackToCpu)
                    {
                        var sessionOptions = new SessionOptions()
                        {
                            InterOpNumThreads = interOpNumThreads.GetValueOrDefault(),
                            IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault()
                        };
                        _session = new InferenceSession(modelFile, sessionOptions);
                    }
                    else
                    {
                        // If called from OnnxTransform, this exception is caught there and rethrown
                        throw;
                    }
                }
            }
            else
            {
                var sessionOptions = new SessionOptions()
                {
                    InterOpNumThreads = interOpNumThreads.GetValueOrDefault(),
                    IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault()
                };
                _session = new InferenceSession(modelFile, sessionOptions);
            }

            try
            {
                // Load the ONNX model file and parse its input and output schema, because ONNXRuntime
                // doesn't expose full type information via its C# APIs.
                var model = new OnnxCSharpToProtoWrapper.ModelProto();
                // If we own the model file set the DeleteOnClose flag so it is always deleted.
                if (ownModelFile)
                {
                    ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose);
                }
                else
                {
                    ModelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read);
                }

                // The CodedInputStream auto closes the stream, and we need to make sure that our main stream stays open, so creating a new one here.
                using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
                    using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, recursionLimit))
                        model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);

                // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
                var inputTypePool = new Dictionary <string, DataViewType>();
                foreach (var valueInfo in model.Graph.Input)
                {
                    inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                }

                var initializerTypePool = new Dictionary <string, DataViewType>();
                foreach (var valueInfo in model.Graph.Initializer)
                {
                    initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
                }

                var outputTypePool = new Dictionary <string, DataViewType>();
                // Build casters which map NamedOnnxValue to .NET objects.
                var casterPool = new Dictionary <string, Func <NamedOnnxValue, object> >();
                foreach (var valueInfo in model.Graph.Output)
                {
                    outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
                    casterPool[valueInfo.Name]     = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
                }

                var inputInfos  = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
                var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
                var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

                // Create a view to the used ONNX model from ONNXRuntime's perspective.
                ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

                Graph = model.Graph;
            }
            catch
            {
                _session.Dispose();
                _session = null;
                throw;
            }
        }
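
A minimal usage sketch for this overload; the path and thread count are placeholders. Note that leaving both thread counts null passes 0 via GetValueOrDefault(), which ONNX Runtime effectively treats as "use its default":

 // Hypothetical usage: CPU session limited to 4 intra-op threads.
 var model = new OnnxModel("model.onnx", intraOpNumThreads: 4);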
Code example #5
        /// <summary>Remove unnecessary initializer reshapes from graph.</summary>
        // https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
        public static void RemoveUnnecessaryInitializerReshapes(this GraphProto graph)
        {
            var nameToInitializer = graph.Initializer.ToDictionary(i => i.Name, i => i);

            var nodes      = graph.Node;
            var valueInfos = graph.ValueInfo;

            var nodesToRemove = new List <NodeProto>();

            for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++)
            {
                var node = nodes[nodeIndex];

                var opSpec = Ops.Reshape.Spec;
                if (node.OpType == opSpec.OpType)
                {
                    var inputs  = node.Input;
                    var outputs = node.Output;

                    // Reshape is expected to take 2 inputs and produce 1 output
                    if (inputs.Count == opSpec.Inputs && outputs.Count == opSpec.Outputs)
                    {
                        var dataName          = inputs[0];
                        var shapeName         = inputs[1];
                        var reshapeOutputName = outputs[0];

                        // Both inputs must be initializers ("static")
                        if (nameToInitializer.TryGetValue(dataName, out var dataInitializer) &&
                            nameToInitializer.TryGetValue(shapeName, out var shapeInitializer))
                        {
                            // TODO: Check initializer not used in other nodes

                            var outputShapeValue = valueInfos.Single(v => v.Name, reshapeOutputName);

                            var outputShapeDims = outputShapeValue.Type.TensorType.Shape.Dim;
                            var allValue        = outputShapeDims.All(d => d.ValueCase ==
                                                                      TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue);
                            if (allValue)
                            {
                                var outputShape = outputShapeDims.Select(d => d.DimValue).ToArray();

                                var allPositive = outputShape.All(d => d > 0);
                                if (allPositive)
                                {
                                    // Check shape compared to initializer shape
                                    var dataShape = dataInitializer.Dims.ToArray();

                                    var outputShapeProductSum = outputShape.Product();
                                    var dataShapeProductSum   = dataShape.Product();

                                    if (outputShapeProductSum == dataShapeProductSum)
                                    {
                                        // Change data shape to the reshape output shape directly
                                        dataInitializer.Dims.Clear();
                                        dataInitializer.Dims.AddRange(outputShape);

                                        // Remove reshape data shape both as initializer and input
                                        graph.Initializer.TryRemove(i => i.Name, shapeName);
                                        graph.Input.TryRemove(i => i.Name, shapeName);

                                        nodesToRemove.Add(node);

                                        // Replace reshape output name with data name directly in all nodes
                                        ReplaceInput(nodes, reshapeOutputName, dataName);
                                    }
                                }
                            }
                        }
                    }
                }
            }
            foreach (var node in nodesToRemove)
            {
                nodes.Remove(node);
            }
        }
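
The example above relies on a few helpers that are not standard LINQ: Product over the dimension arrays, a Single(selector, value) overload, and TryRemove. A minimal sketch of what they might look like, with names and signatures inferred from the call sites rather than taken from the original project:

 using System;
 using System.Collections.Generic;
 using System.Linq;

 public static class GraphExtensionsSketch
 {
     // Multiplies all dimension values together (the element count of a tensor shape).
     public static long Product(this IEnumerable<long> values)
     {
         long product = 1;
         foreach (var value in values)
         {
             product *= value;
         }
         return product;
     }

     // Returns the single item whose selected key equals the given value.
     public static T Single<T, TKey>(this IEnumerable<T> items, Func<T, TKey> selector, TKey value) =>
         items.Single(item => EqualityComparer<TKey>.Default.Equals(selector(item), value));

     // Removes the first item whose selected key equals the given value, if any.
     public static bool TryRemove<T, TKey>(this IList<T> items, Func<T, TKey> selector, TKey value)
     {
         for (int i = 0; i < items.Count; i++)
         {
             if (EqualityComparer<TKey>.Default.Equals(selector(items[i]), value))
             {
                 items.RemoveAt(i);
                 return true;
             }
         }
         return false;
     }
 }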
Code example #6
 /// <summary>Clean graph for inference.</summary>
 public static void Clean(this GraphProto graph)
 {
     graph.RemoveInitializersFromInputs();
     graph.RemoveUnnecessaryInitializerReshapes();
 }
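
An assumed call site: clean a parsed model's graph before re-serializing it for inference (modelProto is hypothetical):

 // Hypothetical usage; modelProto is a previously parsed ModelProto.
 modelProto.Graph.Clean();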
Code example #7
        // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
        public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
        {
            FloatTensor input_tensor = ctrl.floatTensorFactory.Get(inputTensorId);

            this.Forward(input_tensor);

            NodeProto node = new NodeProto
            {
                Input     = { inputTensorId.ToString(), _weights.Id.ToString() },
                Output    = { activation.ToString() },
                Name      = this.name,
                OpType    = "Gemm",
                DocString = ""
            };

            if (_biased)
            {
                node.Input.Add(_bias.Id.ToString());
            }

            node.Attribute.Add(new AttributeProto {
                Name = "alpha",
                Type = AttributeProto.Types.AttributeType.Float,
                F    = 1.0f
            });
            node.Attribute.Add(new AttributeProto {
                Name = "beta",
                Type = AttributeProto.Types.AttributeType.Float,
                F    = 1.0f
            });
            node.Attribute.Add(new AttributeProto {
                Name = "broadcast",
                Type = AttributeProto.Types.AttributeType.Int,
                I    = 1
            });

            TensorProto w_init = _weights.GetProto();

            ValueInfoProto input_info = input_tensor.GetValueInfoProto();
            ValueInfoProto w_info     = _weights.GetValueInfoProto();

            GraphProto g = new GraphProto
            {
                Name        = Guid.NewGuid().ToString("N"),
                Node        = { node },
                Initializer = { w_init },
                Input       = { input_info, w_info },
                Output      = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
            };

            if (_biased)
            {
                TensorProto    b_init = _bias.GetProto();
                ValueInfoProto b_info = _bias.GetValueInfoProto();
                g.Initializer.Add(b_init);
                g.Input.Add(b_info);
            }
            else
            {
                // The Gemm schema requires 3 inputs, so a zero bias must be supplied even for unbiased layers
                float[] tmpData = new float[1] {
                    0
                };
                int[] tmpDims = new int[1] {
                    1
                };
                FloatTensor tmpBias = ctrl.floatTensorFactory.Create(_data: tmpData, _shape: tmpDims, _autograd: false, _keepgrads: false);
                g.Initializer.Add(tmpBias.GetProto());
                g.Input.Add(tmpBias.GetValueInfoProto());
                g.Node[0].Input.Add(tmpBias.Id.ToString());
            }

            return g;
        }
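
For reference, ONNX Gemm computes Y = alpha * A * B + beta * C (with optional transposes of A and B). With alpha = beta = 1 as set above, the emitted node is a plain linear layer Y = X * W + b, which is why the older Gemm schema targeted here needs a zero-valued dummy bias even for unbiased layers.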
Code example #8
 /// <summary>
 /// Sets the dimension of inputs, value infos, outputs and potential Reshape ops.
 /// By default, sets the leading dimension to the dynamic batch size 'N'.
 /// </summary>
 public static void SetDim(this GraphProto graph) =>
 graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"));
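
A hedged usage sketch; graph is an assumed GraphProto instance:

 // Hypothetical usage: make the leading dimension the dynamic batch size "N".
 graph.SetDim(); // equivalent to graph.SetDim(dimIndex: 0, DimParamOrValue.New("N"))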
Code example #9
File: SyftController.cs Project: ygambhir/OpenMined
        public void ProcessMessage(string json_message, MonoBehaviour owner, Action <string> response)
        {
            Command msgObj = JsonUtility.FromJson <Command> (json_message);

            try
            {
                switch (msgObj.objectType)
                {
                case "Optimizer":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string optimizer_type = msgObj.tensorIndexParams[0];

                        // Extract parameters
                        List <int> p = new List <int>();
                        for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            p.Add(int.Parse(msgObj.tensorIndexParams[i]));
                        }
                        List <float> hp = new List <float>();
                        for (int i = 0; i < msgObj.hyperParams.Length; i++)
                        {
                            hp.Add(float.Parse(msgObj.hyperParams[i]));
                        }

                        Optimizer optim = null;

                        if (optimizer_type == "sgd")
                        {
                            optim = new SGD(this, p, hp[0], hp[1], hp[2]);
                        }
                        else if (optimizer_type == "rmsprop")
                        {
                            optim = new RMSProp(this, p, hp[0], hp[1], hp[2], hp[3]);
                        }
                        else if (optimizer_type == "adam")
                        {
                            optim = new Adam(this, p, hp[0], hp[1], hp[2], hp[3], hp[4]);
                        }

                        response(optim.Id.ToString());
                        return;
                    }
                    else
                    {
                        Optimizer optim = this.GetOptimizer(msgObj.objectIndex);
                        response(optim.ProcessMessage(msgObj, this));

                        return;
                    }
                }

                case "FloatTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(_shape: msgObj.shape, _data: msgObj.data, _shader: this.Shader);
                        response(tensor.Id.ToString());
                        return;
                    }
                    else
                    {
                        FloatTensor tensor = floatTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        response(tensor.ProcessMessage(msgObj, this));
                        return;
                    }
                }

                case "IntTensor":
                {
                    if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                    {
                        int[] data = new int[msgObj.data.Length];
                        for (int i = 0; i < msgObj.data.Length; i++)
                        {
                            data[i] = (int)msgObj.data[i];
                        }
                        IntTensor tensor = intTensorFactory.Create(_shape: msgObj.shape, _data: data);
                        response(tensor.Id.ToString());
                        return;
                    }
                    else
                    {
                        IntTensor tensor = intTensorFactory.Get(msgObj.objectIndex);
                        // Process message's function
                        response(tensor.ProcessMessage(msgObj, this));
                        return;
                    }
                }

                case "agent":
                {
                    if (msgObj.functionCall == "create")
                    {
                        Layer     model     = (Layer)GetModel(int.Parse(msgObj.tensorIndexParams[0]));
                        Optimizer optimizer = optimizers[int.Parse(msgObj.tensorIndexParams[1])];
                        response(new Syft.NN.RL.Agent(this, model, optimizer).Id.ToString());
                        return;
                    }

                    //Debug.Log("Getting Model:" + msgObj.objectIndex);
                    Syft.NN.RL.Agent agent = this.GetAgent(msgObj.objectIndex);
                    response(agent.ProcessMessageLocal(msgObj, this));
                    return;
                }

                case "model":
                {
                    if (msgObj.functionCall == "create")
                    {
                        string model_type = msgObj.tensorIndexParams[0];

                        Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);

                        if (model_type == "linear")
                        {
                            response(this.BuildLinear(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "relu")
                        {
                            response(this.BuildReLU().Id.ToString());
                            return;
                        }
                        else if (model_type == "log")
                        {
                            response(this.BuildLog().Id.ToString());
                            return;
                        }
                        else if (model_type == "dropout")
                        {
                            response(this.BuildDropout(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "sigmoid")
                        {
                            response(this.BuildSigmoid().Id.ToString());
                            return;
                        }
                        else if (model_type == "sequential")
                        {
                            response(this.BuildSequential().Id.ToString());
                            return;
                        }
                        else if (model_type == "softmax")
                        {
                            response(this.BuildSoftmax(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "logsoftmax")
                        {
                            response(this.BuildLogSoftmax(msgObj.tensorIndexParams).Id.ToString());
                            return;
                        }
                        else if (model_type == "tanh")
                        {
                            response(new Tanh(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "crossentropyloss")
                        {
                            response(new CrossEntropyLoss(this, int.Parse(msgObj.tensorIndexParams[1])).Id.ToString());
                            return;
                        }
                        else if (model_type == "categorical_crossentropy")
                        {
                            response(new CategoricalCrossEntropyLoss(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "nllloss")
                        {
                            response(new NLLLoss(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "mseloss")
                        {
                            response(new MSELoss(this).Id.ToString());
                            return;
                        }
                        else if (model_type == "embedding")
                        {
                            response(new Embedding(this, int.Parse(msgObj.tensorIndexParams[1]), int.Parse(msgObj.tensorIndexParams[2])).Id.ToString());
                            return;
                        }
                        else
                        {
                            Debug.LogFormat("<color=red>Model Type Not Found:</color> {0}", model_type);
                        }
                    }
                    else
                    {
                        //Debug.Log("Getting Model:" + msgObj.objectIndex);
                        Model model = this.GetModel(msgObj.objectIndex);
                        response(model.ProcessMessage(msgObj, this));
                        return;
                    }
                    response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                    return;
                }

                case "controller":
                {
                    if (msgObj.functionCall == "num_tensors")
                    {
                        response(floatTensorFactory.Count() + "");
                        return;
                    }
                    else if (msgObj.functionCall == "num_models")
                    {
                        response(models.Count + "");
                        return;
                    }
                    else if (msgObj.functionCall == "new_tensors_allowed")
                    {
                        Debug.LogFormat("New Tensors Allowed:{0}", msgObj.tensorIndexParams[0]);
                        if (msgObj.tensorIndexParams[0] == "True")
                        {
                            allow_new_tensors = true;
                        }
                        else if (msgObj.tensorIndexParams[0] == "False")
                        {
                            allow_new_tensors = false;
                        }
                        else
                        {
                            throw new Exception("Invalid parameter for new_tensors_allowed. Did you mean true or false?");
                        }

                        response(allow_new_tensors + "");
                        return;
                    }
                    else if (msgObj.functionCall == "load_floattensor")
                    {
                        FloatTensor tensor = floatTensorFactory.Create(filepath: msgObj.tensorIndexParams[0], _shader: this.Shader);
                        response(tensor.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "set_seed")
                    {
                        Random.InitState(int.Parse(msgObj.tensorIndexParams[0]));
                        response("Random seed set!");
                        return;
                    }
                    else if (msgObj.functionCall == "concatenate")
                    {
                        List <int> tensor_ids = new List <int>();
                        for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            tensor_ids.Add(int.Parse(msgObj.tensorIndexParams[i]));
                        }
                        FloatTensor result = Functional.Concatenate(floatTensorFactory, tensor_ids, int.Parse(msgObj.tensorIndexParams[0]));
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "ones")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Ones(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "randn")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Randn(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "random")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Random(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "zeros")
                    {
                        int[] dims = new int[msgObj.tensorIndexParams.Length];
                        for (int i = 0; i < msgObj.tensorIndexParams.Length; i++)
                        {
                            dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                        }
                        FloatTensor result = Functional.Zeros(floatTensorFactory, dims);
                        response(result.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "model_from_json")
                    {
                        Debug.Log("Loading Model from JSON:");
                        var json_str = msgObj.tensorIndexParams[0];
                        var config   = JObject.Parse(json_str);

                        Sequential model;

                        if ((string)config["class_name"] == "Sequential")
                        {
                            model = this.BuildSequential();
                        }
                        else
                        {
                            response("Unity Error: SyftController.processMessage: while Loading model, Class :" + config["class_name"] + " is not implemented");
                            return;
                        }

                        for (int i = 0; i < config["config"].ToList().Count; i++)
                        {
                            var layer_desc        = config["config"][i];
                            var layer_config_desc = layer_desc["config"];

                            if ((string)layer_desc["class_name"] == "Linear")
                            {
                                int previous_output_dim;

                                if (i == 0)
                                {
                                    previous_output_dim = (int)layer_config_desc["batch_input_shape"][layer_config_desc["batch_input_shape"].ToList().Count - 1];
                                }
                                else
                                {
                                    previous_output_dim = (int)layer_config_desc["units"];
                                }

                                string[] parameters = { "linear", previous_output_dim.ToString(), layer_config_desc["units"].ToString(), "Xavier" };
                                Layer    layer      = this.BuildLinear(parameters);
                                model.AddLayer(layer);

                                string activation_name = layer_config_desc["activation"].ToString();

                                if (activation_name != "linear")
                                {
                                    Layer activation;
                                    if (activation_name == "softmax")
                                    {
                                        parameters = new string[] { activation_name, "1" };
                                        activation = this.BuildSoftmax(parameters);
                                    }
                                    else if (activation_name == "relu")
                                    {
                                        activation = this.BuildReLU();
                                    }
                                    else
                                    {
                                        response("Unity Error: SyftController.processMessage: while Loading activations, Activation :" + activation_name + " is not implemented");
                                        return;
                                    }
                                    model.AddLayer(activation);
                                }
                            }
                            else
                            {
                                response("Unity Error: SyftController.processMessage: while Loading layers, Layer :" + layer_desc["class_name"] + " is not implemented");
                                return;
                            }
                        }

                        response(model.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "from_proto")
                    {
                        Debug.Log("Loading Model from ONNX:");
                        var filename = msgObj.tensorIndexParams[0];

                        var        input      = File.OpenRead(filename);
                        ModelProto modelProto = ModelProto.Parser.ParseFrom(input);

                        Sequential model = this.BuildSequential();

                        foreach (NodeProto node in modelProto.Graph.Node)
                        {
                            Layer      layer;
                            GraphProto g = ONNXTools.GetSubGraphFromNodeAndMainGraph(node, modelProto.Graph);
                            if (node.OpType == "Gemm")
                            {
                                layer = new Linear(this, g);
                            }
                            else if (node.OpType == "Dropout")
                            {
                                layer = new Dropout(this, g);
                            }
                            else if (node.OpType == "Relu")
                            {
                                layer = new ReLU(this, g);
                            }
                            else if (node.OpType == "Softmax")
                            {
                                layer = new Softmax(this, g);
                            }
                            else
                            {
                                response("Unity Error: SyftController.processMessage: Layer not yet implemented for deserialization:");
                                return;
                            }
                            model.AddLayer(layer);
                        }

                        response(model.Id.ToString());
                        return;
                    }
                    else if (msgObj.functionCall == "to_proto")
                    {
                        ModelProto model    = this.ToProto(msgObj.tensorIndexParams);
                        string     filename = msgObj.tensorIndexParams[2];
                        string     type     = msgObj.tensorIndexParams[3];
                        if (type == "json")
                        {
                            response(model.ToString());
                        }
                        else
                        {
                            using (var output = File.Create(filename))
                            {
                                model.WriteTo(output);
                            }
                            response(new FileInfo(filename).FullName);
                        }
                        return;
                    }

                    response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                    return;
                }

                case "Grid":
                    if (msgObj.functionCall == "learn")
                    {
                        var inputId  = int.Parse(msgObj.tensorIndexParams[0]);
                        var targetId = int.Parse(msgObj.tensorIndexParams[1]);

                        response(this.grid.Run(inputId, targetId, msgObj.configurations, owner));
                        return;
                    }

                    if (msgObj.functionCall == "getResults")
                    {
                        this.grid.GetResults(msgObj.experimentId, response);
                        return;
                    }

                    // Like getResults, but doesn't pause to wait for results:
                    // this function returns right away, telling you whether
                    // it knows the run is done.
                    if (msgObj.functionCall == "checkStatus")
                    {
                        this.grid.CheckStatus(msgObj.experimentId, response);
                        return;
                    }

                    break;

                default:
                    break;
                }
            }
            catch (Exception e)
            {
                Debug.LogFormat("<color=red>{0}</color>", e.ToString());
                response("Unity Error: " + e.ToString());
                return;
            }

            // If not executing createTensor or tensor function, return default error.

            response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
            return;
        }
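
For context, a hedged sketch of how a caller might drive ProcessMessage. The JSON field names are inferred from the msgObj usages above, and the controller instance is assumed to exist inside a MonoBehaviour:

 // Hypothetical caller; asks the controller how many float tensors it holds.
 string json = "{\"objectType\":\"controller\",\"functionCall\":\"num_tensors\",\"tensorIndexParams\":[]}";
 controller.ProcessMessage(json, this, reply => Debug.Log(reply));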