예제 #1
0
        /// <summary>
        /// Computes the differences between neighbouring elements along a single axis.
        /// </summary>
        /// <param name="axis">The axis along which neighbouring elements are subtracted.</param>
        /// <param name="source">The NdArray that supplies the values.</param>
        /// <returns>
        /// The NdArray of differences: the view of <paramref name="source"/> starting at index 1 along
        /// <paramref name="axis"/> minus the view ending at index <c>Shape[axis] - 2</c>. It has one element
        /// fewer in dimension <paramref name="axis"/> than the source NdArray.
        /// </returns>
        public static NdArray <T> DiffAxis(int axis, NdArray <T> source)
        {
            NdArray<T>.CheckAxis(axis, source);

            var dimensionCount = source.NumDimensions;

            // One range per dimension for each of the two views that get subtracted.
            var shiftedView = new IRange[dimensionCount];
            var trimmedView = new IRange[dimensionCount];

            for (var dim = 0; dim < dimensionCount; dim++)
            {
                if (dim != axis)
                {
                    // Every dimension other than the target axis is taken in full by both views.
                    shiftedView[dim] = RangeFactory.All;
                    trimmedView[dim] = RangeFactory.All;
                    continue;
                }

                // Along the target axis: one view drops the first element, the other drops the last.
                shiftedView[dim] = RangeFactory.Range(1, SpecialIdx.None);
                trimmedView[dim] = RangeFactory.Range(SpecialIdx.None, source.Shape[dim] - 2);
            }

            return source[shiftedView] - source[trimmedView];
        }
예제 #2
0
        /// <summary>
        /// Concatenates NdArrays along an axis.
        /// </summary>
        /// <param name="axis">The concatenation axis. Must be a valid dimension index of the first element of <paramref name="sources"/>.</param>
        /// <param name="sources">
        /// The NdArrays to concatenate, in order. Every element must have the same shape as the first one,
        /// except along <paramref name="axis"/>.
        /// </param>
        /// <returns>
        /// A new NdArray, constructed with the <c>Storage.Device</c> of the first source, that holds the sources
        /// one after another along <paramref name="axis"/>. Sources with length zero along the axis are skipped.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="sources"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="sources"/> is empty, or an element's shape differs from the first element's shape
        /// in a dimension other than <paramref name="axis"/>.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="axis"/> is not a valid dimension index of the first element.</exception>
        public static NdArray <T> Concat(int axis, NdArray <T>[] sources)
        {
            // Report a null sequence as an argument error rather than letting it surface as a NullReferenceException.
            if (sources == null)
            {
                throw new ArgumentNullException("sources");
            }

            if (sources.Length == 0)
            {
                throw new ArgumentException("Cannot concatenate empty sequence of NdArray.", "sources");
            }

            // Snapshot of the first element's shape; every other element is validated against it.
            var shape = sources[0].Shape.ToArray();

            if (!(axis >= 0 && axis < shape.Length))
            {
                var errorMessage = string.Format("Concatenation axis {0} is out of range for shape {1}.", axis, ErrorMessage.ShapeToString(shape));
                throw new ArgumentOutOfRangeException("axis", errorMessage);
            }

            // The reference shape without the concatenation axis is identical for every element, so compute it once.
            var expectedShape = List.Without(axis, shape);
            var arrayIndex    = 0;

            foreach (var source in sources)
            {
                if (!Enumerable.SequenceEqual(List.Without(axis, source.Shape), expectedShape))
                {
                    var errorMessage = string.Format("Concatentation element with index {0} with shape{1} must be equal to shape {2} of the first element, except in the concatenation axis {3}", arrayIndex, ErrorMessage.ShapeToString(source.Shape), ErrorMessage.ShapeToString(shape), axis);
                    throw new ArgumentException(errorMessage, "sources");
                }

                arrayIndex++;
            }

            // The result has the reference shape, with the axis length replaced by the sum over all sources.
            var totalSize   = sources.Sum(i => i.Shape[axis]);
            var concatShape = List.Set(axis, totalSize, shape);

            var result   = new NdArray <T>(concatShape, sources[0].Storage.Device);
            var position = 0;

            foreach (var source in sources)
            {
                var arrayLength = source.Shape[axis];
                if (arrayLength > 0)
                {
                    // Target slice: [position, position + arrayLength - 1] along the axis, everything elsewhere.
                    var range = Enumerable.Range(0, shape.Length).Select(idx =>
                    {
                        if (idx == axis)
                        {
                            return(RangeFactory.Range(position, position + arrayLength - 1));
                        }

                        return(RangeFactory.All);
                    });

                    // The lazy query is materialized here, before position advances, so the captured values are current.
                    result[range.ToArray()] = source;
                    position += arrayLength;
                }
            }

            return(result);
        }
예제 #3
0
        /// <summary>
        /// Verifies that RangeFactory.Range(start, stop, step) yields a Range exposing the requested Start, Stop and Step.
        /// </summary>
        public void Range_ReturnRange()
        {
            // Arrange: the bounds and step requested from the factory.
            const int start = 10;
            const int stop  = 30;
            const int step  = 2;

            // Act: view the factory result as the concrete Range type (null if it is something else).
            var range = RangeFactory.Range(start, stop, step) as Range;

            // Assert: the result is a Range and carries the requested values.
            Assert.IsInstanceOfType(range, typeof(Range));
            Assert.AreEqual(start, range.Start);
            Assert.AreEqual(stop, range.Stop);
            Assert.AreEqual(step, range.Step);
        }
예제 #4
0
        /// <summary>
        /// Verifies that ErrorMessage.RangeArgsToString renders a stepped range followed by a plain range as "[0:4:2, 0:3]".
        /// </summary>
        public void RangeArgsToString_TwoRangesWithStep()
        {
            // Arrange: one range with an explicit step and one without.
            var steppedRange = RangeFactory.Range(0, 4, 2);
            var plainRange   = RangeFactory.Range(0, 3);
            var rangeArgs    = new[] { steppedRange, plainRange };

            // Act
            var actual = ErrorMessage.RangeArgsToString(rangeArgs);

            // Assert
            const string expected = "[0:4:2, 0:3]";
            Assert.AreEqual(expected, actual);
        }
예제 #5
0
        /// <summary>
        /// Verifies that ErrorMessage.RangeArgsToString renders a single range created without a step as "[0:2]".
        /// </summary>
        public void RangeArgsToString_RangeDefault()
        {
            // Arrange: a single range created without an explicit step.
            var plainRange = RangeFactory.Range(0, 2);
            var rangeArgs  = new[] { plainRange };

            // Act
            var actual = ErrorMessage.RangeArgsToString(rangeArgs);

            // Assert
            const string expected = "[0:2]";
            Assert.AreEqual(expected, actual);
        }