示例#1
0
            /// <summary>
            /// Executes one step of the native RNN cell.
            /// </summary>
            /// <param name="input">Input tensor for this step.</param>
            /// <param name="h0">Optional previous hidden state; a null handle is sent to native code when omitted.</param>
            /// <returns>The hidden state produced by the native cell.</returns>
            public Tensor forward(Tensor input, Tensor? h0 = null)
            {
                IntPtr inputHandle = input.Handle;
                IntPtr previousState = h0 is null ? IntPtr.Zero : h0.Handle;
                IntPtr nextState = THSNN_RNNCell_forward(handle, inputHandle, previousState);

                // A zero handle means the native call failed; let torch surface the recorded error.
                bool failed = nextState == IntPtr.Zero;
                if (failed)
                {
                    torch.CheckForErrors();
                }

                return new Tensor(nextState);
            }
示例#2
0
                /// <summary>
                /// Computes binary cross-entropy between <paramref name="src"/> and <paramref name="target"/> via the native backend.
                /// </summary>
                /// <param name="src">Prediction tensor.</param>
                /// <param name="target">Target tensor.</param>
                /// <param name="weight">Optional weight tensor; a null handle is sent when omitted.</param>
                /// <param name="reduction">Reduction mode, forwarded to native code as a <c>long</c>.</param>
                /// <returns>The loss tensor returned by the native call.</returns>
                public static Tensor binary_cross_entropy(Tensor src, Tensor target, Tensor? weight = null, Reduction reduction = Reduction.Mean)
                {
                    IntPtr srcHandle = src.Handle;
                    IntPtr targetHandle = target.Handle;
                    IntPtr weightHandle = weight is null ? IntPtr.Zero : weight.Handle;

                    IntPtr loss = THSNN_binary_cross_entropy(srcHandle, targetHandle, weightHandle, (long)reduction);

                    // Zero signals a native failure.
                    if (loss == IntPtr.Zero)
                    {
                        torch.CheckForErrors();
                    }

                    return new Tensor(loss);
                }
示例#3
0
            /// <summary>
            /// Runs the native GRU forward pass.
            /// </summary>
            /// <param name="input">Input tensor.</param>
            /// <param name="h0">Optional initial hidden state; a null handle is sent when omitted.</param>
            /// <returns>The (output, final hidden state) pair returned by the native call.</returns>
            public (Tensor, Tensor) forward(Tensor input, Tensor? h0 = null)
            {
                IntPtr inputHandle = input.Handle;
                IntPtr initialState = h0 is null ? IntPtr.Zero : h0.Handle;

                IntPtr output = THSNN_GRU_forward(handle, inputHandle, initialState, out IntPtr finalState);

                // Either handle being zero indicates the native call failed.
                bool anyMissing = output == IntPtr.Zero || finalState == IntPtr.Zero;
                if (anyMissing)
                {
                    torch.CheckForErrors();
                }

                var outputTensor = new Tensor(output);
                var stateTensor = new Tensor(finalState);
                return (outputTensor, stateTensor);
            }
示例#4
0
 /// <summary>
 /// Creates a loss delegate that calls the native binary-cross-entropy-with-logits routine.
 /// </summary>
 /// <param name="weight">Optional weight tensor; a null handle is sent when omitted.</param>
 /// <param name="reduction">Reduction mode, forwarded to native code as a <c>long</c>.</param>
 /// <param name="posWeights">Optional positive weights; a null handle is sent when omitted.</param>
 /// <returns>A <see cref="Loss"/> that maps (src, target) to the loss tensor.</returns>
 public static Loss binary_cross_entropy_with_logits_loss(Tensor? weight = null, Reduction reduction = Reduction.Mean, Tensor? posWeights = null)
 {
     // Handles of the captured optional tensors are read on every invocation.
     Tensor Evaluate(Tensor src, Tensor target)
     {
         IntPtr srcHandle = src.Handle;
         IntPtr targetHandle = target.Handle;
         IntPtr weightHandle = weight is null ? IntPtr.Zero : weight.Handle;
         long reductionCode = (long)reduction;
         IntPtr posWeightHandle = posWeights is null ? IntPtr.Zero : posWeights.Handle;

         IntPtr loss = THSNN_binary_cross_entropy_with_logits(srcHandle, targetHandle, weightHandle, reductionCode, posWeightHandle);
         if (loss == IntPtr.Zero)
         {
             torch.CheckForErrors();
         }
         return new Tensor(loss);
     }

     return Evaluate;
 }
