コード例 #1
0
ファイル: AssertTensor.cs プロジェクト: stuarthillary/TheaNet
        /// <summary>
        /// Binds every dimension of <paramref name="thiz"/>'s shape to the entry of
        /// <paramref name="shape"/> at the same position, via <c>ShapeExtension.Bind</c>.
        /// </summary>
        /// <param name="thiz">Tensor whose shape dimensions are bound.</param>
        /// <param name="shape">Target dimensions; its length must match the tensor's rank
        /// (enforced by <c>AssertOfDim</c> before any binding happens).</param>
        /// <remarks>
        /// Both arguments are passed to <c>Bind</c> by reference, so the tensor's shape array
        /// and <paramref name="shape"/> may both be modified.
        /// </remarks>
        public static void BindToShape(this ITensor thiz, Scalar <int>[] shape)
        {
            var dims = thiz.Shape;
            thiz.AssertOfDim(shape.Length);

            int axis = 0;
            while (axis < thiz.NDim)
            {
                ShapeExtension.Bind(ref dims[axis], ref shape[axis]);
                axis++;
            }
        }
コード例 #2
0
 /// <summary>
 /// Records <paramref name="value"/> as the substitution for <paramref name="expr"/> by
 /// assigning it into <c>_substitutions</c> under the key <paramref name="expr"/>.
 /// </summary>
 /// <param name="expr">Expression to be replaced.</param>
 /// <param name="value">Replacement expression.</param>
 /// <exception cref="ArgumentException">
 /// <c>PreserveShape</c> is set, <paramref name="expr"/> is a tensor, and <paramref name="value"/>
 /// is either not a tensor (including null) or has a shape that cannot equal the original's.
 /// </exception>
 public void Add_(IExpr expr, IExpr value)
 {
     if (expr is ITensor tensor && PreserveShape)
     {
         // A non-tensor (or null) replacement cannot preserve the shape. The previous `as` cast
         // dereferenced null here and threw NullReferenceException instead of this error.
         if (!(value is ITensor replacement) || !ShapeExtension.CanEqualTo(replacement.Shape, tensor.Shape))
         {
             throw new ArgumentException($"Can't patch {expr} with {value}");
         }
     }
     _substitutions[expr] = value;
 }
コード例 #3
0
ファイル: AssertTensor.cs プロジェクト: stuarthillary/TheaNet
        /// <summary>
        /// Verifies that <paramref name="thiz"/> has exactly as many dimensions as
        /// <paramref name="shape"/>, and that each dimension is compatible with the matching
        /// entry according to <c>ShapeExtension.CanEqualTo</c>.
        /// </summary>
        /// <param name="thiz">Tensor whose shape is checked.</param>
        /// <param name="shape">Expected dimensions.</param>
        /// <remarks>
        /// On the first mismatch (rank or any dimension) this throws the exception built by the
        /// <c>RankException</c> helper, with both shapes formatted into the message.
        /// </remarks>
        public static void AssertOfShape <T>(this ITensor <T> thiz, params Scalar <int>[] shape)
        {
            var actual = thiz.Shape;

            // A rank mismatch short-circuits the per-dimension comparison entirely.
            bool compatible = thiz.NDim == shape.Length;
            for (int axis = 0; compatible && axis < thiz.NDim; ++axis)
            {
                compatible = ShapeExtension.CanEqualTo(actual[axis], shape[axis]);
            }

            if (!compatible)
            {
                throw RankException("{0} of shape {1}, won't match with: {2}", thiz, thiz.Shape.Format(thiz), shape.Format(thiz));
            }
        }
コード例 #4
0
            /// <summary>
            /// Builds an elementwise node that applies <paramref name="abstraction"/> over
            /// <paramref name="inputs"/>, where <c>vars[i]</c> is the captured scalar variable
            /// standing for <c>inputs[i]</c>.
            /// </summary>
            /// <param name="inputs">Input tensors; must be non-empty and all of the same rank.</param>
            /// <param name="vars">One captured variable per input, in the same order.</param>
            /// <param name="abstraction">Scalar expression evaluated per element.</param>
            /// <param name="shape">
            /// Output shape; must have the same rank as the inputs. The array is stored as
            /// <c>Shape</c> without copying, and its entries are passed by reference to
            /// <c>ShapeExtension.Bind</c>, so the caller's array may be modified.
            /// </param>
            /// <exception cref="ArgumentNullException">An array argument is null.</exception>
            /// <exception cref="ArgumentException">No inputs, or inputs/vars counts differ.</exception>
            /// <exception cref="RankException">
            /// Inputs have different ranks, or <paramref name="shape"/>'s rank differs from theirs.
            /// </exception>
            private Elementwise(Tensor <Type>[] inputs, Scalar <Type> .Var[] vars, Scalar <Type> abstraction, Dim[] shape) :
                base("Elementwise", inputs)
            {
                if (inputs == null)
                {
                    throw new ArgumentNullException(nameof(inputs));
                }
                if (vars == null)
                {
                    throw new ArgumentNullException(nameof(vars));
                }
                if (shape == null)
                {
                    throw new ArgumentNullException(nameof(shape));
                }
                if (inputs.Length == 0)
                {
                    throw new ArgumentException("Need at least one input");
                }
                if (inputs.Length != vars.Length)
                {
                    throw new ArgumentException("Need one captured var per input");
                }
                var nDim = inputs.First().NDim;

                if (!inputs.All(x => x.NDim == nDim))
                {
                    throw new RankException($"Dims don't match: [{string.Join(", ", inputs.Select(_ => _.NDim))}]");
                }

                // The binding loop below indexes shape[0..nDim). Reject a rank mismatch up front
                // instead of failing with IndexOutOfRangeException or leaving dims unbound.
                if (shape.Length != nDim)
                {
                    throw new RankException($"Output shape has {shape.Length} dims but inputs have {nDim}");
                }

                this.Vars   = vars;
                this.Inputs = inputs;

                this.Abstraction = abstraction;
                // Per input: the list of axes on which it is broadcast to the output shape.
                broadcast        = new Dictionary <Tensor <Type>, List <int> >(inputs.Length);
                foreach (var x in inputs)
                {
                    broadcast[x] = new List <int>();
                }

                // checks and binds shape: a dimension that needs broadcasting is recorded,
                // any other dimension is bound to the corresponding output dimension.
                Shape = shape;
                for (int d = 0; d < nDim; ++d)
                {
                    foreach (var x in inputs)
                    {
                        if (x.Shape[d].NeedBroadcast(this.Shape[d]))
                        {
                            broadcast[x].Add(d);
                        }
                        else
                        {
                            ShapeExtension.Bind(ref this.Shape[d], ref x.Shape[d]);
                        }
                    }
                }
            }