Example #1
0
        /// <summary>Basic slicing: applies one <c>XSlice</c> per leading dimension of <paramref name="x"/>.</summary>
        /// <remarks>
        /// Dimensions at index <c>slices.Count</c> and beyond are treated as unsliced (kept whole).
        /// When <paramref name="x"/> is a <c>BroadCast</c> and every broadcasted dimension is indexed by a
        /// singleton slice, the slicing is applied directly to the un-broadcast operand <c>xBroad.x</c>,
        /// so the broadcast disappears from the result. If any broadcasted dimension survives (sliced by a
        /// non-singleton slice, or not sliced at all), a regular <c>Slicing</c> node is built instead.
        /// When <paramref name="x"/> is a <c>Fill</c>, the result is a constant with the sliced shape.
        /// </remarks>
        /// <param name="x">the tensor to slice</param>
        /// <param name="slices">the slices to apply, one per leading dimension of <paramref name="x"/></param>
        /// <returns>the sliced tensor expression</returns>
        /// <exception cref="RankException"><paramref name="slices"/> is empty</exception>
        public static Tensor <Type> Create(Tensor <Type> x, IReadOnlyList <XSlice> slices)
        {
            if (slices.Count == 0)
            {
                throw new RankException("no slices provided");
            }

            if (x is BroadCast <Type> xBroad)
            {
                var broadcasted   = xBroad.broadcast;
                var keptSlices    = new List <XSlice>(slices.Count);
                // Becomes true as soon as a broadcasted dimension would survive in the result,
                // which means the slicing cannot be pushed below the broadcast.
                var needBroadcast = false;

                for (int d = 0; d < slices.Count && !needBroadcast; ++d)
                {
                    if (!broadcasted.Contains(d))
                    {
                        keptSlices.Add(slices[d]);
                    }
                    else if (slices[d].IsSingleton)
                    {
                        // A single index into a broadcasted dimension reads the same value whatever the
                        // index, so index 0 of the un-broadcast operand is used (presumably it has extent 1 here).
                        keptSlices.Add(0);
                    }
                    else
                    {
                        needBroadcast = true;
                    }
                }

                // Trailing dimensions are not sliced, so any broadcasted one among them is kept whole
                // and still needs the broadcast; slicing xBroad.x would give the wrong extent there.
                for (int d = slices.Count; d < x.NDim && !needBroadcast; ++d)
                {
                    needBroadcast = broadcasted.Contains(d);
                }

                if (!needBroadcast)
                {
                    return(xBroad.x[keptSlices]);
                }

                return(new Slicing <Type>(x, slices.ToArray()));
            }

            var result = new Slicing <Type>(x, slices.ToArray());

            switch (x)
            {
            case Fill <Type> fill:
                // A constant fill stays constant after slicing; only the shape changes.
                return(Op.ConstLike(fill.x, result));

            default:
                return(result);
            }
        }
Example #2
0
        /// <summary>Advanced indexing: gathers elements of <paramref name="x"/> at the given index tensors.</summary>
        /// <remarks>
        /// `indices.Length` should match `x.NDim`; this method does not check that itself.
        /// Indexing a <c>Fill</c> yields a constant shaped like the indexing result.
        /// </remarks>
        /// <param name="x">the array to index</param>
        /// <param name="indices">the indices to take from `x`.</param>
        /// <returns>the indexing expression, or a constant when <paramref name="x"/> is a <c>Fill</c></returns>
        public static Tensor <Type> Create(Tensor <Type> x, Tensor <int>[] indices)
        {
            // The node is built first even for a constant input: its shape is what the constant copies.
            var indexed = new Indexing <Type>(x, indices);

            if (x is Fill <Type> constant)
            {
                return Op.ConstLike(constant.x, indexed);
            }

            return indexed;
        }
Example #3
0
        /// <summary>Sums <paramref name="x"/> along <paramref name="axis"/>.</summary>
        /// <remarks>
        /// Simplifies a few special inputs:
        /// a <c>Fill</c> becomes a constant;
        /// a <c>OneHot</c> summed over axis 0 becomes its reshaped content;
        /// a <c>OneHotPoint</c> becomes another one-hot point.
        /// Any other input produces a plain <c>Sum</c> node.
        /// </remarks>
        /// <param name="x">the tensor to reduce</param>
        /// <param name="axis">the axis to sum over; a negative value is offset by <c>x.NDim</c></param>
        /// <returns>the summed tensor expression</returns>
        public static Tensor <Type> Create(Tensor <Type> x, int axis)
        {
            // Normalize a negative axis so it can index x.Shape directly.
            axis = axis < 0 ? axis + x.NDim : axis;
            var result = new Sum <Type>(x, axis);

            switch (x)
            {
            case Fill <Type> fill:
                // Summing the fill value over an axis of length x.Shape[axis] multiplies it by that length.
                return(Op.ConstLike(fill.x * x.Shape[axis].As <Type>(), result));

            case OneHot <Type> oneHot:
                // Only the axis-0 reduction is simplified; other axes keep the generic Sum node.
                return(axis == 0 ? oneHot.Content.Reshape(result.Shape) : result);

            case OneHotPoint <Type> oneHotPoint:
                // ToArray copies, so updating `point` below leaves oneHotPoint.Indexes untouched.
                var point = ((Scalar <int>[])oneHotPoint.Indexes).ToArray();
                // NOTE(review): the summed axis is pinned to index 0 rather than removed from `point`.
                // This assumes result.Shape keeps a length-1 entry at `axis`; confirm against Sum's shape rules.
                point[axis] = 0;
                return(Op.OneHot(result.Shape, point, oneHotPoint.Content));

            default:
                return(result);
            }
        }