Ejemplo n.º 1
0
 /// <summary>
 /// Element-wise <c>result = a + alpha * b</c>.
 /// </summary>
 /// <param name="a">Left operand.</param>
 /// <param name="b">Right operand, scaled by <paramref name="alpha"/>.</param>
 /// <param name="alpha">Scale applied to <paramref name="b"/>; defaults to 1 (plain addition).</param>
 /// <param name="result">
 /// Destination; may share storage with <paramref name="a"/> or <paramref name="b"/>.
 /// How <c>null</c> is handled is decided by <c>Array_.ElementwiseOp</c>
 /// (presumably a new array is allocated — TODO confirm).
 /// </param>
 /// <returns>The array returned by <c>Array_.ElementwiseOp</c>.</returns>
 public static Array<Real> Add(this Array<Real> a, Array<Real> b, Real alpha = 1, Array<Real> result = null)
 {
     return Array_.ElementwiseOp(a, b, result,
         (n, x, offsetx, incx, y, offsety, incy, z, offsetz, incz) =>
         {
             bool contiguous = incx == 1 && incy == 1 && incz == 1;

             if (alpha == 1 && contiguous)
             {
                 // alpha == 1: z = x + y
                 Blas.vadd(n, x, offsetx, y, offsety, z, offsetz);
             }
             else if (alpha == -1 && contiguous)
             {
                 // alpha == -1: z = x - y
                 Blas.vsub(n, x, offsetx, y, offsety, z, offsetz);
             }
             else if (z == x && offsetz == offsetx && incz == incx)
             {
                 // The output view is exactly x's view, so update in place: x += alpha * y.
                 // Offset and stride must match too; sharing the backing array alone is not
                 // enough, or the result would be written somewhere other than z's view.
                 Blas.axpy(n, alpha, y, offsety, incy, x, offsetx, incx);
             }
             else if (alpha == 1 && z == y && offsetz == offsety && incz == incy)
             {
                 // The output view is exactly y's view and addition commutes: y += x.
                 Blas.axpy(n, alpha, x, offsetx, incx, y, offsety, incy);
             }
             // TODO: else if (incx == 0) => broadcast x ??
             // TODO: else if (incy == 0) => broadcast y ??
             else
             {
                 // Generic strided fallback.
                 // NOTE(review): if z partially overlaps x or y (same storage, different
                 // offset/stride), later reads may observe earlier writes.
                 for (int i = 0; i < n; i++)         // TODO: Blas.copy y => x, Blas.axpy(1, x, y)
                 {
                     z[offsetz] = x[offsetx] + alpha * y[offsety];
                     offsetx += incx;
                     offsety += incy;
                     offsetz += incz;
                 }
             }   // See also: mkl_?omatadd => C := alpha*op(A) + beta*op(B)
         });
 }
Ejemplo n.º 2
0
        /// <summary>
        /// Ad-hoc profiler that compares element-wise addition strategies on square matrices
        /// from 1x1 up to 300x300. It times <c>NN.Apply</c> with the parallel threshold raised
        /// above the array size, <c>NN.Apply</c> with the threshold at 0, the <c>Add</c>
        /// extension, and a direct <c>Blas.vadd</c> call.
        /// </summary>
        /// <remarks>
        /// Results go to the console through <see cref="Trace"/>. The method always throws at the
        /// end ("see output for profiler results"), presumably so the test runner shows the
        /// captured output. Before that exception propagates, the global
        /// <c>NN.MIN_SIZE_FOR_PARELLELISM</c> is restored and the added trace listener is removed,
        /// so other code in the same process is not affected.
        /// </remarks>
        public void CompareElementWisePerformance()
        {
            var listener = new ConsoleTraceListener();
            Trace.Listeners.Add(listener);
            // NN.MIN_SIZE_FOR_PARELLELISM is global state mutated below; remember it for restore.
            var previousMinSizeForParallelism = NN.MIN_SIZE_FOR_PARELLELISM;

            try
            {
                Func <float, float, float> f = (x, y) => x + y;
                var clock = new Stopwatch();

#if DEBUG
                Trace.WriteLine($"Testing on DEBUG");
#else
                Trace.WriteLine($"Testing on RELEASE");
#endif
                Trace.WriteLine($"Testing on {Blas.NThreads} threads");

                for (int i = 0; i < 300; ++i)
                {
                    int n = i + 1;
                    var a = NN.Random.Uniform(-1f, 1f, n, n);
                    var b = NN.Random.Uniform(-1f, 1f, n, n);
                    var r = NN.Zeros(n, n);

                    var size = a.Size;

                    // Estimate how many iterations fit in ~1s for this size. The threshold is set
                    // above the array size, presumably keeping NN.Apply on its sequential path.
                    NN.MIN_SIZE_FOR_PARELLELISM = size * 2;
                    var loopCount = 0;
                    clock.Restart();
                    while (clock.ElapsedMilliseconds < 1000)
                    {
                        NN.Apply(a, b, f, result: r);
                        ++loopCount;
                    }
                    Trace.WriteLine($"doing {loopCount} loops for size {size}");

                    // Profiling with the threshold above the size (normal path).
                    // TotalMilliseconds keeps sub-millisecond precision for fast variants.
                    clock.Restart();
                    for (int _ = 0; _ < loopCount; _++)
                    {
                        NN.Apply(a, b, f, result: r);
                    }
                    var avg = clock.Elapsed.TotalMilliseconds / loopCount;

                    // Profiling with the threshold forced to 0 (parallelized path).
                    NN.MIN_SIZE_FOR_PARELLELISM = 0;
                    clock.Restart();
                    for (int _ = 0; _ < loopCount; _++)
                    {
                        NN.Apply(a, b, f, result: r);
                    }
                    var avgPar = clock.Elapsed.TotalMilliseconds / loopCount;

                    // NOTE(review): the threshold is still 0 for the two measurements below;
                    // whether Add / Blas.vadd consult it is not visible from here.
                    clock.Restart();
                    for (int _ = 0; _ < loopCount; _++)
                    {
                        a.Add(b, result: r);
                    }
                    var avgAdd = clock.Elapsed.TotalMilliseconds / loopCount;

                    // Direct BLAS call; assumes a, b and r are contiguous from offset 0 (they are
                    // freshly created above — TODO confirm NN allocators guarantee this).
                    clock.Restart();
                    for (int _ = 0; _ < loopCount; _++)
                    {
                        Blas.vadd(size, a.Values, 0, b.Values, 0, r.Values, 0);
                    }
                    var avgBlas = clock.Elapsed.TotalMilliseconds / loopCount;

                    var message = $"On size {size}, avg time: {avg:F4}ms \t with parallelism {avgPar:F4}ms \t with Add {avgAdd:F4}ms \t with Blas {avgBlas:F4}ms.";
                    Trace.WriteLine(message);
                }
            }
            finally
            {
                // Undo global side effects so later tests are not affected.
                NN.MIN_SIZE_FOR_PARELLELISM = previousMinSizeForParallelism;
                Trace.Listeners.Remove(listener);
            }

            // Deliberate failure so the runner reports the test and its output is inspected.
            throw new Exception("see output for profiler results");
        }