Example #1
0
 /// <summary>
 /// Creates an <c>UnArgmax</c> node over <paramref name="x"/> and <paramref name="indexes"/>,
 /// after checking ranks and binding the dimensions of <paramref name="x"/> to <paramref name="shape"/>.
 /// </summary>
 /// <param name="x">Input tensor; must have exactly <c>shape.Length</c> dimensions.</param>
 /// <param name="indexes">Integer tensor forwarded unchanged to the node (presumably argmax positions — TODO confirm).</param>
 /// <param name="axis">The one dimension that is NOT bound between <paramref name="x"/> and <paramref name="shape"/>.</param>
 /// <param name="shape">Target shape; every entry except the one at <paramref name="axis"/> is bound to the matching dim of <paramref name="x"/>.</param>
 /// <returns>The newly constructed <c>UnArgmax</c> node.</returns>
 public static Tensor<T> Create<T>(Tensor<T> x, Tensor<int> indexes, int axis, Dim[] shape)
 {
     // x and the target shape must have the same number of dimensions.
     x.AssertOfDim(shape.Length);

     for (var dim = 0; dim < x.NDim; dim++)
     {
         // The `axis` dimension is skipped; all others are unified with the target shape.
         if (dim == axis)
         {
             continue;
         }

         ShapeExtension.Bind(ref x.Shape[dim], ref shape[dim]);
     }

     var node = new UnArgmax<T>(x, indexes, axis, shape);
     return node;
 }
Example #2
0
        /// <summary>
        /// Computes the output shape of a two-operand Einstein-style product described by <paramref name="einstein"/>.
        /// </summary>
        /// <param name="shapeX">Shape of the left operand; entries may be rebound in place via <c>ShapeExtension.Bind</c>.</param>
        /// <param name="shapeY">Shape of the right operand; entries may be rebound in place via <c>ShapeExtension.Bind</c>.</param>
        /// <param name="einstein">Per-axis descriptors; each entry whose <c>axisZ</c> is non-null contributes one output axis.</param>
        /// <returns>The output shape, with one slot per descriptor that has an <c>axisZ</c>.</returns>
        /// <remarks>
        /// INNER and ELEMENTWISE bind the paired X/Y dims. INNER, SUMX and SUMY give size 1 on their output
        /// axis when they have one. ELEMENTWISE and OUTERX copy the X dim; OUTERY copies the Y dim.
        /// ELEMENTWISE/OUTERX/OUTERY cast <c>axisZ</c> unconditionally, so they expect it to be set.
        /// </remarks>
        static Dim[] EinsteinShape(Dim[] shapeX, Dim[] shapeY, Einstein[] einstein)
        {
            var outputRank = einstein.Count(e => e.axisZ != null);
            var shapeZ     = new Dim[outputRank];

            foreach (var entry in einstein)
            {
                var mode = entry.mode;

                if (mode == EinsteinMode.INNER)
                {
                    // Contracted pair: unify the X and Y dims; a kept output axis collapses to 1.
                    ShapeExtension.Bind(ref shapeX[(int)entry.axisX], ref shapeY[(int)entry.axisY]);
                    if (entry.axisZ != null)
                    {
                        shapeZ[(int)entry.axisZ] = 1;
                    }
                }
                else if (mode == EinsteinMode.ELEMENTWISE)
                {
                    // Paired axis that survives into the output with the (bound) X size.
                    ShapeExtension.Bind(ref shapeX[(int)entry.axisX], ref shapeY[(int)entry.axisY]);
                    shapeZ[(int)entry.axisZ] = shapeX[(int)entry.axisX];
                }
                else if (mode == EinsteinMode.OUTERX)
                {
                    shapeZ[(int)entry.axisZ] = shapeX[(int)entry.axisX];
                }
                else if (mode == EinsteinMode.OUTERY)
                {
                    shapeZ[(int)entry.axisZ] = shapeY[(int)entry.axisY];
                }
                else if (mode == EinsteinMode.SUMX || mode == EinsteinMode.SUMY)
                {
                    // Summed-out axis: only a kept output axis needs a size (1).
                    if (entry.axisZ != null)
                    {
                        shapeZ[(int)entry.axisZ] = 1;
                    }
                }
            }

            return shapeZ;
        }
