choiseungmi

code update

@@ -151,7 +151,7 @@ def compute_psnr(a, b):
     mse = torch.mean((a - b)**2).item()
     return -10 * math.log10(mse)
 
-def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
+def _encode(seq, path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
     compressai.set_entropy_coder(coder)
     enc_start = time.time()
 
@@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
         strings.append([s[0]])
 
     with torch.no_grad():
-        recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
+        recon_out = net.decompress(strings, shape)
     x_recon = crop(recon_out["x_hat"], (h, w))
 
     psnr = compute_psnr(x, x_recon)
 
-    if i==False:
-        diff = x - ref
-        diff1 = torch.clamp(diff, min=-0.5, max=0.5) + 0.5
-        diff_img = torch2img(diff1)
-        diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png")
+    #if i==False:
+    #    diff = x - ref
+    #    diff1 = torch.clamp(diff, min=-0.5, max=0.5) + 0.5
+    #    diff_img = torch2img(diff1)
+    #    diff_img.save("../Data/train/"+seq+str(ff)+"_train_v1_q"+str(quality)+".png")
 
     enc_time = time.time() - enc_start
     size = filesize(output)
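The decompress fix above relies on a property of the CompressAI API: compress() returns a dict whose "shape" entry is exactly the second argument decompress() expects, so rebuilding a 3-tuple by hand is unnecessary and breaks the call. A minimal in-memory round trip showing that contract (the model choice is illustrative only, not part of this commit):

    import torch
    from compressai.zoo import bmshj2018_hyperprior

    net = bmshj2018_hyperprior(quality=3, pretrained=True).eval()
    x = torch.rand(1, 3, 768, 768)          # input already padded to a multiple of 64

    with torch.no_grad():
        out = net.compress(x)               # {"strings": [...], "shape": torch.Size([...])}
        rec = net.decompress(out["strings"], out["shape"])
    x_hat = rec["x_hat"]                    # reconstruction in [0, 1]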
@@ -336,15 +336,15 @@ def encode(argv):
     total_psnr = 0.0
     total_bpp = 0.0
     total_time = 0.0
-    args.image = path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
-    img = args.image+"_frame"+str(0)+".png"
-    total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
+    img_path = path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
+    img = img_path+"_frame"+str(0)+".png"
+    total_psnr, total_bpp, ref, total_time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
     for ff in range(1, args.frame):
         with Path(log_path).open("a") as f:
             f.write(f" {ff:3d} | ")
-        img = args.image+"_frame"+str(ff)+".png"
+        img = img_path+"_frame"+str(ff)+".png"
 
-        psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
+        psnr, total_bpp, ref, time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
         total_psnr += psnr
         total_time += time
 
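In this version of encode(), frame 0 is coded standalone (i=True) and every later frame is coded against ref, the reconstruction returned by the previous _encode call, so prediction is closed-loop. Since the bitstream is opened in append mode, filesize(output) grows across frames; a sketch of the rate bookkeeping this implies, assuming bpp is derived from that cumulative size (the actual bpp computation is in a collapsed part of _encode):

    frame_bpp = bpp - total_bpp      # per-frame rate: increment of the cumulative rate, written to the log
    total_bpp = bpp                  # returned, becomes the next call's total_bpp
    # after the last frame: total_bpp / args.frame == average bpp per frame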
@@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
             strings.append([s[0]])
 
         with torch.no_grad():
-            recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
+            recon_out1 = net.decompress(strings, shape)
         x_hat1 = crop(recon_out1["x_hat"], (h, w))
 
         with torch.no_grad():
@@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
             strings.append([s[0]])
 
         with torch.no_grad():
-            recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
+            recon_out = net.decompress(strings, shape)
         x_hat2 = crop(recon_out["x_hat"], (h, w))
         x_recon = ref + x_hat1 - x_hat2
 
@@ -17,6 +17,7 @@ import struct
 import sys
 import time
 import math
+from pytorch_msssim import ms_ssim
 
 from pathlib import Path
@@ -27,7 +28,12 @@ from PIL import Image
 from torchvision.transforms import ToPILImage, ToTensor
 
 import compressai
-
+from compressai.transforms.functional import (
+    rgb2ycbcr,
+    ycbcr2rgb,
+    yuv_420_to_444,
+    yuv_444_to_420,
+)
 from compressai.zoo import models
 
 model_ids = {k: i for i, k in enumerate(models.keys())}
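The two new imports back the metrics added further down: pytorch_msssim.ms_ssim computes MS-SSIM directly on NCHW tensors, and rgb2ycbcr/ycbcr2rgb from compressai.transforms.functional do the colour conversion for the YCbCr PSNR. A quick sanity check of the assumptions those metrics rely on (shapes and data_range=1.0); values shown are illustrative:

    import torch
    from pytorch_msssim import ms_ssim
    from compressai.transforms.functional import rgb2ycbcr

    a = torch.rand(1, 3, 256, 256)
    print(ms_ssim(a, a, data_range=1.).item())   # identical inputs -> 1.0
    print(rgb2ycbcr(a).shape)                    # torch.Size([1, 3, 256, 256]); channels are Y, Cb, Cr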
@@ -151,13 +157,28 @@ def compute_psnr(a, b):
     mse = torch.mean((a - b)**2).item()
     return -10 * math.log10(mse)
 
-def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
+def compute_msssim(a, b):
+    return ms_ssim(a, b, data_range=1.).item()
+
+def ycbcr_psnr(a, b):
+    yuv_a = rgb2ycbcr(a)
+    yuv_b = rgb2ycbcr(b)
+    a_y, a_cb, a_cr = yuv_a.chunk(3, -3)
+    b_y, b_cb, b_cr = yuv_b.chunk(3, -3)
+    y = compute_psnr(a_y, b_y)
+    cb = compute_psnr(a_cb, b_cb)
+    cr = compute_psnr(a_cr, b_cr)
+    return (4*y+cb+cr)/6
+
+def _encode(checkpoint, path, seq, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
     compressai.set_entropy_coder(coder)
     enc_start = time.time()
 
-    img = load_image(image)
+    img = load_image(image+"_frame"+str(ff)+".png")
     start = time.time()
-    net = models[model](quality=quality, metric=metric, pretrained=True).eval()
+    net = models[model](quality=quality, metric=metric, pretrained=True)
+
+    net.eval()
     load_time = time.time() - start
 
     x = img2torch(img)
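ycbcr_psnr converts both tensors to YCbCr, takes per-channel PSNR, and averages with the usual 4:1:1 luma weighting, i.e. PSNR_YCbCr = (4·PSNR_Y + PSNR_Cb + PSNR_Cr) / 6. With purely illustrative numbers:

    # if the per-channel values were Y = 40 dB, Cb = 43 dB, Cr = 44 dB:
    (4 * 40 + 43 + 44) / 6      # -> 41.1666... dB, the value written to the log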
@@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
             strings.append([s[0]])
 
         with torch.no_grad():
-            recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
+            recon_out = net.decompress(strings, shape)
         x_recon = crop(recon_out["x_hat"], (h, w))
 
         psnr = compute_psnr(x, x_recon)
+        ssim = compute_msssim(x, x_recon)
+        ycbcr = ycbcr_psnr(x, x_recon)
     else:
+        if checkpoint:  # load from previous checkpoint
+            checkpoint = torch.load(checkpoint)
+            #state_dict = load_state_dict(checkpoint["state_dict"])
+            net = models[model](quality=quality, metric=metric)
+            net.load_state_dict(checkpoint["state_dict"])
+            net.update(force=True)
+        else:
+            net = models[model](quality=quality, metric=metric, pretrained=True)
+
         diff = x - ref
-        #1
         diff1 = torch.clamp(diff, min=-0.5, max=0.5) + 0.5
+
 
-        #2
-        '''
-        diff1=torch.clamp(diff, min=0.0, max=1.0)
-        diff2=-torch.clamp(diff, min=-1.0, max=0.0)
-
-        diff1=pad(diff1, p)
-        diff2=pad(diff2, p)
-        '''
-        #1
-
-        with torch.no_grad():
-            out1 = net.compress(diff1)
-            shape1 = out1["shape"]
-        strings = []
-
-        with Path(output).open("ab") as f:
-            # write shape and number of encoded latents
-            write_uints(f, (shape1[0], shape1[1], len(out1["strings"])))
-
-        for s in out1["strings"]:
-            write_uints(f, (len(s[0]),))
-            write_bytes(f, s[0])
-            strings.append([s[0]])
-
-        with torch.no_grad():
-            recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
-        x_hat1 = crop(recon_out["x_hat"], (h, w))
-
-        #2
-        '''
         with torch.no_grad():
             out1 = net.compress(diff1)
             shape1 = out1["shape"]
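For residual frames the model is now either rebuilt from a training checkpoint or taken from the pretrained zoo. Calling net.update(force=True) after load_state_dict recomputes the entropy coder's CDF tables for the loaded weights; without it compress()/decompress() would run with stale tables. The same branch written as a small helper, shown only as a sketch (the helper name is made up; it assumes the checkpoint stores its weights under a "state_dict" key, as this diff does):

    import torch
    from compressai.zoo import models

    def load_inter_model(model, quality, metric, checkpoint_path=None):
        if checkpoint_path:
            ckpt = torch.load(checkpoint_path, map_location="cpu")
            net = models[model](quality=quality, metric=metric)
            net.load_state_dict(ckpt["state_dict"])
            net.update(force=True)   # rebuild entropy-bottleneck CDFs for the loaded weights
        else:
            net = models[model](quality=quality, metric=metric, pretrained=True)
        return net.eval()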
@@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
             strings.append([s[0]])
 
         with torch.no_grad():
-            recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
+            recon_out = net.decompress(strings, shape1)
         x_hat1 = crop(recon_out["x_hat"], (h, w))
-        with torch.no_grad():
-            out = net.compress(diff2)
-            shape = out["shape"]
-        strings = []
-
-        with Path(output).open("ab") as f:
-            # write shape and number of encoded latents
-            write_uints(f, (shape[0], shape[1], len(out["strings"])))
-
-        for s in out["strings"]:
-            write_uints(f, (len(s[0]),))
-            write_bytes(f, s[0])
-            strings.append([s[0]])
-
-        with torch.no_grad():
-            recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
-        x_hat2 = crop(recon_out["x_hat"], (h, w))
-        x_recon = ref + x_hat1 - x_hat2
-        '''
+
 
         x_recon = ref + x_hat1 - 0.5
         psnr = compute_psnr(x, x_recon)
+        ssim = compute_msssim(x, x_recon)
+        ycbcr = ycbcr_psnr(x, x_recon)
         diff_img = torch2img(diff1)
-        diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png")
+        #diff_img.save(path+"recon/"+seq+str(ff)+"_q"+str(quality)+".png")
+        #diff_img.save("../Data/train/"+seq+str(ff)+"_train8_q"+str(quality)+".png")
 
     enc_time = time.time() - enc_start
     size = filesize(output)
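The residual path now keeps only the single clipped stream: the residual x - ref is clamped to [-0.5, 0.5] and shifted by +0.5 so it lands in [0, 1] and can be fed to the image codec, and the decoder undoes the shift with x_recon = ref + x_hat1 - 0.5. Residual values beyond ±0.5 are clipped and cannot be recovered. The mapping in isolation (no compression in this sketch):

    import torch

    x   = torch.rand(1, 3, 64, 64)    # current frame
    ref = torch.rand(1, 3, 64, 64)    # previous reconstruction

    diff1 = torch.clamp(x - ref, min=-0.5, max=0.5) + 0.5   # residual mapped into [0, 1]
    x_hat1 = diff1                    # stand-in for compressing/decompressing diff1
    x_recon = ref + x_hat1 - 0.5      # exact wherever |x - ref| <= 0.5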
@@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
     with Path(log_path).open("a") as f:
         f.write( f" {bpp-total_bpp:.4f} | "
                  f" {psnr:.4f} |"
+                 f" {ssim:.4f} |"
+                 f" {ycbcr:.4f} |"
                  f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n")
     recon_img = torch2img(x_recon)
     recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png")
 
-    return psnr, bpp, x_recon, enc_time
+    return psnr, bpp, x_recon, enc_time, ssim, ycbcr
 
 
 def _decode(inputpath, coder, show, frame, output=None):
@@ -381,13 +370,19 @@ def encode(argv):
         default=768,
         help="hight setting (default: %(default))",
     )
+    parser.add_argument(
+        "-check",
+        "--checkpoint",
+        type=str,
+        help="Path to a checkpoint",
+    )
     parser.add_argument("-o", "--output", help="Output path")
     args = parser.parse_args(argv)
     path = "examples/"+args.image+"/"
     if not args.output:
         #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
-        args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin"
-        log_path = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt"
+        args.output = path+args.image+"_q"+str(args.quality)+"_train_ssim.bin"
+        log_path = path+args.image+"_q"+str(args.quality)+"_train_ssim.txt"
 
     header = get_header(args.model, args.metric, args.quality)
     with Path(args.output).open("wb") as f:
@@ -400,32 +395,43 @@ def encode(argv):
                  f"frames : {args.frame}\n")
         f.write( f"frame | bpp | "
                  f" psnr |"
+                 f" ssim |"
                  f" Encoded time (model loading)\n"
                  f" {0:3d} | ")
 
     total_psnr = 0.0
+    total_ssim = 0.0
+    total_ycbcr = 0.0
     total_bpp = 0.0
     total_time = 0.0
-    args.image = path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
-    img = args.image+"_frame"+str(0)+".png"
-    total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
+    img = path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
+    total_psnr, total_bpp, ref, total_time, total_ssim, total_ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
     for ff in range(1, args.frame):
         with Path(log_path).open("a") as f:
             f.write(f" {ff:3d} | ")
-        img = args.image+"_frame"+str(ff)+".png"
-
-        psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
+        if ff % 25 == 0:
+            psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, ref, total_bpp, ff, args.output, log_path)
+        else:
+            psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
         total_psnr += psnr
+        total_ssim += ssim
+        total_ycbcr += ycbcr
         total_time += time
 
     total_psnr /= args.frame
+    total_ssim /= args.frame
+    total_ycbcr /= args.frame
     total_bpp /= args.frame
 
     with Path(log_path).open("a") as f:
         f.write( f"\n Total Encoded time: {total_time:.2f}s\n"
                  f"\n Total PSNR: {total_psnr:.6f}\n"
+                 f"\n Total SSIM: {total_ssim:.6f}\n"
+                 f"\n Total ycbcr: {total_ycbcr:.6f}\n"
                  f" Total BPP: {total_bpp:.6f}\n")
     print(total_psnr)
+    print(total_ssim)
+    print(total_ycbcr)
     print(total_bpp)
 
 
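The frame loop now picks the coding mode per frame: every 25th frame is sent back through the intra path (i=True), which resets prediction drift, while all other frames go through the residual path against the running reconstruction ref. A condensed restatement of that decision, with the boolean named for clarity (sketch only, relying on the surrounding encode() context):

    for ff in range(1, args.frame):
        intra = (ff % 25 == 0)        # periodic intra refresh
        psnr, total_bpp, ref, t, ssim, ycbcr = _encode(
            args.checkpoint, path, args.image, img, args.model, args.metric,
            args.quality, args.coder, intra, ref, total_bpp, ff,
            args.output, log_path)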