/// <summary>
/// Multiplies <paramref name="inputs"/> element-wise by a mask built along axis 0, then flattens
/// and squeezes the result.
/// When <paramref name="mask"/> is null, the mask is inferred as a one-hot vector of length
/// <c>inputs</c> dim 0. Its hot entry is the argmax of the L2 norm of <paramref name="inputs"/>
/// taken over static axis <c>rank - 2</c>.
/// </summary>
/// <param name="inputs">
/// Tensor to mask. NOTE(review): presumably laid out as [numCapsules, capsuleDim, 1], so that axis
/// <c>rank - 2</c> is the per-capsule vector axis. Confirm against callers.
/// </param>
/// <param name="mask">
/// Optional precomputed mask (may be null). A trailing axis of size 1 is appended to it before the
/// element-wise multiply. NOTE(review): presumably a one-hot vector over axis 0 of
/// <paramref name="inputs"/>. Confirm.
/// </param>
/// <returns>The masked tensor, passed through <c>Flatten</c> and then <c>Squeeze</c>.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="inputs"/> is null.</exception>
/// <exception cref="System.InvalidOperationException">
/// The mask has to be inferred, but the squeezed norm tensor is not a 1-D vector whose length equals
/// dim 0 of <paramref name="inputs"/>.
/// </exception>
C.Function get_mask_and_infer_from_last_dimension(C.Function inputs, C.Function mask) {
  if (inputs == null) {
    throw new System.ArgumentNullException(nameof(inputs));
  }

  if (mask == null) {
    var inputs_shape = inputs.Output.Shape.Dimensions.ToArray();
    var num_classes = inputs_shape[0];

    // Reduction axis is rank - 2 (unchanged from the original `ndims - 1` with ndims = rank - 1).
    // For a 3-D input this is axis 1.
    var reduce_axis = inputs_shape.Length - 2;
    var norms = CC.Sqrt(CC.ReduceSum(CC.Square(inputs), new C.Axis(reduce_axis)));
    norms = CC.Squeeze(norms);

    // Always-on check (Debug.Assert is compiled out of Release builds). The one-hot below is built
    // along axis 0 with inputs_shape[0] classes and then multiplied against `inputs`, so the argmax
    // index is only meaningful if the norm vector is 1-D with exactly inputs_shape[0] entries.
    var norm_dims = norms.Output.Shape.Dimensions;
    if (norm_dims.Count != 1 || norm_dims[0] != num_classes) {
      throw new System.InvalidOperationException(
        $"Cannot infer mask: expected the squeezed norm to be a vector of length {num_classes}, " +
        $"but got shape [{string.Join(", ", norm_dims)}] from input shape [{string.Join(", ", inputs_shape)}].");
    }

    var winner = CC.Argmax(norms, new C.Axis(0));
    mask = CC.OneHotOp(winner, numClass: (uint)num_classes, outputSparse: false, axis: new C.Axis(0));
  }

  // Append a trailing size-1 axis to the mask. NOTE(review): presumably so it broadcasts across the
  // non-leading axes of `inputs` in ElementTimes. Confirm against CNTK broadcasting rules.
  mask = CC.Reshape(mask, mask.Output.Shape.AppendShape(new int[] { 1 }));
  var masked = CC.ElementTimes(inputs, mask);
  masked = CC.Flatten(masked);
  masked = CC.Squeeze(masked);
  return masked;
}