Exemplo n.º 1
0
        /// <summary>
        /// Binds every entry of <paramref name="feedDict"/> as a model input, evaluates the
        /// session once, and reads back the requested outputs as <see cref="ONNXTensor"/> values.
        /// </summary>
        /// <param name="outputs">Names of the model outputs to read back, in the order they are returned.</param>
        /// <param name="feedDict">Input name to tensor data; each entry is bound before evaluation.</param>
        /// <returns>One <see cref="ONNXTensor"/> per name in <paramref name="outputs"/>, in the same order.</returns>
        /// <exception cref="InvalidOperationException">
        /// A requested output is neither a <see cref="TensorFloat"/> nor a <see cref="TensorFloat16Bit"/>.
        /// </exception>
        public override List <ONNXTensor> Run(IEnumerable <string> outputs, Dictionary <string, ONNXTensor> feedDict)
        {
            var binding = new LearningModelBinding(sess);

            foreach (var item in feedDict)
            {
                object tensor;
                if (!IsFP16)
                {
                    tensor = TensorFloat.CreateFromArray(item.Value.Shape, item.Value.Buffer);
                    if (IsGPU)
                    {
                        //TODO: Move SoftwareTensor to DX12Tensor
                        tensor = MoveToGPU((TensorFloat)tensor);
                    }
                }
                else
                {
                    // NOTE(review): IsGPU is not consulted on the FP16 path; these inputs are bound as created.
                    tensor = TensorFloat16Bit.CreateFromArray(item.Value.Shape, item.Value.Buffer);
                }
                binding.Bind(item.Key, tensor);
            }

            // evalCount is advanced on every call so each evaluation gets a distinct correlation id.
            var result = sess.Evaluate(binding, $"eval{++evalCount}");

            var ret = new List<ONNXTensor>();

            foreach (var name in outputs)
            {
                var output = result.Outputs[name];

                // A model run in FP16 can produce TensorFloat16Bit outputs, so both float tensor
                // kinds are accepted; both expose their data as a float vector view.
                var f32 = output as TensorFloat;
                if (f32 != null)
                {
                    ret.Add(new ONNXTensor()
                    {
                        Buffer = f32.GetAsVectorView().ToArray(),
                        Shape = f32.Shape.ToArray()
                    });
                    continue;
                }

                var f16 = output as TensorFloat16Bit;
                if (f16 != null)
                {
                    ret.Add(new ONNXTensor()
                    {
                        Buffer = f16.GetAsVectorView().ToArray(),
                        Shape = f16.Shape.ToArray()
                    });
                    continue;
                }

                throw new InvalidOperationException(
                    $"Output '{name}' is not a float or float16 tensor (got {output?.GetType().Name ?? "null"}).");
            }

            return ret;
        }
        /// <summary>
        /// Converts a planar three-channel float16 tensor into a <see cref="SoftwareBitmap"/>.
        /// </summary>
        /// <remarks>
        /// The tensor data is split into three equal planes. Plane 0, 1 and 2 are written to the
        /// blue, green and red bytes of each Bgra8 pixel respectively, and alpha is forced to 255.
        /// Each value is clamped to [0, 255] and truncated; values are NOT rescaled from [0, 1].
        /// The pixels are round-tripped through an in-memory JPEG encode/decode to produce the
        /// bitmap, so the result is lossy.
        /// </remarks>
        /// <param name="outputImage">
        /// Tensor holding the B, G and R planes back to back (presumably shaped (1, 3, H, W) — TODO confirm).
        /// </param>
        /// <param name="imageWidth">Pixel width handed to the encoder.</param>
        /// <param name="imageHeigth">Pixel height handed to the encoder.</param>
        /// <param name="DpiX">Horizontal DPI handed to the encoder.</param>
        /// <param name="DpiY">Vertical DPI handed to the encoder.</param>
        /// <returns>The bitmap decoded from the intermediate JPEG.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="outputImage"/> is null.</exception>
        public static async Task <SoftwareBitmap> GetImageFromTensorFloatDataAsync(TensorFloat16Bit outputImage, uint imageWidth, uint imageHeigth, double DpiX, double DpiY)
        {
            if (outputImage == null)
            {
                throw new ArgumentNullException(nameof(outputImage));
            }

            var outData   = outputImage.GetAsVectorView().ToArray();
            var planeSize = outData.Length / 3;
            var pixels    = new byte[4 * planeSize];

            // Index straight into the planes instead of copying each one out.
            var greenOffset = planeSize;
            var redOffset   = planeSize * 2;
            for (var i = 0; i < planeSize; i++)
            {
                pixels[i * 4 + 0] = ClampToByteChannel(outData[i]);
                pixels[i * 4 + 1] = ClampToByteChannel(outData[greenOffset + i]);
                pixels[i * 4 + 2] = ClampToByteChannel(outData[redOffset + i]);
                pixels[i * 4 + 3] = 255;
            }

            using (var ms = new InMemoryRandomAccessStream())
            {
                var encoder = await BitmapEncoder.CreateAsync(BitmapEncoder.JpegEncoderId, ms);

                encoder.SetPixelData(BitmapPixelFormat.Bgra8, BitmapAlphaMode.Ignore, imageWidth, imageHeigth, DpiX, DpiY, pixels);
                await encoder.FlushAsync();

                // Rewind explicitly so decoding does not depend on where the encoder left the stream position.
                ms.Seek(0);

                var decoder = await BitmapDecoder.CreateAsync(ms);
                return await decoder.GetSoftwareBitmapAsync();
            }
        }

        /// <summary>
        /// Clamps a channel value to [0, 255] and truncates it to a byte. NaN maps to 0, because
        /// converting NaN to byte is otherwise unspecified.
        /// </summary>
        private static byte ClampToByteChannel(float value)
        {
            if (float.IsNaN(value) || value <= 0f)
            {
                return 0;
            }
            if (value >= 255f)
            {
                return 255;
            }
            return (byte)value;
        }