김민석

B2I

Showing 506 changed files with 875 additions and 0 deletions
1 +#folder
2 +experiments/ecg/dataset/preprocessed/ano0
3 +
4 +experiments/ecg/output/beatgan/ecg/model/
5 +
6 +
1 +/workspace/2Dtest/1234/experiments/ecg/plotUtil.py:213: fixed
This diff is collapsed. Click to expand it.
No preview for this file type
No preview for this file type
1 +Execution change.py and change2.py file for change __samples.npy to __spectogram.npy.
2 +
3 +The shape is changed n*2*320 to n*128*128 to n*1*128*128
1 +
2 +import librosa
3 +import numpy as np
4 +import matplotlib.pyplot as plt
5 +import librosa, librosa.display
6 +import cv2
7 +
8 +n_data = np.load('N_samples.npy')
9 +s_data = np.load('S_samples.npy')
10 +v_data = np.load('V_samples.npy')
11 +f_data = np.load('F_samples.npy')
12 +q_data = np.load('Q_samples.npy')
13 +
14 +
15 +
16 +n_fft_n= 256
17 +win_length_n=64
18 +hp_length_n=2
19 +sr = 360
20 +
21 +data =n_data #데이터 종류
22 +
23 +lst = [] #npy로 저장할 데이터들
24 +length = len(data) #출력할 데이터 개수
25 +
26 +
27 +for i in range(length):
28 + #원래 ECG 그래프 그리기
29 + #ax1 = fig1.add_subplot(length,2,2*(i+1)-1)
30 + #ax1.plot(data[i,0,:])
31 +
32 + # STFT 이미지 그리기
33 + #ax2 = fig1.add_subplot(length,2,2*(i+1))
34 +
35 + #STFT
36 + D_highres = librosa.stft(data[i,0,:].flatten(), n_fft=n_fft_n, hop_length=hp_length_n, win_length=win_length_n)
37 +
38 + #ampiltude로 변환
39 + magnitude = np.abs(D_highres)
40 +
41 + #amplitude를 db 스케일로 변환
42 + log_spectrogram = librosa.amplitude_to_db(magnitude)
43 +
44 + #화이트 노이즈 제거
45 + log_spectrogram = log_spectrogram[:,10:150]
46 +
47 + #128,128로 resize
48 + log_spectrogram = cv2.resize(log_spectrogram, (128,128), interpolation = cv2.INTER_AREA)
49 +
50 + #스펙트로그램 출력
51 + #img = librosa.display.specshow(log_spectrogram, sr=sr, hop_length = hp_length_n, ax=ax2, y_axis="linear", x_axis="time")
52 +
53 + #컬러바
54 + #fig.colorbar(img, ax=ax2)# format="%+2.f dB")
55 +
56 + #print(log_spectrogram.shape)
57 +
58 + lst.append(log_spectrogram)
59 + if i%30==0:
60 + print(i,'/',length)
61 +
62 +#npy로 저장
63 +lst = np.array(lst)
64 +output_filename = 'n_spectrogram'
65 +print(lst.shape)
66 +np.save(output_filename, lst)
67 +
68 +
69 +##########
70 +
71 +data =s_data #데이터 종류
72 +
73 +lst = [] #npy로 저장할 데이터들
74 +length = len(data) #출력할 데이터 개수
75 +
76 +
77 +for i in range(length):
78 + #원래 ECG 그래프 그리기
79 + #ax1 = fig1.add_subplot(length,2,2*(i+1)-1)
80 + #ax1.plot(data[i,0,:])
81 +
82 + # STFT 이미지 그리기
83 + #ax2 = fig1.add_subplot(length,2,2*(i+1))
84 +
85 + #STFT
86 + D_highres = librosa.stft(data[i,0,:].flatten(), n_fft=n_fft_n, hop_length=hp_length_n, win_length=win_length_n)
87 +
88 + #ampiltude로 변환
89 + magnitude = np.abs(D_highres)
90 +
91 + #amplitude를 db 스케일로 변환
92 + log_spectrogram = librosa.amplitude_to_db(magnitude)
93 +
94 + #화이트 노이즈 제거
95 + log_spectrogram = log_spectrogram[:,10:150]
96 +
97 + #128,128로 resize
98 + log_spectrogram = cv2.resize(log_spectrogram, (128,128), interpolation = cv2.INTER_AREA)
99 +
100 + #스펙트로그램 출력
101 + #img = librosa.display.specshow(log_spectrogram, sr=sr, hop_length = hp_length_n, ax=ax2, y_axis="linear", x_axis="time")
102 +
103 + #컬러바
104 + #fig.colorbar(img, ax=ax2)# format="%+2.f dB")
105 +
106 + #print(log_spectrogram.shape)
107 +
108 + lst.append(log_spectrogram)
109 + if i%30==0:
110 + print(i,'/',length)
111 +
112 +#npy로 저장
113 +lst = np.array(lst)
114 +output_filename = 's_spectrogram'
115 +print(lst.shape)
116 +np.save(output_filename, lst)
117 +
118 +##########
119 +
120 +data =v_data #데이터 종류
121 +
122 +lst = [] #npy로 저장할 데이터들
123 +length = len(data) #출력할 데이터 개수
124 +
125 +
126 +for i in range(length):
127 + #원래 ECG 그래프 그리기
128 + #ax1 = fig1.add_subplot(length,2,2*(i+1)-1)
129 + #ax1.plot(data[i,0,:])
130 +
131 + # STFT 이미지 그리기
132 + #ax2 = fig1.add_subplot(length,2,2*(i+1))
133 +
134 + #STFT
135 + D_highres = librosa.stft(data[i,0,:].flatten(), n_fft=n_fft_n, hop_length=hp_length_n, win_length=win_length_n)
136 +
137 + #ampiltude로 변환
138 + magnitude = np.abs(D_highres)
139 +
140 + #amplitude를 db 스케일로 변환
141 + log_spectrogram = librosa.amplitude_to_db(magnitude)
142 +
143 + #화이트 노이즈 제거
144 + log_spectrogram = log_spectrogram[:,10:150]
145 +
146 + #128,128로 resize
147 + log_spectrogram = cv2.resize(log_spectrogram, (128,128), interpolation = cv2.INTER_AREA)
148 +
149 + #스펙트로그램 출력
150 + #img = librosa.display.specshow(log_spectrogram, sr=sr, hop_length = hp_length_n, ax=ax2, y_axis="linear", x_axis="time")
151 +
152 + #컬러바
153 + #fig.colorbar(img, ax=ax2)# format="%+2.f dB")
154 +
155 + #print(log_spectrogram.shape)
156 +
157 + lst.append(log_spectrogram)
158 + if i%30==0:
159 + print(i,'/',length)
160 +
161 +#npy로 저장
162 +lst = np.array(lst)
163 +output_filename = 'v_spectrogram'
164 +print(lst.shape)
165 +np.save(output_filename, lst)
166 +
167 +##########
168 +
169 +data =f_data #데이터 종류
170 +
171 +lst = [] #npy로 저장할 데이터들
172 +length = len(data) #출력할 데이터 개수
173 +
174 +
175 +for i in range(length):
176 + #원래 ECG 그래프 그리기
177 + #ax1 = fig1.add_subplot(length,2,2*(i+1)-1)
178 + #ax1.plot(data[i,0,:])
179 +
180 + # STFT 이미지 그리기
181 + #ax2 = fig1.add_subplot(length,2,2*(i+1))
182 +
183 + #STFT
184 + D_highres = librosa.stft(data[i,0,:].flatten(), n_fft=n_fft_n, hop_length=hp_length_n, win_length=win_length_n)
185 +
186 + #ampiltude로 변환
187 + magnitude = np.abs(D_highres)
188 +
189 + #amplitude를 db 스케일로 변환
190 + log_spectrogram = librosa.amplitude_to_db(magnitude)
191 +
192 + #화이트 노이즈 제거
193 + log_spectrogram = log_spectrogram[:,10:150]
194 +
195 + #128,128로 resize
196 + log_spectrogram = cv2.resize(log_spectrogram, (128,128), interpolation = cv2.INTER_AREA)
197 +
198 + #스펙트로그램 출력
199 + #img = librosa.display.specshow(log_spectrogram, sr=sr, hop_length = hp_length_n, ax=ax2, y_axis="linear", x_axis="time")
200 +
201 + #컬러바
202 + #fig.colorbar(img, ax=ax2)# format="%+2.f dB")
203 +
204 + #print(log_spectrogram.shape)
205 +
206 + lst.append(log_spectrogram)
207 + if i%30==0:
208 + print(i,'/',length)
209 +
210 +#npy로 저장
211 +lst = np.array(lst)
212 +output_filename = 'f_spectrogram'
213 +print(lst.shape)
214 +np.save(output_filename, lst)
215 +
216 +##########
217 +
218 +data =q_data #데이터 종류
219 +
220 +lst = [] #npy로 저장할 데이터들
221 +length = len(data) #출력할 데이터 개수
222 +
223 +
224 +for i in range(length):
225 + #원래 ECG 그래프 그리기
226 + #ax1 = fig1.add_subplot(length,2,2*(i+1)-1)
227 + #ax1.plot(data[i,0,:])
228 +
229 + # STFT 이미지 그리기
230 + #ax2 = fig1.add_subplot(length,2,2*(i+1))
231 +
232 + #STFT
233 + D_highres = librosa.stft(data[i,0,:].flatten(), n_fft=n_fft_n, hop_length=hp_length_n, win_length=win_length_n)
234 +
235 + #ampiltude로 변환
236 + magnitude = np.abs(D_highres)
237 +
238 + #amplitude를 db 스케일로 변환
239 + log_spectrogram = librosa.amplitude_to_db(magnitude)
240 +
241 + #화이트 노이즈 제거
242 + log_spectrogram = log_spectrogram[:,10:150]
243 +
244 + #128,128로 resize
245 + log_spectrogram = cv2.resize(log_spectrogram, (128,128), interpolation = cv2.INTER_AREA)
246 +
247 + #스펙트로그램 출력
248 + #img = librosa.display.specshow(log_spectrogram, sr=sr, hop_length = hp_length_n, ax=ax2, y_axis="linear", x_axis="time")
249 +
250 + #컬러바
251 + #fig.colorbar(img, ax=ax2)# format="%+2.f dB")
252 +
253 + #print(log_spectrogram.shape)
254 +
255 + lst.append(log_spectrogram)
256 + if i%30==0:
257 + print(i,'/',length)
258 +
259 +#npy로 저장
260 +lst = np.array(lst)
261 +output_filename = 'q_spectrogram'
262 +print(lst.shape)
263 +np.save(output_filename, lst)
1 +import numpy as np
2 +
3 +N_samples = np.load('n_spectrogram.npy')
4 +S_samples = np.load('s_spectrogram.npy')
5 +V_samples = np.load('v_spectrogram.npy')
6 +F_samples = np.load('f_spectrogram.npy')
7 +Q_samples = np.load('q_spectrogram.npy')
8 +##########
9 +
10 +S_samples = S_samples.reshape(S_samples.shape[0],1,S_samples.shape[1],S_samples.shape[2])
11 +V_samples = V_samples.reshape(V_samples.shape[0],1,V_samples.shape[1],V_samples.shape[2])
12 +F_samples = F_samples.reshape(F_samples.shape[0],1,F_samples.shape[1],F_samples.shape[2])
13 +Q_samples = Q_samples.reshape(Q_samples.shape[0],1,Q_samples.shape[1],Q_samples.shape[2])
14 +N_samples = N_samples.reshape(N_samples.shape[0],1,N_samples.shape[1],N_samples.shape[2])
15 +
16 +np.save('q_spectrogram', Q_samples)
17 +np.save('v_spectrogram', V_samples)
18 +np.save('s_spectrogram', S_samples)
19 +np.save('f_spectrogram', F_samples)
20 +np.save('n_spectrogram', N_samples)
1 +import os
2 +import numpy as np
3 +
4 +import torch
5 +from torch.utils.data import DataLoader,TensorDataset
6 +
7 +from model import BeatGAN
8 +from options import Options
9 +import matplotlib.pyplot as plt
10 +import matplotlib
11 +plt.rcParams["font.family"] = "Times New Roman"
12 +matplotlib.rcParams.update({'font.size': 38})
13 +from plotUtil import save_ts_heatmap
14 +from data import normalize
15 +
16 +device = torch.device("cpu")
17 +
18 +SAVE_DIR="output/demo/"
19 +
20 +
21 +
22 +
23 +def load_case(normal=True):
24 + if normal:
25 + test_samples = np.load(os.path.join("dataset/demo/", "normal_samples.npy"))
26 + else:
27 + test_samples = np.load(os.path.join("dataset/demo/", "abnormal_samples.npy"))
28 +
29 + for i in range(test_samples.shape[0]):
30 + for j in range(1):
31 + test_samples[i][j] = normalize(test_samples[i][j][:])
32 + test_samples = test_samples[:, :1, :]
33 + print(test_samples.shape)
34 + if not normal :
35 + test_y=np.ones([test_samples.shape[0],1])
36 + else:
37 + test_y = np.zeros([test_samples.shape[0], 1])
38 + test_dataset = TensorDataset(torch.Tensor(test_samples), torch.Tensor(test_y))
39 +
40 + return DataLoader(dataset=test_dataset, # torch TensorDataset format
41 + batch_size=64,
42 + shuffle=False,
43 + num_workers=0,
44 + drop_last=False)
45 +
46 +normal_dataloader=load_case(normal=True)
47 +abnormal_dataloader=load_case(normal=False)
48 +opt = Options()
49 +opt.nc=1
50 +opt.nz=50
51 +opt.isize=320
52 +opt.ndf=32
53 +opt.ngf=32
54 +opt.batchsize=64
55 +opt.ngpu=1
56 +opt.istest=True
57 +opt.lr=0.001
58 +opt.beta1=0.5
59 +opt.niter=None
60 +opt.dataset=None
61 +opt.model = None
62 +opt.outf=None
63 +
64 +
65 +
66 +model=BeatGAN(opt,None,device)
67 +model.G.load_state_dict(torch.load('model/beatgan_folder_0_G.pkl',map_location='cpu'))
68 +model.D.load_state_dict(torch.load('model/beatgan_folder_0_D.pkl',map_location='cpu'))
69 +
70 +
71 +model.G.eval()
72 +model.D.eval()
73 +with torch.no_grad():
74 +
75 + abnormal_input=[]
76 + abnormal_output=[]
77 +
78 + normal_input=[]
79 + normal_output=[]
80 + for i, data in enumerate(abnormal_dataloader, 0):
81 + test_x=data[0]
82 + fake_x, _ = model.G(test_x)
83 +
84 + batch_input = test_x.cpu().numpy()
85 + batch_output = fake_x.cpu().numpy()
86 + abnormal_input.append(batch_input)
87 + abnormal_output.append(batch_output)
88 + abnormal_input=np.concatenate(abnormal_input)
89 + abnormal_output=np.concatenate(abnormal_output)
90 +
91 + for i, data in enumerate(normal_dataloader, 0):
92 + test_x=data[0]
93 + fake_x, _ = model.G(test_x)
94 +
95 + batch_input = test_x.cpu().numpy()
96 + batch_output = fake_x.cpu().numpy()
97 + normal_input.append(batch_input)
98 + normal_output.append(batch_output)
99 + normal_input=np.concatenate(normal_input)
100 + normal_output=np.concatenate(normal_output)
101 +
102 + # print(normal_input.shape)
103 + # print(np.reshape((normal_input-normal_output)**2,(normal_input.shape[0],-1)).shape)
104 +
105 + normal_heat= np.reshape((normal_input-normal_output)**2,(normal_input.shape[0],-1))
106 +
107 + abnormal_heat = np.reshape((abnormal_input - abnormal_output)**2 , (abnormal_input.shape[0], -1))
108 +
109 + # print(normal_heat.shape)
110 + # assert False
111 +
112 + max_val = max(np.max(normal_heat), np.max(abnormal_heat))
113 + min_val = min(np.min(normal_heat), np.min(abnormal_heat))
114 +
115 + normal_heat_norm = (normal_heat - min_val) / (max_val - min_val)
116 + abnormal_heat_norm = (abnormal_heat - min_val) / (max_val - min_val)
117 +
118 +
119 +
120 + # for fig
121 + dataset=["normal","abnormal"]
122 +
123 + for d in dataset:
124 + if not os.path.exists(os.path.join(SAVE_DIR , d)):
125 + os.makedirs(os.path.join(SAVE_DIR , d))
126 +
127 + if d=="normal":
128 + data_input=normal_input
129 + data_output=normal_output
130 + data_heat=normal_heat_norm
131 + else:
132 + data_input = abnormal_input
133 + data_output = abnormal_output
134 + data_heat = abnormal_heat_norm
135 +
136 + for i in range(50):
137 +
138 + input_sig=data_input[i]
139 + output_sig=data_output[i]
140 + heat=data_heat[i]
141 +
142 + # print(input_sig.shape)
143 + # print(output_sig.shape)
144 + # print(heat.shape)
145 + # assert False
146 +
147 + x_points = np.arange(input_sig.shape[1])
148 + fig, ax = plt.subplots(2, 1, sharex=True, figsize=(6, 6), gridspec_kw={'height_ratios': [7, 1],
149 + })
150 +
151 + sig_in = input_sig[0, :]
152 +
153 + sig_out = output_sig[0, :]
154 + ax[0].plot(x_points, sig_in, 'k-', linewidth=2.5, label="ori")
155 + ax[0].plot(x_points, sig_out, 'k--', linewidth=2.5, label="gen")
156 + ax[0].set_yticks([])
157 +
158 + # leg=ax[0].legend(loc="upper right",bbox_to_anchor=(1.06, 1.06))
159 + # leg.get_frame().set_alpha(0.0)
160 +
161 + heat_norm = np.reshape(heat, (1, -1))
162 + # heat_norm=np.zeros((1,320))
163 + # if d=="normal":
164 + # heat_norm[0,100:120]=0.0003
165 + # else:
166 + # heat_norm[0,100:120]=0.9997
167 +
168 + ax[1].imshow(heat_norm, cmap="jet", aspect="auto",vmin = 0,vmax = 0.2)
169 + ax[1].set_yticks([])
170 + # ax[1].set_xlim((0,len(x_points)))
171 +
172 + # fig.subplots_adjust(hspace=0.01)
173 + fig.tight_layout()
174 + # fig.show()
175 + # return
176 + fig.savefig(os.path.join(SAVE_DIR+d,str(i)+"_output.png"))
177 +
178 + fig2, ax2 = plt.subplots(1, 1)
179 + ax2.plot(x_points, sig_in, 'k-', linewidth=2.5, label="input signal")
180 + fig2.savefig(os.path.join(SAVE_DIR + d, str(i) + "_input.png"))
181 +
182 +
183 + plt.clf()
184 +
185 +print("output files are in:{}".format(SAVE_DIR))
...\ No newline at end of file ...\ No newline at end of file
1 +
2 +
3 +import os
4 +os.environ["CUDA_VISIBLE_DEVICES"] = "1"
5 +import torch
6 +from options import Options
7 +
8 +from data import load_data
9 +
10 +# from dcgan import DCGAN as myModel
11 +
12 +
13 +device = torch.device("cuda:0" if
14 +torch.cuda.is_available() else "cpu")
15 +print('device: ',device)
16 +
17 +
18 +
19 +opt = Options().parse()
20 +print(opt)
21 +dataloader=load_data(opt)
22 +print("load data success!!!")
23 +
24 +if opt.model == "beatgan":
25 + from model import BeatGAN as MyModel
26 +
27 +else:
28 + raise Exception("no this model :{}".format(opt.model))
29 +
30 +
31 +model=MyModel(opt,dataloader,device)
32 +print('\nmodel_device:',model.device,'\n')
33 +
34 +if not opt.istest:
35 + print("################ Train ##################")
36 + model.train()
37 +else:
38 + print("################ Eval ##################")
39 + model.load()
40 + model.test_type()
41 + # model.test_time()
42 + # model.plotTestFig()
43 + # print("threshold:{}\tf1-score:{}\tauc:{}".format( th, f1, auc))
This diff is collapsed. Click to expand it.
No preview for this file type
This diff is collapsed. Click to expand it.
1 +
2 +
3 +import os,pickle
4 +import numpy as np
5 +import torch
6 +import torch.nn as nn
7 +
8 +from plotUtil import plot_dist,save_pair_fig,save_plot_sample,print_network,save_plot_pair_sample,loss_plot
9 +
10 +def weights_init(mod):
11 + """
12 + Custom weights initialization called on netG, netD and netE
13 + :param m:
14 + :return:
15 + """
16 + classname = mod.__class__.__name__
17 + if classname.find('Conv') != -1:
18 + # mod.weight.data.normal_(0.0, 0.02)
19 + nn.init.xavier_normal_(mod.weight.data)
20 + # nn.init.kaiming_uniform_(mod.weight.data)
21 +
22 + elif classname.find('BatchNorm') != -1:
23 + mod.weight.data.normal_(1.0, 0.02)
24 + mod.bias.data.fill_(0)
25 + elif classname.find('Linear') !=-1 :
26 + torch.nn.init.xavier_uniform(mod.weight)
27 + mod.bias.data.fill_(0.01)
28 +
29 +
30 +class Encoder(nn.Module):
31 + def __init__(self, ngpu,opt,out_z):
32 + super(Encoder, self).__init__()
33 + self.ngpu = ngpu
34 + self.main = nn.Sequential(
35 + # input is (nc) x 320
36 + #nn.Conv1d(opt.nc,opt.ndf,4,2,1,bias=False),
37 + nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
38 + nn.LeakyReLU(0.2, inplace=True),
39 + # state size. (ndf) x 160
40 + #nn.Conv1d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False),
41 + #nn.BatchNorm1d(opt.ndf * 2),
42 + nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False),
43 + nn.BatchNorm2d(opt.ndf * 2),
44 + nn.LeakyReLU(0.2, inplace=True),
45 + # state size. (ndf*2) x 80
46 + #nn.Conv1d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False),
47 + #nn.BatchNorm1d(opt.ndf * 4),
48 + nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False),
49 + nn.BatchNorm2d(opt.ndf * 4),
50 + nn.LeakyReLU(0.2, inplace=True),
51 +
52 + nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False),
53 + nn.BatchNorm2d(opt.ndf * 8),
54 + nn.LeakyReLU(0.2, inplace=True),
55 +
56 + nn.Conv2d(opt.ndf * 8, opt.ndf * 16, 4, 1, 1, bias=False),
57 + nn.BatchNorm2d(opt.ndf * 16),
58 + nn.LeakyReLU(0.2, inplace=True),
59 + # state size. (ndf*16) x 10
60 +
61 + #nn.Conv1d(opt.ndf * 16, out_z, 10, 1, 0, bias=False)
62 + nn.Conv2d(opt.ndf * 16, out_z, 7, 1, 0, bias=False),
63 + # state size. (nz) x 1
64 + )
65 +
66 + def forward(self, input):
67 + if input.is_cuda and self.ngpu > 1:
68 + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
69 + else:
70 + output = self.main(input)
71 +
72 + return output
73 +
74 +
75 +##
76 +
77 +
78 +class Decoder(nn.Module):
79 + def __init__(self, ngpu,opt):
80 + super(Decoder, self).__init__()
81 + self.ngpu = ngpu
82 + self.main=nn.Sequential(
83 + nn.ConvTranspose2d(opt.nz,opt.ngf*16,7,1,0,bias=False),
84 + nn.BatchNorm2d(opt.ngf*16),
85 + nn.ReLU(True),
86 +
87 + nn.ConvTranspose2d(opt.ngf * 16, opt.ngf * 8, 4, 1, 1, bias=False),
88 + nn.BatchNorm2d(opt.ngf * 8),
89 + nn.ReLU(True),
90 +
91 + nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False),
92 + nn.BatchNorm2d(opt.ngf * 4),
93 + nn.ReLU(True),
94 + # state size. (ngf*2) x 40
95 + #nn.ConvTranspose1d(opt.ngf * 4, opt.ngf*2, 4, 2, 1, bias=False),
96 + #nn.BatchNorm1d(opt.ngf*2),
97 + nn.ConvTranspose2d(opt.ngf * 4, opt.ngf*2, 4, 2, 1, bias=False),
98 + nn.BatchNorm2d(opt.ngf*2),
99 + nn.ReLU(True),
100 + # state size. (ngf) x 80
101 + #nn.ConvTranspose1d(opt.ngf * 2, opt.ngf , 4, 2, 1, bias=False),
102 + #nn.BatchNorm1d(opt.ngf ),
103 + nn.ConvTranspose2d(opt.ngf * 2, opt.ngf , 4, 2, 1, bias=False),
104 + nn.BatchNorm2d(opt.ngf ),
105 + nn.ReLU(True),
106 + # state size. (ngf) x 160
107 + #nn.ConvTranspose1d(opt.ngf , opt.nc, 4, 2, 1, bias=False),
108 + nn.ConvTranspose2d(opt.ngf , opt.nc, 4, 2, 1, bias=False),
109 + nn.Tanh()
110 + # state size. (nc) x 320
111 +
112 +
113 + )
114 +
115 + def forward(self, input):
116 + if input.is_cuda and self.ngpu > 1:
117 + output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
118 + else:
119 + output = self.main(input)
120 + return output
121 +
122 +
123 +class AD_MODEL(object):
124 + def __init__(self,opt,dataloader,device):
125 + self.G=None
126 + self.D=None
127 +
128 + self.opt=opt
129 + self.niter=opt.niter
130 + self.dataset=opt.dataset
131 + self.model = opt.model
132 + self.outf=opt.outf
133 +
134 +
135 + def train(self):
136 + raise NotImplementedError
137 +
138 + def visualize_results(self, epoch,samples,is_train=True):
139 + if is_train:
140 + sub_folder="train"
141 + else:
142 + sub_folder="test"
143 +
144 + save_dir=os.path.join(self.outf,self.model,self.dataset,sub_folder)
145 +
146 + if not os.path.exists(save_dir):
147 + os.makedirs(save_dir)
148 +
149 + save_plot_sample(samples, epoch, self.dataset, num_epochs=self.niter,
150 + impath=os.path.join(save_dir,'epoch%03d' % epoch + '.png'))
151 +
152 +
153 + def visualize_pair_results(self,epoch,samples1,samples2,is_train=True):
154 + if is_train:
155 + sub_folder="train"
156 + else:
157 + sub_folder="test"
158 +
159 + save_dir=os.path.join(self.outf,self.model,self.dataset,sub_folder)
160 +
161 + if not os.path.exists(save_dir):
162 + os.makedirs(save_dir)
163 +
164 + save_plot_pair_sample(samples1, samples2, epoch, self.dataset, num_epochs=self.niter, impath=os.path.join(save_dir,'epoch%03d' % epoch + '.png'))
165 +
166 + def save(self,train_hist):
167 + save_dir = os.path.join(self.outf, self.model, self.dataset,"model")
168 +
169 + if not os.path.exists(save_dir):
170 + os.makedirs(save_dir)
171 +
172 + with open(os.path.join(save_dir, self.model + '_history.pkl'), 'wb') as f:
173 + pickle.dump(train_hist, f)
174 +
175 + def save_weight_GD(self):
176 + save_dir = os.path.join(self.outf, self.model, self.dataset, "model")
177 +
178 + if not os.path.exists(save_dir):
179 + os.makedirs(save_dir)
180 +
181 + torch.save(self.G.state_dict(), os.path.join(save_dir, self.model+"_folder_"+str(self.opt.folder) + '_G.pkl'))
182 + torch.save(self.D.state_dict(), os.path.join(save_dir, self.model+"_folder_"+str(self.opt.folder) + '_D.pkl'))
183 +
184 + def load(self):
185 + save_dir = os.path.join(self.outf, self.model, self.dataset,"model")
186 +
187 + self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model+"_folder_"+str(self.opt.folder) + '_G.pkl')))
188 + self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model+"_folder_"+str(self.opt.folder) + '_D.pkl')))
189 +
190 +
191 + def save_loss(self,train_hist):
192 + loss_plot(train_hist, os.path.join(self.outf, self.model, self.dataset), self.model)
193 +
194 +
195 +
196 + def saveTestPair(self,pair,save_dir):
197 + '''
198 +
199 + :param pair: list of (input,output)
200 + :param save_dir:
201 + :return:
202 + '''
203 + assert save_dir is not None
204 + for idx,p in enumerate(pair):
205 + input=p[0]
206 + output=p[1]
207 + save_pair_fig(input,output,os.path.join(save_dir,str(idx)+".png"))
208 +
209 +
210 +
211 +
212 + def analysisRes(self,N_res,A_res,min_score,max_score,threshold,save_dir):
213 + '''
214 +
215 + :param N_res: list of normal score
216 + :param A_res: dict{ "S": list of S score, "V":...}
217 + :param min_score:
218 + :param max_score:
219 + :return:
220 + '''
221 + print("############ Analysis #############")
222 + print("############ Threshold:{} #############".format(threshold))
223 + all_abnormal_score=[]
224 + all_normal_score=np.array([])
225 + for a_type in A_res:
226 + a_score=A_res[a_type]
227 + print("********* Type:{} *************".format(a_type))
228 + normal_score=normal(N_res, min_score, max_score)
229 + abnormal_score=normal(a_score, min_score, max_score)
230 + all_abnormal_score=np.concatenate((all_abnormal_score,np.array(abnormal_score)))
231 + all_normal_score=normal_score
232 + plot_dist(normal_score,abnormal_score , str(self.opt.folder)+"_"+"N", a_type,
233 + save_dir)
234 +
235 + TP=np.count_nonzero(abnormal_score >= threshold)
236 + FP=np.count_nonzero(normal_score >= threshold)
237 + TN=np.count_nonzero(normal_score < threshold)
238 + FN=np.count_nonzero(abnormal_score<threshold)
239 + print("TP:{}".format(TP))
240 + print("FP:{}".format(FP))
241 + print("TN:{}".format(TN))
242 + print("FN:{}".format(FN))
243 + print("Accuracy:{}".format((TP + TN) * 1.0 / (TP + TN + FP + FN)))
244 + print("Precision/ppv:{}".format(TP * 1.0 / (TP + FP)))
245 + print("sensitivity/Recall:{}".format(TP * 1.0 / (TP + FN)))
246 + print("specificity:{}".format(TN * 1.0 / (TN + FP)))
247 + print("F1:{}".format(2.0 * TP / (2 * TP + FP + FN)))
248 +
249 + # all_abnormal_score=np.reshape(np.array(all_abnormal_score),(-1))
250 + # print(all_abnormal_score.shape)
251 + plot_dist(all_normal_score, all_abnormal_score, str(self.opt.folder)+"_"+"N", "A",
252 + save_dir)
253 +
254 +
255 +
256 +
257 +
258 +
259 +def normal(array,min_val,max_val):
260 + return (array-min_val)/(max_val-min_val)
1 +
2 +
3 +import argparse
4 +import os
5 +import torch
6 +
7 +class Options():
8 + """Options class
9 +
10 + Returns:
11 + [argparse]: argparse containing train and test options
12 + """
13 +
14 + def __init__(self):
15 + ##
16 + #
17 + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 +
19 + ##
20 + # Base
21 + self.parser.add_argument('--dataset', default='ecg', help='ecg dataset')
22 + self.parser.add_argument('--dataroot', default='', help='path to dataset')
23 + self.parser.add_argument('--batchsize', type=int, default=64, help='input batch size')
24 + self.parser.add_argument('--workers', type=int, help='number of data loading workers', default=1)
25 + self.parser.add_argument('--isize', type=int, default=320, help='input sequence size.')
26 + self.parser.add_argument('--nc', type=int, default=1, help='input sequence channels')
27 + self.parser.add_argument('--nz', type=int, default=50, help='size of the latent z vector')
28 + self.parser.add_argument('--ngf', type=int, default=32)
29 + self.parser.add_argument('--ndf', type=int, default=32)
30 + self.parser.add_argument('--device', type=str, default='gpu', help='Device: gpu | cpu')
31 + self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
32 + self.parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
33 + self.parser.add_argument('--model', type=str, default='beatgan', help='choose model')
34 + self.parser.add_argument('--outf', default='./output', help='output folder')
35 +
36 + ##
37 + # Train
38 + self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
39 + self.parser.add_argument('--niter', type=int, default=100, help='number of epochs to train for')
40 + self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
41 + self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
42 + self.parser.add_argument('--w_adv', type=float, default=1, help='parameter')
43 + self.parser.add_argument('--folder', type=int, default=0, help='folder index 0-4')
44 + self.parser.add_argument('--n_aug', type=int, default=0, help='aug data times')
45 +
46 +
47 +
48 + ## Test
49 + self.parser.add_argument('--istest',action='store_true',help='train model or test model')
50 + self.parser.add_argument('--threshold', type=float, default=0.05, help='threshold score for anomaly')
51 +
52 + self.opt = None
53 +
54 + def parse(self):
55 + """ Parse Arguments.
56 + """
57 +
58 + self.opt = self.parser.parse_args()
59 +
60 + str_ids = self.opt.gpu_ids.split(',')
61 + self.opt.gpu_ids = []
62 + for str_id in str_ids:
63 + id = int(str_id)
64 + if id >= 0:
65 + self.opt.gpu_ids.append(id)
66 +
67 + # set gpu ids
68 + if self.opt.device == 'gpu':
69 + torch.cuda.set_device(self.opt.gpu_ids[0])
70 +
71 + args = vars(self.opt)
72 +
73 + # print('------------ Options -------------')
74 + # for k, v in sorted(args.items()):
75 + # print('%s: %s' % (str(k), str(v)))
76 + # print('-------------- End ----------------')
77 +
78 + # save to the disk
79 + self.opt.name = "%s/%s" % (self.opt.model, self.opt.dataset)
80 + expr_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
81 + test_dir = os.path.join(self.opt.outf, self.opt.name, 'test')
82 +
83 + if not os.path.isdir(expr_dir):
84 + os.makedirs(expr_dir)
85 + if not os.path.isdir(test_dir):
86 + os.makedirs(test_dir)
87 +
88 + file_name = os.path.join(expr_dir, 'opt.txt')
89 + with open(file_name, 'wt') as opt_file:
90 + opt_file.write('------------ Options -------------\n')
91 + for k, v in sorted(args.items()):
92 + opt_file.write('%s: %s\n' % (str(k), str(v)))
93 + opt_file.write('-------------- End ----------------\n')
94 + return self.opt
...\ No newline at end of file ...\ No newline at end of file
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
No preview for this file type
No preview for this file type
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.