Ejemplo n.º 1
0
                    /// <summary>
                    /// Functional dropout for floating-point tensors (float16, float32, float64).
                    /// </summary>
                    /// <param name="x">Input tensor.</param>
                    /// <param name="p">
                    /// Value forwarded to <c>MKL.Dropout</c> and <c>MKL.dDropout</c>; presumably the probability of
                    /// dropping an element (TODO: confirm against MKL.Dropout). Must lie in the closed range [0, 1].
                    /// </param>
                    /// <param name="training">
                    /// When false, <paramref name="x"/> itself (the same instance) is returned and no work is done.
                    /// </param>
                    /// <returns>
                    /// A new tensor with the same shape and dtype as <paramref name="x"/> for floating-point dtypes;
                    /// <c>null</c> for any dtype not handled by the switch (callers must check for this).
                    /// </returns>
                    /// <exception cref="TorchException">
                    /// Thrown when <paramref name="p"/> is outside [0, 1] or is NaN, or when <paramref name="x"/> is an
                    /// integer or bool tensor.
                    /// </exception>
                    public static Tensor dropout(Tensor x, double p = 0.5, bool training = true)
                    {
                        // Validate before the training check, as PyTorch's F.dropout does, so a bad p is caught
                        // even in eval mode. The negated range test also rejects NaN, for which every
                        // comparison is false.
                        if (!(p >= 0.0 && p <= 1.0))
                        {
                            throw new TorchException("TorchException: nn.functional.dropout probability has to be between 0 and 1, but got " + p + ".");
                        }
                        if (!training)
                        {
                            return(x);
                        }
                        // Each floating-point branch below follows the same pattern:
                        //  1. allocate y with x's shape/dtype; y tracks gradients only if x does and
                        //     no_grad.prev is false (presumably "not inside a no_grad scope" — verify);
                        //  2. run MKL.Dropout, which fills y and records the per-element mask in dropmap;
                        //  3. if y tracks gradients, install a backward closure that calls MKL.dDropout with
                        //     x.grad, p, the saved mask and y.grad (presumably writing the masked upstream
                        //     gradient into x.grad — confirm against MKL.dDropout), then continues into
                        //     x's own backward function when one exists.
                        switch (x.dtype)
                        {
                        case torch.float16:
                        {
                            var y       = new Tensor(x.__shape, x.dtype, (!torch.autograd.grad_mode.no_grad.prev) && x.requires_grad);
                            var dropmap = new bool[x.__half.Length];
                            MKL.Dropout(x.__half, p, y.__half, dropmap);
                            if (y.requires_grad)
                            {
                                // dropmap is captured by the closure so backward reuses the forward mask.
                                y.__backward_fn = () =>
                                {
                                    MKL.dDropout(x.grad.__half, p, dropmap, y.grad.__half);
                                    if (x.__backward_fn != null)
                                    {
                                        x.__backward_fn();
                                    }
                                };
                            }
                            return(y);
                        }

                        case torch.float32:
                        {
                            var y       = new Tensor(x.__shape, x.dtype, (!torch.autograd.grad_mode.no_grad.prev) && x.requires_grad);
                            var dropmap = new bool[x.__float.Length];
                            MKL.Dropout(x.__float, p, y.__float, dropmap);
                            if (y.requires_grad)
                            {
                                // dropmap is captured by the closure so backward reuses the forward mask.
                                y.__backward_fn = () =>
                                {
                                    MKL.dDropout(x.grad.__float, p, dropmap, y.grad.__float);
                                    if (x.__backward_fn != null)
                                    {
                                        x.__backward_fn();
                                    }
                                };
                            }
                            return(y);
                        }

                        case torch.float64:
                        {
                            var y       = new Tensor(x.__shape, x.dtype, (!torch.autograd.grad_mode.no_grad.prev) && x.requires_grad);
                            var dropmap = new bool[x.__double.Length];
                            MKL.Dropout(x.__double, p, y.__double, dropmap);
                            if (y.requires_grad)
                            {
                                // dropmap is captured by the closure so backward reuses the forward mask.
                                y.__backward_fn = () =>
                                {
                                    MKL.dDropout(x.grad.__double, p, dropmap, y.grad.__double);
                                    if (x.__backward_fn != null)
                                    {
                                        x.__backward_fn();
                                    }
                                };
                            }
                            return(y);
                        }

                        case torch.int8:
                        case torch.uint8:
                        case torch.int16:
                        case torch.int32:
                        case torch.int64:
                        {
                            throw new TorchException("TorchException: nn.functional.dropout is not implemented for integer tensors.");
                        }

                        case torch.@bool:
                        {
                            throw new TorchException("TorchException: nn.functional.dropout is not implemented for bool tensors.");
                        }

                        default:
                        {
                            // Any other dtype: no tensor is produced and null is returned.
                            return(null);
                        }
                        }
                    }