/// <summary>
/// Calculates the difference between adjoining elements along the specified axis.
/// </summary>
/// <param name="axis">The axis to operate along.</param>
/// <param name="source">The NdArray containing the source values.</param>
/// <returns>
/// The differences NdArray: along <paramref name="axis"/>, each element is the following source element minus
/// the current one. It has one element less in dimension <paramref name="axis"/> than the source NdArray.
/// </returns>
public static NdArray<T> DiffAxis(int axis, NdArray<T> source)
{
    NdArray<T>.CheckAxis(axis, source);

    var rank = source.NumDimensions;
    var upperRanges = new IRange[rank];
    var lowerRanges = new IRange[rank];

    for (var dim = 0; dim < rank; dim++)
    {
        if (dim != axis)
        {
            // Every dimension other than the diff axis is taken whole in both views.
            upperRanges[dim] = RangeFactory.All;
            lowerRanges[dim] = RangeFactory.All;
            continue;
        }

        // Along the diff axis the "upper" view starts at index 1 (drops the first element) and the
        // "lower" view stops at Shape - 2 (drops the last element), so both views have equal length.
        upperRanges[dim] = RangeFactory.Range(1, SpecialIdx.None);
        lowerRanges[dim] = RangeFactory.Range(SpecialIdx.None, source.Shape[dim] - 2);
    }

    // Element-wise: result[..., i, ...] = source[..., i + 1, ...] - source[..., i, ...].
    return source[upperRanges] - source[lowerRanges];
}
/// <summary>
/// Concatenates NdArrays along an axis.
/// </summary>
/// <param name="axis">
/// The concatenation axis. Must be in the range [0, number of dimensions of the first NdArray).
/// </param>
/// <param name="sources">
/// Sequence of NdArrays to concatenate. Every element must have the same number of dimensions as the first
/// element and the same shape except in dimension <paramref name="axis"/>.
/// </param>
/// <returns>The concatenated NdArray, allocated on the device of the first source.</returns>
/// <exception cref="ArgumentNullException"><paramref name="sources"/> is null.</exception>
/// <exception cref="ArgumentException">
/// <paramref name="sources"/> is empty, contains a null element, or contains an element whose shape is
/// incompatible with the first element.
/// </exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="axis"/> is outside the dimensions of the first element.
/// </exception>
public static NdArray<T> Concat(int axis, NdArray<T>[] sources)
{
    if (sources == null)
    {
        throw new ArgumentNullException("sources");
    }

    if (sources.Length == 0)
    {
        throw new ArgumentException("Cannot concatenate empty sequence of NdArray.", "sources");
    }

    // Reject null elements up front so callers get an ArgumentException rather than a
    // NullReferenceException from the shape checks below.
    for (var index = 0; index < sources.Length; index++)
    {
        if (ReferenceEquals(sources[index], null))
        {
            var errorMessage = string.Format("Concatenation element with index {0} is null.", index);
            throw new ArgumentException(errorMessage, "sources");
        }
    }

    // Copy of the first element's shape; every other element is validated against it.
    var shape = sources[0].Shape.Select(s => s).ToArray();
    if (!(axis >= 0 && axis < shape.Length))
    {
        var errorMessage = string.Format("Concatenation axis {0} is out of range for shape {1}.", axis, ErrorMessage.ShapeToString(shape));
        throw new ArgumentOutOfRangeException("axis", errorMessage);
    }

    var arrayIndex = 0;
    foreach (var source in sources)
    {
        // Check the rank explicitly before comparing the shapes with the concatenation axis removed, so an
        // element with a different number of dimensions is always rejected here and never reaches
        // source.Shape[axis] below.
        var sameRank = source.NumDimensions == shape.Length;
        if (!sameRank || !Enumerable.SequenceEqual(List.Without(axis, source.Shape), List.Without(axis, shape)))
        {
            var errorMessage = string.Format(
                "Concatenation element with index {0} with shape {1} must be equal to shape {2} of the first element, except in the concatenation axis {3}",
                arrayIndex,
                ErrorMessage.ShapeToString(source.Shape),
                ErrorMessage.ShapeToString(shape),
                axis);
            throw new ArgumentException(errorMessage, "sources");
        }

        arrayIndex++;
    }

    // The result has the first element's shape, except that the concatenation axis holds the summed lengths.
    var totalSize = sources.Sum(i => i.Shape[axis]);
    var concatShape = List.Set(axis, totalSize, shape);
    var result = new NdArray<T>(concatShape, sources[0].Storage.Device);

    var position = 0;
    foreach (var source in sources)
    {
        var arrayLength = source.Shape[axis];

        // Sources of length zero along the axis contribute nothing and are skipped.
        if (arrayLength > 0)
        {
            // Range stops are inclusive, so this source occupies [position, position + arrayLength - 1].
            var range = Enumerable.Range(0, shape.Length).Select(idx =>
            {
                if (idx == axis)
                {
                    return RangeFactory.Range(position, position + arrayLength - 1);
                }

                return RangeFactory.All;
            });

            result[range.ToArray()] = source;
            position += arrayLength;
        }
    }

    return result;
}
public void Range_ReturnRange()
{
    // arrange & action: build a range with an explicit start, stop and step
    var created = RangeFactory.Range(10, 30, 2);

    // assert: the factory produced a Range carrying exactly the requested bounds and step
    Assert.IsInstanceOfType(created, typeof(Range));
    var range = (Range)created;
    Assert.AreEqual(10, range.Start);
    Assert.AreEqual(30, range.Stop);
    Assert.AreEqual(2, range.Step);
}
public void RangeArgsToString_TwoRangesWithStep()
{
    // arrange: one range with an explicit step, one without
    const string expected = "[0:4:2, 0:3]";
    var ranges = new[]
    {
        RangeFactory.Range(0, 4, 2),
        RangeFactory.Range(0, 3),
    };

    // action
    var actual = ErrorMessage.RangeArgsToString(ranges);

    // assert: a stepped range renders as start:stop:step, and ranges are joined by ", " inside brackets
    Assert.AreEqual(expected, actual);
}
public void RangeArgsToString_RangeDefault()
{
    // arrange: a single range created without an explicit step
    const string expected = "[0:2]";
    var ranges = new[] { RangeFactory.Range(0, 2) };

    // action
    var actual = ErrorMessage.RangeArgsToString(ranges);

    // assert: without a step the range renders as start:stop
    Assert.AreEqual(expected, actual);
}