Ejemplo n.º 1
0
        /// <summary>
        /// Exports this softmax layer as a single-node ONNX graph.
        /// </summary>
        /// <param name="inputTensorId">Id of the input tensor, resolved through <paramref name="ctrl"/>'s float tensor factory.</param>
        /// <param name="ctrl">Controller whose <c>floatTensorFactory</c> resolves tensor ids.</param>
        /// <returns>
        /// A graph holding one <c>Softmax</c> node whose <c>axis</c> attribute is <c>this.dim</c>,
        /// with the input tensor as graph input and the activation tensor as graph output.
        /// </returns>
        public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
        {
            FloatTensor input_tensor = ctrl.floatTensorFactory.Get(inputTensorId);

            // Run the forward pass unconditionally, as the other layers' GetProto implementations do,
            // so that `activation` is populated before it is read below. The former
            // `if (activation != null)` guard was either always true (value-type id) or inverted
            // (it skipped Forward exactly when no activation existed yet).
            this.Forward(input_tensor);

            NodeProto node = new NodeProto
            {
                Input     = { inputTensorId.ToString() },
                Output    = { activation.ToString() },
                OpType    = "Softmax",
                Attribute = { new AttributeProto {
                                  Name = "axis",
                                  Type = AttributeProto.Types.AttributeType.Int,
                                  I    = this.dim
                              } }
            };

            ValueInfoProto input_info = input_tensor.GetValueInfoProto();

            GraphProto g = new GraphProto
            {
                Name        = Guid.NewGuid().ToString("N"),
                Node        = { node },
                Initializer = {  },
                Input       = { input_info },
                Output      = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
            };

            return(g);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Assembles an ONNX <see cref="ModelProto"/> around a single graph built from the given nodes,
        /// value descriptions and initializers.
        /// </summary>
        /// <remarks>
        /// The model imports two operator sets: <c>ai.onnx.ml</c> version 1 and the default domain version 7.
        /// Its IR version is taken from <c>UniversalModelFormat.Onnx.Version.IrVersion</c>.
        /// </remarks>
        /// <param name="nodes">Graph nodes, added in order.</param>
        /// <param name="producerName">Model producer name; must be non-empty.</param>
        /// <param name="name">Graph name; must be non-empty.</param>
        /// <param name="domain">Model domain; must be non-empty.</param>
        /// <param name="producerVersion">Producer version string; must be non-empty.</param>
        /// <param name="modelVersion">Model version number.</param>
        /// <param name="inputs">Descriptions of the graph inputs.</param>
        /// <param name="outputs">Descriptions of the graph outputs.</param>
        /// <param name="intermediateValues">Descriptions stored in the graph's <c>ValueInfo</c> list.</param>
        /// <param name="initializers">Tensors appended to the graph's initializer list.</param>
        /// <returns>The newly built model.</returns>
        public static ModelProto MakeModel(List <NodeProto> nodes, string producerName, string name,
                                           string domain, string producerVersion, long modelVersion, List <ModelArgs> inputs,
                                           List <ModelArgs> outputs, List <ModelArgs> intermediateValues, List <TensorProto> initializers)
        {
            Contracts.CheckValue(nodes, nameof(nodes));
            Contracts.CheckValue(inputs, nameof(inputs));
            Contracts.CheckValue(outputs, nameof(outputs));
            Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
            Contracts.CheckValue(initializers, nameof(initializers));
            Contracts.CheckNonEmpty(producerName, nameof(producerName));
            Contracts.CheckNonEmpty(name, nameof(name));
            Contracts.CheckNonEmpty(domain, nameof(domain));
            Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion));

            var graph = new GraphProto { Name = name };
            graph.Node.Add(nodes);

            var model = new ModelProto
            {
                Domain          = domain,
                ProducerName    = producerName,
                ProducerVersion = producerVersion,
                IrVersion       = (long)UniversalModelFormat.Onnx.Version.IrVersion,
                ModelVersion    = modelVersion,
                OpsetImport     =
                {
                    new OperatorSetIdProto { Domain = "ai.onnx.ml", Version = 1 },
                    new OperatorSetIdProto { Domain = "", Version = 7 },
                },
                Graph           = graph,
            };

            // Each description is registered in its target list first and then filled in place by MakeValue.
            void AppendValueInfos(ICollection<ValueInfoProto> target, List<ModelArgs> args)
            {
                foreach (var arg in args)
                {
                    var info = new ValueInfoProto();
                    target.Add(info);
                    MakeValue(info, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
                }
            }

            AppendValueInfos(graph.Input, inputs);
            AppendValueInfos(graph.Output, outputs);
            AppendValueInfos(graph.ValueInfo, intermediateValues);

            graph.Initializer.AddRange(initializers);

            return model;
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Fills <paramref name="value"/> in place: sets its name and populates its type via <c>MakeType</c>.
        /// A fresh <see cref="TypeProto"/> is attached first if the value has none.
        /// </summary>
        /// <returns>The same <paramref name="value"/> instance.</returns>
        private static ValueInfoProto MakeValue(ValueInfoProto value, string name, TensorProto.Types.DataType dataType,
            List<long> dims, List<bool> dimsParam)
        {
            Contracts.CheckValue(value, nameof(value));
            Contracts.CheckNonEmpty(name, nameof(name));

            value.Name = name;

            TypeProto typeProto = value.Type;
            if (typeProto == null)
            {
                typeProto = new TypeProto();
                value.Type = typeProto;
            }

            MakeType(typeProto, dataType, dims, dimsParam);
            return value;
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Creates an input layer for a graph input description. The node's tensor type is recorded in
        /// <paramref name="types"/>, and the layer's output connector is registered in <paramref name="outputConns"/>,
        /// both keyed by the node name.
        /// </summary>
        /// <param name="node">Graph input description to convert.</param>
        /// <param name="types">Receives the node's parsed tensor type.</param>
        /// <param name="outputConns">Receives the created layer's output connector.</param>
        /// <returns>The created <c>InputVariable&lt;double&gt;</c> layer.</returns>
        /// <exception cref="NotSupportedException">The input's element type is not <see cref="double"/>.</exception>
        /// <exception cref="ArgumentException">A dictionary already contains the node name (from <c>Dictionary.Add</c>).</exception>
        private static Layer ParseInputVariableNode(ValueInfoProto node, Dictionary <string, TensorType> types, Dictionary <string, OutputConnector> outputConns)
        {
            var desiredType = TensorType.From(node.Type);

            // Recorded before the element-type check, so the entry is present even when the check below throws.
            types.Add(node.Name, desiredType);

            if (desiredType.ElementType != typeof(double))
            {
                throw new NotSupportedException(
                    $"Input '{node.Name}' has element type '{desiredType.ElementType}'; only System.Double inputs are supported.");
            }

            var layer = new InputVariable <double>(node.Name, desiredType.Dimensions);
            outputConns.Add(node.Name, layer.Value);
            return layer;
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Overwrites dimension <paramref name="dimIndex"/> of <paramref name="valueInfo"/>'s tensor shape
        /// with <paramref name="dimParamOrValue"/>. This happens only when that dimension currently holds
        /// the fixed value 1; any other dimension is left unchanged.
        /// </summary>
        internal static void SetDim(ValueInfoProto valueInfo,
                                    int dimIndex, DimParamOrValue dimParamOrValue)
        {
            TensorShapeProto.Types.Dimension target = valueInfo.Type.TensorType.Shape.Dim[dimIndex];

            // TODO: Should perhaps be parameter that says
            //       bool shouldSetDimFor(dim)
            bool isFixedUnitDim =
                target.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue &&
                target.DimValue == 1;

            if (isFixedUnitDim)
            {
                SetDim(target, dimParamOrValue);
            }
        }
Ejemplo n.º 6
0
        // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Dropout
        /// <summary>
        /// Exports this dropout layer as a single-node ONNX graph with two outputs: the activation and the mask.
        /// </summary>
        /// <param name="inputTensorId">Id of the input tensor, resolved through <paramref name="ctrl"/>'s float tensor factory.</param>
        /// <param name="ctrl">Controller whose <c>floatTensorFactory</c> resolves tensor ids.</param>
        /// <returns>
        /// A graph holding one <c>Dropout</c> node with <c>ratio = this.rate</c> and <c>is_test = 1</c>.
        /// </returns>
        public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
        {
            FloatTensor source = ctrl.floatTensorFactory.Get(inputTensorId);

            // Forward runs before `activation` and `_mask_source` are read below
            // (presumably it refreshes them for this input).
            this.Forward(source);

            var ratioAttribute = new AttributeProto
            {
                Name = "ratio",
                Type = AttributeProto.Types.AttributeType.Float,
                F    = this.rate
            };
            var isTestAttribute = new AttributeProto
            {
                Name = "is_test",
                Type = AttributeProto.Types.AttributeType.Int,
                I    = 1
            };

            var dropoutNode = new NodeProto
            {
                Name      = this.name,
                OpType    = "Dropout",
                DocString = "",
                Input     = { inputTensorId.ToString() },
                Output    = { activation.ToString(), _mask_source.Id.ToString() },
                Attribute = { ratioAttribute, isTestAttribute }
            };

            ValueInfoProto sourceInfo     = source.GetValueInfoProto();
            ValueInfoProto activationInfo = ctrl.floatTensorFactory.Get(activation).GetValueInfoProto();
            ValueInfoProto maskInfo       = _mask_source.GetValueInfoProto();

            var graph = new GraphProto { Name = Guid.NewGuid().ToString("N") };
            graph.Node.Add(dropoutNode);
            graph.Input.Add(sourceInfo);
            graph.Output.Add(activationInfo);
            graph.Output.Add(maskInfo);

            return graph;
        }
Ejemplo n.º 7
0
        /// <summary>
        /// Assembles an ONNX <see cref="ModelProto"/> around a single graph built from the given nodes
        /// and value descriptions.
        /// </summary>
        /// <param name="nodes">Graph nodes, added in order.</param>
        /// <param name="producerName">Model producer name; must be non-empty.</param>
        /// <param name="name">Graph name; must be non-empty.</param>
        /// <param name="domain">Model domain; must be non-empty.</param>
        /// <param name="inputs">Descriptions of the graph inputs.</param>
        /// <param name="outputs">Descriptions of the graph outputs.</param>
        /// <param name="intermediateValues">Descriptions stored in the graph's <c>ValueInfo</c> list.</param>
        /// <returns>The newly built model.</returns>
        public static ModelProto MakeModel(List <NodeProto> nodes, string producerName, string name, string domain, List <ModelArgs> inputs,
                                           List <ModelArgs> outputs, List <ModelArgs> intermediateValues)
        {
            Contracts.CheckValue(nodes, nameof(nodes));
            Contracts.CheckValue(inputs, nameof(inputs));
            Contracts.CheckValue(outputs, nameof(outputs));
            // Validate intermediateValues itself. Previously `outputs` was re-checked under this name,
            // so a null intermediateValues slipped through and failed later in the foreach below.
            Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
            Contracts.CheckNonEmpty(producerName, nameof(producerName));
            Contracts.CheckNonEmpty(name, nameof(name));
            Contracts.CheckNonEmpty(domain, nameof(domain));

            var model = new ModelProto();

            model.Domain       = domain;
            model.ProducerName = producerName;
            model.Graph        = new GraphProto();
            var graph = model.Graph;

            graph.Node.Add(nodes);
            graph.Name = name;
            foreach (var arg in inputs)
            {
                var val = new ValueInfoProto();
                graph.Input.Add(val);
                MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
            }

            foreach (var arg in outputs)
            {
                var val = new ValueInfoProto();
                graph.Output.Add(val);
                MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
            }

            foreach (var arg in intermediateValues)
            {
                var val = new ValueInfoProto();
                graph.ValueInfo.Add(val);
                MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
            }

            return(model);
        }
Ejemplo n.º 8
0
        /// <summary>
        /// Describes this tensor as an ONNX <see cref="ValueInfoProto"/>. The description is named by the
        /// tensor id, typed as INT32, and has one fixed-size dimension per entry of <c>Shape</c>.
        /// </summary>
        public ValueInfoProto GetValueInfoProto()
        {
            string valueName = this.Id.ToString();

            TensorShapeProto.Types.Dimension[] dimensions = Array.ConvertAll(
                this.Shape,
                extent => new TensorShapeProto.Types.Dimension { DimValue = extent });

            var shapeProto = new TensorShapeProto();
            shapeProto.Dim.Add(dimensions);

            var tensorType = new TypeProto.Types.Tensor
            {
                ElemType = TensorProto.Types.DataType.Int32,
                Shape    = shapeProto
            };

            return new ValueInfoProto
            {
                Name = valueName,
                Type = new TypeProto { TensorType = tensorType }
            };
        }
Ejemplo n.º 9
0
        // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
        /// <summary>
        /// Exports this linear layer as a single-node ONNX <c>Gemm</c> graph
        /// (<c>alpha = 1</c>, <c>beta = 1</c>, <c>broadcast = 1</c>).
        /// </summary>
        /// <param name="inputTensorId">Id of the input tensor, resolved through <paramref name="ctrl"/>'s float tensor factory.</param>
        /// <param name="ctrl">Controller whose <c>floatTensorFactory</c> resolves and creates tensors.</param>
        /// <returns>
        /// A graph whose node inputs are (input, weights, bias). The weights and bias are also stored as
        /// initializers and graph inputs.
        /// </returns>
        public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
        {
            FloatTensor source = ctrl.floatTensorFactory.Get(inputTensorId);

            this.Forward(source);

            var gemm = new NodeProto
            {
                Name      = this.name,
                OpType    = "Gemm",
                DocString = "",
                Input     = { inputTensorId.ToString(), _weights.Id.ToString() },
                Output    = { activation.ToString() },
                Attribute =
                {
                    new AttributeProto { Name = "alpha", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
                    new AttributeProto { Name = "beta", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
                    new AttributeProto { Name = "broadcast", Type = AttributeProto.Types.AttributeType.Int, I = 1 },
                }
            };

            var graph = new GraphProto { Name = Guid.NewGuid().ToString("N") };
            graph.Node.Add(gemm);
            graph.Initializer.Add(_weights.GetProto());
            graph.Input.Add(source.GetValueInfoProto());
            graph.Input.Add(_weights.GetValueInfoProto());
            graph.Output.Add(ctrl.floatTensorFactory.Get(activation).GetValueInfoProto());

            // The Gemm schema targeted here requires a third input (the bias), so an unbiased layer
            // gets a single-element zero tensor as a stand-in.
            FloatTensor bias = _biased
                ? _bias
                : ctrl.floatTensorFactory.Create(_data: new float[] { 0 }, _shape: new int[] { 1 }, _autograd: false, _keepgrads: false);

            gemm.Input.Add(bias.Id.ToString());
            graph.Initializer.Add(bias.GetProto());
            graph.Input.Add(bias.GetValueInfoProto());

            return graph;
        }
Ejemplo n.º 10
0
 /// <summary>
 /// Sets the name and then the type of <paramref name="ValueInfo"/> in place.
 /// </summary>
 /// <param name="ValueInfo">Value description to mutate.</param>
 /// <param name="name">Name to assign.</param>
 /// <param name="type">Type to assign.</param>
 /// <returns>The same <paramref name="ValueInfo"/> instance, so calls can be chained.</returns>
 public static ValueInfoProto Initialize(this ValueInfoProto ValueInfo, string name, TypeProto type)
 {
     ValueInfoProto target = ValueInfo;
     target.Name = name;
     target.Type = type;
     return target;
 }
Ejemplo n.º 11
0
 /// <summary>
 /// Renders a value description as <c>"name: type"</c>, where the type text comes from the
 /// <c>Print</c> extension on its <see cref="TypeProto"/>.
 /// </summary>
 /// <param name="value">Value description to render.</param>
 /// <returns>The formatted description.</returns>
 public static string Print(this ValueInfoProto value) => $"{value.Name}: {value.Type.Print()}";