Example #1
0
        /// <summary>
        /// Computes a scalar error between the current result (<c>_res</c>) and the
        /// ground-truth data (<c>_groundTrouthData</c>) using the requested norm.
        /// </summary>
        /// <param name="norm">
        /// <c>Norm.L1</c>: mean absolute difference over batch, channels, width and height.
        /// <c>Norm.L2</c>: mean squared difference over the same elements (no square root is taken).
        /// <c>Norm.Mix</c>: sum of the values written by the MS-SSIM/L1 kernel (not normalized here).
        /// <c>Norm.MSSSIM</c>: not implemented; always returns 0.
        /// Any other value returns 0.
        /// </param>
        /// <returns>The error value as described for <paramref name="norm"/>.</returns>
        /// <exception cref="InvalidOperationException">
        /// <paramref name="norm"/> is <c>Norm.Mix</c> but <c>_kernelMSSSIML1</c> has not been loaded.
        /// </exception>
        public float GetError(Norm norm)
        {
            switch (norm)
            {
            case Norm.L1:
            case Norm.L2:
            {
                // Element-wise difference between ground truth and result, written into the scratch buffer _temp.
                _groundTrouthData.Sub(_res, _temp);

                // The two norms differ only in how each difference is made non-negative.
                if (norm == Norm.L1)
                {
                    _temp.Abs();
                }
                else
                {
                    _temp.Sqr();
                }

                _temp.Sum(_summedError, _buffer);

                // Implicit conversion of the device variable to float (presumably a device-to-host copy).
                float error = _summedError;

                // Divide step by step (same order as before) instead of by the product of the
                // dimensions, which would also avoid overflow if these fields are ints.
                return error / _batch / _inChannels / _inWidth / _inHeight;
            }

            case Norm.MSSSIM:
                // Not implemented: always reports zero error.
                return 0;

            case Norm.Mix:
            {
                // Nothing in this method creates the kernel. Fail with a descriptive exception
                // instead of a NullReferenceException from the RunSafe call below.
                if (_kernelMSSSIML1 == null)
                {
                    throw new InvalidOperationException(
                        "The MS-SSIM/L1 kernel (_kernelMSSSIML1) has not been loaded; cannot compute the Mix error.");
                }

                // Lazily allocate the kernel's output buffer (_inChannels * _batch floats).
                if (_msssiml1 == null)
                {
                    _msssiml1 = new CudaDeviceVariable<float>(_inChannels * _batch);
                }

                // 0.84f is passed straight to the kernel; presumably the MS-SSIM vs. L1
                // weighting factor. TODO confirm against the kernel source.
                _kernelMSSSIML1.RunSafe(_res, _groundTrouthData, _msssiml1, _dx, _inChannels, _batch, 0.84f);
                _msssiml1.Sum(_summedError, _buffer);

                float error = _summedError;
                return error;
            }

            default:
                return 0;
            }
        }
        /// <summary>
        /// Refines the measured pairwise image shifts of every tile with a batched, iterative
        /// least-squares fit on the GPU (cuBLAS). The refined shifts are then written back into
        /// the per-tile shift buffers (<c>shifts_d</c>).
        /// </summary>
        /// <remarks>
        /// For each tile, let A be the shift matrix (m x n1, with m = shift count and
        /// n1 = frameCount - 1), and let b be the measured shifts (2 columns). Each iteration then:
        /// <list type="bullet">
        /// <item>forms A^T A and inverts it;</item>
        /// <item>computes x = (A^T A)^-1 A^T b (<c>shiftOneToOneArray</c>) and the fitted shifts A x (<c>shiftOptimArray</c>);</item>
        /// <item>runs the <c>checkForOutliers</c> kernel, which updates a per-tile <c>status</c>
        /// (and presumably prunes outliers from <c>shiftMatrices</c>; confirm against the kernel).</item>
        /// </list>
        /// The loop stops after a fixed number of passes, or early once the status values sum to
        /// -tileCount (presumably every tile reporting -1, i.e. no further outliers).
        /// Finally, all shifts are recomputed from x using the saved, unmodified copy of the shift matrix.
        /// </remarks>
        /// <param name="tileCountX">Number of tiles in X.</param>
        /// <param name="tileCountY">Number of tiles in Y.</param>
        public void MinimizeCUBLAS(int tileCountX, int tileCountY)
        {
            // Upper bound on the number of fit / outlier-rejection passes.
            const int maxIterations = 10;

            int shiftCount = GetShiftCount();

            // Pack the per-tile shift buffers into one contiguous array and use it as the measurement vector b.
            concatenateShifts.RunSafe(shifts_d, shiftPitches, AllShifts_d, shiftCount, tileCountX, tileCountY);
            shiftsMeasured.CopyToDevice(AllShifts_d);

            // CudaStopWatch wraps CUDA events; dispose it so they are released on every call.
            using (CudaStopWatch sw = new CudaStopWatch())
            {
                sw.Start();

                int imageCount = frameCount;
                int tileCount  = tileCountX * tileCountY;
                int n1         = imageCount - 1; // unknowns per tile (columns of A)
                int m          = shiftCount;     // measured shifts per tile (rows of A)

                status.Memset(0);
                shiftMatrices.Memset(0);
                float[] shiftMatrix = CreateShiftMatrix();
                shiftMatrices.CopyToDevice(shiftMatrix, 0, 0, shiftMatrix.Length * sizeof(float));

                copyShiftMatrixKernel.RunSafe(shiftMatrices, tileCount, imageCount, shiftCount);
                // Keep an untouched copy for the final reconstruction, because checkForOutliers
                // receives (and presumably modifies) shiftMatrices inside the loop.
                shiftSafeMatrices.CopyToDevice(shiftMatrices);

                for (int i = 0; i < maxIterations; i++)
                {
                    // matrixSquare = A^T * A   (n1 x n1)
                    blas.GemmBatched(Operation.Transpose, Operation.NonTranspose, n1, n1, m, one, shiftMatrixArray, m, shiftMatrixArray, m, zero, matrixSquareArray, n1, tileCount);

                    if (n1 <= 32)
                    {
                        // MatinvBatchedS can only invert up to 32x32 matrices
                        blas.MatinvBatchedS(n1, matrixSquareArray, n1, matrixInvertedArray, n1, infoInverse, tileCount);
                    }
                    else
                    {
                        // LU-factorize in place, then invert from the LU factors.
                        blas.GetrfBatchedS(n1, matrixSquareArray, n1, pivotArray, infoInverse, tileCount);
                        blas.GetriBatchedS(n1, matrixSquareArray, n1, pivotArray, matrixInvertedArray, n1, infoInverse, tileCount);
                    }

                    // solved = (A^T A)^-1 * A^T   (n1 x m)
                    blas.GemmBatched(Operation.NonTranspose, Operation.Transpose, n1, m, n1, one, matrixInvertedArray, n1, shiftMatrixArray, m, zero, solvedMatrixArray, n1, tileCount);
                    // x = solved * b   (n1 x 2); b is stored with leading dimension 2, hence the transpose.
                    blas.GemmBatched(Operation.NonTranspose, Operation.Transpose, n1, 2, m, one, solvedMatrixArray, n1, shiftMeasuredArray, 2, zero, shiftOneToOneArray, n1, tileCount);
                    // fitted = A * x   (m x 2)
                    blas.GemmBatched(Operation.NonTranspose, Operation.NonTranspose, m, 2, n1, one, shiftMatrixArray, m, shiftOneToOneArray, n1, zero, shiftOptimArray, m, tileCount);

                    checkForOutliers.RunSafe(shiftsMeasured, shiftsOptim, shiftMatrices, status, infoInverse, tileCount, imageCount, shiftCount);

                    status.Sum(statusSum, buffer, 0);
                    int[] stats = status;

                    // Diagnostic output: report every tile whose status is non-negative.
                    for (int j = 0; j < tileCount; j++)
                    {
                        if (stats[j] >= 0)
                        {
                            Console.Write(j + ": " + stats[j] + "; ");
                        }
                    }
                    Console.WriteLine();

                    int stat = statusSum;
                    if (stat == -tileCount)
                    {
                        break;
                    }
                }

                // Recompute all pairwise shifts from the solved one-to-one shifts, using the
                // unmodified shift matrix so outlier pairs are replaced by fitted values.
                blas.GemmBatched(Operation.NonTranspose, Operation.NonTranspose, m, 2, n1, one, shiftMatrixSafeArray, m, shiftOneToOneArray, n1, zero, shiftMeasuredArray, m, tileCount);

                AllShifts_d.Memset(0);
                transposeShifts.RunSafe(AllShifts_d, shiftsMeasured, shiftsOneToOne, shiftsOneToOne_d, tileCount, imageCount, shiftCount);

                sw.Stop();
                Console.WriteLine("Time for optimisation: " + sw.GetElapsedTime() + " msec.");
            }

            // Scatter the contiguous result back into the per-tile shift buffers.
            separateShifts.RunSafe(AllShifts_d, shifts_d, shiftPitches, shiftCount, tileCountX, tileCountY);
        }