/// <summary>
/// Exports this softmax layer as a single-node ONNX graph.
/// </summary>
/// <param name="inputTensorId">Id of the input tensor in <paramref name="ctrl"/>'s float tensor factory.</param>
/// <param name="ctrl">Controller whose float tensor factory resolves tensor ids.</param>
/// <returns>
/// A graph with a random (GUID) name containing one "Softmax" node whose "axis" attribute is
/// <c>this.dim</c>; the input tensor's value info is the graph input and the activation
/// tensor's value info is the graph output.
/// </returns>
public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
{
    FloatTensor input_tensor = ctrl.floatTensorFactory.Get(inputTensorId);

    // Always run a forward pass so that `activation` identifies this layer's output tensor,
    // as the other layer exporters do. `activation` is used as a tensor id below, so it cannot be null.
    this.Forward(input_tensor);

    NodeProto node = new NodeProto
    {
        Input = { inputTensorId.ToString() },
        Output = { activation.ToString() },
        OpType = "Softmax",
        Attribute =
        {
            new AttributeProto
            {
                Name = "axis",
                Type = AttributeProto.Types.AttributeType.Int,
                I = this.dim
            }
        }
    };

    ValueInfoProto input_info = input_tensor.GetValueInfoProto();

    GraphProto g = new GraphProto
    {
        Name = Guid.NewGuid().ToString("N"),
        Node = { node },
        Initializer = { },
        Input = { input_info },
        Output = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
    };

    return g;
}
/// <summary>
/// Assembles an ONNX <see cref="ModelProto"/> from graph nodes, value descriptions and initializers.
/// </summary>
/// <param name="nodes">Nodes added to the model graph.</param>
/// <param name="producerName">Producer name recorded on the model; must be non-empty.</param>
/// <param name="name">Name of the model graph; must be non-empty.</param>
/// <param name="domain">Model domain; must be non-empty.</param>
/// <param name="producerVersion">Producer version recorded on the model; must be non-empty.</param>
/// <param name="modelVersion">Model version recorded on the model.</param>
/// <param name="inputs">Descriptions turned into the graph's <c>Input</c> entries.</param>
/// <param name="outputs">Descriptions turned into the graph's <c>Output</c> entries.</param>
/// <param name="intermediateValues">Descriptions turned into the graph's <c>ValueInfo</c> entries.</param>
/// <param name="initializers">Tensors appended to the graph's <c>Initializer</c> list.</param>
/// <returns>
/// The model, with IR version taken from <c>UniversalModelFormat.Onnx.Version.IrVersion</c> and opset
/// imports "ai.onnx.ml" version 1 and the default domain ("") version 7.
/// </returns>
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name, string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs, List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
{
    Contracts.CheckValue(nodes, nameof(nodes));
    Contracts.CheckValue(inputs, nameof(inputs));
    Contracts.CheckValue(outputs, nameof(outputs));
    Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
    Contracts.CheckValue(initializers, nameof(initializers));
    Contracts.CheckNonEmpty(producerName, nameof(producerName));
    Contracts.CheckNonEmpty(name, nameof(name));
    Contracts.CheckNonEmpty(domain, nameof(domain));
    Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion));

    var model = new ModelProto
    {
        Domain = domain,
        ProducerName = producerName,
        ProducerVersion = producerVersion,
        IrVersion = (long)UniversalModelFormat.Onnx.Version.IrVersion,
        ModelVersion = modelVersion,
        OpsetImport =
        {
            new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 },
            new OperatorSetIdProto() { Domain = "", Version = 7 },
        },
        Graph = new GraphProto(),
    };

    var graph = model.Graph;
    graph.Node.Add(nodes);
    graph.Name = name;

    // MakeValue fills in and returns the ValueInfoProto it is given, so each entry is
    // built first and then appended to the matching graph list.
    foreach (var input in inputs)
    {
        graph.Input.Add(MakeValue(new ValueInfoProto(), input.Name, input.DataType, input.Dims, input.DimParams));
    }

    foreach (var output in outputs)
    {
        graph.Output.Add(MakeValue(new ValueInfoProto(), output.Name, output.DataType, output.Dims, output.DimParams));
    }

    foreach (var intermediate in intermediateValues)
    {
        graph.ValueInfo.Add(MakeValue(new ValueInfoProto(), intermediate.Name, intermediate.DataType, intermediate.Dims, intermediate.DimParams));
    }

    graph.Initializer.AddRange(initializers);
    return model;
}
/// <summary>
/// Sets the name of <paramref name="value"/> and fills in its type via <c>MakeType</c>,
/// creating an empty <see cref="TypeProto"/> first if the value has none.
/// </summary>
/// <param name="value">The value description to populate; validated by <c>Contracts.CheckValue</c>.</param>
/// <param name="name">Name to assign; validated by <c>Contracts.CheckNonEmpty</c>.</param>
/// <param name="dataType">Element data type forwarded to <c>MakeType</c>.</param>
/// <param name="dims">Dimensions forwarded to <c>MakeType</c>.</param>
/// <param name="dimsParam">Per-dimension flags forwarded to <c>MakeType</c>.</param>
/// <returns>The same <paramref name="value"/> instance, for chaining.</returns>
private static ValueInfoProto MakeValue(ValueInfoProto value, string name, TensorProto.Types.DataType dataType, List<long> dims, List<bool> dimsParam)
{
    Contracts.CheckValue(value, nameof(value));
    Contracts.CheckNonEmpty(name, nameof(name));

    value.Name = name;

    TypeProto typeProto = value.Type;
    if (typeProto == null)
    {
        typeProto = new TypeProto();
        value.Type = typeProto;
    }

    MakeType(typeProto, dataType, dims, dimsParam);
    return value;
}
/// <summary>
/// Creates an input-variable layer for an ONNX graph input and records its type and output connector.
/// </summary>
/// <param name="node">The graph input description.</param>
/// <param name="types">Receives the input's tensor type, keyed by <c>node.Name</c>.</param>
/// <param name="outputConns">Receives the created layer's <c>Value</c> connector, keyed by <c>node.Name</c>.</param>
/// <returns>A new <c>InputVariable&lt;double&gt;</c> layer for the input.</returns>
/// <exception cref="NotSupportedException">The input's element type is not <see cref="double"/>.</exception>
private static Layer ParseInputVariableNode(ValueInfoProto node, Dictionary<string, TensorType> types, Dictionary<string, OutputConnector> outputConns)
{
    var desiredType = TensorType.From(node.Type);

    // NOTE(review): the type is registered before the element-type check, so an unsupported
    // input still leaves its entry in `types` when the exception below is thrown.
    types.Add(node.Name, desiredType);

    if (desiredType.ElementType != typeof(double))
    {
        // Name the offending input and type; the previous bare exception gave no diagnostic.
        throw new NotSupportedException(
            $"Input '{node.Name}' has element type {desiredType.ElementType}; only double inputs are supported.");
    }

    var layer = new InputVariable<double>(node.Name, desiredType.Dimensions);
    outputConns.Add(node.Name, layer.Value);
    return layer;
}
/// <summary>
/// Updates dimension <paramref name="dimIndex"/> of <paramref name="valueInfo"/>'s tensor shape
/// when, and only when, that dimension currently holds the fixed value 1.
/// </summary>
/// <remarks>
/// The matching dimension is passed to the <c>SetDim(dim, dimParamOrValue)</c> overload
/// (presumably overwriting it with <paramref name="dimParamOrValue"/> — TODO confirm).
/// Dimensions that are not a fixed value of 1 are left untouched.
/// NOTE(review): assumes <paramref name="valueInfo"/> describes a tensor type whose shape has at
/// least <c>dimIndex + 1</c> dimensions; verify against callers.
/// </remarks>
/// <param name="valueInfo">The value description whose shape is modified in place.</param>
/// <param name="dimIndex">Index into the shape's dimension list.</param>
/// <param name="dimParamOrValue">Replacement passed to the per-dimension overload.</param>
internal static void SetDim(ValueInfoProto valueInfo, int dimIndex, DimParamOrValue dimParamOrValue)
{
    var dim = valueInfo.Type.TensorType.Shape.Dim[dimIndex];

    // TODO: Should perhaps be parameter that says
    // bool shouldSetDimFor(dim)
    if (dim.ValueCase == TensorShapeProto.Types.Dimension.ValueOneofCase.DimValue && dim.DimValue == 1)
    {
        SetDim(dim, dimParamOrValue);
    }
}
// See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Dropout
/// <summary>
/// Exports this dropout layer as a single-node ONNX "Dropout" graph.
/// </summary>
/// <param name="inputTensorId">Id of the input tensor in <paramref name="ctrl"/>'s float tensor factory.</param>
/// <param name="ctrl">Controller whose float tensor factory resolves tensor ids.</param>
/// <returns>
/// A graph with a random (GUID) name containing one node with attributes "ratio" = <c>this.rate</c>
/// and "is_test" = 1, the input's value info as graph input, and the activation and mask
/// value infos as graph outputs.
/// </returns>
public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
{
    FloatTensor input = ctrl.floatTensorFactory.Get(inputTensorId);

    // Run the layer forward so `activation` (and presumably `_mask_source`) refer to this
    // export's tensors — TODO confirm.
    this.Forward(input);

    var dropout = new NodeProto
    {
        Input = { inputTensorId.ToString() },
        Output = { activation.ToString(), _mask_source.Id.ToString() },
        Name = this.name,
        OpType = "Dropout",
        DocString = "",
        Attribute =
        {
            new AttributeProto { Name = "ratio", Type = AttributeProto.Types.AttributeType.Float, F = this.rate },
            // A non-zero is_test selects test (inference) mode in the pre-opset-7 Dropout schema.
            new AttributeProto { Name = "is_test", Type = AttributeProto.Types.AttributeType.Int, I = 1 },
        },
    };

    return new GraphProto
    {
        Name = Guid.NewGuid().ToString("N"),
        Node = { dropout },
        Initializer = { },
        Input = { input.GetValueInfoProto() },
        Output =
        {
            ctrl.floatTensorFactory.Get(activation).GetValueInfoProto(),
            _mask_source.GetValueInfoProto(),
        },
    };
}
/// <summary>
/// Assembles an ONNX <see cref="ModelProto"/> whose graph holds the given nodes and value descriptions.
/// </summary>
/// <param name="nodes">Nodes added to the model graph.</param>
/// <param name="producerName">Producer name recorded on the model; must be non-empty.</param>
/// <param name="name">Name of the model graph; must be non-empty.</param>
/// <param name="domain">Model domain; must be non-empty.</param>
/// <param name="inputs">Descriptions turned into the graph's <c>Input</c> entries.</param>
/// <param name="outputs">Descriptions turned into the graph's <c>Output</c> entries.</param>
/// <param name="intermediateValues">Descriptions turned into the graph's <c>ValueInfo</c> entries.</param>
/// <returns>The assembled model. Unlike the overload taking versions and initializers, this one
/// sets no IR/producer/model version and adds no opset imports.</returns>
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name, string domain, List<ModelArgs> inputs, List<ModelArgs> outputs, List<ModelArgs> intermediateValues)
{
    Contracts.CheckValue(nodes, nameof(nodes));
    Contracts.CheckValue(inputs, nameof(inputs));
    Contracts.CheckValue(outputs, nameof(outputs));
    // Validate the list that is actually iterated below, not `outputs` a second time.
    Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
    Contracts.CheckNonEmpty(producerName, nameof(producerName));
    Contracts.CheckNonEmpty(name, nameof(name));
    Contracts.CheckNonEmpty(domain, nameof(domain));

    var model = new ModelProto();
    model.Domain = domain;
    model.ProducerName = producerName;
    model.Graph = new GraphProto();
    var graph = model.Graph;
    graph.Node.Add(nodes);
    graph.Name = name;

    foreach (var arg in inputs)
    {
        var val = new ValueInfoProto();
        graph.Input.Add(val);
        MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
    }

    foreach (var arg in outputs)
    {
        var val = new ValueInfoProto();
        graph.Output.Add(val);
        MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
    }

    foreach (var arg in intermediateValues)
    {
        var val = new ValueInfoProto();
        graph.ValueInfo.Add(val);
        MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
    }

    return model;
}
/// <summary>
/// Describes this tensor as an ONNX <see cref="ValueInfoProto"/>.
/// </summary>
/// <returns>
/// A value info named after <c>this.Id</c>, with element type always reported as
/// <c>Int32</c> and one fixed-value dimension per entry of <c>this.Shape</c>.
/// </returns>
public ValueInfoProto GetValueInfoProto()
{
    var dims = Array.ConvertAll(this.Shape, extent => new TensorShapeProto.Types.Dimension { DimValue = extent });

    var tensorType = new TypeProto.Types.Tensor
    {
        ElemType = TensorProto.Types.DataType.Int32,
        Shape = new TensorShapeProto { Dim = { dims } },
    };

    return new ValueInfoProto
    {
        Name = this.Id.ToString(),
        Type = new TypeProto { TensorType = tensorType },
    };
}
// See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
/// <summary>
/// Exports this linear layer as a single-node ONNX "Gemm" graph.
/// </summary>
/// <param name="inputTensorId">Id of the input tensor in <paramref name="ctrl"/>'s float tensor factory.</param>
/// <param name="ctrl">Controller whose float tensor factory resolves and creates tensors.</param>
/// <returns>
/// A graph with a random (GUID) name containing one node with inputs (input, weights, bias) and
/// attributes alpha = 1, beta = 1, broadcast = 1. Weights and bias are stored both as initializers
/// and as graph inputs. When the layer has no bias, a one-element zero tensor is created in the
/// factory and used instead.
/// </returns>
public override GraphProto GetProto(int inputTensorId, SyftController ctrl)
{
    FloatTensor input = ctrl.floatTensorFactory.Get(inputTensorId);

    // Run the layer forward so `activation` refers to this export's output tensor.
    this.Forward(input);

    var gemm = new NodeProto
    {
        Input = { inputTensorId.ToString(), _weights.Id.ToString() },
        Output = { activation.ToString() },
        Name = this.name,
        OpType = "Gemm",
        DocString = "",
        Attribute =
        {
            new AttributeProto { Name = "alpha", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
            new AttributeProto { Name = "beta", Type = AttributeProto.Types.AttributeType.Float, F = 1.0f },
            new AttributeProto { Name = "broadcast", Type = AttributeProto.Types.AttributeType.Int, I = 1 },
        },
    };

    var graph = new GraphProto
    {
        Name = Guid.NewGuid().ToString("N"),
        Node = { gemm },
        Initializer = { _weights.GetProto() },
        Input = { input.GetValueInfoProto(), _weights.GetValueInfoProto() },
        Output = { ctrl.floatTensorFactory.Get(activation).GetValueInfoProto() },
    };

    if (_biased)
    {
        gemm.Input.Add(_bias.Id.ToString());
        graph.Initializer.Add(_bias.GetProto());
        graph.Input.Add(_bias.GetValueInfoProto());
    }
    else
    {
        // This exporter's Gemm schema needs a third (bias) input, so supply a one-element zero
        // tensor; the node's broadcast attribute is set to 1.
        FloatTensor zeroBias = ctrl.floatTensorFactory.Create(
            _data: new float[] { 0 },
            _shape: new int[] { 1 },
            _autograd: false,
            _keepgrads: false);

        graph.Initializer.Add(zeroBias.GetProto());
        graph.Input.Add(zeroBias.GetValueInfoProto());
        gemm.Input.Add(zeroBias.Id.ToString());
    }

    return graph;
}
/// <summary>
/// Sets the name and type of a value description in place.
/// </summary>
/// <param name="ValueInfo">The value description to mutate.</param>
/// <param name="name">Name to assign.</param>
/// <param name="type">Type to assign.</param>
/// <returns>The same <paramref name="ValueInfo"/> instance, for fluent chaining.</returns>
public static ValueInfoProto Initialize(this ValueInfoProto ValueInfo, string name, TypeProto type)
{
    ValueInfoProto target = ValueInfo;

    // Name is assigned before Type, as before, so a rejected name leaves Type untouched.
    target.Name = name;
    target.Type = type;

    return target;
}
/// <summary>
/// Formats a value description as "<c>name: type</c>", with the type rendered by its own <c>Print</c> extension.
/// </summary>
/// <param name="value">The value description to format.</param>
/// <returns>The formatted text.</returns>
public static string Print(this ValueInfoProto value) => value.Name + ": " + value.Type.Print();