示例#5
0
 /// <summary>
 /// Creates a loss delegate that calls the native negative-log-likelihood loss.
 /// </summary>
 /// <param name="weight">Optional weight tensor; a null handle is sent when omitted.</param>
 /// <param name="reduction">Reduction mode, forwarded to native code as a <c>long</c>.</param>
 /// <returns>A <see cref="Loss"/> that maps (src, target) to the loss tensor.</returns>
 public static Loss nll_loss(Tensor? weight = null, Reduction reduction = Reduction.Mean)
 {
     // The weight handle is resolved on each call rather than when the factory runs.
     Tensor Evaluate(Tensor src, Tensor target)
     {
         IntPtr srcHandle = src.Handle;
         IntPtr targetHandle = target.Handle;
         IntPtr weightHandle = weight is null ? IntPtr.Zero : weight.Handle;

         IntPtr loss = THSNN_nll_loss(srcHandle, targetHandle, weightHandle, (long)reduction);
         if (loss == IntPtr.Zero)
         {
             torch.CheckForErrors();
         }
         return new Tensor(loss);
     }

     return Evaluate;
 }
示例#6
0
 /// <summary>
 /// Creates a loss delegate that calls the native cross-entropy loss.
 /// </summary>
 /// <param name="weight">Optional weight tensor; a null handle is sent when omitted.</param>
 /// <param name="ignore_index">Optional ignore index. When null, -100 is passed together with a
 /// <c>false</c> "has value" flag.</param>
 /// <param name="reduction">Reduction mode, forwarded to native code as a <c>long</c>.</param>
 /// <returns>A <see cref="Loss"/> that maps (src, target) to the loss tensor.</returns>
 public static Loss cross_entropy_loss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean)
 {
     Tensor Evaluate(Tensor src, Tensor target)
     {
         long ignoreIndex = ignore_index ?? -100;
         IntPtr srcHandle = src.Handle;
         IntPtr targetHandle = target.Handle;
         IntPtr weightHandle = weight is null ? IntPtr.Zero : weight.Handle;

         IntPtr loss = THSNN_cross_entropy(srcHandle, targetHandle, weightHandle, ignoreIndex, ignore_index.HasValue, (long)reduction);
         if (loss == IntPtr.Zero)
         {
             torch.CheckForErrors();
         }
         return new Tensor(loss);
     }

     return Evaluate;
 }
示例#7
0
                /// <summary>
                /// Creates a loss delegate that calls the native multi-margin loss for each (input, target) pair.
                /// </summary>
                /// <param name="p">Exponent forwarded unchanged to the native call (default 1).</param>
                /// <param name="margin">Margin forwarded unchanged to the native call (default 1.0).</param>
                /// <param name="weight">Optional weight tensor; a null handle is sent when omitted.</param>
                /// <param name="reduction">Reduction mode, forwarded to native code as a <c>long</c>.</param>
                /// <returns>A <see cref="Loss"/> that maps (input, target) to the loss tensor.</returns>
                public static Loss multi_margin_loss(int p = 1, double margin = 1.0, Tensor? weight = null, Reduction reduction = Reduction.Mean)
                {
                    return (Tensor input, Tensor target) =>
                    {
                        // Resolve the weight handle on every call, as the other loss factories do, instead of
                        // capturing a raw IntPtr when the factory runs. A captured IntPtr does not keep the
                        // managed `weight` Tensor reachable, so that tensor could be disposed (or finalized,
                        // if its finalizer releases the native tensor — TODO confirm) while this delegate
                        // kept handing its stale native pointer to THSNN_multi_margin_loss.
                        var res = THSNN_multi_margin_loss(input.Handle, target.Handle, p, margin, weight?.Handle ?? IntPtr.Zero, (long)reduction);
                        if (res == IntPtr.Zero)
                        {
                            torch.CheckForErrors();
                        }
                        return new Tensor(res);
                    };
                }
