Showing 6 changed files with 84 additions and 78 deletions
... | @@ -151,7 +151,7 @@ def compute_psnr(a, b): | ... | @@ -151,7 +151,7 @@ def compute_psnr(a, b): |
151 | mse = torch.mean((a - b)**2).item() | 151 | mse = torch.mean((a - b)**2).item() |
152 | return -10 * math.log10(mse) | 152 | return -10 * math.log10(mse) |
153 | 153 | ||
154 | -def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path): | 154 | +def _encode(seq, path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path): |
155 | compressai.set_entropy_coder(coder) | 155 | compressai.set_entropy_coder(coder) |
156 | enc_start = time.time() | 156 | enc_start = time.time() |
157 | 157 | ||
... | @@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou | ... | @@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou |
182 | strings.append([s[0]]) | 182 | strings.append([s[0]]) |
183 | 183 | ||
184 | with torch.no_grad(): | 184 | with torch.no_grad(): |
185 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | 185 | + recon_out = net.decompress(strings, shape) |
186 | x_recon = crop(recon_out["x_hat"], (h, w)) | 186 | x_recon = crop(recon_out["x_hat"], (h, w)) |
187 | 187 | ||
188 | psnr=compute_psnr(x, x_recon) | 188 | psnr=compute_psnr(x, x_recon) |
189 | 189 | ||
190 | - if i==False: | 190 | + #if i==False: |
191 | - diff=x-ref | 191 | + # diff=x-ref |
192 | - diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 | 192 | + # diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 |
193 | - diff_img = torch2img(diff1) | 193 | + # diff_img = torch2img(diff1) |
194 | - diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png") | 194 | + # diff_img.save("../Data/train/"+seq+str(ff)+"_train_v1_q"+str(quality)+".png") |
195 | 195 | ||
196 | enc_time = time.time() - enc_start | 196 | enc_time = time.time() - enc_start |
197 | size = filesize(output) | 197 | size = filesize(output) |
... | @@ -336,15 +336,15 @@ def encode(argv): | ... | @@ -336,15 +336,15 @@ def encode(argv): |
336 | total_psnr=0.0 | 336 | total_psnr=0.0 |
337 | total_bpp=0.0 | 337 | total_bpp=0.0 |
338 | total_time=0.0 | 338 | total_time=0.0 |
339 | - args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | 339 | + img_path =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" |
340 | - img=args.image+"_frame"+str(0)+".png" | 340 | + img=img_path+"_frame"+str(0)+".png" |
341 | - total_psnr, total_bpp, ref,total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | 341 | + total_psnr, total_bpp, ref,total_time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) |
342 | for ff in range(1, args.frame): | 342 | for ff in range(1, args.frame): |
343 | with Path(log_path).open("a") as f: | 343 | with Path(log_path).open("a") as f: |
344 | f.write(f" {ff:3d} | ") | 344 | f.write(f" {ff:3d} | ") |
345 | - img=args.image+"_frame"+str(ff)+".png" | 345 | + img=img_path+"_frame"+str(ff)+".png" |
346 | 346 | ||
347 | - psnr, total_bpp, ref,time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | 347 | + psnr, total_bpp, ref,time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) |
348 | total_psnr+=psnr | 348 | total_psnr+=psnr |
349 | total_time+=time | 349 | total_time+=time |
350 | 350 | ... | ... |
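A note on the recurring fix in the hunks above (not part of the diff): the hand-built tuple `(shape[0], shape[1], len(out["strings"]))` is replaced by the `shape` value that `compress()` itself returns, which is what `decompress()` expects. A minimal standalone sketch of that round trip, assuming a recent CompressAI install; the model name and quality are illustrative:

```python
import torch
from compressai.zoo import bmshj2018_hyperprior

net = bmshj2018_hyperprior(quality=3, pretrained=True).eval()
x = torch.rand(1, 3, 768, 768)  # input already padded to a multiple of 64

with torch.no_grad():
    out = net.compress(x)                                # {"strings": [...], "shape": ...}
    rec = net.decompress(out["strings"], out["shape"])   # pass the returned shape through unchanged
x_hat = rec["x_hat"].clamp(0, 1)
```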
... | @@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
213 | strings.append([s[0]]) | 213 | strings.append([s[0]]) |
214 | 214 | ||
215 | with torch.no_grad(): | 215 | with torch.no_grad(): |
216 | - recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | 216 | + recon_out1 = net.decompress(strings,shape) |
217 | x_hat1 = crop(recon_out1["x_hat"], (h, w)) | 217 | x_hat1 = crop(recon_out1["x_hat"], (h, w)) |
218 | 218 | ||
219 | with torch.no_grad(): | 219 | with torch.no_grad(): |
... | @@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
231 | strings.append([s[0]]) | 231 | strings.append([s[0]]) |
232 | 232 | ||
233 | with torch.no_grad(): | 233 | with torch.no_grad(): |
234 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | 234 | + recon_out = net.decompress(strings, shape) |
235 | x_hat2 = crop(recon_out["x_hat"], (h, w)) | 235 | x_hat2 = crop(recon_out["x_hat"], (h, w)) |
236 | x_recon=ref+x_hat1-x_hat2 | 236 | x_recon=ref+x_hat1-x_hat2 |
237 | 237 | ... | ... |
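The hunk above belongs to the two-pass residual variant: the signed frame difference is split into a positive and a negative part, each part is coded with the image model, and the frame is recombined as `ref + x_hat1 - x_hat2`. A minimal standalone sketch of that scheme (not part of the diff), assuming a CompressAI zoo model and (1, 3, H, W) tensors in [0, 1]:

```python
import torch
from compressai.zoo import bmshj2018_hyperprior

net = bmshj2018_hyperprior(quality=3, pretrained=True).eval()
ref = torch.rand(1, 3, 768, 768)                       # previous reconstruction
x = (ref + 0.05 * torch.randn_like(ref)).clamp(0, 1)   # current frame

diff = x - ref
diff1 = torch.clamp(diff, min=0.0, max=1.0)            # positive part of the residual
diff2 = -torch.clamp(diff, min=-1.0, max=0.0)          # magnitude of the negative part

with torch.no_grad():
    out1 = net.compress(diff1)
    x_hat1 = net.decompress(out1["strings"], out1["shape"])["x_hat"]
    out2 = net.compress(diff2)
    x_hat2 = net.decompress(out2["strings"], out2["shape"])["x_hat"]

x_recon = (ref + x_hat1 - x_hat2).clamp(0, 1)          # recombine the two coded residuals
```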
... | @@ -17,6 +17,7 @@ import struct | ... | @@ -17,6 +17,7 @@ import struct |
17 | import sys | 17 | import sys |
18 | import time | 18 | import time |
19 | import math | 19 | import math |
20 | +from pytorch_msssim import ms_ssim | ||
20 | 21 | ||
21 | from pathlib import Path | 22 | from pathlib import Path |
22 | 23 | ||
... | @@ -27,7 +28,12 @@ from PIL import Image | ... | @@ -27,7 +28,12 @@ from PIL import Image |
27 | from torchvision.transforms import ToPILImage, ToTensor | 28 | from torchvision.transforms import ToPILImage, ToTensor |
28 | 29 | ||
29 | import compressai | 30 | import compressai |
30 | - | 31 | +from compressai.transforms.functional import ( |
32 | + rgb2ycbcr, | ||
33 | + ycbcr2rgb, | ||
34 | + yuv_420_to_444, | ||
35 | + yuv_444_to_420, | ||
36 | +) | ||
31 | from compressai.zoo import models | 37 | from compressai.zoo import models |
32 | 38 | ||
33 | model_ids = {k: i for i, k in enumerate(models.keys())} | 39 | model_ids = {k: i for i, k in enumerate(models.keys())} |
... | @@ -151,13 +157,28 @@ def compute_psnr(a, b): | ... | @@ -151,13 +157,28 @@ def compute_psnr(a, b): |
151 | mse = torch.mean((a - b)**2).item() | 157 | mse = torch.mean((a - b)**2).item() |
152 | return -10 * math.log10(mse) | 158 | return -10 * math.log10(mse) |
153 | 159 | ||
154 | -def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): | 160 | +def compute_msssim(a, b): |
161 | + return ms_ssim(a, b, data_range=1.).item() | ||
162 | + | ||
163 | +def ycbcr_psnr(a, b): | ||
164 | + yuv_a=rgb2ycbcr(a) | ||
165 | + yuv_b=rgb2ycbcr(b) | ||
166 | + a_y, a_cb, a_cr = yuv_a.chunk(3, -3) | ||
167 | + b_y, b_cb, b_cr = yuv_b.chunk(3, -3) | ||
168 | + y=compute_psnr(a_y, b_y) | ||
169 | + cb=compute_psnr(a_cb, b_cb) | ||
170 | + cr=compute_psnr(a_cr, b_cr) | ||
171 | + return (4*y+cb+cr)/6 | ||
172 | + | ||
173 | +def _encode(checkpoint, path, seq, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path): | ||
155 | compressai.set_entropy_coder(coder) | 174 | compressai.set_entropy_coder(coder) |
156 | enc_start = time.time() | 175 | enc_start = time.time() |
157 | 176 | ||
158 | - img = load_image(image) | 177 | + img = load_image(image+"_frame"+str(ff)+".png") |
159 | start = time.time() | 178 | start = time.time() |
160 | - net = models[model](quality=quality, metric=metric, pretrained=True).eval() | 179 | + net = models[model](quality=quality, metric=metric, pretrained=True) |
180 | + | ||
181 | + net.eval() | ||
161 | load_time = time.time() - start | 182 | load_time = time.time() - start |
162 | 183 | ||
163 | x = img2torch(img) | 184 | x = img2torch(img) |
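The helpers added above report MS-SSIM and a luma-weighted YCbCr PSNR, `(4*Y + Cb + Cr) / 6`, alongside the RGB PSNR. A minimal standalone usage sketch (not part of the diff), assuming RGB tensors in [0, 1]; the local `psnr` mirrors `compute_psnr`:

```python
import torch
from pytorch_msssim import ms_ssim
from compressai.transforms.functional import rgb2ycbcr

def psnr(x, y):
    return -10 * torch.log10(((x - y) ** 2).mean()).item()

a = torch.rand(1, 3, 768, 768)                    # reference RGB frame in [0, 1]
b = (a + 0.01 * torch.randn_like(a)).clamp(0, 1)  # distorted frame

y_a, cb_a, cr_a = rgb2ycbcr(a).chunk(3, -3)       # split the YCbCr channels
y_b, cb_b, cr_b = rgb2ycbcr(b).chunk(3, -3)
weighted = (4 * psnr(y_a, y_b) + psnr(cb_a, cb_b) + psnr(cr_a, cr_b)) / 6
print(weighted, ms_ssim(a, b, data_range=1.0).item())
```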
... | @@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
182 | strings.append([s[0]]) | 203 | strings.append([s[0]]) |
183 | 204 | ||
184 | with torch.no_grad(): | 205 | with torch.no_grad(): |
185 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | 206 | + recon_out = net.decompress(strings,shape) |
186 | x_recon = crop(recon_out["x_hat"], (h, w)) | 207 | x_recon = crop(recon_out["x_hat"], (h, w)) |
187 | 208 | ||
188 | psnr=compute_psnr(x, x_recon) | 209 | psnr=compute_psnr(x, x_recon) |
210 | + ssim=compute_msssim(x, x_recon) | ||
211 | + ycbcr=ycbcr_psnr(x, x_recon) | ||
212 | + else: | ||
213 | + if checkpoint: # load from previous checkpoint | ||
214 | + checkpoint = torch.load(checkpoint) | ||
215 | + #state_dict = load_state_dict(checkpoint["state_dict"]) | ||
216 | + net=models[model](quality=quality, metric=metric) | ||
217 | + net.load_state_dict(checkpoint["state_dict"]) | ||
218 | + net.update(force=True) | ||
189 | else: | 219 | else: |
220 | + net = models[model](quality=quality, metric=metric, pretrained=True) | ||
221 | + | ||
190 | diff=x-ref | 222 | diff=x-ref |
191 | - #1 | ||
192 | diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 | 223 | diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5 |
193 | 224 | ||
194 | - #2 | ||
195 | - ''' | ||
196 | - diff1=torch.clamp(diff, min=0.0, max=1.0) | ||
197 | - diff2=-torch.clamp(diff, min=-1.0, max=0.0) | ||
198 | - | ||
199 | - diff1=pad(diff1, p) | ||
200 | - diff2=pad(diff2, p) | ||
201 | - ''' | ||
202 | - #1 | ||
203 | - | ||
204 | - with torch.no_grad(): | ||
205 | - out1 = net.compress(diff1) | ||
206 | - shape1 = out1["shape"] | ||
207 | - strings = [] | ||
208 | - | ||
209 | - with Path(output).open("ab") as f: | ||
210 | - # write shape and number of encoded latents | ||
211 | - write_uints(f, (shape1[0], shape1[1], len(out1["strings"]))) | ||
212 | - | ||
213 | - for s in out1["strings"]: | ||
214 | - write_uints(f, (len(s[0]),)) | ||
215 | - write_bytes(f, s[0]) | ||
216 | - strings.append([s[0]]) | ||
217 | - | ||
218 | - with torch.no_grad(): | ||
219 | - recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | ||
220 | - x_hat1 = crop(recon_out["x_hat"], (h, w)) | ||
221 | 225 | ||
222 | - #2 | ||
223 | - ''' | ||
224 | with torch.no_grad(): | 226 | with torch.no_grad(): |
225 | out1 = net.compress(diff1) | 227 | out1 = net.compress(diff1) |
226 | shape1 = out1["shape"] | 228 | shape1 = out1["shape"] |
... | @@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
236 | strings.append([s[0]]) | 238 | strings.append([s[0]]) |
237 | 239 | ||
238 | with torch.no_grad(): | 240 | with torch.no_grad(): |
239 | - recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"]))) | 241 | + recon_out = net.decompress(strings, shape1) |
240 | x_hat1 = crop(recon_out["x_hat"], (h, w)) | 242 | x_hat1 = crop(recon_out["x_hat"], (h, w)) |
241 | - with torch.no_grad(): | ||
242 | - out = net.compress(diff2) | ||
243 | - shape = out["shape"] | ||
244 | - strings = [] | ||
245 | 243 | ||
246 | - with Path(output).open("ab") as f: | ||
247 | - # write shape and number of encoded latents | ||
248 | - write_uints(f, (shape[0], shape[1], len(out["strings"]))) | ||
249 | - | ||
250 | - for s in out["strings"]: | ||
251 | - write_uints(f, (len(s[0]),)) | ||
252 | - write_bytes(f, s[0]) | ||
253 | - strings.append([s[0]]) | ||
254 | - | ||
255 | - with torch.no_grad(): | ||
256 | - recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"]))) | ||
257 | - x_hat2 = crop(recon_out["x_hat"], (h, w)) | ||
258 | - x_recon=ref+x_hat1-x_hat2 | ||
259 | - ''' | ||
260 | 244 | ||
261 | x_recon=ref+x_hat1-0.5 | 245 | x_recon=ref+x_hat1-0.5 |
262 | psnr=compute_psnr(x, x_recon) | 246 | psnr=compute_psnr(x, x_recon) |
247 | + ssim=compute_msssim(x, x_recon) | ||
248 | + ycbcr=ycbcr_psnr(x, x_recon) | ||
263 | diff_img = torch2img(diff1) | 249 | diff_img = torch2img(diff1) |
264 | - diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png") | 250 | +# diff_img.save(path+"recon/"+seq+str(ff)+"_q"+str(quality)+".png") |
251 | +# diff_img.save("../Data/train/"+seq+str(ff)+"_train8_q"+str(quality)+".png") | ||
265 | 252 | ||
266 | enc_time = time.time() - enc_start | 253 | enc_time = time.time() - enc_start |
267 | size = filesize(output) | 254 | size = filesize(output) |
... | @@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o | ... | @@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o |
269 | with Path(log_path).open("a") as f: | 256 | with Path(log_path).open("a") as f: |
270 | f.write( f" {bpp-total_bpp:.4f} | " | 257 | f.write( f" {bpp-total_bpp:.4f} | " |
271 | f" {psnr:.4f} |" | 258 | f" {psnr:.4f} |" |
259 | + f" {ssim:.4f} |" | ||
260 | + f" {ycbcr:.4f} |" | ||
272 | f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") | 261 | f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n") |
273 | recon_img = torch2img(x_recon) | 262 | recon_img = torch2img(x_recon) |
274 | recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png") | 263 | recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png") |
275 | 264 | ||
276 | - return psnr, bpp, x_recon, enc_time | 265 | + return psnr, bpp, x_recon, enc_time, ssim, ycbcr |
277 | 266 | ||
278 | 267 | ||
279 | def _decode(inputpath, coder, show, frame, output=None): | 268 | def _decode(inputpath, coder, show, frame, output=None): |
... | @@ -381,13 +370,19 @@ def encode(argv): | ... | @@ -381,13 +370,19 @@ def encode(argv): |
381 | default=768, | 370 | default=768, |
382 | help="hight setting (default: %(default))", | 371 | help="hight setting (default: %(default))", |
383 | ) | 372 | ) |
373 | + parser.add_argument( | ||
374 | + "-check", | ||
375 | + "--checkpoint", | ||
376 | + type=str, | ||
377 | + help="Path to a checkpoint", | ||
378 | + ) | ||
384 | parser.add_argument("-o", "--output", help="Output path") | 379 | parser.add_argument("-o", "--output", help="Output path") |
385 | args = parser.parse_args(argv) | 380 | args = parser.parse_args(argv) |
386 | path="examples/"+args.image+"/" | 381 | path="examples/"+args.image+"/" |
387 | if not args.output: | 382 | if not args.output: |
388 | #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") | 383 | #args.output = Path(Path(args.image).resolve().name).with_suffix(".bin") |
389 | - args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin" | 384 | + args.output = path+args.image+"_q"+str(args.quality)+"_train_ssim.bin" |
390 | - log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt" | 385 | + log_path=path+args.image+"_q"+str(args.quality)+"_train_ssim.txt" |
391 | 386 | ||
392 | header = get_header(args.model, args.metric, args.quality) | 387 | header = get_header(args.model, args.metric, args.quality) |
393 | with Path(args.output).open("wb") as f: | 388 | with Path(args.output).open("wb") as f: |
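The new `--checkpoint` option above selects between fine-tuned weights and the pretrained zoo weights, matching the branch added in `_encode`. A minimal standalone sketch (not part of the diff); `build_net` is a hypothetical helper, and the model name in the comment is illustrative:

```python
import torch
from compressai.zoo import models

def build_net(model, quality, metric, checkpoint_path=None):
    if checkpoint_path:  # weights from a fine-tuning run (e.g. the new train_*.py scripts)
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        net = models[model](quality=quality, metric=metric)
        net.load_state_dict(ckpt["state_dict"])
        net.update(force=True)  # rebuild the entropy-coder CDF tables before compress()
    else:
        net = models[model](quality=quality, metric=metric, pretrained=True)
    return net.eval()

# e.g. net = build_net("bmshj2018-hyperprior", 3, "mse", checkpoint_path=args.checkpoint)
```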
... | @@ -400,32 +395,43 @@ def encode(argv): | ... | @@ -400,32 +395,43 @@ def encode(argv): |
400 | f"frames : {args.frame}\n") | 395 | f"frames : {args.frame}\n") |
401 | f.write( f"frame | bpp | " | 396 | f.write( f"frame | bpp | " |
402 | f" psnr |" | 397 | f" psnr |" |
398 | + f" ssim |" | ||
403 | f" Encoded time (model loading)\n" | 399 | f" Encoded time (model loading)\n" |
404 | f" {0:3d} | ") | 400 | f" {0:3d} | ") |
405 | 401 | ||
406 | total_psnr=0.0 | 402 | total_psnr=0.0 |
403 | + total_ssim=0.0 | ||
404 | + total_ycbcr=0.0 | ||
407 | total_bpp=0.0 | 405 | total_bpp=0.0 |
408 | total_time=0.0 | 406 | total_time=0.0 |
409 | - args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" | 407 | + img =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444" |
410 | - img=args.image+"_frame"+str(0)+".png" | 408 | + total_psnr, total_bpp, ref, total_time, total_ssim, total_ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) |
411 | - total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path) | ||
412 | for ff in range(1, args.frame): | 409 | for ff in range(1, args.frame): |
413 | with Path(log_path).open("a") as f: | 410 | with Path(log_path).open("a") as f: |
414 | f.write(f" {ff:3d} | ") | 411 | f.write(f" {ff:3d} | ") |
415 | - img=args.image+"_frame"+str(ff)+".png" | 412 | + if ff%25==0: |
416 | - | 413 | + psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, ref, total_bpp, ff, args.output, log_path) |
417 | - psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | 414 | + else: |
415 | + psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path) | ||
418 | total_psnr+=psnr | 416 | total_psnr+=psnr |
417 | + total_ssim+=ssim | ||
418 | + total_ycbcr+=ycbcr | ||
419 | total_time+=time | 419 | total_time+=time |
420 | 420 | ||
421 | total_psnr/=args.frame | 421 | total_psnr/=args.frame |
422 | + total_ssim/=args.frame | ||
423 | + total_ycbcr/=args.frame | ||
422 | total_bpp/=args.frame | 424 | total_bpp/=args.frame |
423 | 425 | ||
424 | with Path(log_path).open("a") as f: | 426 | with Path(log_path).open("a") as f: |
425 | f.write( f"\n Total Encoded time: {total_time:.2f}s\n" | 427 | f.write( f"\n Total Encoded time: {total_time:.2f}s\n" |
426 | f"\n Total PSNR: {total_psnr:.6f}\n" | 428 | f"\n Total PSNR: {total_psnr:.6f}\n" |
429 | + f"\n Total SSIM: {total_ssim:.6f}\n" | ||
430 | + f"\n Total ycbcr: {total_ycbcr:.6f}\n" | ||
427 | f" Total BPP: {total_bpp:.6f}\n") | 431 | f" Total BPP: {total_bpp:.6f}\n") |
428 | print(total_psnr) | 432 | print(total_psnr) |
433 | + print(total_ssim) | ||
434 | + print(total_ycbcr) | ||
429 | print(total_bpp) | 435 | print(total_bpp) |
430 | 436 | ||
431 | 437 | ... | ... |
Our Encoder/train_RGB.py (new file, 0 → 100644; diff collapsed)
Our Encoder/train_RGB_MS-SSIMloss.py (new file, 0 → 100644; diff collapsed)
Our Encoder/train_YCbCr.py (new file, 0 → 100644; diff collapsed)