/// <summary>
/// Cuts one training window out of <paramref name="source"/> starting at row <paramref name="index"/>.
/// The input is rows [index, index + n); the target is the same window shifted one row forward,
/// [index + 1, index + 1 + n), flattened to 1-D. n is <paramref name="bptt"/>, clamped so the
/// shifted target never reads past the last row of <paramref name="source"/>.
/// </summary>
/// <returns>A tuple of (input window, flattened target window).</returns>
static (TorchTensor, TorchTensor) GetBatch(TorchTensor source, int index, int bptt)
{
    // The target is one row ahead of the input, so at most shape[0] - 1 - index rows are usable.
    long remaining = source.shape[0] - 1 - index;
    long window = Math.Min(bptt, remaining);

    var inputRows = source[TorchTensorIndex.Slice(index, index + window)];
    var shiftedRows = source[TorchTensorIndex.Slice(index + 1, index + 1 + window)];

    return (inputRows, shiftedRows.reshape(-1));
}
/// <summary>
/// Builds a fixed sinusoidal positional-encoding table of <paramref name="maxLen"/> positions by
/// <paramref name="dmodel"/> features. Even columns hold sin(pos * freq) and odd columns hold
/// cos(pos * freq). The table is stored in <c>pe</c> with shape (maxLen, 1, dmodel).
/// </summary>
/// <param name="dmodel">Feature width of the encoding. Odd widths are supported.</param>
/// <param name="dropout">Value passed to <c>Dropout(...)</c> to build the dropout module.</param>
/// <param name="maxLen">Number of positions precomputed in the table.</param>
public PositionalEncoding(long dmodel, double dropout, int maxLen = 5000) : base("PositionalEncoding")
{
    this.dropout = Dropout(dropout);

    var pe = Float32Tensor.zeros(new long[] { maxLen, dmodel });

    // Column vector of positions 0..maxLen-1, shape (maxLen, 1).
    var position = Float32Tensor.arange(0, maxLen, 1).unsqueeze(1);

    // One frequency per even column index 2i: exp(2i * -ln(10000) / dmodel); length ceil(dmodel / 2).
    var divTerm = (Float32Tensor.arange(0, dmodel, 2) * (-Math.Log(10000.0) / dmodel)).exp();

    // Broadcasts to (maxLen, ceil(dmodel / 2)); computed once and shared by the sin and cos halves.
    var angles = position * divTerm;

    pe[TorchTensorIndex.Ellipsis, TorchTensorIndex.Slice(0, null, 2)] = angles.sin();

    // There are only floor(dmodel / 2) odd columns. For odd dmodel that is one fewer than the
    // number of angle columns, so trim the angles to avoid a shape mismatch on assignment.
    pe[TorchTensorIndex.Ellipsis, TorchTensorIndex.Slice(1, null, 2)] =
        angles[TorchTensorIndex.Ellipsis, TorchTensorIndex.Slice(0, dmodel / 2)].cos();

    // (maxLen, dmodel) -> (1, maxLen, dmodel) -> (maxLen, 1, dmodel).
    this.pe = pe.unsqueeze(0).transpose(0, 1);
    RegisterComponents();
}
/// <summary>
/// Adds the first <c>t.shape[0]</c> rows of the positional-encoding table <c>pe</c> to
/// <paramref name="t"/> and passes the sum through the dropout module.
/// NOTE(review): presumably <paramref name="t"/> is laid out (seq, batch, dmodel) so that it
/// broadcasts against pe's (seq, 1, dmodel) slice. Confirm against callers.
/// </summary>
public override TorchTensor forward(TorchTensor t)
{
    long seqLen = t.shape[0];
    var encoding = pe[TorchTensorIndex.Slice(null, seqLen), TorchTensorIndex.Slice()];
    return dropout.forward(t + encoding);
}