示例#1
0
        /// <summary>
        /// Validates the arguments and then delegates to <c>TensorApplyCPU.Gather</c>, writing into a
        /// tensor sized like <paramref name="indices"/> with the element type of <paramref name="src"/>.
        /// </summary>
        /// <param name="result">
        /// Optional destination tensor. When non-null it must have the same number of dimensions as
        /// <paramref name="src"/>, the same size as <paramref name="indices"/>, and match
        /// <paramref name="src"/> in every dimension except <paramref name="dim"/>.
        /// May be null; the write target is then whatever <c>TensorResultBuilder.GetWriteTarget</c> returns
        /// (presumably a freshly allocated tensor - confirm against that helper).
        /// </param>
        /// <param name="src">Tensor the values are read from.</param>
        /// <param name="dim">Zero-based dimension passed to the gather kernel; must lie in [0, src.DimensionCount).</param>
        /// <param name="indices">Index tensor; must have the same number of dimensions as <paramref name="src"/>.</param>
        /// <returns>The tensor that was written to.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="src"/> or <paramref name="indices"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="dim"/> is not a valid dimension of <paramref name="src"/>.</exception>
        /// <exception cref="InvalidOperationException">The dimension counts or sizes of the tensors are incompatible.</exception>
        public Tensor Gather(Tensor result, Tensor src, int dim, Tensor indices)
        {
            if (src == null)
            {
                throw new ArgumentNullException(nameof(src));
            }

            if (indices == null)
            {
                throw new ArgumentNullException(nameof(indices));
            }

            if (result != null && result.DimensionCount != src.DimensionCount)
            {
                throw new InvalidOperationException("result and src must have same number of dimensions");
            }

            // The original guard used '&&', which can never be true, and only ran when result was supplied.
            // Checking against src covers both cases: a non-null result is already known to have the
            // same dimension count as src at this point.
            if (dim < 0 || dim >= src.DimensionCount)
            {
                throw new ArgumentOutOfRangeException(nameof(dim));
            }

            if (indices.DimensionCount != src.DimensionCount)
            {
                throw new InvalidOperationException("src and indices must have same number of dimensions");
            }

            if (result != null && !result.IsSameSizeAs(indices))
            {
                throw new InvalidOperationException("result and indices must be the same size");
            }

            if (result != null && !TensorResultBuilder.ArrayEqualExcept(src.Sizes, result.Sizes, dim))
            {
                throw new InvalidOperationException("result and src must be the same size except in dimension dim");
            }

            // The write target takes its allocator and shape from indices and its element type from src.
            Tensor writeTarget = TensorResultBuilder.GetWriteTarget(result, indices.Allocator, src.ElementType, false, indices.Sizes);

            TensorApplyCPU.Gather(writeTarget, src, dim, indices);

            return writeTarget;
        }