/// <summary>
/// Attempts to rewrite an einsum with exactly one contracted (inner) axis as a
/// plain matrix product: "ij,jk->ik" and "ji,jk->ik" are both a Dot. Rank-&gt;2
/// operands fall back to a TensorDot over the single inner axis.
/// </summary>
/// <returns>The simplified expression, or null when no simplification applies
/// (zero or more than one inner axis).</returns>
private static Tensor<float> _simplifyDot(Tensor<float> x, Tensor<float> y, Einstein[] einstein)
{
    int firstOuterX = -1, firstInner = -1, lastOuterY = -1;
    bool singleInner = true;

    for (int i = 0; i < einstein.Length; ++i)
    {
        var mode = einstein[i].mode;
        if (mode == EinsteinMode.INNER)
        {
            // A second contracted axis means this is not a simple dot.
            if (firstInner >= 0) singleInner = false;
            else firstInner = i;
        }
        else if (mode == EinsteinMode.OUTERX)
        {
            if (firstOuterX < 0) firstOuterX = i;
        }
        else if (mode == EinsteinMode.OUTERY)
        {
            lastOuterY = i;
        }
    }

    if (!singleInner || firstInner < 0) return null;

    if (x.NDim > 2 || y.NDim > 2)
    {
        // Higher-rank operands: contract the single inner axis with TensorDot.
        return Op.TensorDot(
            x, new int[] { (int)einstein[firstInner].axisX },
            y, new int[] { (int)einstein[firstInner].axisY }
        );
    }

    // x must present as (outer, inner): transpose when its outer axis follows the inner one.
    var transX = firstOuterX >= 0 && einstein[firstOuterX].axisX > einstein[firstInner].axisX;
    // y must present as (inner, outer): transpose when its outer axis precedes the inner one.
    var transY = lastOuterY >= 0 && einstein[lastOuterY].axisY < einstein[firstInner].axisY;
    // Output axes in swapped order? Then use (x.y)^T = y^T . x^T.
    var transZ = firstOuterX >= 0 && lastOuterY >= 0
        && einstein[firstOuterX].axisZ > einstein[lastOuterY].axisZ;

    return transZ
        ? Op.Dot(y, x, transposeX: !transY, transposeY: !transX)
        : Op.Dot(x, y, transposeX: transX, transposeY: transY);
}
/// <summary>
/// Builds a tensor contraction (TensorDot) of <paramref name="a"/> and <paramref name="b"/>
/// over the given axis pairs, implemented as axis shuffles and reshapes around a single 2-D Dot.
/// </summary>
/// <param name="a">Left operand.</param>
/// <param name="axesA">Axes of <paramref name="a"/> to contract; negative values count from the end.</param>
/// <param name="b">Right operand.</param>
/// <param name="axesB">Axes of <paramref name="b"/> to contract, paired positionally with <paramref name="axesA"/>.</param>
/// <returns>The contracted tensor, shaped as the kept axes of a followed by the kept axes of b.</returns>
/// <exception cref="RankException">When the two axis lists have different lengths.</exception>
public static Tensor<float> Create(Tensor<float> a, IEnumerable<int> axesA, Tensor<float> b, IEnumerable<int> axesB)
{
    // Normalize negative axes (numpy-style) to absolute positions.
    var removeA = axesA.Select(i => i < 0 ? i + a.NDim : i).ToArray();
    var removeB = axesB.Select(i => i < 0 ? i + b.NDim : i).ToArray();

    int n = removeA.Length;
    if (removeB.Length != n)
    {
        throw new RankException(string.Format(
            "The axes parameters of TensorDot should have the same size. Found [{0}] and [{1}].",
            string.Join(", ", removeA.AsEnumerable()),
            string.Join(", ", removeB.AsEnumerable())));
    }

    // Unify the symbolic dimensions of each contracted axis pair.
    for (int d = 0; d < n; ++d)
    {
        ShapeExtension.Bind(ref a.Shape[removeA[d]], ref b.Shape[removeB[d]]);
    }

    // Kept (non-contracted) axes, materialized once. The previous version built
    // these twice (unused keptX/keptY arrays plus deferred keptA/keptB queries
    // that were re-enumerated on every use).
    var keptA = Enumerable.Range(0, a.NDim).Where(d => !removeA.Contains(d)).ToArray();
    var keptB = Enumerable.Range(0, b.NDim).Where(d => !removeB.Contains(d)).ToArray();

    // Move the axes to sum over to the end of "a" and to the front of "b"...
    var at = a.DimShuffle(keptA.Concat(removeA).ToArray());
    var bt = b.DimShuffle(removeB.Concat(keptB).ToArray());

    var resultShape = keptA.Select(axis => a.Shape[axis])
        .Concat(keptB.Select(axis => b.Shape[axis]))
        .ToArray();

    // ...so the whole contraction collapses to one 2-D matrix product.
    var a2d = Reshape2D(at, a.NDim - n);
    var b2d = Reshape2D(bt, n);
    var res = Op.Dot(a2d, b2d).Reshape(resultShape);
    res.Comment = "TensorDot";
    return res;
}
/// <summary>
/// Regression test: a cost that consumes "x1" twice (directly and through "y")
/// must not cause the backward pass to duplicate the Dot node — the compiled
/// source should contain exactly two "Dot" occurrences.
/// </summary>
public void DontDuplicatesGradient()
{
    var x = Op.Vector<float>("x");
    var W = Op.Shared(NN.Random.Uniform(-1f, 1f, 10, 10), "W1");

    var x1 = Op.Dot(W, x);
    x1.Name = nameof(x1);
    var y = Op.Sigmoid(Op.Tanh(x1));
    y.Name = nameof(y);

    // "x1" feeds the cost both directly and via "y".
    var cost = Op.Norm2(y) + Op.Norm2(x1);
    var dW = Op.Grad(cost, W);

    var update = new OrderedDictionary();
    update[W] = W - 0.05f * dW;
    var f = Op.Function(input: x, output: cost, updates: update);

    AssertSourceContains("Dot", exactly: 2);
}