Example #3
0
        /// <summary>
        /// Builds a node that concatenates <paramref name="inputs"/> along <paramref name="axis"/>.
        /// </summary>
        /// <param name="axis">Axis to concatenate along; a negative value counts back from the last axis.</param>
        /// <param name="inputs">
        /// Tensors to join. Every input must have the rank of the highest-rank input. All dimensions other
        /// than <paramref name="axis"/> are bound to the first input's dims, and the output size along
        /// <paramref name="axis"/> is the sum of the inputs' sizes.
        /// </param>
        /// <exception cref="System.ArgumentNullException"><paramref name="inputs"/> is null.</exception>
        /// <exception cref="System.ArgumentException"><paramref name="inputs"/> is empty.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="axis"/> is outside [-ndim, ndim) for the inputs' rank.
        /// </exception>
        protected Concat(int axis, params Tensor<T>[] inputs) : base("Concat", (Scalar<int>)axis, inputs.ToStructArray())
        {
            // NOTE(review): base() above receives the axis exactly as passed in, i.e. before the
            // negative-axis normalization below — confirm that is what the base class expects.
            if (inputs == null)
            {
                throw new System.ArgumentNullException(nameof(inputs));
            }
            if (inputs.Length == 0)
            {
                // Without this guard Max() fails with an unrelated "Sequence contains no elements".
                throw new System.ArgumentException("Concat requires at least one input.", nameof(inputs));
            }

            int ndim = inputs.Max(t => t.NDim);

            int requestedAxis = axis;
            if (axis < 0)
            {
                axis += ndim;
            }

            // All inputs must share the same rank (the highest one found above).
            foreach (var x in inputs)
            {
                x.AssertOfDim(ndim);
            }

            // Previously an out-of-range axis only surfaced later as an IndexOutOfRangeException on Shape[axis].
            if (axis < 0 || axis >= ndim)
            {
                throw new System.ArgumentOutOfRangeException(
                    nameof(axis), requestedAxis,
                    $"Concat axis must be in [-{ndim}, {ndim}) for inputs of rank {ndim}.");
            }

            // Start from the first input's shape; the concat axis grows as further inputs are appended.
            this.Shape = new Dim[ndim];
            for (int a = 0; a < ndim; ++a)
            {
                Shape[a] = inputs[0].Shape[a];
            }
            _slices    = new XSlice[inputs.Length];
            _slices[0] = Range(0, this.Shape[axis]);

            for (int i = 1; i < inputs.Length; ++i)
            {
                // Every axis except the concat axis is unified with the running output shape.
                for (int a = 0; a < ndim; ++a)
                {
                    if (a != axis)
                    {
                        ShapeExtension.Bind(ref inputs[i].Shape[a], ref this.Shape[a]);
                    }
                }

                // Input i occupies [start, start + its length) along the concat axis of the output.
                var start = Shape[axis];
                Shape[axis] += inputs[i].Shape[axis];
                _slices[i]   = Range(start, Shape[axis]);
            }

            _inputs = inputs;
            _axis   = axis;
        }
Example #4
0
        /// <summary>
        /// Contracts <paramref name="a"/> and <paramref name="b"/> over paired axes (tensordot).
        /// </summary>
        /// <remarks>
        /// The summed axes are moved to the end of <paramref name="a"/> and to the front of <paramref name="b"/>.
        /// Both are then reshaped with <c>Reshape2D</c> (presumably flattened to matrices split at the given
        /// axis count — TODO confirm), multiplied with <c>Op.Dot</c>, and reshaped to the kept axes of
        /// <paramref name="a"/> followed by the kept axes of <paramref name="b"/>.
        /// </remarks>
        /// <param name="a">Left operand.</param>
        /// <param name="axesA">Axes of <paramref name="a"/> to sum over; negative values count from the end.</param>
        /// <param name="b">Right operand.</param>
        /// <param name="axesB">Axes of <paramref name="b"/> paired with <paramref name="axesA"/>; negative values count from the end.</param>
        /// <returns>The contracted tensor, tagged with the comment "TensorDot".</returns>
        /// <exception cref="RankException">The two axis lists have different lengths.</exception>
        public static Tensor<float> Create(Tensor<float> a, IEnumerable<int> axesA, Tensor<float> b, IEnumerable<int> axesB)
        {
            // Normalize negative axes so they count from the end.
            var removeA = axesA.Select(i => i < 0 ? i + a.NDim : i).ToArray();
            var removeB = axesB.Select(i => i < 0 ? i + b.NDim : i).ToArray();

            int n = removeA.Length;

            if (removeB.Length != n)
            {
                throw new RankException(string.Format(
                                            "The axes parameters of TensorDot should have the same size. Found [{0}] and [{1}].",
                                            string.Join(", ", removeA.AsEnumerable()), string.Join(", ", removeB.AsEnumerable())));
            }

            // Paired summed axes must have matching sizes.
            for (int d = 0; d < n; ++d)
            {
                ShapeExtension.Bind(ref a.Shape[removeA[d]], ref b.Shape[removeB[d]]);
            }

            // Materialized once: these are used several times below. As deferred queries they were re-evaluated on each use.
            var keptA = Enumerable.Range(0, a.NDim).Where(d => !removeA.Contains(d)).ToArray();
            var keptB = Enumerable.Range(0, b.NDim).Where(d => !removeB.Contains(d)).ToArray();

            // Move the axes to sum over to the end of "a" ...
            var at = a.DimShuffle(keptA.Concat(removeA).ToArray());
            // ... and to the front of "b", so a plain matrix product performs the contraction.
            var bt = b.DimShuffle(removeB.Concat(keptB).ToArray());

            var resultShape = keptA.Select(axis => a.Shape[axis]).Concat(keptB.Select(axis => b.Shape[axis])).ToArray();
            var a2d         = Reshape2D(at, a.NDim - n);
            var b2d         = Reshape2D(bt, n);
            var res         = Op.Dot(a2d, b2d).Reshape(resultShape);

            res.Comment = "TensorDot";

            return res;
        }
