/// <summary>
/// Builds the hardware argument describing a global average pooling layer.
/// </summary>
/// <param name="layer">The layer whose input dimensions are read.</param>
/// <param name="context">The conversion context (not used by this method).</param>
/// <returns>
/// An argument whose <c>KernelSize</c> is the product of input dimensions 2 and 3
/// and whose <c>Channels</c> is input dimension 1.
/// </returns>
public GlobalAveragePool2dLayerArgument Convert(GlobalAveragePool layer, ConvertContext context)
{
    // NOTE(review): the index usage suggests an NCHW layout (1 = channels,
    // 2 = height, 3 = width) - confirm against the layer definition.
    var dimensions = layer.Input.Dimensions;
    var height = dimensions[2];
    var width = dimensions[3];
    var channels = dimensions[1];

    var argument = new GlobalAveragePool2dLayerArgument();
    argument.KernelSize = (uint)(height * width);
    argument.Channels = (uint)channels;
    return argument;
}
/// <summary>
/// Fills in the memory-related fields of a global average pooling argument.
/// </summary>
/// <param name="layer">The layer whose input source and output are looked up in the memory map.</param>
/// <param name="argument">The argument that receives the flags and the input/output addresses.</param>
/// <param name="context">The inference context providing <c>MainMemoryMap</c>.</param>
public void Infer(GlobalAveragePool layer, GlobalAveragePool2dLayerArgument argument, InferenceContext context)
{
    // Resolve both allocations before touching the argument, so a failed
    // lookup leaves the argument unmodified.
    var sourceAllocation = context.MainMemoryMap[layer.Input.Connection.From];
    var targetAllocation = context.MainMemoryMap[layer.Output];

    argument.Flags = K210LayerFlags.MainMemoryOutput;
    argument.MainMemoryInputAddress = sourceAllocation.GetAddress();
    argument.MainMemoryOutputAddress = targetAllocation.GetAddress();
}
/// <summary>
/// Replaces the matched <c>AveragePool2d</c> layer with a <c>GlobalAveragePool</c>
/// layer, moving the upstream connection and all downstream connections onto
/// the new layer.
/// </summary>
/// <param name="context">The transform context; the first matched layer is the pooling layer to replace.</param>
public override void Process(TransformContext context)
{
    var pooling = (AveragePool2d)context.MatchedLayers[0];
    var producer = pooling.Input.Connection.From;
    var pooledOutput = pooling.Output;

    // Detach the old layer from its producer, then feed the producer into the replacement.
    pooling.Input.ClearConnection();
    var replacement = new GlobalAveragePool(producer.Dimensions);
    replacement.Input.SetConnection(producer);

    // Iterate over a snapshot of the consumers; presumably SetConnection
    // alters the connection collection being enumerated - verify.
    foreach (var consumer in pooledOutput.Connections.Select(link => link.To).ToArray())
    {
        consumer.SetConnection(replacement.Output);
    }
}
/// <summary>
/// Converts a TFLite mean operator into a <c>GlobalAveragePool</c> layer
/// followed by a <c>Reshape</c> layer.
/// </summary>
/// <param name="op">The operator to convert; input 0 is the data tensor, input 1 holds the axes.</param>
/// <returns>The <c>Reshape</c> layer connected to the pooling layer's output.</returns>
/// <exception cref="LayerNotSupportedException">Thrown when the axes are anything other than [1, 2].</exception>
private Layer ConvertMean(tflite.Operator op)
{
    var inputIds = op.GetInputsArray();
    var dataTensor = _graph.Tensors(inputIds[0]).Value;
    var reduceAxes = _model.GetTensor<int>(_graph.Tensors(inputIds[1]).Value);

    // Only a reduction over axes 1 and 2 of the source tensor can be mapped.
    var isSupported = reduceAxes.ToArray().SequenceEqual(new[] { 1, 2 });
    if (!isSupported)
    {
        throw new LayerNotSupportedException(op.ToString(), "Only [1,2] axis mean is supported");
    }

    var pool = new GlobalAveragePool(dataTensor.GetShapeArray().ToNCHW());
    _inputs.Add(pool.Input, inputIds[0]);

    // Reshape the pooled result to two dimensions: (-1, dimension 1 of the pooled output).
    var pooledDimensions = pool.Output.Dimensions;
    var flatten = new Reshape(pooledDimensions, new[] { -1, pooledDimensions[1] });
    flatten.Input.SetConnection(pool.Output);

    // NOTE(review): the operator's output tensor is registered against the pooling
    // output rather than the reshape output - confirm this is intentional.
    _outputs.Add(op.Outputs(0), pool.Output);
    return flatten;
}