Example #1
0
        /// <summary>
        /// Builds a new tensor with the same shape as this one. Each element is
        /// ConstantsAndFunctions.Forward applied to the element at the same
        /// position in this tensor.
        /// Use it when you want a fresh tensor that re-wraps the values through
        /// Forward rather than duplicating them (compare with Copy(true)).
        ///
        /// O(V)
        /// </summary>
        public GenTensor <T> Forward()
        {
            var forwarded = new GenTensor <T>(Shape);

            foreach (var position in forwarded.IterateOverElements())
            {
                var source = GetValueNoCheck(position);
                forwarded.SetValueNoCheck(ConstantsAndFunctions <T> .Forward(source), position);
            }

            return forwarded;
        }
        /// <summary>
        /// Creates a one-dimensional tensor (vector) holding the given elements
        /// in order. The vector's length equals elements.Length.
        /// </summary>
        /// <param name="elements">Values to place into the vector, in order.</param>
        /// <returns>A new vector whose i-th element is elements[i].</returns>
        public static GenTensor <T> CreateVector(params T[] elements)
        {
            int length = elements.Length;
            var vector = new GenTensor <T>(length);

            int position = 0;
            while (position < length)
            {
                vector.SetValueNoCheck(elements[position], position);
                position++;
            }

            return vector;
        }
Example #3
0
        /// <summary>
        /// Combines two tensors of the same shape element by element.
        /// The element of the result at [i, j, k...] is
        /// operation(a[i, j, k...], b[i, j, k...]), passed through
        /// ConstantsAndFunctions.Forward before it is stored.
        ///
        /// Results of rank 1, 2 and 3 are written directly into the result's
        /// Data buffer. Any other rank falls back to generic index iteration.
        /// </summary>
        /// <param name="a">Left operand.</param>
        /// <param name="b">Right operand; expected to have the same shape as a.</param>
        /// <param name="operation">Function applied to each pair of corresponding elements.</param>
        /// <returns>A newly allocated tensor with a's shape.</returns>
        /// <remarks>
        /// When ALLOW_EXCEPTIONS is defined, an InvalidShapeException is thrown if the
        /// shapes differ. Otherwise the shapes are not checked.
        /// </remarks>
        public static GenTensor <T> Zip(GenTensor <T> a,
                                        GenTensor <T> b, Func <T, T, T> operation)
        {
            #if ALLOW_EXCEPTIONS
            if (a.Shape != b.Shape)
            {
                throw new InvalidShapeException("Arguments should be of the same shape");
            }
            #endif
            var result = new GenTensor <T>(a.Shape);
            var dims = result.Shape.shape;

            switch (dims.Length)
            {
                case 1:
                {
                    for (int i = 0; i < dims[0]; i++)
                    {
                        result.Data[i] = ConstantsAndFunctions <T> .Forward(
                            operation(a.GetValueNoCheck(i), b.GetValueNoCheck(i)));
                    }
                    break;
                }
                case 2:
                {
                    for (int i = 0; i < dims[0]; i++)
                    {
                        for (int j = 0; j < dims[1]; j++)
                        {
                            // Row-major offset into the freshly allocated buffer.
                            result.Data[i * result.Blocks[0] + j] = ConstantsAndFunctions <T> .Forward(
                                operation(a.GetValueNoCheck(i, j), b.GetValueNoCheck(i, j)));
                        }
                    }
                    break;
                }
                case 3:
                {
                    for (int i = 0; i < dims[0]; i++)
                    {
                        for (int j = 0; j < dims[1]; j++)
                        {
                            for (int k = 0; k < dims[2]; k++)
                            {
                                result.Data[i * result.Blocks[0] + j * result.Blocks[1] + k] = ConstantsAndFunctions <T> .Forward(
                                    operation(a.GetValueNoCheck(i, j, k), b.GetValueNoCheck(i, j, k)));
                            }
                        }
                    }
                    break;
                }
                default:
                {
                    foreach (var position in result.IterateOverElements())
                    {
                        var combined = operation(a.GetValueNoCheck(position), b.GetValueNoCheck(position));
                        result.SetValueNoCheck(ConstantsAndFunctions <T> .Forward(combined), position);
                    }
                    break;
                }
            }

            return result;
        }
