Example #1
0
        /// <summary>
        /// Validates the arguments passed to a kernel function against that function's
        /// declared parameters (matched by position). Primitive arguments are coerced in
        /// place to the parameter's declared type.
        /// </summary>
        /// <param name="functionName">Name of the kernel function whose parameters are checked.</param>
        /// <param name="args">
        /// The arguments in parameter order. Primitive entries are replaced with the result of
        /// <see cref="Convert.ChangeType(object, Type)"/> to the parameter's <c>TypeName</c>.
        /// </param>
        /// <exception cref="ExecutionException">
        /// The function is not found in <c>KernelFunctions</c>, more arguments than parameters
        /// were supplied, an argument is null, a primitive parameter's type name cannot be
        /// resolved, or an array / XArray argument's type does not match the parameter.
        /// </exception>
        private void ValidateArgs(string functionName, object[] args)
        {
            const string mismatchFormat = "Data type mismatch for parameter {0}. Expected is {1} but got {2}";

            // Name-based XArray detection (type itself or its direct base). BaseType is null
            // for System.Object, so it must be checked before dereferencing.
            bool IsXArray(Type type) =>
                type.Name == "XArray" || (type.BaseType != null && type.BaseType.Name == "XArray");

            var method = KernelFunctions.FirstOrDefault(x => (x.Name == functionName));
            if (method == null)
            {
                throw new ExecutionException(string.Format("Kernal function {0} not found", functionName));
            }

            // Materialised once so the positional lookup below does not re-enumerate the
            // parameter collection for every argument.
            var parameters = method.Parameters.ToList();
            if (args.Length > parameters.Count)
            {
                throw new ExecutionException(string.Format("Kernel function {0} expects at most {1} arguments but got {2}",
                                                           functionName,
                                                           parameters.Count,
                                                           args.Length));
            }

            for (int i = 0; i < args.Length; i++)
            {
                var parameter = parameters[i];
                object arg = args[i];

                if (arg == null)
                {
                    throw new ExecutionException(string.Format("Argument for parameter {0} is null", parameter.Key));
                }

                Type argType = arg.GetType();

                if (argType.IsPrimitive)
                {
                    Type targetType = Type.GetType(parameter.Value.TypeName);
                    if (targetType == null)
                    {
                        throw new ExecutionException(string.Format("Unknown type {0} declared for parameter {1}",
                                                                   parameter.Value.TypeName,
                                                                   parameter.Key));
                    }

                    args[i] = Convert.ChangeType(arg, targetType);
                }
                else if (argType.IsArray)
                {
                    if (parameter.Value.TypeName != argType.FullName)
                    {
                        // Three separate format arguments: the original passed one tuple, which
                        // left {2} unfilled and made string.Format throw FormatException.
                        throw new ExecutionException(string.Format(mismatchFormat,
                                                                   parameter.Key,
                                                                   parameter.Value.TypeName,
                                                                   argType.FullName));
                    }
                }
                else if (IsXArray(argType))
                {
                    XArray array = (XArray)arg;
                    string actualTypeName = array.DataType.ToCLRType().Name;
                    if (!parameter.Value.TypeName.Contains(actualTypeName))
                    {
                        throw new ExecutionException(string.Format(mismatchFormat,
                                                                   parameter.Key,
                                                                   parameter.Value.TypeName,
                                                                   actualTypeName));
                    }
                }
            }
        }
Example #2
0
        /// <summary>
        /// Builds an <see cref="XArray"/> from a native tensor reference.
        /// </summary>
        /// <param name="tensorRef">
        /// Reference supplying the dimension count, a pointer to the per-dimension sizes,
        /// the element type and the data buffer pointer.
        /// </param>
        /// <returns>
        /// A new XArray whose shape is copied out of <c>tensorRef.sizes</c> and whose
        /// <c>NativePtr</c> is set to <c>tensorRef.buffer</c> (the buffer itself is not copied).
        /// </returns>
        internal static XArray FromRef(XArrayRef tensorRef)
        {
            var shape = new long[tensorRef.dimCount];
            Marshal.Copy(tensorRef.sizes, shape, 0, shape.Length);

            return new XArray(shape, tensorRef.elementType)
            {
                NativePtr = tensorRef.buffer,
            };
        }
Example #3
0
        /// <summary>
        /// Produces an <see cref="XArrayRef"/> describing <paramref name="tensor"/>.
        /// </summary>
        /// <param name="tensor">Tensor whose buffer start, sizes, strides and element type are exported.</param>
        /// <returns>
        /// A reference whose buffer comes from <c>GetBufferStart</c> and whose sizes and strides
        /// come from <c>AllocArray</c>. NOTE(review): AllocArray presumably allocates native
        /// copies that the caller must release — confirm against its implementation.
        /// </returns>
        internal XArrayRef AllocTensorRef(XArray tensor)
        {
            return new XArrayRef
            {
                buffer = GetBufferStart(tensor),
                dimCount = tensor.Sizes.Length,
                sizes = AllocArray(tensor.Sizes),
                strides = AllocArray(tensor.strides),
                elementType = tensor.DataType,
            };
        }
