Beispiel #1
0
        // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
        /// <summary>
        /// Exports this layer as a single-node ONNX graph whose node is a <c>Gemm</c> operator
        /// (alpha = 1, beta = 1, broadcast = 1).
        /// </summary>
        /// <param name="inputTensorId">Id of the input tensor, looked up via <c>ctrl.floatTensorFactory</c>.</param>
        /// <param name="ctrl">Controller whose float tensor factory holds the tensors referenced by the graph.</param>
        /// <returns>
        /// A graph named with a fresh hex GUID. It contains the Gemm node, initializers for the weights and
        /// bias (C), the input/weights/bias value infos as graph inputs, and the activation value info as output.
        /// </returns>
        public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
        {
            FloatTensor x = ctrl.floatTensorFactory.Get(inputTensorId);

            // Run the layer once; the node's output name comes from `activation` afterwards
            // (presumably set by Forward — confirm).
            this.Forward(x);

            NodeProto gemm = new NodeProto
            {
                Input     = { inputTensorId.ToString(), _weights.Id.ToString() },
                Output    = { activation.ToString() },
                Name      = this.name,
                OpType    = "Gemm",
                DocString = "",
                Attribute =
                {
                    new AttributeProto { Name = "alpha",     Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
                    new AttributeProto { Name = "beta",      Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
                    new AttributeProto { Name = "broadcast", Type = AttributeProto.Types.AttributeType.Int,   I = 1 }
                }
            };

            if (_biased)
            {
                gemm.Input.Add(_bias.Id.ToString());
            }

            // The weights are both an initializer and a declared graph input.
            GraphProto graph = new GraphProto
            {
                Name        = Guid.NewGuid().ToString("N"),
                Node        = { gemm },
                Initializer = { _weights.GetProto() },
                Input       = { x.GetValueInfoProto(), _weights.GetValueInfoProto() },
                Output      = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
            };

            // The Gemm schema needs a third (C) input. Without a bias, a one-element zero tensor is created
            // here (after Forward, as before) to fill that slot.
            FloatTensor cTensor = _biased
                ? _bias
                : ctrl.floatTensorFactory.Create(_data: new float[] { 0 }, _shape: new int[] { 1 }, _autograd: false, _keepgrads: false);

            graph.Initializer.Add(cTensor.GetProto());
            graph.Input.Add(cTensor.GetValueInfoProto());

            if (!_biased)
            {
                // The biased path already wired _bias into the node above; only the stand-in needs adding.
                graph.Node[0].Input.Add(cTensor.Id.ToString());
            }

            return graph;
        }