/// <summary>
/// Draws <c>total_count</c> categorical trials for every requested sample and returns, for each
/// category, the number of trials that landed in it.
/// </summary>
/// <remarks>
/// The result has the shape produced by <c>ExtendedShape(sample_shape)</c> and is cast to the dtype
/// of <c>probs</c>.
/// NOTE(review): the counts are assembled from the integer indices returned by
/// <c>categorical.sample</c> via <c>scatter_add_</c>. No gradient is expected to flow back to
/// <c>probs</c> despite the "reparameterized" naming; confirm before relying on differentiability.
/// </remarks>
/// <param name="sample_shape">The sample shape.</param>
/// <returns>A tensor of per-category trial counts.</returns>
public override Tensor rsample(params long[] sample_shape)
{
    // The leading dimension enumerates the individual trials; the requested sample dims follow it.
    var drawShape = new List<long> { total_count };
    drawShape.AddRange(sample_shape);
    var draws = categorical.sample(drawShape.ToArray());

    // Move the trial dimension from the front to the back: permutation (1, 2, ..., rank-1, 0).
    // This lets scatter_add_ below count trials along the last axis.
    long rank = draws.dim();
    var order = new long[rank];
    for (long axis = 0; axis < rank; axis++)
    {
        order[axis] = (axis + 1) % rank;
    }
    draws = draws.permute(order);

    // Each trial index contributes a single 1 to its category's bucket.
    var counts = draws.new_zeros(ExtendedShape(sample_shape));
    counts.scatter_add_(-1, draws, torch.ones_like(draws));
    return counts.type_as(probs);
}