/// <summary>
/// Pairs every dimension of this tensor's shape with the corresponding entry of
/// <paramref name="shape"/> and calls <c>ShapeExtension.Bind</c> on each pair.
/// </summary>
/// <param name="thiz">Tensor whose shape dimensions are bound.</param>
/// <param name="shape">
/// One symbolic size per dimension. Its length is validated with <c>AssertOfDim</c>
/// (presumably against <c>NDim</c> — confirm) before any binding happens.
/// </param>
public static void BindToShape(this ITensor thiz, Scalar<int>[] shape) {
    // Grab the tensor's shape array first; the elements are bound in place by reference.
    var dims = thiz.Shape;
    thiz.AssertOfDim(shape.Length);

    var axis = 0;
    while (axis < thiz.NDim) {
        ShapeExtension.Bind(ref dims[axis], ref shape[axis]);
        axis++;
    }
}
/// <summary>
/// Creates an element-wise node over <paramref name="inputs"/>. For every axis, each input
/// either is recorded as needing broadcasting along that axis (when
/// <c>NeedBroadcast</c> reports so against the output shape) or has its dimension bound
/// to the output dimension via <c>ShapeExtension.Bind</c>.
/// </summary>
/// <param name="inputs">Operand tensors; must be non-empty and all of the same rank.</param>
/// <param name="vars">One scalar variable per input, in the same order as <paramref name="inputs"/>.</param>
/// <param name="abstraction">Scalar expression stored as <c>Abstraction</c>.</param>
/// <param name="shape">Output shape; must have exactly one entry per input dimension. Stored (not copied) as <c>Shape</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="inputs"/>, <paramref name="vars"/> or <paramref name="shape"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="inputs"/> is empty, or its length differs from <paramref name="vars"/>.</exception>
/// <exception cref="RankException">Inputs have differing ranks, or <paramref name="shape"/> has a different rank than the inputs.</exception>
private Elementwise(Tensor<Type>[] inputs, Scalar<Type>.Var[] vars, Scalar<Type> abstraction, Dim[] shape)
    : base("Elementwise", inputs) {
    if (inputs == null) { throw new ArgumentNullException(nameof(inputs)); }
    if (vars == null) { throw new ArgumentNullException(nameof(vars)); }
    if (shape == null) { throw new ArgumentNullException(nameof(shape)); }
    if (inputs.Length == 0) { throw new ArgumentException("Need at least one input"); }
    if (inputs.Length != vars.Length) { throw new ArgumentException("Need one captured by inputs"); }

    var nDim = inputs.First().NDim;
    if (!inputs.All(x => x.NDim == nDim)) {
        throw new RankException($"Dims don't match: [{string.Join(", ", inputs.Select(_ => _.NDim))}]");
    }
    // Without this check a shorter shape crashes with IndexOutOfRangeException in the loop
    // below, and a longer shape leaves its extra axes unchecked and unbound.
    if (shape.Length != nDim) {
        throw new RankException($"Shape has {shape.Length} dims but inputs have {nDim}");
    }

    this.Vars = vars;
    this.Inputs = inputs;
    this.Abstraction = abstraction;

    broadcast = new Dictionary<Tensor<Type>, List<int>>(inputs.Length);
    foreach (var x in inputs) {
        broadcast[x] = new List<int>();
    }

    // checks and binds shape
    Shape = shape;
    for (int d = 0; d < nDim; ++d) {
        foreach (var x in inputs) {
            if (x.Shape[d].NeedBroadcast(this.Shape[d])) {
                var axes = broadcast[x];
                // The same tensor may occur several times in inputs (e.g. x * x) and then shares
                // one list; since d only increases, a duplicate can only be the last entry.
                if (axes.Count == 0 || axes[axes.Count - 1] != d) {
                    axes.Add(d);
                }
            } else {
                ShapeExtension.Bind(ref this.Shape[d], ref x.Shape[d]);
            }
        }
    }
}