示例#8
0
            /// <summary>
            /// Runs the native transformer-encoder forward pass over <paramref name="src"/>.
            /// </summary>
            /// <param name="src">Source tensor.</param>
            /// <param name="src_mask">Optional source mask; a null handle is sent when omitted.</param>
            /// <param name="src_key_padding_mask">Optional key-padding mask; a null handle is sent when omitted.</param>
            /// <returns>The tensor returned by the native call.</returns>
            public Tensor forward(Tensor src, Tensor? src_mask = null, Tensor? src_key_padding_mask = null)
            {
                IntPtr srcHandle = src.Handle;
                IntPtr maskHandle = src_mask is null ? IntPtr.Zero : src_mask.Handle;
                IntPtr paddingHandle = src_key_padding_mask is null ? IntPtr.Zero : src_key_padding_mask.Handle;

                IntPtr encoded = THSNN_TransformerEncoder_forward(handle, srcHandle, maskHandle, paddingHandle);

                // Zero signals a native failure.
                if (encoded == IntPtr.Zero)
                {
                    torch.CheckForErrors();
                }

                return new Tensor(encoded);
            }
示例#9
0
            /// <summary>
            /// Runs the native transformer-decoder forward pass.
            /// </summary>
            /// <param name="tgt">Target tensor.</param>
            /// <param name="memory">Memory tensor.</param>
            /// <param name="tgt_mask">Optional target mask; a null handle is sent when omitted.</param>
            /// <param name="memory_mask">Optional memory mask; a null handle is sent when omitted.</param>
            /// <param name="tgt_key_padding_mask">Optional target key-padding mask; a null handle is sent when omitted.</param>
            /// <param name="memory_key_padding_mask">Optional memory key-padding mask; a null handle is sent when omitted.</param>
            /// <returns>The tensor returned by the native call.</returns>
            public Tensor forward(Tensor tgt, Tensor memory, Tensor? tgt_mask = null, Tensor? memory_mask = null, Tensor? tgt_key_padding_mask = null, Tensor? memory_key_padding_mask = null)
            {
                IntPtr tgtHandle = tgt.Handle;
                IntPtr memoryHandle = memory.Handle;
                IntPtr tgtMaskHandle = tgt_mask is null ? IntPtr.Zero : tgt_mask.Handle;
                IntPtr memoryMaskHandle = memory_mask is null ? IntPtr.Zero : memory_mask.Handle;
                IntPtr tgtPaddingHandle = tgt_key_padding_mask is null ? IntPtr.Zero : tgt_key_padding_mask.Handle;
                IntPtr memoryPaddingHandle = memory_key_padding_mask is null ? IntPtr.Zero : memory_key_padding_mask.Handle;

                IntPtr decoded = THSNN_TransformerDecoder_forward(handle, tgtHandle, memoryHandle, tgtMaskHandle, memoryMaskHandle, tgtPaddingHandle, memoryPaddingHandle);

                // Zero signals a native failure.
                if (decoded == IntPtr.Zero)
                {
                    torch.CheckForErrors();
                }

                return new Tensor(decoded);
            }
