/// <summary>
/// Returns the scale parameter used to draw initial weights for the given
/// initialization scheme, computed from the layer's fan-in and fan-out.
/// </summary>
/// <remarks>
/// For the uniform schemes the returned value is the limit <c>a</c> of the symmetric
/// interval [-a, a]. For the normal schemes it is the standard deviation of a zero-mean
/// normal distribution.
/// <list type="bullet">
/// <item><description>GlorotUniform: a = sqrt(6 / (fanIn + fanOut))</description></item>
/// <item><description>HeUniform: a = sqrt(6 / fanIn)</description></item>
/// <item><description>GlorotNormal: sigma = sqrt(2 / (fanIn + fanOut))</description></item>
/// <item><description>HeNormal: sigma = sqrt(2 / fanIn)</description></item>
/// </list>
/// A uniform distribution on [-a, a] has variance a^2 / 3. The uniform limits therefore
/// use 6 where the normal standard deviations use 2, so both variants of a scheme reach
/// the same target variance: 2 / (fanIn + fanOut) for Glorot and 2 / fanIn for He.
/// </remarks>
/// <param name="initialization">The weight initialization scheme.</param>
/// <param name="fans">The fan-in and fan-out of the layer being initialized.</param>
/// <returns>The uniform limit or the normal standard deviation, as described above.</returns>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="initialization"/> is not one of the supported schemes.
/// </exception>
public static float InitializationBound(Initialization initialization, FanInFanOut fans)
{
    switch (initialization)
    {
        case Initialization.GlorotUniform:
            return (float)Math.Sqrt(6.0 / (fans.FanIn + fans.FanOut));
        case Initialization.HeUniform:
            // Uniform limit: 6 rather than 2, because Var(U[-a, a]) = a^2 / 3.
            return (float)Math.Sqrt(6.0 / fans.FanIn);
        case Initialization.GlorotNormal:
            // Standard deviation: 2 rather than 6, because the normal is parameterized directly by sigma.
            return (float)Math.Sqrt(2.0 / (fans.FanIn + fans.FanOut));
        case Initialization.HeNormal:
            return (float)Math.Sqrt(2.0 / fans.FanIn);
        default:
            throw new ArgumentException("Unsupported Initialization type: " + initialization);
    }
}
/// <summary>
/// Builds the distribution from which initial weights are sampled for the given
/// initialization scheme and layer fan-in/fan-out.
/// </summary>
/// <param name="initialization">The weight initialization scheme.</param>
/// <param name="fans">The fan-in and fan-out of the layer being initialized.</param>
/// <param name="random">
/// Source of randomness. One value is drawn from it to seed a new, independent
/// <see cref="Random"/> that the returned distribution owns.
/// </param>
/// <returns>
/// A <c>ContinuousUniform</c> on [-limit, limit] for the uniform schemes, or a zero-mean
/// <c>Normal</c> with standard deviation <c>limit</c> for the normal schemes. Here
/// <c>limit</c> is the value returned by <c>InitializationBound</c>.
/// </returns>
/// <exception cref="ArgumentException">
/// Thrown when <paramref name="initialization"/> is not one of the supported schemes.
/// </exception>
public static IContinuousDistribution GetWeightDistribution(Initialization initialization, FanInFanOut fans, Random random)
{
    // Computed before branching, so an unsupported scheme fails in InitializationBound.
    float limit = InitializationBound(initialization, fans);

    switch (initialization)
    {
        // Uniform schemes: sample from the symmetric interval [-limit, limit].
        case Initialization.GlorotUniform:
        case Initialization.HeUniform:
            return new ContinuousUniform(-limit, limit, new Random(random.Next()));

        // Normal schemes: zero mean, with limit used as the standard deviation.
        case Initialization.GlorotNormal:
        case Initialization.HeNormal:
            return new Normal(0.0, limit, new Random(random.Next()));

        default:
            throw new ArgumentException("Unsupported Initialization type: " + initialization);
    }
}