Showing
1 changed file
with
51 additions
and
0 deletions
2DCNN/lib/utils/samplers.py
0 → 100644
1 | +""" torch samplers for different distributions""" | ||
2 | + | ||
3 | +import numpy as np | ||
4 | +import torch | ||
5 | +from scipy.linalg import circulant | ||
6 | + | ||
7 | + | ||
8 | +def sample_gaussian(mean, sigma, tril_sigma=False): | ||
9 | + noise = torch.randn_like(mean) | ||
10 | + | ||
11 | + # we getting sigma | ||
12 | + if tril_sigma: | ||
13 | + z_sample = torch.bmm(sigma, noise.unsqueeze(dim=2)).squeeze() + mean | ||
14 | + else: | ||
15 | + z_sample = noise * sigma + mean | ||
16 | + return z_sample | ||
17 | + | ||
18 | + | ||
19 | +def sample_echo(f, s, m=None, replace=False, pop=True): | ||
20 | + """ | ||
21 | + f, s : are the outputs of encoder (shape : [B, Z] for f) | ||
22 | + s is shape [B, Z] or [B, Z, Z] | ||
23 | + tril_sigma: if we have s as diagonal matrix or lt matrix | ||
24 | + m : number of samples to consider to generate noise when replace | ||
25 | + is true (default to batch_size) | ||
26 | + replace : sampling with replacement or not (if sampling with | ||
27 | + replacement, pop is not considered) | ||
28 | + pop: If true, remove the sample to which noise is being added | ||
29 | + detach_noise_grad : detach gradient of noise or not | ||
30 | + """ | ||
31 | + batch_size, z_size = f.shape[0], f.shape[1:] | ||
32 | + | ||
33 | + # get indices | ||
34 | + if not replace: | ||
35 | + indices = circulant(np.arange(batch_size)) | ||
36 | + if pop: | ||
37 | + # just ignore the first column | ||
38 | + indices = indices[:, 1:] | ||
39 | + for i in indices: | ||
40 | + np.random.shuffle(i) | ||
41 | + else: | ||
42 | + m = batch_size if m is None else m | ||
43 | + indices = np.random.choice(batch_size, size=(batch_size, m), replace=True) | ||
44 | + | ||
45 | + f_arr = f[indices.reshape(-1)].view(indices.shape + z_size) | ||
46 | + s_arr = s[indices.reshape(-1)].view(indices.shape + z_size) | ||
47 | + | ||
48 | + epsilon = f_arr[:, 0] + torch.sum(f_arr[:, 1:] * torch.cumprod(s_arr[:, :-1], dim=1), dim=1) | ||
49 | + | ||
50 | + z_sample = f + s * epsilon | ||
51 | + return z_sample |
-
Please register or login to post a comment