示例#10
0
            /// <summary>
            /// Runs the native RNN forward pass over <paramref name="input"/>.
            /// </summary>
            /// <param name="input">Input sequence. The batch size is read from dimension 0 when
            /// <c>_batch_first</c> is set, otherwise from dimension 1.</param>
            /// <param name="h0">Initial hidden state. When null, a zero tensor of shape
            /// (D * _num_layers, N, _hidden_size) is used, with D = 2 if <c>_bidirectional</c> and 1 otherwise.
            /// That zero tensor takes the dtype and device of <paramref name="input"/>.</param>
            /// <returns>The (output, final hidden state) pair returned by the native call.</returns>
            public new (Tensor, Tensor) forward(Tensor input, Tensor? h0 = null)
            {
                // Zero state allocated here only when the caller omitted h0; this call owns and disposes it.
                Tensor? defaultH0 = null;

                if (h0 is null)
                {
                    var N = _batch_first ? input.shape[0] : input.shape[1];
                    var D = _bidirectional ? 2 : 1;

                    // Match the input's dtype and device. A plain torch.zeros(shape) uses the default
                    // dtype/device and would not line up with, e.g., a CUDA or double-precision input.
                    defaultH0 = torch.zeros(new long[] { D * _num_layers, N, _hidden_size }, dtype: input.dtype, device: input.device);
                    h0 = defaultH0;
                }

                try
                {
                    var res = THSNN_RNN_forward(handle, input.Handle, h0.Handle, out IntPtr hN);

                    if (res == IntPtr.Zero || hN == IntPtr.Zero)
                    {
                        torch.CheckForErrors();
                    }
                    return (new Tensor(res), new Tensor(hN));
                }
                finally
                {
                    // Release only the temporary we created; a caller-supplied h0 is left untouched.
                    defaultH0?.Dispose();
                }
            }
示例#11
0
文件: RNN.cs 项目: dsyme/TorchSharp
 /// <summary>
 /// Sets the hidden-hidden bias at index <paramref name="idx"/> in the native RNN module and
 /// conditionally registers it under the name <c>bias_hh_l{idx}</c>.
 /// </summary>
 /// <param name="value">New bias tensor, or null (a null handle is then sent to native code).</param>
 /// <param name="idx">Index forwarded to the native setter and used in the parameter name.</param>
 public void set_bias_hh(Tensor? value, long idx)
 {
     IntPtr biasHandle = value?.Handle ?? IntPtr.Zero;
     THSNN_RNN_set_bias_hh(handle, biasHandle, idx);

     // The native setter reports no status, so always check for a pending error.
     torch.CheckForErrors();

     string parameterName = $"bias_hh_l{idx}";
     ConditionallyRegisterParameter(parameterName, value);
 }
示例#12
0
 /// <summary>
 /// Runs the native multi-head attention forward pass.
 /// </summary>
 /// <param name="query">Query tensor.</param>
 /// <param name="key">Key tensor.</param>
 /// <param name="value">Value tensor.</param>
 /// <param name="key_padding_mask">Optional key-padding mask; a null handle is sent when omitted.</param>
 /// <param name="need_weights">Forwarded to native code. When false, a zero weights handle is not treated
 /// as an error, and it is still wrapped in the returned <see cref="Tensor"/>.</param>
 /// <param name="attn_mask">Optional attention mask; a null handle is sent when omitted.</param>
 /// <returns>A tuple of (attention output, attention weights) wrapping the native results.</returns>
 public Tuple<Tensor, Tensor> forward(Tensor query, Tensor key, Tensor value, Tensor? key_padding_mask = null, bool need_weights = true, Tensor? attn_mask = null)
 {
     IntPtr queryHandle = query.Handle;
     IntPtr keyHandle = key.Handle;
     IntPtr valueHandle = value.Handle;
     IntPtr paddingHandle = key_padding_mask is null ? IntPtr.Zero : key_padding_mask.Handle;
     IntPtr maskHandle = attn_mask is null ? IntPtr.Zero : attn_mask.Handle;

     THSNN_MultiheadAttention_forward(handle, queryHandle, keyHandle, valueHandle, paddingHandle, need_weights, maskHandle, out var output, out var weights);

     // The output must always exist; the weights are only required when they were requested.
     bool missingOutput = output == IntPtr.Zero;
     bool missingWeights = need_weights && weights == IntPtr.Zero;
     if (missingOutput || missingWeights)
     {
         torch.CheckForErrors();
     }

     var outputTensor = new Tensor(output);
     var weightsTensor = new Tensor(weights);
     return Tuple.Create(outputTensor, weightsTensor);
 }