/// <summary>
/// Normalizes <paramref name="X"/> with statistics computed by <c>GlobalAvgVariancePool2D</c>
/// and the <c>NormalizationTail</c> compute kernel, then applies <paramref name="S"/> and
/// <paramref name="B"/> through <c>ScaleBias</c>.
/// Only pool == 1 (InstanceNormalization) is currently implemented.
/// </summary>
/// <param name="X">Input tensor.</param>
/// <param name="S">Scale tensor, forwarded to <c>ScaleBias</c>.</param>
/// <param name="B">Bias tensor, forwarded to <c>ScaleBias</c>.</param>
/// <param name="pool">Pool size; values &lt;= 0 default to <c>X.batch</c>. Only an effective value of 1 is supported.</param>
/// <param name="axis">Normalization axis; only 3 or -1 are supported.</param>
/// <param name="epsilon">
/// Passed to the kernel as <c>_Epsilon</c>. Presumably added to the variance for numerical
/// stability; confirm against the NormalizationTail kernel.
/// </param>
/// <returns>The result of <c>ScaleBias</c> applied to the normalized tensor.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when <paramref name="axis"/> is neither 3 nor -1, or when the effective
/// <paramref name="pool"/> is greater than 1.
/// </exception>
public override Tensor Normalization(Tensor X, Tensor S, Tensor B, int pool, int axis, float epsilon)
{
    // NOTE(review): axis 3 / -1 is presumably the channel axis of the tensor layout; confirm.
    if (axis != 3 && axis != -1)
    {
        throw new NotImplementedException(
            $"Normalization along axis {axis} is not implemented; only axis 3 (or -1) is supported.");
    }

    // A non-positive pool defaults to the batch size.
    if (pool <= 0)
    {
        pool = X.batch;
    }

    // @TODO: support other types of Normalization at test time.
    // Currently supported only pool=1 (InstanceNormalization).
    if (pool > 1)
    {
        throw new NotImplementedException(
            $"Normalization with pool={pool} is not implemented; only pool=1 (InstanceNormalization) is supported.");
    }

    // Statistics tensor bound to the kernel as "W".
    // Presumably per-channel mean and variance, given the helper's name; confirm.
    var meanVariance = GlobalAvgVariancePool2D(X);

    var O = NewTensor(X.shape);
    var fn = BestKernel(ComputeKernelLibrary.NormalizationTail(X.shape, O.shape));
    fn.SetTensor("X", X.shape, Pin(X).buffer);
    fn.SetTensor("O", O.shape, Pin(O).buffer);
    fn.SetTensor("W", meanVariance.shape, Pin(meanVariance).buffer);
    fn.shader.SetFloat("_Epsilon", epsilon);
    fn.Dispatch();

    return ScaleBias(O, S, B);
}