Hyunji

samplers

1 +""" torch samplers for different distributions"""
2 +
3 +import numpy as np
4 +import torch
5 +from scipy.linalg import circulant
6 +
7 +
8 +def sample_gaussian(mean, sigma, tril_sigma=False):
9 + noise = torch.randn_like(mean)
10 +
11 + # we getting sigma
12 + if tril_sigma:
13 + z_sample = torch.bmm(sigma, noise.unsqueeze(dim=2)).squeeze() + mean
14 + else:
15 + z_sample = noise * sigma + mean
16 + return z_sample
17 +
18 +
19 +def sample_echo(f, s, m=None, replace=False, pop=True):
20 + """
21 + f, s : are the outputs of encoder (shape : [B, Z] for f)
22 + s is shape [B, Z] or [B, Z, Z]
23 + tril_sigma: if we have s as diagonal matrix or lt matrix
24 + m : number of samples to consider to generate noise when replace
25 + is true (default to batch_size)
26 + replace : sampling with replacement or not (if sampling with
27 + replacement, pop is not considered)
28 + pop: If true, remove the sample to which noise is being added
29 + detach_noise_grad : detach gradient of noise or not
30 + """
31 + batch_size, z_size = f.shape[0], f.shape[1:]
32 +
33 + # get indices
34 + if not replace:
35 + indices = circulant(np.arange(batch_size))
36 + if pop:
37 + # just ignore the first column
38 + indices = indices[:, 1:]
39 + for i in indices:
40 + np.random.shuffle(i)
41 + else:
42 + m = batch_size if m is None else m
43 + indices = np.random.choice(batch_size, size=(batch_size, m), replace=True)
44 +
45 + f_arr = f[indices.reshape(-1)].view(indices.shape + z_size)
46 + s_arr = s[indices.reshape(-1)].view(indices.shape + z_size)
47 +
48 + epsilon = f_arr[:, 0] + torch.sum(f_arr[:, 1:] * torch.cumprod(s_arr[:, :-1], dim=1), dim=1)
49 +
50 + z_sample = f + s * epsilon
51 + return z_sample