/// <summary>
/// Creates an <c>UnArgmax</c> node of the requested <paramref name="shape"/> from the values
/// <paramref name="x"/> and the <paramref name="indexes"/> along <paramref name="axis"/>.
/// </summary>
/// <param name="x">Values tensor; must have exactly <c>shape.Length</c> dimensions.</param>
/// <param name="indexes">Index tensor forwarded to the node unchanged.</param>
/// <param name="axis">Axis along which the result differs from <paramref name="x"/>; negative values count from the end.</param>
/// <param name="shape">
/// Result shape. Every dimension except <paramref name="axis"/> is bound to the matching dimension of
/// <paramref name="x"/>. Entries are passed by <c>ref</c> to <c>ShapeExtension.Bind</c>, so the caller's
/// array may be updated in place.
/// </param>
/// <exception cref="System.ArgumentOutOfRangeException">When <paramref name="axis"/> is outside <c>[-NDim, NDim)</c>.</exception>
public static Tensor<T> Create<T>(Tensor<T> x, Tensor<int> indexes, int axis, Dim[] shape)
{
    x.AssertOfDim(shape.Length);

    // Accept negative axes like Concat/TensorDot do. Without this, a negative axis never matches `i`
    // below, and the un-argmaxed axis itself would wrongly be bound to shape[axis].
    int requestedAxis = axis;
    if (axis < 0)
    {
        axis += x.NDim;
    }
    if (axis < 0 || axis >= x.NDim)
    {
        throw new System.ArgumentOutOfRangeException("axis", string.Format(
            "UnArgmax axis {0} is out of range for a tensor with {1} dimensions.", requestedAxis, x.NDim));
    }

    // All axes other than `axis` must agree between x and the result shape.
    for (int i = 0; i < x.NDim; ++i)
    {
        if (i != axis)
        {
            ShapeExtension.Bind(ref x.Shape[i], ref shape[i]);
        }
    }
    return new UnArgmax<T>(x, indexes, axis, shape);
}
/// <summary>
/// Computes the output shape of an Einstein-style contraction described by <paramref name="einstein"/>.
/// For INNER and ELEMENTWISE entries, the paired axes of <paramref name="shapeX"/> and
/// <paramref name="shapeY"/> are bound together in place.
/// </summary>
/// <returns>One dimension per entry that names an output axis (<c>axisZ</c>).</returns>
static Dim[] EinsteinShape(Dim[] shapeX, Dim[] shapeY, Einstein[] einstein)
{
    int rankZ = einstein.Count(item => item.axisZ != null);
    var result = new Dim[rankZ];

    foreach (var item in einstein)
    {
        var mode = item.mode;

        // Phase 1: axes that pair x with y must have compatible sizes.
        if (mode == EinsteinMode.INNER || mode == EinsteinMode.ELEMENTWISE)
        {
            ShapeExtension.Bind(ref shapeX[(int)item.axisX], ref shapeY[(int)item.axisY]);
        }

        // Phase 2: decide the size of the output axis, if any.
        if (mode == EinsteinMode.ELEMENTWISE || mode == EinsteinMode.OUTERX)
        {
            result[(int)item.axisZ] = shapeX[(int)item.axisX];
        }
        else if (mode == EinsteinMode.OUTERY)
        {
            result[(int)item.axisZ] = shapeY[(int)item.axisY];
        }
        else if (mode == EinsteinMode.INNER || mode == EinsteinMode.SUMX || mode == EinsteinMode.SUMY)
        {
            // Reduced axes survive only as size-1 axes when an output axis is requested.
            if (item.axisZ != null)
            {
                result[(int)item.axisZ] = 1;
            }
        }
    }

    return result;
}
/// <summary>
/// Builds a node concatenating <paramref name="inputs"/> along <paramref name="axis"/>.
/// All inputs must share the same rank, and every non-concatenated dimension is bound across inputs.
/// The result's size along <paramref name="axis"/> is the sum of the inputs' sizes there.
/// </summary>
/// <param name="axis">Concatenation axis; negative values count from the end.</param>
/// <param name="inputs">Tensors to concatenate, in order.</param>
/// <exception cref="System.ArgumentOutOfRangeException">When <paramref name="axis"/> is outside <c>[-ndim, ndim)</c>.</exception>
protected Concat(int axis, params Tensor<T>[] inputs)
    : base("Concat", (Scalar<int>)axis, inputs.ToStructArray())
{
    int ndim = inputs.Max(i => i.NDim);
    int requestedAxis = axis;
    if (axis < 0)
    {
        axis += ndim;
    }

    // Rank check done once for every input (the original repeated it inside the main loop).
    foreach (var x in inputs)
    {
        x.AssertOfDim(ndim);
    }

    if (axis < 0 || axis >= ndim)
    {
        throw new System.ArgumentOutOfRangeException("axis", string.Format(
            "Concat axis {0} is out of range for inputs with {1} dimensions.", requestedAxis, ndim));
    }

    // Start from the first input's shape; the concatenated axis grows as inputs are appended.
    this.Shape = new Dim[ndim];
    for (int a = 0; a < ndim; ++a)
    {
        Shape[a] = inputs[0].Shape[a];
    }

    // _slices[i] is the range occupied by inputs[i] along `axis` in the result.
    _slices = new XSlice[inputs.Length];
    _slices[0] = Range(0, this.Shape[axis]);
    for (int i = 1; i < inputs.Length; ++i)
    {
        for (int a = 0; a < ndim; ++a)
        {
            if (a != axis)
            {
                ShapeExtension.Bind(ref inputs[i].Shape[a], ref this.Shape[a]);
            }
        }
        var start = Shape[axis];
        Shape[axis] += inputs[i].Shape[axis];
        _slices[i] = Range(start, Shape[axis]);
    }

    _inputs = inputs;
    _axis = axis;
}
/// <summary>
/// Tensor dot product: sums the products of <paramref name="a"/> and <paramref name="b"/> over the
/// paired axes <paramref name="axesA"/>[k] / <paramref name="axesB"/>[k].
/// Implemented by moving the summed axes of <paramref name="a"/> to the end and those of
/// <paramref name="b"/> to the front, reshaping both to 2D, doing a matrix <c>Dot</c>, and reshaping back.
/// </summary>
/// <param name="axesA">Axes of <paramref name="a"/> to sum over; negative values count from the end.</param>
/// <param name="axesB">Axes of <paramref name="b"/> to sum over, paired positionally with <paramref name="axesA"/>.</param>
/// <returns>A tensor whose shape is the kept axes of <paramref name="a"/> followed by the kept axes of <paramref name="b"/>.</returns>
/// <exception cref="RankException">When the two axis lists have different lengths.</exception>
public static Tensor<float> Create(Tensor<float> a, IEnumerable<int> axesA, Tensor<float> b, IEnumerable<int> axesB)
{
    var removeA = axesA.Select(i => i < 0 ? i + a.NDim : i).ToArray();
    var removeB = axesB.Select(i => i < 0 ? i + b.NDim : i).ToArray();
    int n = removeA.Length;
    if (removeB.Length != n)
    {
        throw new RankException(string.Format(
            "The axes parameters of TensorDot should have the same size. Found [{0}] and [{1}].",
            string.Join(", ", removeA.AsEnumerable()),
            string.Join(", ", removeB.AsEnumerable())));
    }

    // Paired summed axes must have compatible sizes.
    for (int d = 0; d < n; ++d)
    {
        ShapeExtension.Bind(ref a.Shape[removeA[d]], ref b.Shape[removeB[d]]);
    }

    // Axes that survive in the result. Materialized once (they are used several times below).
    // NOTE(review): duplicate entries in axesA/axesB are not rejected here — confirm DimShuffle reports them.
    var keptA = Enumerable.Range(0, a.NDim).Where(d => !removeA.Contains(d)).ToArray();
    var keptB = Enumerable.Range(0, b.NDim).Where(d => !removeB.Contains(d)).ToArray();

    // Move the axes to sum over to the end of "a"...
    var at = a.DimShuffle(keptA.Concat(removeA).ToArray());
    // ...and to the front of "b".
    var bt = b.DimShuffle(removeB.Concat(keptB).ToArray());

    var resultShape = keptA.Select(axis => a.Shape[axis])
        .Concat(keptB.Select(axis => b.Shape[axis]))
        .ToArray();

    // presumably Reshape2D flattens the first k axes into rows and the rest into columns — confirm.
    var a2d = Reshape2D(at, a.NDim - n);
    var b2d = Reshape2D(bt, n);
    var res = Op.Dot(a2d, b2d).Reshape(resultShape);
    res.Comment = "TensorDot";
    return res;
}
/// <summary>
/// Creates a dot-product node between <paramref name="x"/> and <paramref name="y"/> and infers its result shape.
/// The contracted axes of x and y are bound together.
/// <see cref="TransposeX"/>/<see cref="TransposeY"/> may be rewritten to the canonical form the node actually uses
/// (e.g. a vector·vector inner product forces TransposeX; a vector or scalar y forces TransposeY off).
/// A transposed operand is treated as having its axes fully reversed.
/// </summary>
/// <param name="x">Left operand.</param>
/// <param name="y">Right operand.</param>
/// <param name="transposeX">Whether x is used transposed.</param>
/// <param name="transposeY">Whether y is used transposed.</param>
/// <remarks>Throws the exception built by <c>Rank(this, ...)</c> when the contracted axes are incompatible
/// or when dotting two row vectors.</remarks>
private Dot(Tensor<float> x, Tensor<float> y, bool transposeX = false, bool transposeY = false)
    : base("Dot", x, y, transposeX.Named("transA"), transposeY.Named("transB"))
{
    this.TransposeX = transposeX;
    this.TransposeY = transposeY;

    if (x.NDim == 1 && y.NDim == 1)
    {
        // rowV dot rowV (forbidden)
        if (TransposeX && TransposeY)
        {
            throw Rank(this, "Can't dot two row vectors: {0} and {1}");
        }
        // colV dot colV (used as inner product)
        if (!TransposeX && !TransposeY)
        {
            TransposeX = true;
        }
        // colV dot rowV (outer product)
        if (!TransposeX && TransposeY)
        {
            _shape = new Dim[] { x.Shape[0], y.Shape[0] };
        }
        // rowV dot colV (inner product)
        if (TransposeX && !TransposeY)
        {
            if (!x.Shape[0].CanEqualTo(y.Shape[0]))
            {
                throw Rank(this, "Can't dot {0} with {1}");
            }
            ShapeExtension.Bind(ref x.Shape[0], ref y.Shape[0]);
            _shape = new Dim[] { };
        }
        return;
    }

    // mat dot colV
    if (y.NDim == 1)
    {
        TransposeY = false;
        int contractedX = TransposeX ? 0 : x.NDim - 1;
        if (!x.Shape[contractedX].CanEqualTo(y.Shape[0]))
        {
            throw Rank(this, "Can't dot {0} with {1}");
        }
        ShapeExtension.Bind(ref x.Shape[contractedX], ref y.Shape[0]);
        _shape = (TransposeX ? x.Shape.Reverse().ToArray() : x.Shape).DropRight(1);
        return;
    }

    // mat dot scalar
    if (y.NDim == 0)
    {
        TransposeY = false;
        // NOTE(review): TransposeX is ignored here; a transposed x would presumably need a reversed shape — confirm.
        _shape = x.Shape.ToArray();
        return;
    }

    // mat dot mat
    if (x.NDim == 1)
    {
        TransposeX = true;
    }
    _shape = new Dim[x.NDim + y.NDim - 2];

    // Contract the last axis of (possibly transposed) x with the second-to-last axis of (possibly transposed) y.
    int axisX = TransposeX ? 0 : x.NDim - 1;
    int axisY = TransposeY ? 1 : y.NDim - 2;
    if (!x.Shape[axisX].CanEqualTo(y.Shape[axisY]))
    {
        throw Rank(this, "Can't dot {0} with {1}");
    }
    // Bind the array slots themselves (not local copies), as the other branches do,
    // so that any reassignment done by Bind is visible on x.Shape / y.Shape.
    ShapeExtension.Bind(ref x.Shape[axisX], ref y.Shape[axisY]);

    // copy x dims but the contracted one -> result slots [0, x.NDim - 1)
    for (int i = 0; i < x.NDim - 1; ++i)
    {
        _shape[i] = TransposeX ? x.Shape[x.NDim - 1 - i] : x.Shape[i];
    }

    // copy y dims but the contracted one -> result slots [x.NDim - 1, x.NDim + y.NDim - 2).
    // (The original wrote at x.NDim + i, leaving slot x.NDim - 1 unset and overwriting a slot when y.NDim >= 3.)
    int offset = x.NDim - 1;
    if (!TransposeY)
    {
        for (int i = 0; i < y.NDim - 2; ++i)
        {
            _shape[offset + i] = y.Shape[i];
        }
        _shape[offset + y.NDim - 2] = y.Shape[y.NDim - 1];
    }
    else
    {
        for (int i = 0; i < y.NDim - 2; ++i)
        {
            _shape[offset + i] = y.Shape[y.NDim - 1 - i];
        }
        _shape[offset + y.NDim - 2] = y.Shape[0];
    }
}