/// <summary>
/// Expands 4D NHWC parameters (one value per N, H, W, C axis) into an 8D parameter array
/// of length <see cref="TensorShape.MaxRank"/>.
/// </summary>
/// <param name="shape">Shape the parameters will be applied to; must be NHWC when 4D parameters are given.</param>
/// <param name="parameters">Either 8D parameters (returned unchanged) or exactly 4 parameters in N, H, W, C order.</param>
/// <param name="defaultValue">Value written to every 8D slot that has no NHWC counterpart.</param>
/// <returns>
/// <paramref name="parameters"/> itself when it already has <see cref="TensorShape.MaxRank"/> entries;
/// otherwise a new 8-element array with the NHWC values at indices 2, 5, 6 and 7 and
/// <paramref name="defaultValue"/> in every other slot.
/// </returns>
public static int[] Get8DParametersFromNHWCParametersAndShape(TensorShape shape, int[] parameters, int defaultValue)
{
    if (parameters.Length == TensorShape.MaxRank)
    {
        return parameters;
    }

    Assert.AreEqual(4, parameters.Length);
    // Interpolating an int[] directly yields "System.Int32[]"; join the values so the message is actionable.
    Assert.IsTrue(shape.IsNHWC(), $"4D NHWC Parameters [{string.Join(", ", parameters)}] can't be used with a tensor of shape {shape} as it contains other dimensions, please use 8D parameters for this shape.");

    // N, H, W, C land in slots 2, 5, 6, 7; all remaining slots take defaultValue.
    return new int[] { defaultValue, defaultValue, parameters[0], defaultValue, defaultValue, parameters[1], parameters[2], parameters[3] };
}
/// <summary>
/// Collapses 8D parameters back to 4D NHWC parameters by picking the batch, height,
/// width and channel entries.
/// </summary>
/// <param name="shape">Shape the parameters belong to; must be NHWC for the 8D-to-4D conversion.</param>
/// <param name="parameters">Either 4 NHWC parameters (returned unchanged) or <see cref="TensorShape.MaxRank"/> parameters.</param>
/// <returns>
/// <paramref name="parameters"/> itself when it already has 4 entries; otherwise a new array
/// containing the entries at <see cref="TensorShape.DataBatch"/>, <see cref="TensorShape.H"/>,
/// <see cref="TensorShape.W"/> and <see cref="TensorShape.C"/>, in that order.
/// </returns>
public static int[] GetNHWCParametersFrom8DParameterAndShape(TensorShape shape, int[] parameters)
{
    if (parameters.Length == 4)
    {
        return parameters;
    }

    // Interpolating an int[] directly yields "System.Int32[]"; join the values so the message is actionable.
    Assert.IsTrue(shape.IsNHWC(), $"Parameters [{string.Join(", ", parameters)}] can't be converted to NHWC with a tensor of shape {shape} as it contains other dimensions.");
    // Expected value first, actual second, so a failure reports the mismatch correctly.
    Assert.AreEqual(TensorShape.MaxRank, parameters.Length);

    return new int[]
    {
        parameters[TensorShape.DataBatch],
        parameters[TensorShape.H],
        parameters[TensorShape.W],
        parameters[TensorShape.C],
    };
}
/// <summary>
/// Expands a 4D NHWC axis permutation into an 8D permutation of length <see cref="TensorShape.MaxRank"/>.
/// </summary>
/// <param name="shape">Shape the permutation will be applied to; must be NHWC when a 4D permutation is given.</param>
/// <param name="permutations">Either an 8D permutation (returned unchanged) or exactly 4 NHWC axis indices.</param>
/// <returns>
/// <paramref name="permutations"/> itself when it already has <see cref="TensorShape.MaxRank"/> entries;
/// otherwise a new permutation that keeps axes 0, 1, 3 and 4 in place and fills slots 2, 5, 6 and 7
/// with the NHWC source axes translated by <c>NHWCTo8DAxis</c>.
/// </returns>
public static int[] Get8DPermutationsForNHWCPermutationsAndShape(TensorShape shape, int[] permutations)
{
    if (permutations.Length == TensorShape.MaxRank)
    {
        return permutations;
    }

    Assert.AreEqual(4, permutations.Length);
    // Interpolating an int[] directly yields "System.Int32[]"; join the values so the message is actionable.
    Assert.IsTrue(shape.IsNHWC(), $"4D NHWC Permutation [{string.Join(", ", permutations)}] can't be used with a tensor of shape {shape} as it contains other dimensions, please use an 8D permutation for this shape.");

    // Each 4D entry names an NHWC source axis; translate it to the matching 8D axis index.
    int batchOldAxis = NHWCTo8DAxis(permutations[0]);
    int heightOldAxis = NHWCTo8DAxis(permutations[1]);
    int widthOldAxis = NHWCTo8DAxis(permutations[2]);
    int channelOldAxis = NHWCTo8DAxis(permutations[3]);

    // Axes without an NHWC counterpart (0, 1, 3, 4) map to themselves.
    return new int[] { 0, 1, batchOldAxis, 3, 4, heightOldAxis, widthOldAxis, channelOldAxis };
}