Example #4
0
        /// <summary>
        /// Creates a new tensor with the same shape as this one.
        /// If copyElements is true, every element is passed through
        /// ConstantsAndFunctions.Copy. If it is false, every element is passed
        /// through ConstantsAndFunctions.Forward instead, which is the same
        /// transformation the parameterless Forward() applies.
        ///
        /// O(V)
        /// </summary>
        /// <param name="copyElements">
        /// Whether each element should be copied via ConstantsAndFunctions.Copy
        /// (true) or forwarded via ConstantsAndFunctions.Forward (false).
        /// </param>
        public GenTensor <T> Copy(bool copyElements)
        {
            var duplicate = new GenTensor <T>(Shape);

            foreach (var position in duplicate.IterateOverElements())
            {
                var source = GetValueNoCheck(position);
                var cell = copyElements
                    ? ConstantsAndFunctions <T> .Copy(source)
                    : ConstantsAndFunctions <T> .Forward(source);
                duplicate.SetValueNoCheck(cell, position);
            }

            return duplicate;
        }
        /// <summary>
        /// Creates a square identity matrix whose width and height both equal diag.
        /// Every cell is first filled with ConstantsAndFunctions.CreateZero().
        /// The main diagonal is then overwritten using ConstantsAndFunctions.CreateOne.
        /// </summary>
        /// <param name="diag">Width and height of the resulting matrix.</param>
        /// <returns>A new diag x diag identity matrix.</returns>
        public static GenTensor <T> CreateIdentityMatrix(int diag)
        {
            var identity = new GenTensor <T>(diag, diag);
            int cellCount = identity.Data.Length;

            // Zero-fill the entire backing buffer first...
            for (int cell = 0; cell < cellCount; ++cell)
            {
                identity.Data[cell] = ConstantsAndFunctions <T> .CreateZero();
            }

            // ...then place ones along the main diagonal.
            for (int k = 0; k < diag; ++k)
            {
                identity.SetValueNoCheck(ConstantsAndFunctions <T> .CreateOne, k, k);
            }

            return identity;
        }
 /// <summary>
 /// Computes the scalar (dot) product of every pair of corresponding vectors
 /// along the last dimension, giving a tensor with one dimension fewer
 /// (e.g. TensorVectorDotProduct([4 x 3 x 2], [4 x 3 x 2]) -> [4 x 3]).
 ///
 /// O(V)
 /// </summary>
 /// <param name="a">Left operand.</param>
 /// <param name="b">Right operand; all dimensions except the last must match a's.</param>
 /// <returns>A tensor of shape a.Shape.SubShape(0, 1) holding the dot products.</returns>
 public static GenTensor <T> TensorVectorDotProduct(GenTensor <T> a,
                                                    GenTensor <T> b)
 {
     #if ALLOW_EXCEPTIONS
     if (a.Shape.SubShape(0, 1) != b.Shape.SubShape(0, 1))
     {
         throw new InvalidShapeException("Other dimensions of tensors should be equal");
     }
     #endif
     var products = new GenTensor <T>(a.Shape.SubShape(0, 1));

     foreach (var outerIndex in products.IterateOverElements())
     {
         // Each outer index selects one vector from a and one from b.
         products.SetValueNoCheck(
             VectorDotProduct(a.GetSubtensor(outerIndex), b.GetSubtensor(outerIndex)),
             outerIndex);
     }

     return products;
 }
Example #7
0
        /// <summary>
        /// Copies the submatrix of a that remains after deleting row rowId and
        /// column colId into temp. Only indices below diagLength are read from a.
        /// Values are written to temp in row-major order, filling rows of width
        /// diagLength - 1.
        /// NOTE(review): assumes 0 &lt;= rowId, colId &lt; diagLength and that temp is at
        /// least (diagLength - 1) x (diagLength - 1); not validated here.
        /// </summary>
        private static void GetCofactor(GenTensor <T> a, GenTensor <T> temp, int rowId,
                                        int colId, int diagLength)
        {
            int targetRow = 0;
            int targetCol = 0;

            for (int sourceRow = 0; sourceRow < diagLength; sourceRow++)
            {
                if (sourceRow == rowId)
                {
                    continue;
                }

                for (int sourceCol = 0; sourceCol < diagLength; sourceCol++)
                {
                    if (sourceCol == colId)
                    {
                        continue;
                    }

                    temp.SetValueNoCheck(a.GetValueNoCheck(sourceRow, sourceCol), targetRow, targetCol);
                    targetCol++;

                    // A full row of the submatrix has been written: wrap to the next row.
                    if (targetCol == diagLength - 1)
                    {
                        targetCol = 0;
                        targetRow++;
                    }
                }
            }
        }