示例#1
0
        /// <summary>
        /// Assembles a batch of randomly selected images together with one loss-label
        /// object per image.
        /// </summary>
        /// <param name="device">Metal device on which every image and label object is created.</param>
        /// <param name="batchSize">
        /// Number of samples to draw. Each sample index comes from its own
        /// <c>random.Next(numImages)</c> call.
        /// </param>
        /// <returns>
        /// <c>Inputs</c>: the images, in draw order. <c>Losses</c>: the matching label objects,
        /// in the same order.
        /// </returns>
        public (NSArray <MPSImage> Inputs, NSArray <MPSState> Losses) GetRandomBatch(IMTLDevice device, int batchSize)
        {
            // Length of the float vector handed to the loss descriptor for each sample.
            const int LabelVectorLength = 12;

            var imageDescriptor = MPSImageDescriptor.GetImageDescriptor(
                MPSImageFeatureChannelFormat.Unorm8,
                ImageSize, ImageSize, 1,
                1,
                MTLTextureUsage.ShaderWrite | MTLTextureUsage.ShaderRead);

            var images = new List<MPSImage>();
            var labelStates = new List<MPSState>();

            unsafe
            {
                // Pin both raw buffers for the whole loop so pointer arithmetic stays valid.
                fixed (byte* imageBytes = imagesData, labelBytes = labelsData)
                {
                    for (var sample = 0; sample < batchSize; sample++)
                    {
                        var pick = random.Next(numImages);

                        // Image: skip the prefix, then step over `pick` images of ImageSize * ImageSize bytes.
                        var pixelOffset = pick * ImageSize * ImageSize;
                        var pixelSource = imageBytes + ImagesPrefixSize + pixelOffset;

                        var image = new MPSImage(device, imageDescriptor);
                        image.Label = "TrainImage" + sample;
                        images.Add(image);
                        image.WriteBytes((IntPtr)pixelSource, MPSDataLayout.HeightPerWidthPerFeatureChannels, 0);

                        // Label: a single byte per sample, stored after the labels prefix.
                        var classIndex = *(labelBytes + LabelsPrefixSize + pick);

                        // One-hot vector: all zeros except a 1 at the sample's class index.
                        var oneHot = new float[LabelVectorLength];
                        oneHot[classIndex] = 1f;

                        fixed (float* oneHotPointer = oneHot)
                        {
                            using (var payload = NSData.FromBytes((IntPtr)oneHotPointer, LabelVectorLength * sizeof(float)))
                            {
                                var lossDescriptor = MPSCnnLossDataDescriptor.Create(
                                    payload,
                                    MPSDataLayout.HeightPerWidthPerFeatureChannels,
                                    new MTLSize(1, 1, LabelVectorLength));

                                labelStates.Add(new MPSCnnLossLabels(device, lossDescriptor));
                            }
                        }
                    }
                }
            }

            var inputs = NSArray<MPSImage>.FromNSObjects(images.ToArray());
            var losses = NSArray<MPSState>.FromNSObjects(labelStates.ToArray());
            return (inputs, losses);
        }
示例#2
0
        /// <summary>
        /// Prepares the Metal fixtures: asserts the runtime preconditions, stores the system
        /// default Metal device in <c>device</c>, and fills <c>cache</c> with five Float32
        /// images (224 x 224, 3 feature channels).
        /// Reports the test as inconclusive when there is no default device or the device is
        /// not supported by <c>MPSKernel</c>.
        /// </summary>
        public void Metal()
        {
            TestRuntime.AssertDevice();
            TestRuntime.AssertXcodeVersion(10, 0);

            device = MTLDevice.SystemDefault;

            // some older hardware won't have a default
            var metalUsable = device != null && MPSKernel.Supports(device);
            if (!metalUsable)
            {
                Assert.Inconclusive("Metal is not supported");
            }

            // Each image gets its own freshly created descriptor.
            var images = new MPSImage[5];
            for (var slot = 0; slot < images.Length; slot++)
            {
                var descriptor = MPSImageDescriptor.GetImageDescriptor(MPSImageFeatureChannelFormat.Float32, 224, 224, 3);
                images[slot] = new MPSImage(device, descriptor);
            }

            cache = NSArray<MPSImage>.FromNSObjects(images);
        }