Example #4
0
        /// <summary>
        /// Executes the specified kernel function on the default device and reads back
        /// any output arguments.
        /// </summary>
        /// <param name="functionName">Name of the kernel function to run.</param>
        /// <param name="args">
        /// Kernel arguments in parameter order. Arrays and XArray instances bound to
        /// <c>Out</c> / <c>InOut</c> parameters are read back from device memory after the launch.
        /// The global work size comes from the first array argument. If there is none, it comes
        /// from the last Output-direction XArray, falling back to the first XArray argument.
        /// </param>
        /// <exception cref="ExecutionException">
        /// Argument validation fails, the kernel is not compiled, no array/XArray argument is
        /// available to size the launch, or any error occurs while building arguments,
        /// launching the kernel or reading results back.
        /// </exception>
        public override void Execute(string functionName, params object[] args)
        {
            ValidateArgs(functionName, args);

            ComputeKernel kernel = _compiledKernels.FirstOrDefault(x => (x.FunctionName == functionName));
            if (kernel == null)
            {
                throw new ExecutionException(string.Format("Kernal function {0} not found", functionName));
            }

            // Name-based XArray detection (type itself or its direct base); BaseType is null
            // for System.Object, and null arguments are never XArrays.
            bool IsXArray(object value)
            {
                if (value == null)
                {
                    return false;
                }

                Type type = value.GetType();
                return type.Name == "XArray" || (type.BaseType != null && type.BaseType.Name == "XArray");
            }

            // Arguments for which BuildKernelArguments' result is indexed and disposed.
            bool HasDeviceBuffer(object value) => (value != null && value.GetType().IsArray) || IsXArray(value);

            // Created only after the kernel lookup succeeds so a missing kernel cannot leak a queue.
            ComputeCommandQueue commands = new ComputeCommandQueue(_context, _defaultDevice, ComputeCommandQueueFlags.None);

            try
            {
                List<long> length = new List<long>();
                long totalLength = 0;

                // Adopts an XArray's extent as the launch size. The range is replaced, never
                // appended: previously each element-wise output appended another dimension, so
                // several outputs produced a bogus multi-dimensional range.
                void UseShapeOf(XArray source)
                {
                    totalLength = source.Count;
                    length = source.IsElementWise ? new List<long> { totalLength } : source.Sizes.ToList();
                }

                Array ndobject = (Array)args.FirstOrDefault(x => x != null && x.GetType().IsArray);
                if (ndobject == null)
                {
                    List<XArray> xarrayList = args.Where(IsXArray).Cast<XArray>().ToList();
                    foreach (XArray candidate in xarrayList)
                    {
                        if (candidate.Direction == Direction.Output)
                        {
                            UseShapeOf(candidate);
                        }
                    }

                    if (totalLength == 0)
                    {
                        if (xarrayList.Count == 0)
                        {
                            throw new ExecutionException(string.Format("Kernel function {0} has no array or XArray argument to derive the work size from",
                                                                       functionName));
                        }

                        UseShapeOf(xarrayList[0]);
                    }
                }
                else
                {
                    totalLength = ndobject.Length;
                    for (int dim = 0; dim < ndobject.Rank; dim++)
                    {
                        length.Add(ndobject.GetLength(dim));
                    }
                }

                var method = KernelFunctions.FirstOrDefault(x => (x.Name == functionName));

                var buffers = BuildKernelArguments(method, args, kernel, totalLength);
                try
                {
                    commands.Execute(kernel, null, length.ToArray(), null, null);

                    for (int i = 0; i < args.Length; i++)
                    {
                        object arg = args[i];
                        if (!HasDeviceBuffer(arg))
                        {
                            continue;
                        }

                        var ioMode = method.Parameters.ElementAt(i).Value.IOMode;
                        if (ioMode != IOMode.InOut && ioMode != IOMode.Out)
                        {
                            continue;
                        }

                        // Separate typed locals so the Array / XArray ReadFromMemory overloads are chosen.
                        if (arg.GetType().IsArray)
                        {
                            Array r = (Array)arg;
                            commands.ReadFromMemory(buffers[i], ref r, true, 0, null);
                        }
                        else
                        {
                            XArray r = (XArray)arg;
                            commands.ReadFromMemory(buffers[i], ref r, true, 0, null);
                        }
                    }
                }
                finally
                {
                    // Release device buffers even when the launch or a read-back throws;
                    // previously they were disposed only on the success path.
                    for (int i = 0; i < args.Length; i++)
                    {
                        if (HasDeviceBuffer(args[i]))
                        {
                            buffers[i].Dispose();
                        }
                    }
                }
            }
            catch (Exception ex)
            {
                // NOTE(review): only the message survives; the inner exception and stack trace
                // are lost. Pass `ex` along if ExecutionException offers an (string, Exception) ctor.
                throw new ExecutionException(ex.Message);
            }
            finally
            {
                commands.Finish();
                commands.Dispose();
            }
        }
Example #5
0
 /// <summary>
 /// Computes the address of <paramref name="tensor"/>'s first element by offsetting
 /// <c>NativePtr</c> by the tensor's storage offset multiplied by its element size.
 /// </summary>
 /// <param name="tensor">Tensor whose <c>storageOffset</c> and <c>DataType</c> determine the offset.</param>
 /// <returns>The pointer produced by <c>PtrAdd(NativePtr, offset)</c>.</returns>
 public IntPtr GetBufferStart(XArray tensor)
 {
     // storageOffset is scaled by the element size, so it is presumably an element count; the product is a byte offset.
     var byteOffset = tensor.storageOffset * tensor.DataType.Size();
     return PtrAdd(NativePtr, byteOffset);
 }