Exemplo n.º 1
0
        /// <summary>
        /// Fills in the main-memory input and output addresses of a fully-connected layer argument.
        /// </summary>
        /// <param name="layer">The layer whose input connection source and output are looked up.</param>
        /// <param name="argument">Receives <c>MainMemoryInputAddress</c> and <c>MainMemoryOutputAddress</c>.</param>
        /// <param name="context">Supplies <c>MainMemoryMap</c>, which maps tensors to their allocations.</param>
        public void Infer(FullyConnected layer, FullyConnectedLayerArgument argument, InferenceContext context)
        {
            // Both allocations are resolved before the argument is touched, so a failed
            // lookup leaves it unmodified.
            var (source, target) = (
                context.MainMemoryMap[layer.Input.Connection.From],
                context.MainMemoryMap[layer.Output]);

            argument.MainMemoryInputAddress = source.GetAddress();
            argument.MainMemoryOutputAddress = target.GetAddress();
        }
Exemplo n.º 2
0
        /// <summary>
        /// Runs the fully-connected layer on float data held in main RAM.
        /// For each output channel, it writes the dot product of the input vector with
        /// that channel's weight row, plus the channel's bias.
        /// </summary>
        /// <param name="argument">
        /// Layer parameters: RAM addresses, channel counts, row-major <c>Weights</c>
        /// (<c>OutputChannels</c> rows of <c>InputChannels</c> floats) and per-channel <c>Bias</c>.
        /// </param>
        /// <param name="context">Provides access to main RAM via <c>GetMainRamAt</c>.</param>
        /// <remarks>
        /// NOTE(review): <c>Activation</c> is read during deserialization but is not applied
        /// here. Confirm whether the activation should run as part of this layer.
        /// </remarks>
        public void Forward(FullyConnectedLayerArgument argument, ForwardContext context)
        {
            var input = MemoryMarshal.Cast<byte, float>(context.GetMainRamAt((int)argument.MainMemoryInputAddress));
            var output = MemoryMarshal.Cast<byte, float>(context.GetMainRamAt((int)argument.MainMemoryOutputAddress));
            var rowLength = (int)argument.InputChannels;

            for (int channel = 0; channel < argument.OutputChannels; channel++)
            {
                // Weight row for this output channel: a contiguous slice of the flat weight array.
                var row = new ReadOnlySpan<float>(argument.Weights, (int)(channel * argument.InputChannels), rowLength);

                float acc = 0;
                for (int i = 0; i < row.Length; i++)
                {
                    acc += input[i] * row[i];
                }

                output[channel] = acc + argument.Bias[channel];
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Reads a fully-connected layer argument from the K210 binary model at <paramref name="offset"/>.
        /// </summary>
        /// <param name="offset">Offset passed to <c>context.GetReaderAt</c> to locate the record.</param>
        /// <param name="context">Deserialization context that supplies the reader.</param>
        /// <returns>The populated <see cref="FullyConnectedLayerArgument"/>.</returns>
        /// <exception cref="OverflowException">
        /// The weight or bias byte size derived from the channel counts in the header does not
        /// fit in an <see cref="int"/> (e.g. a corrupt model file).
        /// </exception>
        /// <remarks>
        /// Fields are read in this order: flags, input address, output address, input channel
        /// count, output channel count, activation. These are followed by
        /// <c>InputChannels * OutputChannels</c> float weights and then <c>OutputChannels</c> float biases.
        /// </remarks>
        public FullyConnectedLayerArgument DeserializeBin(int offset, K210BinDeserializeContext context)
        {
            var sr = context.GetReaderAt(offset);
            var argument = new FullyConnectedLayerArgument();
            argument.Flags = sr.Read<K210LayerFlags>();
            argument.MainMemoryInputAddress = sr.Read<uint>();
            argument.MainMemoryOutputAddress = sr.Read<uint>();
            argument.InputChannels = sr.Read<uint>();
            argument.OutputChannels = sr.Read<uint>();
            argument.Activation = sr.Read<ActivationFunctionType>();

            // The channel counts come straight from the model file. Size the payloads in a
            // checked context so a corrupt header throws instead of silently wrapping to a
            // wrong length, which would also misalign the bias read that follows.
            int weightsBytes = checked((int)(argument.InputChannels * argument.OutputChannels * sizeof(float)));
            int biasBytes = checked((int)(argument.OutputChannels * sizeof(float)));

            argument.Weights = MemoryMarshal.Cast<byte, float>(sr.ReadAsSpan(weightsBytes)).ToArray();
            argument.Bias = MemoryMarshal.Cast<byte, float>(sr.ReadAsSpan(biasBytes)).ToArray();

            return argument;
        }