Example #5
0
        /// <summary>
        /// Creates a dot-product node of <paramref name="x"/> and <paramref name="y"/> and infers its output shape into <c>_shape</c>.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item>vector·vector: inner product (result rank 0) or outer product (rank 2), depending on the transpose flags.
        /// Two "row" vectors (both flags set) are rejected.</item>
        /// <item>tensor·vector: contracts the last axis of x (the first when transposeX) with the vector.</item>
        /// <item>tensor·scalar: the output has x's shape.</item>
        /// <item>tensor·tensor: contracts the last axis of x (the first when transposeX) with the second-to-last axis
        /// of y (index 1 when transposeY). The output is x's remaining axes, then y's remaining axes, then y's last
        /// axis (its first when transposeY): x.NDim + y.NDim - 2 axes in total.</item>
        /// </list>
        /// </remarks>
        /// <param name="x">Left operand.</param>
        /// <param name="y">Right operand.</param>
        /// <param name="transposeX">Treat x as transposed (axes reversed).</param>
        /// <param name="transposeY">Treat y as transposed (axes reversed).</param>
        private Dot(Tensor<float> x, Tensor<float> y, bool transposeX = false, bool transposeY = false)
            : base("Dot", x, y, transposeX.Named("transA"), transposeY.Named("transB"))
        {
            this.TransposeX = transposeX;
            this.TransposeY = transposeY;

            if (x.NDim == 1 && y.NDim == 1)
            {
                // rowV dot rowV (forbidden)
                if (TransposeX && TransposeY)
                {
                    throw Rank(this, "Can't dot two row vectors: {0} and {1}");
                }
                // colV dot colV (used as inner product)
                if (!TransposeX && !TransposeY)
                {
                    TransposeX = true;
                }
                // colV dot rowV (outer product)
                if (!TransposeX && TransposeY)
                {
                    _shape = new Dim[] { x.Shape[0], y.Shape[0] };
                }
                // rowV dot colV (inner product)
                if (TransposeX && !TransposeY)
                {
                    if (!x.Shape[0].CanEqualTo(y.Shape[0]))
                    {
                        throw Rank(this, "Can't dot {0} with {1}");
                    }
                    ShapeExtension.Bind(ref x.Shape[0], ref y.Shape[0]);
                    _shape = new Dim[] { };
                }
                return;
            }

            // mat dot colV
            if (y.NDim == 1)
            {
                TransposeY = false;
                int contractedX = TransposeX ? 0 : x.NDim - 1;
                if (!x.Shape[contractedX].CanEqualTo(y.Shape[0]))
                {
                    throw Rank(this, "Can't dot {0} with {1}");
                }
                ShapeExtension.Bind(ref x.Shape[contractedX], ref y.Shape[0]);
                _shape = (TransposeX ? x.Shape.Reverse().ToArray() : x.Shape).DropRight(1);
                return;
            }
            else if (y.NDim == 0)
            {
                // mat dot scalar: the shape is unchanged.
                TransposeY = false;
                _shape     = x.Shape.ToArray();
            }
            // mat dot mat
            else
            {
                if (x.NDim == 1)
                {
                    TransposeX = true;
                }
                _shape = new Dim[x.NDim + y.NDim - 2];

                int contractedX = TransposeX ? 0 : x.NDim - 1;
                int contractedY = TransposeY ? 1 : y.NDim - 2;
                if (!x.Shape[contractedX].CanEqualTo(y.Shape[contractedY]))
                {
                    throw Rank(this, "Can't dot {0} with {1}");
                }
                // Bind the shape array slots themselves, as the other branches do. Binding local copies
                // may leave x.Shape / y.Shape untouched if Bind reassigns its ref arguments.
                ShapeExtension.Bind(ref x.Shape[contractedX], ref y.Shape[contractedY]);

                // copy x dims but the contracted one -> _shape[0 .. x.NDim - 2]
                if (!TransposeX)
                {
                    for (int i = 0; i < x.NDim - 1; ++i)
                    {
                        _shape[i] = x.Shape[i];
                    }
                }
                else
                {
                    for (int i = 0; i < x.NDim - 1; ++i)
                    {
                        _shape[i] = x.Shape[x.NDim - 1 - i];
                    }
                }

                // y's kept dims start right after x's (x.NDim - 1) dims. The old offset of x.NDim left a null
                // slot and overwrote the last slot whenever y.NDim >= 3.
                int yOffset = x.NDim - 1;
                int last    = _shape.Length - 1; // == x.NDim + y.NDim - 3

                // copy y dims but the contracted one and the last one, then put y's last dim at the end
                if (!TransposeY)
                {
                    for (int i = 0; i < y.NDim - 2; ++i)
                    {
                        _shape[yOffset + i] = y.Shape[i];
                    }
                    _shape[last] = y.Shape[y.NDim - 1];
                }
                else
                {
                    for (int i = 0; i < y.NDim - 2; ++i)
                    {
                        _shape[yOffset + i] = y.Shape[y.NDim - 1 - i];
                    }
                    _shape[last] = y.Shape[0];
                }
            }
        }