public ConvolutionWeights(CompileOptions options, int inChannels, int outChannels, int kernelSize, int stride, bool bias, string label, int seed)
        {
            // Only the execution subset of the compile options is retained on the instance;
            // the full `options` parameter is still used below for Device and LearningRate.
            this.options = options.ExecutionOptions;

            // Square kernel (kernelSize x kernelSize) mapping inChannels -> outChannels,
            // with the same stride applied on both spatial axes.
            var convolutionDescriptor = MPSCnnConvolutionDescriptor.CreateCnnConvolutionDescriptor(
                (System.nuint)kernelSize, (System.nuint)kernelSize,
                (System.nuint)inChannels,
                (System.nuint)outChannels);
            var strideInPixels = (nuint)stride;
            convolutionDescriptor.StrideInPixelsX = strideInPixels;
            convolutionDescriptor.StrideInPixelsY = strideInPixels;
            descriptor = convolutionDescriptor;

            // NOTE(review): the flag is stored, but a bias vector is allocated below regardless.
            this.bias = bias;
            // A null/empty label is replaced by a freshly generated GUID string.
            this.label = string.IsNullOrEmpty(label) ? Guid.NewGuid().ToString() : label;

            // One flat vector holds every kernel weight: in * k * k * out elements.
            var weightCount = inChannels * kernelSize * kernelSize * outChannels;
            // Third argument appears to be the initial fill value (0.0 for weights, 0.1 for biases) — TODO confirm.
            weightVectors = new OptimizableVector(options.Device, VectorDescriptor(weightCount), 0.0f);
            biasVectors   = new OptimizableVector(options.Device, VectorDescriptor(outChannels), 0.1f);

            // Must run after the vectors exist and before their data is wrapped below.
            RandomizeWeights((nuint)seed);

            convWtsAndBias = new MPSCnnConvolutionWeightsAndBiasesState(weightVectors.Value.Data, biasVectors.Value.Data);

            // Optimizer state is paired [weights, biases] to match convWtsAndBias.
            momentumVectors = NSArray<MPSVector>.FromNSObjects(weightVectors.Momentum, biasVectors.Momentum);
            velocityVectors = NSArray<MPSVector>.FromNSObjects(weightVectors.Velocity, biasVectors.Velocity);

            // Adam hyper-parameters.
            const float adamBeta1 = 0.9f;
            const float adamBeta2 = 0.999f;
            const float adamEpsilon = 1e-8f;

            // Learning rate from options; gradient rescale 1.0, no regularization (scale 1.0).
            var optimizerDescriptor = new MPSNNOptimizerDescriptor(options.LearningRate, 1.0f, MPSNNRegularizationType.None, 1.0f);

            updater = new MPSNNOptimizerAdam(
                options.Device,
                beta1: adamBeta1,
                beta2: adamBeta2,
                epsilon: adamEpsilon,
                timeStep: 0,
                optimizerDescriptor: optimizerDescriptor);
        }
        /// <summary>
        /// Encodes one Adam optimizer step into <paramref name="commandBuffer"/>, writing the
        /// updated weights and biases into this instance's <c>convWtsAndBias</c> state.
        /// </summary>
        /// <param name="commandBuffer">Command buffer the optimizer step is encoded into.</param>
        /// <param name="gradientState">Gradients produced by the convolution's backward pass.</param>
        /// <param name="sourceState">Weights/biases state the optimizer reads from.</param>
        /// <returns>The instance's <c>convWtsAndBias</c> state, used as the encode result state.</returns>
        /// <exception cref="InvalidOperationException">
        /// Thrown when the local update counter and the optimizer's <c>TimeStep</c> disagree after encoding.
        /// </exception>
        public override MPSCnnConvolutionWeightsAndBiasesState Update(IMTLCommandBuffer commandBuffer, MPSCnnConvolutionGradientState gradientState, MPSCnnConvolutionWeightsAndBiasesState sourceState)
        {
            updateCount++;

            updater.Encode(commandBuffer, gradientState, sourceState, momentumVectors, velocityVectors, convWtsAndBias);

            // The check expects the optimizer's TimeStep to track updateCount one-to-one after
            // each encode; a mismatch means updates were encoded outside this method or skipped.
            if (updateCount != updater.TimeStep)
            {
                throw new InvalidOperationException(
                    $"Update time step is out of synch: local update count {updateCount}, optimizer TimeStep {updater.TimeStep}");
            }

            return convWtsAndBias;
        }