/// <summary>Basic slicing of <paramref name="x"/>.</summary>
/// <remarks>
/// Two simplifications are attempted before a plain <c>Slicing</c> node is returned:
/// <list type="bullet">
/// <item>if <paramref name="x"/> is a <c>BroadCast</c> and every broadcasted axis is indexed by a
/// singleton slice, the slices are applied directly to the un-broadcasted source <c>xBroad.x</c>
/// (each singleton on a broadcasted axis is replaced by index 0);</item>
/// <item>if <paramref name="x"/> is a <c>Fill</c>, the result is replaced by <c>Op.ConstLike(fill.x, result)</c>.</item>
/// </list>
/// </remarks>
/// <param name="x">the tensor to slice</param>
/// <param name="slices">the slices to apply, <c>slices[d]</c> applying to axis <c>d</c>; must not be empty</param>
/// <exception cref="RankException">if <paramref name="slices"/> is empty</exception>
public static Tensor<Type> Create(Tensor<Type> x, IReadOnlyList<XSlice> slices)
{
    if (slices.Count == 0)
    {
        throw new RankException("no slices provided");
    }

    if (x is BroadCast<Type> xBroad)
    {
        var broadcasted = xBroad.broadcast;
        var keptSlices = new List<XSlice>();
        var canSkipBroadcast = true;

        for (int d = 0; d < slices.Count; ++d)
        {
            if (!broadcasted.Contains(d))
            {
                keptSlices.Add(slices[d]);
            }
            else if (slices[d].IsSingleton)
            {
                // The source is indexed at 0 on a broadcasted axis, whatever single index was asked
                // for (presumably because the source has length 1 there — confirm against BroadCast).
                keptSlices.Add(0);
            }
            else
            {
                // A non-singleton slice on a broadcasted axis keeps the broadcast: no shortcut.
                canSkipBroadcast = false;
                break;
            }
        }

        // A broadcasted axis that `slices` does not reach is left whole by the slicing, so it still
        // needs broadcasting; applying `keptSlices` to the source would silently drop it.
        if (canSkipBroadcast)
        {
            foreach (var axis in broadcasted)
            {
                if (axis >= slices.Count)
                {
                    canSkipBroadcast = false;
                    break;
                }
            }
        }

        if (canSkipBroadcast)
        {
            return xBroad.x[keptSlices];
        }
        return new Slicing<Type>(x, slices.ToArray());
    }

    var result = new Slicing<Type>(x, slices.ToArray());
    switch (x)
    {
        case Fill<Type> fill:
            return Op.ConstLike(fill.x, result);
        default:
            return result;
    }
}
/// <summary>Advanced indexing of <paramref name="x"/> by integer index tensors.</summary>
/// <remarks>
/// `indices.Length` should match `x.NDim`.
/// When <paramref name="x"/> is a <c>Fill</c>, the indexing node is replaced by
/// <c>Op.ConstLike(fill.x, indexing)</c>; otherwise the indexing node itself is returned.
/// </remarks>
/// <param name="x">the array to index</param>
/// <param name="indices">the indices to take from `x`.</param>
public static Tensor<Type> Create(Tensor<Type> x, Tensor<int>[] indices)
{
    var indexing = new Indexing<Type>(x, indices);

    if (x is Fill<Type> constant)
    {
        return Op.ConstLike(constant.x, indexing);
    }

    return indexing;
}
/// <summary>Sum of <paramref name="x"/> along <paramref name="axis"/>.</summary>
/// <remarks>
/// A negative <paramref name="axis"/> is counted from the end (it is shifted by <c>x.NDim</c>).
/// Special inputs are simplified instead of returning a plain <c>Sum</c> node:
/// <list type="bullet">
/// <item><c>Fill</c>: a constant equal to <c>fill.x</c> times the length of the summed axis;</item>
/// <item><c>OneHot</c> summed on axis 0: its content reshaped to the sum's shape;</item>
/// <item><c>OneHotPoint</c>: a one-hot of the sum's shape whose coordinate on the summed axis is 0.</item>
/// </list>
/// </remarks>
/// <param name="x">the tensor to reduce</param>
/// <param name="axis">the axis to sum over; may be negative</param>
public static Tensor<Type> Create(Tensor<Type> x, int axis)
{
    if (axis < 0)
    {
        axis += x.NDim;
    }
    var sum = new Sum<Type>(x, axis);

    if (x is Fill<Type> fill)
    {
        // Summing a constant multiplies it by the number of summed elements.
        var scaled = fill.x * x.Shape[axis].As<Type>();
        return Op.ConstLike(scaled, sum);
    }

    if (x is OneHot<Type> oneHot)
    {
        if (axis == 0)
        {
            return oneHot.Content.Reshape(sum.Shape);
        }
        return sum;
    }

    if (x is OneHotPoint<Type> oneHotPoint)
    {
        // Copy the coordinates before mutating them. The summed axis is pinned to 0, presumably
        // because the sum keeps that axis with length 1 (the point has the same rank as sum.Shape) — confirm.
        var coordinates = ((Scalar<int>[])oneHotPoint.Indexes).ToArray();
        coordinates[axis] = 0;
        return Op.OneHot(sum.Shape, coordinates, oneHotPoint.Content);
    }

    return sum;
}