Example #1
0
        /// <summary>
        /// Multiplies <paramref name="inputs"/> element-wise by a mask and returns the
        /// flattened, squeezed product. When no mask is supplied, one is derived from
        /// <paramref name="inputs"/> itself.
        /// </summary>
        /// <param name="inputs">
        /// Function whose output is masked.
        /// NOTE(review): the inferred-mask path appears to expect a layout where reducing
        /// along axis (rank - 2) and squeezing leaves a rank-1 result; confirm with callers.
        /// </param>
        /// <param name="mask">
        /// Mask to apply, or <c>null</c> to infer a one-hot mask from the position of the
        /// largest root-sum-of-squares value of <paramref name="inputs"/>.
        /// </param>
        /// <returns>The masked inputs after <c>Flatten</c> followed by <c>Squeeze</c>.</returns>
        C.Function get_mask_and_infer_from_last_dimension(C.Function inputs, C.Function mask)
        {
            if (mask == null)
            {
                var dims = inputs.Output.Shape.Dimensions.ToArray();

                // Root of the sum of squares, reduced along axis index (rank - 2).
                var squared      = CC.Square(inputs);
                var sumOfSquares = CC.ReduceSum(squared, new C.Axis(dims.Length - 2));
                var lengths      = CC.Squeeze(CC.Sqrt(sumOfSquares));

                // Checked in debug builds only: the squeezed result must be rank 1.
                System.Diagnostics.Debug.Assert(lengths.Output.Shape.Dimensions.Count == 1);

                // One-hot encode the position of the largest value; the class count is
                // taken from the first dimension of the inputs' shape.
                var winner = CC.Argmax(lengths, new C.Axis(0));
                mask = CC.OneHotOp(winner, numClass: (uint)dims[0], outputSparse: false, axis: new C.Axis(0));
            }

            // Append a dimension of size 1 to the mask's shape before multiplying it with the inputs.
            var expandedShape      = mask.Output.Shape.AppendShape(new int[] { 1 });
            C.Function shapedMask  = CC.Reshape(mask, expandedShape);
            var product            = CC.ElementTimes(inputs, shapedMask);

            return CC.Squeeze(CC.Flatten(product));
        }