// See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
/// <summary>
/// Exports this layer as an ONNX graph containing a single <c>Gemm</c> node
/// whose inputs are the given input tensor, the layer weights and a bias term.
/// </summary>
/// <param name="inputTensorId">
/// Id of the input tensor, looked up in <c>ctrl.floatTensorFactory</c>; it is run
/// through <c>Forward</c> and used as the Gemm node's first input.
/// </param>
/// <param name="ctrl">
/// Controller whose float tensor factory is used to fetch the input/output tensors
/// and, for an unbiased layer, to create a placeholder bias tensor.
/// </param>
/// <returns>
/// A graph with a freshly generated name, the weights (and bias or placeholder) as
/// initializers, and value infos for the input, weights, bias and output.
/// </returns>
public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
{
    FloatTensor source = ctrl.floatTensorFactory.Get(inputTensorId);

    // Forward pass is run before `activation` is read below; presumably it
    // populates `activation` with the output tensor id — TODO confirm.
    this.Forward(source);

    var gemm = new NodeProto
    {
        Input = { inputTensorId.ToString(), _weights.Id.ToString() },
        Output = { activation.ToString() },
        Name = this.name,
        OpType = "Gemm",
        DocString = "",
        Attribute =
        {
            new AttributeProto { Name = "alpha", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
            new AttributeProto { Name = "beta", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
            new AttributeProto { Name = "broadcast", Type = AttributeProto.Types.AttributeType.Int, I = 1 },
        },
    };

    TensorProto weightsInit = _weights.GetProto();
    ValueInfoProto sourceInfo = source.GetValueInfoProto();
    ValueInfoProto weightsInfo = _weights.GetValueInfoProto();

    var graph = new GraphProto
    {
        Name = Guid.NewGuid().ToString("N"),
        Node = { gemm },
        Initializer = { weightsInit },
        Input = { sourceInfo, weightsInfo },
        Output = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
    };

    // The Gemm schema targeted here must have 3 inputs, so a layer without a
    // bias still has to supply one: a one-element zero tensor is used instead.
    TensorProto biasInit;
    ValueInfoProto biasInfo;
    string biasName;
    if (!_biased)
    {
        // NOTE(review): a new tensor is created via the factory on every call and
        // nothing here releases it — confirm whether this accumulates.
        FloatTensor zeroBias = ctrl.floatTensorFactory.Create(
            _data: new float[1] { 0 },
            _shape: new int[1] { 1 },
            _autograd: false,
            _keepgrads: false);
        biasInit = zeroBias.GetProto();
        biasInfo = zeroBias.GetValueInfoProto();
        biasName = zeroBias.Id.ToString();
    }
    else
    {
        biasInit = _bias.GetProto();
        biasInfo = _bias.GetValueInfoProto();
        biasName = _bias.Id.ToString();
    }

    // The bias (real or placeholder) becomes the node's third input, C.
    gemm.Input.Add(biasName);
    graph.Initializer.Add(biasInit);
    graph.Input.Add(biasInfo);

    return graph;
}