/// <summary>
/// Builds the <see cref="QuantizedAddLayerArgument"/> for a quantized add layer from the
/// global value distributions recorded for its two inputs and its output.
/// </summary>
/// <param name="layer">The quantized add layer being converted.</param>
/// <param name="context">Conversion context; its <c>Quantization.Distributions</c> table is read.</param>
/// <returns>The offsets, multipliers, shifts and element count for the layer.</returns>
public QuantizedAddLayerArgument Convert(QuantizedAdd layer, ConvertContext context)
{
    var distributions = context.Quantization.Distributions;
    var rangeA = distributions[layer.InputA.Connection.From].Global;
    var rangeB = distributions[layer.InputB.Connection.From].Global;
    var rangeOut = distributions[layer.Output].Global;

    // NOTE(review): the argument 8 is presumably the quantization bit width - confirm against GetScaleBias.
    var (scaleA, biasA) = rangeA.GetScaleBias(8);
    var (scaleB, biasB) = rangeB.GetScaleBias(8);
    var (scaleOut, biasOut) = rangeOut.GetScaleBias(8);

    // The cross-over is intentional: A's multiplier comes from B's scale and B's from A's.
    // The output multiplier divides by scaleA * scaleB, which suggests both operands are
    // first brought onto that common scale - TODO confirm against the kernel implementation.
    var (multiplierA, shiftA) = Quantizer.ExtractValueAndShift(scaleB, 32, 32);
    var (multiplierB, shiftB) = Quantizer.ExtractValueAndShift(scaleA, 32, 32);
    var (multiplierOut, shiftOut) = Quantizer.ExtractValueAndShift(scaleOut / (scaleA * scaleB), 32, 32);

    return new QuantizedAddLayerArgument
    {
        InputAOffset = (int)biasA,
        InputAMul = (int)Math.Round(multiplierA),
        InputAShift = shiftA,
        InputBOffset = (int)biasB,
        InputBMul = (int)Math.Round(multiplierB),
        InputBShift = shiftB,
        // The output bias is stored negated, unlike the two input offsets.
        OutputOffset = (int)(-biasOut),
        OutputMul = (int)Math.Round(multiplierOut),
        OutputShift = shiftOut,
        Count = (uint)(layer.Output.Dimensions.GetSize())
    };
}
/// <summary>
/// Fills in the memory-related fields of <paramref name="argument"/>: the layer flags and
/// the main-memory addresses of both inputs and of the output.
/// </summary>
/// <param name="layer">The quantized add layer whose connectors are looked up.</param>
/// <param name="argument">The argument to update in place.</param>
/// <param name="context">Inference context; its <c>MainMemoryMap</c> supplies the allocations.</param>
public void Infer(QuantizedAdd layer, QuantizedAddLayerArgument argument, InferenceContext context)
{
    // Resolve every allocation before touching the argument, so a failed lookup
    // leaves the argument unmodified.
    var allocationA = context.MainMemoryMap[layer.InputA.Connection.From];
    var allocationB = context.MainMemoryMap[layer.InputB.Connection.From];
    var allocationOut = context.MainMemoryMap[layer.Output];

    argument.Flags = K210LayerFlags.MainMemoryOutput;
    argument.MainMemoryInputAAddress = allocationA.GetAddress();
    argument.MainMemoryInputBAddress = allocationB.GetAddress();
    argument.MainMemoryOutputAddress = allocationOut.GetAddress();
}
/// <summary>
/// Rewires the graph around the matched <c>Add</c> layer: a new <c>QuantizedAdd</c> followed by a
/// <c>Dequantize</c> is created, and every consumer of the add's output is reconnected to the
/// dequantize output.
/// </summary>
/// <param name="context">Transform context; <c>MatchedLayers[0]</c> must be the <c>Add</c> layer.</param>
public override void Process(TransformContext context)
{
    var matchedAdd = (Add)context.MatchedLayers[0];

    // Each operand is taken from the first input of the layer that feeds the add,
    // i.e. the add's direct producers are bypassed.
    // NOTE(review): presumably those producers are dequantize layers - confirm against the match rule.
    var sourceA = matchedAdd.InputA.Connection.From.Owner.InputConnectors[0].Connection.From;
    var sourceB = matchedAdd.InputB.Connection.From.Owner.InputConnectors[0].Connection.From;
    var addOutput = matchedAdd.Output;

    var quantizedAdd = new QuantizedAdd(matchedAdd.InputA.Dimensions, matchedAdd.InputB.Dimensions);
    var dequantize = new Dequantize(quantizedAdd.Output.Dimensions);

    quantizedAdd.InputA.SetConnection(sourceA);
    quantizedAdd.InputB.SetConnection(sourceB);
    dequantize.Input.SetConnection(quantizedAdd.Output);

    // Snapshot the consumers first; rewiring presumably alters addOutput.Connections,
    // which must not happen while that collection is being enumerated.
    var consumers = addOutput.Connections
        .Select(connection => connection.To)
        .ToList();

    foreach (var consumer in consumers)
    {
        consumer.SetConnection(dequantize.Output);
    }
}