Exemplo n.º 1
0
            /// <summary>
            /// Registers a named sub-module with the underlying native Sequential container.
            /// </summary>
            /// <param name="name">The name under which the sub-module is registered.</param>
            /// <param name="submodule">The module to append; must carry a native boxed handle.</param>
            /// <exception cref="InvalidOperationException">
            /// Thrown when the module has no boxed native handle (e.g. a Sequential or a loaded module).
            /// </exception>
            public void Add(string name, torch.nn.Module submodule)
            {
                Debug.Assert(!handle.IsInvalid);

                var boxed = submodule.BoxedModule;
                if (boxed == null)
                {
                    throw new InvalidOperationException("A Sequential or loaded module may not be added to a Sequential");
                }

                // Push the sub-module into the native container, then surface any native-side error.
                THSNN_Sequential_push_back(handle, name, boxed.handle);
                torch.CheckForErrors();
            }
Exemplo n.º 2
0
        //public static void FreezeModuleParams(ModuleList modules)
        //{
        //    foreach (var module in modules)
        //    {
        //        FreezeModuleParams(module);
        //    }
        //}

        /// <summary>
        /// Disables gradient tracking for every parameter of <paramref name="module"/>.
        /// A null module is treated as a no-op.
        /// </summary>
        /// <param name="module">The module whose parameters should be frozen; may be null.</param>
        public static void FreezeModuleParams(torch.nn.Module module)
        {
            if (module is null)
            {
                return;
            }

            // Turning off requires_grad excludes the parameter from autograd,
            // so it will not accumulate gradients during backpropagation.
            foreach (var p in module.parameters())
            {
                p.requires_grad = false;
            }
        }
Exemplo n.º 3
0
 /// <summary>
 /// Constructs an activation wrapper, selecting the implementation by name.
 /// Recognized names (case-insensitive): "relu", "gelu", "gelu_fast", "tanh", "linear".
 /// </summary>
 /// <param name="name">The activation function's name; a null name falls through to the default arm.</param>
 /// <exception cref="NotSupportedException">Thrown when the name is not a supported activation.</exception>
 public ActivationFunction(string name) : base(name)
 {
     // ToLowerInvariant (not ToLower) avoids culture-sensitive casing: under
     // e.g. the Turkish culture, "LINEAR".ToLower() produces a dotless-ı
     // string that would fail to match "linear" and throw spuriously.
     _function = name?.ToLowerInvariant() switch
     {
         "relu" => torch.nn.ReLU(),
         "gelu" => torch.nn.GELU(),
         "gelu_fast" => new GeLUFast(),
         "tanh" => torch.nn.Tanh(),
         "linear" => torch.nn.Identity(),
         _ => throw new NotSupportedException($"Activation function {name} not supported.")
     };
 }
Exemplo n.º 4
0
            /// <summary>
            /// Appends a named sub-module to the native Sequential container and
            /// records a managed reference so the sub-module is not collected
            /// while this Sequential is alive.
            /// </summary>
            /// <param name="name">The name under which the sub-module is registered.</param>
            /// <param name="submodule">The module to append; must carry a native boxed handle.</param>
            /// <exception cref="InvalidOperationException">
            /// Thrown when the module has no boxed native handle (e.g. a Sequential or a loaded module).
            /// </exception>
            internal void Add(string name, torch.nn.Module submodule)
            {
                Debug.Assert(!handle.IsInvalid);

                var boxed = submodule.BoxedModule;
                if (boxed == null)
                {
                    throw new InvalidOperationException("A Sequential or loaded module may not be added to a Sequential");
                }

                THSNN_Sequential_push_back(handle, name, boxed.handle);
                torch.CheckForErrors();

                // Keep the sub-module alive for at least as long as the Sequential object is alive.
                _modules.Add(submodule);
                _names.Add(name);
            }
Exemplo n.º 5
0
        /// <summary>
        /// Runs one conv + layer-norm sub-block: optionally masks padded positions,
        /// applies the convolution, adds the input back in place, then normalizes.
        /// Intermediate tensors are released via a dispose scope; only the result escapes.
        /// </summary>
        /// <param name="input">The input tensor for this layer.</param>
        /// <param name="selfAttentionPaddingMask">Padding mask; may be a null tensor, in which case no masking occurs.</param>
        /// <param name="convLayer">The convolution module to apply.</param>
        /// <param name="layerNorm">The layer-normalization module to apply last.</param>
        /// <returns>The normalized output tensor, moved to the outer dispose scope.</returns>
        private static torch.Tensor ForwardOneLayer(torch.Tensor input, torch.Tensor selfAttentionPaddingMask,
                                                    torch.nn.Module convLayer, torch.nn.Module layerNorm)
        {
            using var scope = torch.NewDisposeScope();

            torch.Tensor masked;
            if (selfAttentionPaddingMask.IsNull())
            {
                masked = input.alias();
            }
            else
            {
                // Zero out padded positions before convolving.
                masked = input.masked_fill(selfAttentionPaddingMask.T.unsqueeze(-1), 0);
            }

            var convOut = convLayer.forward(masked);
            convOut.add_(input); // in-place: adds the original input to the conv output

            var normalized = layerNorm.forward(convOut);
            return normalized.MoveToOuterDisposeScope();
        }
Exemplo n.º 6
0
            /// <summary>
            /// Appends an unnamed sub-module; its name defaults to its ordinal
            /// position among the modules registered so far.
            /// </summary>
            /// <param name="module">The module to append.</param>
            internal void Add(torch.nn.Module module)
            {
                Add(_modules.Count.ToString(), module);
            }