Ejemplo n.º 1
0
 /// <summary>
 /// Builds the K210 layer argument describing a <see cref="GlobalAveragePool"/> layer.
 /// </summary>
 /// <param name="layer">
 /// The pooling layer. Its input dimensions are indexed as [N, C, H, W]: index 1 is used as the
 /// channel count, and indices 2 and 3 are multiplied to form the kernel size.
 /// </param>
 /// <param name="context">Conversion context (not used by this conversion).</param>
 /// <returns>
 /// An argument whose <c>KernelSize</c> is H * W and whose <c>Channels</c> is C.
 /// </returns>
 public GlobalAveragePool2dLayerArgument Convert(GlobalAveragePool layer, ConvertContext context)
 {
     var dims = layer.Input.Dimensions;
     var height = dims[2];
     var width = dims[3];
     var channels = dims[1];

     // The whole H x W plane is averaged per channel, so the "kernel" spans every spatial element.
     return new GlobalAveragePool2dLayerArgument
     {
         KernelSize = (uint)(height * width),
         Channels = (uint)channels,
     };
 }
Ejemplo n.º 2
0
        /// <summary>
        /// Fills in the memory placement of a <see cref="GlobalAveragePool"/> layer's argument.
        /// </summary>
        /// <param name="layer">The layer whose input producer and output are looked up in main memory.</param>
        /// <param name="argument">The argument to populate; its flags and both main-memory addresses are overwritten.</param>
        /// <param name="context">Inference context providing the main-memory allocation map.</param>
        public void Infer(GlobalAveragePool layer, GlobalAveragePool2dLayerArgument argument, InferenceContext context)
        {
            var memoryMap = context.MainMemoryMap;

            // Both allocations are resolved before anything on the argument is touched.
            var source      = memoryMap[layer.Input.Connection.From];
            var destination = memoryMap[layer.Output];

            argument.Flags = K210LayerFlags.MainMemoryOutput;
            argument.MainMemoryInputAddress  = source.GetAddress();
            argument.MainMemoryOutputAddress = destination.GetAddress();
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Replaces the matched <see cref="AveragePool2d"/> with a <see cref="GlobalAveragePool"/>.
        /// The new layer is fed by the old pool's producer, and every consumer of the old pool's
        /// output is rewired to the new layer's output.
        /// </summary>
        /// <param name="context">Transform context; <c>MatchedLayers[0]</c> must be the <see cref="AveragePool2d"/>.</param>
        public override void Process(TransformContext context)
        {
            var matchedPool = (AveragePool2d)context.MatchedLayers[0];
            var upstream    = matchedPool.Input.Connection.From;
            var downstream  = matchedPool.Output;

            // Detach the old pool from its producer before wiring in the replacement.
            matchedPool.Input.ClearConnection();

            var globalPool = new GlobalAveragePool(upstream.Dimensions);
            globalPool.Input.SetConnection(upstream);

            // Snapshot the consumer list first: SetConnection presumably edits
            // downstream.Connections, which must not change while being enumerated (confirm).
            foreach (var consumer in downstream.Connections.Select(c => c.To).ToList())
            {
                consumer.SetConnection(globalPool.Output);
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Converts a TFLite MEAN operator into a <see cref="GlobalAveragePool"/> followed by a
        /// <see cref="Reshape"/> that flattens the pooled result to [-1, C].
        /// </summary>
        /// <param name="op">The TFLite operator: input 0 is the data tensor, input 1 the reduction axes.</param>
        /// <returns>
        /// The trailing <see cref="Reshape"/> layer. Its output is registered as the operator's output.
        /// </returns>
        /// <exception cref="LayerNotSupportedException">
        /// Thrown when the reduction axes are anything other than {1, 2}. With the input converted from
        /// NHWC via <c>ToNCHW</c>, these are presumably the H and W axes.
        /// </exception>
        private Layer ConvertMean(tflite.Operator op)
        {
            var inputs = op.GetInputsArray();
            var input  = _graph.Tensors(inputs[0]).Value;
            var axes   = _model.GetTensor <int>(_graph.Tensors(inputs[1]).Value);

            // A mean over a set of axes does not depend on their order, so accept [1, 2] and [2, 1] alike.
            if (!axes.ToArray().OrderBy(a => a).SequenceEqual(new[] { 1, 2 }))
            {
                throw new LayerNotSupportedException(op.ToString(), "Only [1,2] axis mean is supported");
            }

            var layer = new GlobalAveragePool(input.GetShapeArray().ToNCHW());
            _inputs.Add(layer.Input, inputs[0]);

            // Flatten the pooled output to [-1, C]; Dimensions[1] is the channel axis in NCHW order.
            var reshape = new Reshape(layer.Output.Dimensions, new[] { -1, layer.Output.Dimensions[1] });
            reshape.Input.SetConnection(layer.Output);

            // The operator's result is the reshaped tensor. Registering layer.Output here would let
            // downstream consumers bypass the reshape and leave reshape.Output unconnected.
            _outputs.Add(op.Outputs(0), reshape.Output);
            return reshape;
        }