/// <summary>
/// Builds a new tensor with the same shape as this one, where each element is
/// ConstantsAndFunctions.Forward applied to the corresponding element here.
/// Use it when you need a new tensor (re-creating a wrapper around each
/// element, if T is one) rather than the element-copying done by Copy(true).
///
/// O(V)
/// </summary>
public GenTensor<T> Forward()
{
    var forwarded = new GenTensor<T>(Shape);
    foreach (var position in forwarded.IterateOverElements())
    {
        var source = GetValueNoCheck(position);
        forwarded.SetValueNoCheck(ConstantsAndFunctions<T>.Forward(source), position);
    }
    return forwarded;
}
/// <summary>
/// Creates a one-dimensional tensor (vector) holding the given elements in
/// order; its length equals elements.Length.
/// </summary>
public static GenTensor<T> CreateVector(params T[] elements)
{
    var vector = new GenTensor<T>(elements.Length);
    var position = 0;
    foreach (var element in elements)
    {
        vector.SetValueNoCheck(element, position);
        position++;
    }
    return vector;
}
/// <summary>
/// Combines two tensors of identical shape element by element:
/// the [i, j, k...]th element of the result is
/// ConstantsAndFunctions.Forward(operation(a[i, j, k...], b[i, j, k...])).
/// Ranks 1, 2 and 3 write straight into the result's Data array using its
/// Blocks offsets; any other rank falls back to generic index iteration.
/// When ALLOW_EXCEPTIONS is defined, mismatched shapes throw InvalidShapeException.
/// </summary>
public static GenTensor<T> Zip(GenTensor<T> a, GenTensor<T> b, Func<T, T, T> operation)
{
#if ALLOW_EXCEPTIONS
    if (a.Shape != b.Shape)
    {
        throw new InvalidShapeException("Arguments should be of the same shape");
    }
#endif
    var res = new GenTensor<T>(a.Shape);
    var dims = res.Shape.shape;

    switch (dims.Length)
    {
        case 1:
            for (var i = 0; i < dims[0]; i++)
            {
                res.Data[i] = ConstantsAndFunctions<T>.Forward(
                    operation(a.GetValueNoCheck(i), b.GetValueNoCheck(i)));
            }
            break;

        case 2:
            for (var i = 0; i < dims[0]; i++)
            {
                var rowOffset = i * res.Blocks[0];
                for (var j = 0; j < dims[1]; j++)
                {
                    res.Data[rowOffset + j] = ConstantsAndFunctions<T>.Forward(
                        operation(a.GetValueNoCheck(i, j), b.GetValueNoCheck(i, j)));
                }
            }
            break;

        case 3:
            for (var i = 0; i < dims[0]; i++)
            {
                var planeOffset = i * res.Blocks[0];
                for (var j = 0; j < dims[1]; j++)
                {
                    var rowOffset = planeOffset + j * res.Blocks[1];
                    for (var k = 0; k < dims[2]; k++)
                    {
                        res.Data[rowOffset + k] = ConstantsAndFunctions<T>.Forward(
                            operation(a.GetValueNoCheck(i, j, k), b.GetValueNoCheck(i, j, k)));
                    }
                }
            }
            break;

        default:
            // Generic path for any other rank: walk every index tuple of the result.
            foreach (var index in res.IterateOverElements())
            {
                var combined = operation(a.GetValueNoCheck(index), b.GetValueNoCheck(index));
                res.SetValueNoCheck(ConstantsAndFunctions<T>.Forward(combined), index);
            }
            break;
    }

    return res;
}
/// <summary>
/// Creates a new tensor with the same shape as this one.
/// When copyElements is true, every element is passed through
/// ConstantsAndFunctions.Copy; otherwise every element is passed through
/// ConstantsAndFunctions.Forward (the same result as Forward()).
///
/// O(V)
/// </summary>
public GenTensor<T> Copy(bool copyElements)
{
    var duplicate = new GenTensor<T>(Shape);
    foreach (var position in duplicate.IterateOverElements())
    {
        var original = GetValueNoCheck(position);
        var value = copyElements
            ? ConstantsAndFunctions<T>.Copy(original)
            : ConstantsAndFunctions<T>.Forward(original);
        duplicate.SetValueNoCheck(value, position);
    }
    return duplicate;
}
/// <summary>
/// Creates an identity matrix of shape [diag x diag].
/// Every element is first filled with ConstantsAndFunctions.CreateZero(),
/// then each main-diagonal element (i, i) is set to ConstantsAndFunctions.CreateOne().
/// </summary>
public static GenTensor<T> CreateIdentityMatrix(int diag)
{
    var res = new GenTensor<T>(diag, diag);
    // One CreateZero() call per cell (rather than a single shared value),
    // so each cell receives whatever the factory produces for it.
    for (int i = 0; i < res.Data.Length; i++)
    {
        res.Data[i] = ConstantsAndFunctions<T>.CreateZero();
    }
    for (int i = 0; i < diag; i++)
    {
        // Invoke the factory, consistent with CreateZero() above. The original
        // passed the un-invoked CreateOne member, which only works if some
        // SetValueNoCheck overload accepts a factory instead of a T.
        res.SetValueNoCheck(ConstantsAndFunctions<T>.CreateOne(), i, i);
    }
    return res;
}
/// <summary>
/// Computes VectorDotProduct for every pair of corresponding innermost vectors
/// of a and b. The result's shape is the inputs' shape with the last
/// dimension removed, e.g. TensorVectorDotProduct([4 x 3 x 2], [4 x 3 x 2]) -> [4 x 3].
/// When ALLOW_EXCEPTIONS is defined, differing leading dimensions throw
/// InvalidShapeException. Any check on the last dimension is presumably left
/// to VectorDotProduct (TODO confirm).
///
/// O(V)
/// </summary>
public static GenTensor<T> TensorVectorDotProduct(GenTensor<T> a, GenTensor<T> b)
{
    var outerShape = a.Shape.SubShape(0, 1);
#if ALLOW_EXCEPTIONS
    if (outerShape != b.Shape.SubShape(0, 1))
    {
        throw new InvalidShapeException("Other dimensions of tensors should be equal");
    }
#endif
    var result = new GenTensor<T>(outerShape);
    foreach (var position in result.IterateOverElements())
    {
        var leftVector = a.GetSubtensor(position);
        var rightVector = b.GetSubtensor(position);
        result.SetValueNoCheck(VectorDotProduct(leftVector, rightVector), position);
    }
    return result;
}
/// <summary>
/// Writes into temp the submatrix of the diagLength x diagLength matrix a
/// obtained by deleting row rowId and column colId. This is the matrix whose
/// determinant is the (rowId, colId) minor. Remaining elements are packed row
/// by row into temp starting at (0, 0), with diagLength - 1 elements per row.
/// NOTE(review): assumes 0 &lt;= rowId, colId &lt; diagLength and that temp is at
/// least (diagLength - 1) x (diagLength - 1). Writes go through
/// SetValueNoCheck, so confirm callers guarantee this.
/// </summary>
private static void GetCofactor(GenTensor<T> a, GenTensor<T> temp, int rowId, int colId, int diagLength)
{
    var destRow = 0;
    var destCol = 0;
    for (var srcRow = 0; srcRow < diagLength; srcRow++)
    {
        if (srcRow == rowId)
        {
            continue;
        }
        for (var srcCol = 0; srcCol < diagLength; srcCol++)
        {
            if (srcCol == colId)
            {
                continue;
            }
            temp.SetValueNoCheck(a.GetValueNoCheck(srcRow, srcCol), destRow, destCol);
            destCol++;
            // A destination row is complete after diagLength - 1 writes.
            if (destCol == diagLength - 1)
            {
                destCol = 0;
                destRow++;
            }
        }
    }
}