Ejemplo n.º 1
0
        /// <summary>
        /// Attempts to express an Einstein-summation contraction of <paramref name="x"/> and
        /// <paramref name="y"/> as a single matrix-product node.
        /// Recognised examples: "ij,jk->ik" (x . y) and "ji,jk->ik" (x^T . y).
        /// </summary>
        /// <param name="x">Left operand.</param>
        /// <param name="y">Right operand.</param>
        /// <param name="einstein">
        /// One descriptor per index. Only entries whose mode is INNER, OUTERX or OUTERY are
        /// inspected; entries with any other mode are ignored.
        /// </param>
        /// <returns>
        /// An <c>Op.Dot</c> node when both operands have at most two dimensions, an
        /// <c>Op.TensorDot</c> node over the single INNER axis otherwise, or <c>null</c> when the
        /// descriptors do not contain exactly one INNER index.
        /// </returns>
        private static Tensor<float> _simplifyDot(Tensor<float> x, Tensor<float> y, Einstein[] einstein)
        {
            int innerIdx = -1;       // the one contracted (summed) index
            int outerXIdx = -1;      // first index that survives from x only
            int outerYIdx = -1;      // last index that survives from y only
            bool singleInner = true; // becomes false as soon as a second INNER index shows up

            for (int k = 0; k < einstein.Length; ++k)
            {
                var mode = einstein[k].mode;

                if (mode == EinsteinMode.INNER)
                {
                    // More than one summed index cannot be a plain dot product.
                    if (innerIdx >= 0)
                    {
                        singleInner = false;
                    }
                    else
                    {
                        innerIdx = k;
                    }
                }
                else if (mode == EinsteinMode.OUTERX)
                {
                    if (outerXIdx == -1)
                    {
                        outerXIdx = k;
                    }
                }
                else if (mode == EinsteinMode.OUTERY)
                {
                    // NOTE(review): last match wins here, unlike OUTERX; in the Dot path each operand
                    // presumably has at most one outer index, so the difference should not matter.
                    outerYIdx = k;
                }
            }

            if (!singleInner || innerIdx < 0)
            {
                return null;
            }

            if (x.NDim > 2 || y.NDim > 2)
            {
                // Higher-rank operands: contract the single INNER axis generically.
                // NOTE(review): the output index order (axisZ) is not consulted on this path.
                return Op.TensorDot(
                    x, new[] { (int)einstein[innerIdx].axisX },
                    y, new[] { (int)einstein[innerIdx].axisY });
            }

            // x must be transposed when its outer axis comes after its inner axis ("ji"),
            // y when its outer axis comes before its inner axis ("kj").
            bool transX = outerXIdx >= 0 && einstein[outerXIdx].axisX > einstein[innerIdx].axisX;
            bool transY = outerYIdx >= 0 && einstein[outerYIdx].axisY < einstein[innerIdx].axisY;

            // The output lists y's index before x's ("...->ki"): the result is transposed.
            bool outputTransposed = outerXIdx >= 0 && outerYIdx >= 0
                                    && einstein[outerXIdx].axisZ > einstein[outerYIdx].axisZ;

            if (outputTransposed)
            {
                // (op(x) . op(y))^T == op(y)^T . op(x)^T: swap the operands and flip both flags.
                return Op.Dot(y, x, transposeX: !transY, transposeY: !transX);
            }

            return Op.Dot(x, y, transposeX: transX, transposeY: transY);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Builds a tensor contraction of <paramref name="a"/> and <paramref name="b"/> over the given
        /// pairs of axes. Both operands are permuted so the contracted axes are adjacent, reshaped
        /// via <c>Reshape2D</c>, multiplied with a single <c>Op.Dot</c>, and reshaped back.
        /// </summary>
        /// <param name="a">Left operand.</param>
        /// <param name="axesA">
        /// Axes of <paramref name="a"/> to sum over; a negative value is offset by <c>a.NDim</c>.
        /// </param>
        /// <param name="b">Right operand.</param>
        /// <param name="axesB">
        /// Axes of <paramref name="b"/> to sum over, paired position-wise with
        /// <paramref name="axesA"/>; a negative value is offset by <c>b.NDim</c>.
        /// </param>
        /// <returns>
        /// A tensor whose shape is the non-contracted axes of <paramref name="a"/> followed by the
        /// non-contracted axes of <paramref name="b"/>, each group in its original order.
        /// The result's <c>Comment</c> is set to "TensorDot".
        /// </returns>
        /// <exception cref="RankException">
        /// When <paramref name="axesA"/> and <paramref name="axesB"/> have different lengths.
        /// </exception>
        public static Tensor<float> Create(Tensor<float> a, IEnumerable<int> axesA, Tensor<float> b, IEnumerable<int> axesB)
        {
            var removeA = axesA.Select(i => i < 0 ? i + a.NDim : i).ToArray();
            var removeB = axesB.Select(i => i < 0 ? i + b.NDim : i).ToArray();

            int n = removeA.Length;

            if (removeB.Length != n)
            {
                throw new RankException(string.Format(
                                            "The axes parameters of TensorDot should have the same size. Found [{0}] and [{1}].",
                                            string.Join(", ", removeA.AsEnumerable()), string.Join(", ", removeB.AsEnumerable())));
            }

            // Bind each pair of contracted dimensions together.
            // NOTE(review): presumably unifies/asserts equal sizes — confirm in ShapeExtension.Bind.
            for (int d = 0; d < n; ++d)
            {
                ShapeExtension.Bind(ref a.Shape[removeA[d]], ref b.Shape[removeB[d]]);
            }

            // Materialised once so the Contains filter is not re-run on every enumeration below.
            var keptA = Enumerable.Range(0, a.NDim).Where(d => !removeA.Contains(d)).ToArray();
            var keptB = Enumerable.Range(0, b.NDim).Where(d => !removeB.Contains(d)).ToArray();

            // Move the axes to sum over to the end of "a"...
            var at = a.DimShuffle(keptA.Concat(removeA).ToArray());

            // ...and to the front of "b".
            var bt = b.DimShuffle(removeB.Concat(keptB).ToArray());

            var resultShape = keptA.Select(axis => a.Shape[axis])
                                   .Concat(keptB.Select(axis => b.Shape[axis]))
                                   .ToArray();

            // NOTE(review): Reshape2D presumably flattens into a matrix split after the given number
            // of leading axes (kept axes of "a", contracted axes of "b") — confirm.
            var a2d = Reshape2D(at, a.NDim - n);
            var b2d = Reshape2D(bt, n);
            var res = Op.Dot(a2d, b2d).Reshape(resultShape);

            res.Comment = "TensorDot";

            return res;
        }
Ejemplo n.º 3
0
        public void DontDuplicatesGradient()
        {
            // Arrange: a small graph where x1 = W1 . x reaches the cost through two paths,
            // directly (Norm2(x1)) and through y = sigmoid(tanh(x1)).
            var input = Op.Vector<float>("x");
            var weights = Op.Shared(NN.Random.Uniform(-1f, 1f, 10, 10), "W1");

            var x1 = Op.Dot(weights, input);
            x1.Name = nameof(x1);

            var y = Op.Sigmoid(Op.Tanh(x1));
            y.Name = nameof(y);

            var cost = Op.Norm2(y) + Op.Norm2(x1);
            var gradW = Op.Grad(cost, weights);

            // Act: compile a function that applies W1 <- W1 - 0.05 * dCost/dW1.
            var step = Op.Function(
                input: input,
                output: cost,
                updates: new OrderedDictionary { [weights] = weights - 0.05f * gradW });

            // Assert: exactly two "Dot" occurrences in the generated source. Presumably one is the
            // forward product and one the gradient w.r.t. W1, i.e. the gradient reaching x1 from its
            // two consumers is not expanded into a Dot per path.
            // TODO confirm what AssertSourceContains inspects.
            AssertSourceContains("Dot", exactly: 2);
        }