/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
/// <param name="shapeDictionary">Optional tensor-name-to-shape map, forwarded when converting the session's input,
/// output and overridable-initializer metadata.</param>
/// <exception cref="OnnxRuntimeException">Creating the GPU session failed and <paramref name="fallbackToCpu"/> is false.</exception>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
    bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
{
    ModelFile = modelFile;
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _ownModelFile = ownModelFile;
    _disposed = false;

    if (gpuDeviceId != null)
    {
        try
        {
            // The options are only needed while the session is being created; dispose them so their native
            // handle is released (also when session creation throws).
            using (var gpuOptions = SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value))
            {
                _session = new InferenceSession(modelFile, gpuOptions);
            }
        }
        catch (OnnxRuntimeException)
        {
            if (!fallbackToCpu)
            {
                // If called from OnnxTransform, is caught and rethrown
                throw;
            }
            _session = new InferenceSession(modelFile);
        }
    }
    else
    {
        _session = new InferenceSession(modelFile);
    }

    try
    {
        // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
        // doesn't expose full type information via its C# APIs.
        OnnxCSharpToProtoWrapper.ModelProto model;
        using (var modelStream = File.OpenRead(modelFile))
        using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
        {
            model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
        }

        // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
        var inputTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Input)
        {
            inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
        }

        // NOTE(review): initializerTypePool is filled but never read below (overridable initializers are typed via
        // inputTypePool). Kept as-is because GetScalarDataViewType may reject unsupported initializer types — confirm intent.
        var initializerTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Initializer)
        {
            initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
        }

        var outputTypePool = new Dictionary<string, DataViewType>();
        // Build casters which maps NamedOnnxValue to .NET objects.
        var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
        foreach (var valueInfo in model.Graph.Output)
        {
            outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type _);
        }

        var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
        var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
        var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

        // Create a view to the used ONNX model from ONNXRuntime's perspective.
        ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

        Graph = model.Graph;
    }
    catch
    {
        // Do not leak the native session when the model's schema cannot be read.
        _session.Dispose();
        _session = null;
        throw;
    }
}
/// <summary>
/// Appends a graph-valued attribute to <paramref name="node"/>.
/// </summary>
/// <param name="node">Node whose attribute list receives the new entry.</param>
/// <param name="argName">Name of the attribute to add.</param>
/// <param name="value">Graph stored in the attribute; the attribute itself is built by <c>MakeAttribute</c>.</param>
public static void NodeAddAttributes(NodeProto node, string argName, GraphProto value)
{
    var graphAttribute = MakeAttribute(argName, value);
    node.Attribute.Add(graphAttribute);
}
/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU. GPU execution is currently disabled, so any
/// non-null value throws <see cref="NotSupportedException"/>.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error. Currently has no effect because
/// the GPU path is disabled.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
/// <param name="shapeDictionary">Optional tensor-name-to-shape map, forwarded when converting the session's input,
/// output and overridable-initializer metadata.</param>
/// <exception cref="NotSupportedException"><paramref name="gpuDeviceId"/> is not null.</exception>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
    bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null)
{
    ModelFile = modelFile;
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _ownModelFile = ownModelFile;
    _disposed = false;

    if (gpuDeviceId != null)
    {
        // The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
        // This code path will be re-enabled when there is appropriate support in onnxruntime
        throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
    }

    _session = new InferenceSession(modelFile);

    try
    {
        // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
        // doesn't expose full type information via its C# APIs.
        OnnxCSharpToProtoWrapper.ModelProto model;
        using (var modelStream = File.OpenRead(modelFile))
        using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10))
        {
            model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
        }

        // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
        var inputTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Input)
        {
            inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
        }

        // NOTE(review): initializerTypePool is filled but never read below (overridable initializers are typed via
        // inputTypePool). Kept as-is because GetScalarDataViewType may reject unsupported initializer types — confirm intent.
        var initializerTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Initializer)
        {
            initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
        }

        var outputTypePool = new Dictionary<string, DataViewType>();
        // Build casters which maps NamedOnnxValue to .NET objects.
        var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
        foreach (var valueInfo in model.Graph.Output)
        {
            outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type _);
        }

        var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
        var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
        var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

        // Create a view to the used ONNX model from ONNXRuntime's perspective.
        ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

        Graph = model.Graph;
    }
    catch
    {
        // Do not leak the native session when the model's schema cannot be read.
        _session.Dispose();
        _session = null;
        throw;
    }
}
/// <summary>
/// Constructs OnnxModel object from file.
/// </summary>
/// <param name="modelFile">Model file path.</param>
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
/// <param name="fallbackToCpu">If true, resumes CPU execution quietly upon GPU error.</param>
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
/// no longer needed.</param>
/// <param name="shapeDictionary">Optional tensor-name-to-shape map, forwarded when converting the session's input,
/// output and overridable-initializer metadata.</param>
/// <param name="recursionLimit">Optional, specifies the Protobuf CodedInputStream recursion limit. Default value is 100.</param>
/// <param name="interOpNumThreads">Controls the number of threads used to parallelize the execution of the graph (across nodes).
/// Applied to the GPU session options as well as the CPU ones.</param>
/// <param name="intraOpNumThreads">Controls the number of threads to use to run the model.
/// Applied to the GPU session options as well as the CPU ones.</param>
/// <exception cref="OnnxRuntimeException">Creating the GPU session failed and <paramref name="fallbackToCpu"/> is false.</exception>
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false,
    bool ownModelFile = false, IDictionary<string, int[]> shapeDictionary = null, int recursionLimit = 100,
    int? interOpNumThreads = null, int? intraOpNumThreads = null)
{
    // If we don't own the model file, _disposed should be false to prevent deleting user's file.
    _disposed = false;

    // Copies the caller's thread settings onto the given options; an unset count becomes 0 via GetValueOrDefault().
    void ApplyThreadSettings(SessionOptions options)
    {
        options.InterOpNumThreads = interOpNumThreads.GetValueOrDefault();
        options.IntraOpNumThreads = intraOpNumThreads.GetValueOrDefault();
    }

    // Creates a CPU session. The options are only needed during creation, so they are disposed afterwards
    // (also when creation throws) to release their native handle.
    InferenceSession CreateCpuSession()
    {
        using (var cpuOptions = new SessionOptions())
        {
            ApplyThreadSettings(cpuOptions);
            return new InferenceSession(modelFile, cpuOptions);
        }
    }

    if (gpuDeviceId != null)
    {
        try
        {
            using (var gpuOptions = SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value))
            {
                ApplyThreadSettings(gpuOptions);
                _session = new InferenceSession(modelFile, gpuOptions);
            }
        }
        catch (OnnxRuntimeException)
        {
            if (!fallbackToCpu)
            {
                // If called from OnnxTransform, is caught and rethrown
                throw;
            }
            _session = CreateCpuSession();
        }
    }
    else
    {
        _session = CreateCpuSession();
    }

    try
    {
        // Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
        // doesn't expose full type information via its C# APIs.

        // If we own the model file set the DeleteOnClose flag so it is always deleted.
        ModelStream = ownModelFile
            ? new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Read, 4096, FileOptions.DeleteOnClose)
            : new FileStream(modelFile, FileMode.Open, FileAccess.Read);

        // The CodedInputStream auto closes the stream, and we need to make sure that our main stream stays open, so creating a new one here.
        OnnxCSharpToProtoWrapper.ModelProto model;
        using (var modelStream = new FileStream(modelFile, FileMode.Open, FileAccess.Read, FileShare.Delete | FileShare.Read))
        using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, recursionLimit))
        {
            model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream);
        }

        // Parse actual input and output types stored in the loaded ONNX model to get their DataViewType's.
        var inputTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Input)
        {
            inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
        }

        // NOTE(review): initializerTypePool is filled but never read below (overridable initializers are typed via
        // inputTypePool). Kept as-is because GetScalarDataViewType may reject unsupported initializer types — confirm intent.
        var initializerTypePool = new Dictionary<string, DataViewType>();
        foreach (var valueInfo in model.Graph.Initializer)
        {
            initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);
        }

        var outputTypePool = new Dictionary<string, DataViewType>();
        // Build casters which maps NamedOnnxValue to .NET objects.
        var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
        foreach (var valueInfo in model.Graph.Output)
        {
            outputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
            casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type _);
        }

        var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
        var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
        var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

        // Create a view to the used ONNX model from ONNXRuntime's perspective.
        ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);

        Graph = model.Graph;
    }
    catch
    {
        // Release everything acquired so far. Without disposing ModelStream here, a failed construction would keep
        // the file handle open (and, for an owned file, postpone its DeleteOnClose deletion) until finalization.
        ModelStream?.Dispose();
        ModelStream = null;
        _session.Dispose();
        _session = null;
        throw;
    }
}
/// <summary>Remove unnecessary initializer reshapes from graph.</summary>
/// <remarks>
/// A Reshape node is folded away when both its data and shape inputs are initializers, its output's entry in
/// <c>graph.ValueInfo</c> has only concrete, strictly positive dims, and the element counts match. Folding rewrites
/// the data initializer's dims to the output shape, removes the shape tensor from the initializers and graph inputs,
/// and rewires every consumer of the Reshape output to the data initializer.
/// The fold is skipped if either initializer is still consumed by another remaining node, because rewriting the
/// data dims or deleting the shape tensor would corrupt those other consumers.
/// </remarks>
// https://github.com/microsoft/onnxruntime/blob/master/tools/python/remove_initializer_from_input.py
public static void RemoveUnnecessaryInitializerReshapes(this GraphProto graph)
{
    var nameToInitializer = graph.Initializer.ToDictionary(i => i.Name, i => i);

    var nodes = graph.Node;
    var valueInfos = graph.ValueInfo;

    var nodesToRemove = new List<NodeProto>();

    // How many times `name` is consumed as an input by nodes not already scheduled for removal.
    // Evaluated on demand because ReplaceInput rewires inputs as the loop below progresses.
    int CountRemainingUses(string name) =>
        nodes.Where(n => !nodesToRemove.Any(removed => ReferenceEquals(removed, n)))
             .Sum(n => n.Input.Count(input => input == name));

    var opSpec = Ops.Reshape.Spec;
    for (int nodeIndex = 0; nodeIndex < nodes.Count; nodeIndex++)
    {
        var node = nodes[nodeIndex];
        if (node.OpType != opSpec.OpType)
        {
            continue;
        }

        var inputs = node.Input;
        var outputs = node.Output;
        // Expected Reshape takes 2 inputs and has 1 output
        if (inputs.Count != opSpec.Inputs || outputs.Count != opSpec.Outputs)
        {
            continue;
        }

        var dataName = inputs[0];
        var shapeName = inputs[1];
        var reshapeOutputName = outputs[0];

        // Both inputs must be initializers ("static")
        if (!nameToInitializer.TryGetValue(dataName, out var dataInitializer) ||
            !nameToInitializer.ContainsKey(shapeName))
        {
            continue;
        }

        // The output shape must be fully known and strictly positive.
        var outputShapeValue = valueInfos.Single(v => v.Name, reshapeOutputName);
        var outputShapeDims = outputShapeValue.Type.TensorType.Shape.Dim;
        if (!outputShapeDims.All(d => d.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue))
        {
            continue;
        }
        var outputShape = outputShapeDims.Select(d => d.DimValue).ToArray();
        if (!outputShape.All(d => d > 0))
        {
            continue;
        }

        // Check shape compared to initializer shape
        var dataShape = dataInitializer.Dims.ToArray();
        if (outputShape.Product() != dataShape.Product())
        {
            continue;
        }

        // Only fold when this Reshape is the sole remaining consumer of both initializers.
        if (CountRemainingUses(dataName) != 1 || CountRemainingUses(shapeName) != 1)
        {
            continue;
        }

        // Change data shape to the reshape output shape directly
        dataInitializer.Dims.Clear();
        dataInitializer.Dims.AddRange(outputShape);
        // Remove reshape data shape both as initializer and input
        graph.Initializer.TryRemove(i => i.Name, shapeName);
        graph.Input.TryRemove(i => i.Name, shapeName);

        nodesToRemove.Add(node);

        // Replace reshape output name with data name directly in all nodes
        ReplaceInput(nodes, reshapeOutputName, dataName);
    }

    foreach (var node in nodesToRemove)
    {
        nodes.Remove(node);
    }
}
/// <summary>
/// Prepares <paramref name="graph"/> for inference by running two in-place clean-up passes:
/// <c>RemoveInitializersFromInputs</c> followed by <c>RemoveUnnecessaryInitializerReshapes</c>.
/// </summary>
/// <param name="graph">Graph to modify in place.</param>
public static void Clean(this GraphProto graph)
{
    // Pass 1: initializers listed among the graph inputs.
    graph.RemoveInitializersFromInputs();

    // Pass 2: Reshape nodes that only reshape constant data.
    graph.RemoveUnnecessaryInitializerReshapes();
}
// See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
/// <summary>
/// Exports this layer as a one-node ONNX graph built around a Gemm operator.
/// </summary>
/// <param name="inputTensorId">Id of the input tensor in <c>ctrl.floatTensorFactory</c>.</param>
/// <param name="ctrl">Controller whose float tensor factory resolves tensor ids.</param>
/// <returns>
/// A graph named with a fresh GUID that holds a single Gemm node (alpha = beta = 1, broadcast = 1). The node's inputs
/// are the input tensor, the weights and a bias. Weights and bias appear both as initializers and as graph inputs.
/// When the layer has no bias, a one-element zero tensor is created and used in its place.
/// </returns>
public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
{
    FloatTensor inputTensor = ctrl.floatTensorFactory.Get(inputTensorId);
    // NOTE(review): the forward pass is presumably what sets `activation` (the output tensor id used below) — confirm.
    this.Forward(inputTensor);

    var gemm = new NodeProto
    {
        Input = { inputTensorId.ToString(), _weights.Id.ToString() },
        Output = { activation.ToString() },
        Name = this.name,
        OpType = "Gemm",
        DocString = "",
        Attribute =
        {
            new AttributeProto { Name = "alpha", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
            new AttributeProto { Name = "beta", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
            new AttributeProto { Name = "broadcast", Type = AttributeProto.Types.AttributeType.Int, I = 1 },
        },
    };
    if (_biased)
    {
        gemm.Input.Add(_bias.Id.ToString());
    }

    var graph = new GraphProto
    {
        Name = Guid.NewGuid().ToString("N"),
        Node = { gemm },
        Initializer = { _weights.GetProto() },
        Input = { inputTensor.GetValueInfoProto(), _weights.GetValueInfoProto() },
        Output = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
    };

    if (_biased)
    {
        graph.Initializer.Add(_bias.GetProto());
        graph.Input.Add(_bias.GetValueInfoProto());
    }
    else
    {
        // The Gemm schema, must have 3 inputs (must have a bias): substitute a single zero.
        FloatTensor zeroBias = ctrl.floatTensorFactory.Create(
            _data: new float[] { 0 },
            _shape: new int[] { 1 },
            _autograd: false,
            _keepgrads: false);
        graph.Initializer.Add(zeroBias.GetProto());
        graph.Input.Add(zeroBias.GetValueInfoProto());
        gemm.Input.Add(zeroBias.Id.ToString());
    }

    return graph;
}
/// <summary>
/// Set dimension of inputs, value infos, outputs and potential Reshape ops.
/// Default sets leading dimension to dynamic batch size 'N'.
/// </summary>
/// <param name="graph">Graph to modify in place.</param>
/// <remarks>Shorthand for the two-argument <c>SetDim</c> overload with dimension index 0 and the symbolic dim "N".</remarks>
public static void SetDim(this GraphProto graph)
{
    var batchDimension = DimParamOrValue.New("N");
    graph.SetDim(dimIndex: 0, batchDimension);
}
/// <summary>
/// Decodes a JSON-serialized <c>Command</c> and dispatches it on its <c>objectType</c> and <c>functionCall</c>.
/// </summary>
/// <param name="json_message">JSON text deserialized into a <c>Command</c> with <c>JsonUtility.FromJson</c>.</param>
/// <param name="owner">Forwarded to <c>grid.Run</c> for "Grid"/"learn" commands.</param>
/// <param name="response">
/// Receives the reply: an object id, a value rendered as text, or an error string starting with "Unity Error: ".
/// For "Grid" getResults/checkStatus commands, the callback is handed to the grid instead of being called here.
/// </param>
/// <remarks>
/// Exceptions, including those raised while parsing <paramref name="json_message"/>, are logged and reported
/// through <paramref name="response"/> rather than propagated to the caller.
/// </remarks>
public void ProcessMessage(string json_message, MonoBehaviour owner, Action<string> response)
{
    // Declared outside the try so the fall-through reply at the end can read it. It is parsed inside the try so that
    // malformed JSON is reported through `response` like every other failure.
    Command msgObj = default(Command);
    try
    {
        msgObj = JsonUtility.FromJson<Command>(json_message);

        switch (msgObj.objectType)
        {
            case "Optimizer":
            {
                if (msgObj.functionCall == "create")
                {
                    string optimizer_type = msgObj.tensorIndexParams[0];

                    // Extract parameters
                    List<int> p = new List<int>();
                    for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                    {
                        p.Add(int.Parse(msgObj.tensorIndexParams[i]));
                    }
                    List<float> hp = new List<float>();
                    for (int i = 0; i < msgObj.hyperParams.Length; i++)
                    {
                        // Invariant culture: the client presumably always sends '.' as the decimal separator, which a
                        // culture-sensitive parse would misread on e.g. a German-locale machine.
                        hp.Add(float.Parse(msgObj.hyperParams[i], System.Globalization.CultureInfo.InvariantCulture));
                    }

                    Optimizer optim = null;
                    if (optimizer_type == "sgd")
                    {
                        optim = new SGD(this, p, hp[0], hp[1], hp[2]);
                    }
                    else if (optimizer_type == "rmsprop")
                    {
                        optim = new RMSProp(this, p, hp[0], hp[1], hp[2], hp[3]);
                    }
                    else if (optimizer_type == "adam")
                    {
                        optim = new Adam(this, p, hp[0], hp[1], hp[2], hp[3], hp[4]);
                    }

                    if (optim == null)
                    {
                        // Unknown type: report it explicitly instead of dereferencing null below.
                        Debug.LogFormat("<color=red>Optimizer Type Not Found:</color> {0}", optimizer_type);
                        response("Unity Error: SyftController.processMessage: Optimizer type not found:" + optimizer_type);
                        return;
                    }

                    response(optim.Id.ToString());
                    return;
                }
                else
                {
                    Optimizer optim = this.GetOptimizer(msgObj.objectIndex);
                    response(optim.ProcessMessage(msgObj, this));
                    return;
                }
            }
            case "FloatTensor":
            {
                if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                {
                    FloatTensor tensor = floatTensorFactory.Create(_shape: msgObj.shape, _data: msgObj.data, _shader: this.Shader);
                    response(tensor.Id.ToString());
                    return;
                }
                else
                {
                    FloatTensor tensor = floatTensorFactory.Get(msgObj.objectIndex);
                    // Process message's function
                    response(tensor.ProcessMessage(msgObj, this));
                    return;
                }
            }
            case "IntTensor":
            {
                if (msgObj.objectIndex == 0 && msgObj.functionCall == "create")
                {
                    int[] data = new int[msgObj.data.Length];
                    for (int i = 0; i < msgObj.data.Length; i++)
                    {
                        data[i] = (int)msgObj.data[i];
                    }
                    IntTensor tensor = intTensorFactory.Create(_shape: msgObj.shape, _data: data);
                    response(tensor.Id.ToString());
                    return;
                }
                else
                {
                    IntTensor tensor = intTensorFactory.Get(msgObj.objectIndex);
                    // Process message's function
                    response(tensor.ProcessMessage(msgObj, this));
                    return;
                }
            }
            case "agent":
            {
                if (msgObj.functionCall == "create")
                {
                    Layer model = (Layer)GetModel(int.Parse(msgObj.tensorIndexParams[0]));
                    Optimizer optimizer = optimizers[int.Parse(msgObj.tensorIndexParams[1])];
                    response(new Syft.NN.RL.Agent(this, model, optimizer).Id.ToString());
                    return;
                }

                Syft.NN.RL.Agent agent = this.GetAgent(msgObj.objectIndex);
                response(agent.ProcessMessageLocal(msgObj, this));
                return;
            }
            case "model":
            {
                if (msgObj.functionCall == "create")
                {
                    string model_type = msgObj.tensorIndexParams[0];

                    Debug.LogFormat("<color=magenta>createModel:</color> {0}", model_type);

                    if (model_type == "linear")
                    {
                        response(this.BuildLinear(msgObj.tensorIndexParams).Id.ToString());
                        return;
                    }
                    else if (model_type == "relu")
                    {
                        response(this.BuildReLU().Id.ToString());
                        return;
                    }
                    else if (model_type == "log")
                    {
                        response(this.BuildLog().Id.ToString());
                        return;
                    }
                    else if (model_type == "dropout")
                    {
                        response(this.BuildDropout(msgObj.tensorIndexParams).Id.ToString());
                        return;
                    }
                    else if (model_type == "sigmoid")
                    {
                        response(this.BuildSigmoid().Id.ToString());
                        return;
                    }
                    else if (model_type == "sequential")
                    {
                        response(this.BuildSequential().Id.ToString());
                        return;
                    }
                    else if (model_type == "softmax")
                    {
                        response(this.BuildSoftmax(msgObj.tensorIndexParams).Id.ToString());
                        return;
                    }
                    else if (model_type == "logsoftmax")
                    {
                        response(this.BuildLogSoftmax(msgObj.tensorIndexParams).Id.ToString());
                        return;
                    }
                    else if (model_type == "tanh")
                    {
                        response(new Tanh(this).Id.ToString());
                        return;
                    }
                    else if (model_type == "crossentropyloss")
                    {
                        response(new CrossEntropyLoss(this, int.Parse(msgObj.tensorIndexParams[1])).Id.ToString());
                        return;
                    }
                    else if (model_type == "categorical_crossentropy")
                    {
                        response(new CategoricalCrossEntropyLoss(this).Id.ToString());
                        return;
                    }
                    else if (model_type == "nllloss")
                    {
                        response(new NLLLoss(this).Id.ToString());
                        return;
                    }
                    else if (model_type == "mseloss")
                    {
                        response(new MSELoss(this).Id.ToString());
                        return;
                    }
                    else if (model_type == "embedding")
                    {
                        response(new Embedding(this, int.Parse(msgObj.tensorIndexParams[1]), int.Parse(msgObj.tensorIndexParams[2])).Id.ToString());
                        return;
                    }
                    else
                    {
                        Debug.LogFormat("<color=red>Model Type Not Found:</color> {0}", model_type);
                    }
                }
                else
                {
                    Model model = this.GetModel(msgObj.objectIndex);
                    response(model.ProcessMessage(msgObj, this));
                    return;
                }
                response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                return;
            }
            case "controller":
            {
                if (msgObj.functionCall == "num_tensors")
                {
                    response(floatTensorFactory.Count() + "");
                    return;
                }
                else if (msgObj.functionCall == "num_models")
                {
                    response(models.Count + "");
                    return;
                }
                else if (msgObj.functionCall == "new_tensors_allowed")
                {
                    Debug.LogFormat("New Tensors Allowed:{0}", msgObj.tensorIndexParams[0]);
                    if (msgObj.tensorIndexParams[0] == "True")
                    {
                        allow_new_tensors = true;
                    }
                    else if (msgObj.tensorIndexParams[0] == "False")
                    {
                        allow_new_tensors = false;
                    }
                    else
                    {
                        throw new Exception("Invalid parameter for new_tensors_allowed. Did you mean true or false?");
                    }

                    response(allow_new_tensors + "");
                    return;
                }
                else if (msgObj.functionCall == "load_floattensor")
                {
                    FloatTensor tensor = floatTensorFactory.Create(filepath: msgObj.tensorIndexParams[0], _shader: this.Shader);
                    response(tensor.Id.ToString());
                    return;
                }
                else if (msgObj.functionCall == "set_seed")
                {
                    Random.InitState(int.Parse(msgObj.tensorIndexParams[0]));
                    response("Random seed set!");
                    return;
                }
                else if (msgObj.functionCall == "concatenate")
                {
                    List<int> tensor_ids = new List<int>();
                    for (int i = 1; i < msgObj.tensorIndexParams.Length; i++)
                    {
                        tensor_ids.Add(int.Parse(msgObj.tensorIndexParams[i]));
                    }
                    FloatTensor result = Functional.Concatenate(floatTensorFactory, tensor_ids, int.Parse(msgObj.tensorIndexParams[0]));
                    response(result.Id.ToString());
                    return;
                }
                else if (msgObj.functionCall == "ones" || msgObj.functionCall == "randn" ||
                         msgObj.functionCall == "random" || msgObj.functionCall == "zeros")
                {
                    // Every tensorIndexParams entry is one dimension of the requested shape.
                    int[] dims = new int[msgObj.tensorIndexParams.Length];
                    for (int i = 0; i < dims.Length; i++)
                    {
                        dims[i] = int.Parse(msgObj.tensorIndexParams[i]);
                    }

                    FloatTensor result;
                    if (msgObj.functionCall == "ones")
                    {
                        result = Functional.Ones(floatTensorFactory, dims);
                    }
                    else if (msgObj.functionCall == "randn")
                    {
                        result = Functional.Randn(floatTensorFactory, dims);
                    }
                    else if (msgObj.functionCall == "random")
                    {
                        result = Functional.Random(floatTensorFactory, dims);
                    }
                    else
                    {
                        result = Functional.Zeros(floatTensorFactory, dims);
                    }
                    response(result.Id.ToString());
                    return;
                }
                else if (msgObj.functionCall == "model_from_json")
                {
                    Debug.Log("Loading Model from JSON:");
                    var json_str = msgObj.tensorIndexParams[0];
                    var config = JObject.Parse(json_str);

                    Sequential model;
                    if ((string)config["class_name"] == "Sequential")
                    {
                        model = this.BuildSequential();
                    }
                    else
                    {
                        response("Unity Error: SyftController.processMessage: while Loading model, Class :" + config["class_name"] + " is not implemented");
                        return;
                    }

                    var layer_descs = config["config"];
                    int layer_count = layer_descs.ToList().Count;
                    // Output width of the most recently added Linear layer, i.e. the input width of the next one.
                    int previous_output_dim = 0;
                    for (int i = 0; i < layer_count; i++)
                    {
                        var layer_desc = layer_descs[i];
                        var layer_config_desc = layer_desc["config"];

                        if ((string)layer_desc["class_name"] == "Linear")
                        {
                            int input_dim;
                            if (i == 0)
                            {
                                // The first layer declares its input width as the last entry of batch_input_shape.
                                var batch_input_shape = layer_config_desc["batch_input_shape"];
                                input_dim = (int)batch_input_shape[batch_input_shape.ToList().Count - 1];
                            }
                            else
                            {
                                // Later layers consume the previous layer's output, not their own "units".
                                input_dim = previous_output_dim;
                            }

                            string[] parameters = { "linear", input_dim.ToString(), layer_config_desc["units"].ToString(), "Xavier" };
                            Layer layer = this.BuildLinear(parameters);
                            model.AddLayer(layer);
                            previous_output_dim = (int)layer_config_desc["units"];

                            string activation_name = layer_config_desc["activation"].ToString();
                            if (activation_name != "linear")
                            {
                                Layer activation;
                                if (activation_name == "softmax")
                                {
                                    parameters = new string[] { activation_name, "1" };
                                    activation = this.BuildSoftmax(parameters);
                                }
                                else if (activation_name == "relu")
                                {
                                    activation = this.BuildReLU();
                                }
                                else
                                {
                                    response("Unity Error: SyftController.processMessage: while Loading activations, Activation :" + activation_name + " is not implemented");
                                    return;
                                }
                                model.AddLayer(activation);
                            }
                        }
                        else
                        {
                            response("Unity Error: SyftController.processMessage: while Loading layers, Layer :" + layer_desc["class_name"] + " is not implemented");
                            return;
                        }
                    }

                    response(model.Id.ToString());
                    return;
                }
                else if (msgObj.functionCall == "from_proto")
                {
                    Debug.Log("Loading Model from ONNX:");
                    var filename = msgObj.tensorIndexParams[0];

                    // Close the file once parsed so the handle is not leaked.
                    ModelProto modelProto;
                    using (var input = File.OpenRead(filename))
                    {
                        modelProto = ModelProto.Parser.ParseFrom(input);
                    }

                    Sequential model = this.BuildSequential();

                    foreach (NodeProto node in modelProto.Graph.Node)
                    {
                        Layer layer;
                        GraphProto g = ONNXTools.GetSubGraphFromNodeAndMainGraph(node, modelProto.Graph);
                        if (node.OpType == "Gemm")
                        {
                            layer = new Linear(this, g);
                        }
                        else if (node.OpType == "Dropout")
                        {
                            layer = new Dropout(this, g);
                        }
                        else if (node.OpType == "Relu")
                        {
                            layer = new ReLU(this, g);
                        }
                        else if (node.OpType == "Softmax")
                        {
                            layer = new Softmax(this, g);
                        }
                        else
                        {
                            response("Unity Error: SyftController.processMessage: Layer not yet implemented for deserialization:" + node.OpType);
                            return;
                        }
                        model.AddLayer(layer);
                    }

                    response(model.Id.ToString());
                    return;
                }
                else if (msgObj.functionCall == "to_proto")
                {
                    ModelProto model = this.ToProto(msgObj.tensorIndexParams);
                    string filename = msgObj.tensorIndexParams[2];
                    string type = msgObj.tensorIndexParams[3];
                    if (type == "json")
                    {
                        response(model.ToString());
                    }
                    else
                    {
                        using (var output = File.Create(filename))
                        {
                            model.WriteTo(output);
                        }
                        response(new FileInfo(filename).FullName);
                    }
                    return;
                }

                response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
                return;
            }
            case "Grid":
            {
                if (msgObj.functionCall == "learn")
                {
                    var inputId = int.Parse(msgObj.tensorIndexParams[0]);
                    var targetId = int.Parse(msgObj.tensorIndexParams[1]);

                    response(this.grid.Run(inputId, targetId, msgObj.configurations, owner));
                    return;
                }

                if (msgObj.functionCall == "getResults")
                {
                    this.grid.GetResults(msgObj.experimentId, response);
                    return;
                }

                // like getResults but doesn't pause to wait for results
                // this function will return right away telling you if
                // it knows whether or not it is done
                if (msgObj.functionCall == "checkStatus")
                {
                    this.grid.CheckStatus(msgObj.experimentId, response);
                    return;
                }

                break;
            }
            default:
                break;
        }
    }
    catch (Exception e)
    {
        Debug.LogFormat("<color=red>{0}</color>", e.ToString());
        response("Unity Error: " + e.ToString());
        return;
    }

    // If not executing createTensor or tensor function, return default error.
    response("Unity Error: SyftController.processMessage: Command not found:" + msgObj.objectType + ":" + msgObj.functionCall);
}