choiseungmi

code update

......@@ -151,7 +151,7 @@ def compute_psnr(a, b):
mse = torch.mean((a - b)**2).item()
return -10 * math.log10(mse)
def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path):
def _encode(seq, path, image, model, metric, quality, coder, i, ref,total_bpp, ff, output, log_path):
compressai.set_entropy_coder(coder)
enc_start = time.time()
......@@ -182,16 +182,16 @@ def _encode(path, image, model, metric, quality, coder, i, ref,total_bpp, ff, ou
strings.append([s[0]])
with torch.no_grad():
recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
recon_out = net.decompress(strings, shape)
x_recon = crop(recon_out["x_hat"], (h, w))
psnr=compute_psnr(x, x_recon)
if i==False:
diff=x-ref
diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
diff_img = torch2img(diff1)
diff_img.save(path+"recon/diff_v1_"+str(ff)+"_q"+str(quality)+".png")
#if i==False:
# diff=x-ref
# diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
# diff_img = torch2img(diff1)
# diff_img.save("../Data/train/"+seq+str(ff)+"_train_v1_q"+str(quality)+".png")
enc_time = time.time() - enc_start
size = filesize(output)
......@@ -336,15 +336,15 @@ def encode(argv):
total_psnr=0.0
total_bpp=0.0
total_time=0.0
args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
img=args.image+"_frame"+str(0)+".png"
total_psnr, total_bpp, ref,total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
img_path =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
img=img_path+"_frame"+str(0)+".png"
total_psnr, total_bpp, ref,total_time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
for ff in range(1, args.frame):
with Path(log_path).open("a") as f:
f.write(f" {ff:3d} | ")
img=args.image+"_frame"+str(ff)+".png"
img=img_path+"_frame"+str(ff)+".png"
psnr, total_bpp, ref,time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
psnr, total_bpp, ref,time = _encode(args.image, path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
total_psnr+=psnr
total_time+=time
......
......@@ -213,7 +213,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
strings.append([s[0]])
with torch.no_grad():
recon_out1 = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
recon_out1 = net.decompress(strings,shape)
x_hat1 = crop(recon_out1["x_hat"], (h, w))
with torch.no_grad():
......@@ -231,7 +231,7 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
strings.append([s[0]])
with torch.no_grad():
recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
recon_out = net.decompress(strings, shape)
x_hat2 = crop(recon_out["x_hat"], (h, w))
x_recon=ref+x_hat1-x_hat2
......
......@@ -17,6 +17,7 @@ import struct
import sys
import time
import math
from pytorch_msssim import ms_ssim
from pathlib import Path
......@@ -27,7 +28,12 @@ from PIL import Image
from torchvision.transforms import ToPILImage, ToTensor
import compressai
from compressai.transforms.functional import (
rgb2ycbcr,
ycbcr2rgb,
yuv_420_to_444,
yuv_444_to_420,
)
from compressai.zoo import models
# Assign each entry of the `models` registry an integer id equal to its
# enumeration position, so a model can be looked up by name and get a number.
model_ids = {name: index for index, name in enumerate(models.keys())}
......@@ -151,13 +157,28 @@ def compute_psnr(a, b):
mse = torch.mean((a - b)**2).item()
return -10 * math.log10(mse)
def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
def compute_msssim(a, b):
    """Return the MS-SSIM score between ``a`` and ``b`` as a Python scalar.

    Delegates to ``ms_ssim`` with ``data_range=1.`` and unwraps the result
    with ``.item()``.

    NOTE(review): ``data_range=1.`` suggests both inputs are expected to be
    scaled to [0, 1] -- confirm against the callers.
    """
    score = ms_ssim(a, b, data_range=1.)
    return score.item()
def ycbcr_psnr(a, b):
    """Return a weighted PSNR of ``a`` versus ``b`` computed per YCbCr plane.

    Both inputs are converted with ``rgb2ycbcr`` and split into three chunks
    along dimension ``-3``. ``compute_psnr`` is evaluated on each pair of
    matching planes, and the results are combined as ``(4*Y + Cb + Cr) / 6``.

    NOTE(review): presumably ``a`` and ``b`` are RGB tensors whose channel
    axis is the third-from-last dimension -- confirm against the callers.
    """
    planes_a = rgb2ycbcr(a).chunk(3, -3)
    planes_b = rgb2ycbcr(b).chunk(3, -3)

    # One PSNR per plane, in Y, Cb, Cr order.
    psnr_y, psnr_cb, psnr_cr = [
        compute_psnr(plane_a, plane_b)
        for plane_a, plane_b in zip(planes_a, planes_b)
    ]

    # Luma counts four times as much as each chroma plane.
    return (4 * psnr_y + psnr_cb + psnr_cr) / 6
def _encode(checkpoint, path, seq, image, model, metric, quality, coder, i, ref, total_bpp, ff, output, log_path):
compressai.set_entropy_coder(coder)
enc_start = time.time()
img = load_image(image)
img = load_image(image+"_frame"+str(ff)+".png")
start = time.time()
net = models[model](quality=quality, metric=metric, pretrained=True).eval()
net = models[model](quality=quality, metric=metric, pretrained=True)
net.eval()
load_time = time.time() - start
x = img2torch(img)
......@@ -182,45 +203,26 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
strings.append([s[0]])
with torch.no_grad():
recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
recon_out = net.decompress(strings,shape)
x_recon = crop(recon_out["x_hat"], (h, w))
psnr=compute_psnr(x, x_recon)
ssim=compute_msssim(x, x_recon)
ycbcr=ycbcr_psnr(x, x_recon)
else:
if checkpoint: # load from previous checkpoint
checkpoint = torch.load(checkpoint)
#state_dict = load_state_dict(checkpoint["state_dict"])
net=models[model](quality=quality, metric=metric)
net.load_state_dict(checkpoint["state_dict"])
net.update(force=True)
else:
net = models[model](quality=quality, metric=metric, pretrained=True)
diff=x-ref
#1
diff1=torch.clamp(diff, min=-0.5, max=0.5)+0.5
#2
'''
diff1=torch.clamp(diff, min=0.0, max=1.0)
diff2=-torch.clamp(diff, min=-1.0, max=0.0)
diff1=pad(diff1, p)
diff2=pad(diff2, p)
'''
#1
with torch.no_grad():
out1 = net.compress(diff1)
shape1 = out1["shape"]
strings = []
with Path(output).open("ab") as f:
# write shape and number of encoded latents
write_uints(f, (shape1[0], shape1[1], len(out1["strings"])))
for s in out1["strings"]:
write_uints(f, (len(s[0]),))
write_bytes(f, s[0])
strings.append([s[0]])
with torch.no_grad():
recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
x_hat1 = crop(recon_out["x_hat"], (h, w))
#2
'''
with torch.no_grad():
out1 = net.compress(diff1)
shape1 = out1["shape"]
......@@ -236,32 +238,17 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
strings.append([s[0]])
with torch.no_grad():
recon_out = net.decompress(strings, (shape1[0], shape1[1], len(out1["strings"])))
recon_out = net.decompress(strings, shape1)
x_hat1 = crop(recon_out["x_hat"], (h, w))
with torch.no_grad():
out = net.compress(diff2)
shape = out["shape"]
strings = []
with Path(output).open("ab") as f:
# write shape and number of encoded latents
write_uints(f, (shape[0], shape[1], len(out["strings"])))
for s in out["strings"]:
write_uints(f, (len(s[0]),))
write_bytes(f, s[0])
strings.append([s[0]])
with torch.no_grad():
recon_out = net.decompress(strings, (shape[0], shape[1], len(out["strings"])))
x_hat2 = crop(recon_out["x_hat"], (h, w))
x_recon=ref+x_hat1-x_hat2
'''
x_recon=ref+x_hat1-0.5
psnr=compute_psnr(x, x_recon)
ssim=compute_msssim(x, x_recon)
ycbcr=ycbcr_psnr(x, x_recon)
diff_img = torch2img(diff1)
diff_img.save(path+"recon/diff"+str(ff)+"_q"+str(quality)+".png")
# diff_img.save(path+"recon/"+seq+str(ff)+"_q"+str(quality)+".png")
# diff_img.save("../Data/train/"+seq+str(ff)+"_train8_q"+str(quality)+".png")
enc_time = time.time() - enc_start
size = filesize(output)
......@@ -269,11 +256,13 @@ def _encode(path, image, model, metric, quality, coder, i, ref, total_bpp, ff, o
with Path(log_path).open("a") as f:
f.write( f" {bpp-total_bpp:.4f} | "
f" {psnr:.4f} |"
f" {ssim:.4f} |"
f" {ycbcr:.4f} |"
f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)\n")
recon_img = torch2img(x_recon)
recon_img.save(path+"recon/recon"+str(ff)+"_q"+str(quality)+".png")
return psnr, bpp, x_recon, enc_time
return psnr, bpp, x_recon, enc_time, ssim, ycbcr
def _decode(inputpath, coder, show, frame, output=None):
......@@ -381,13 +370,19 @@ def encode(argv):
default=768,
help="hight setting (default: %(default))",
)
parser.add_argument(
"-check",
"--checkpoint",
type=str,
help="Path to a checkpoint",
)
parser.add_argument("-o", "--output", help="Output path")
args = parser.parse_args(argv)
path="examples/"+args.image+"/"
if not args.output:
#args.output = Path(Path(args.image).resolve().name).with_suffix(".bin")
args.output = path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.bin"
log_path=path+args.image+"_"+args.model+"_q"+str(args.quality)+"_v2.txt"
args.output = path+args.image+"_q"+str(args.quality)+"_train_ssim.bin"
log_path=path+args.image+"_q"+str(args.quality)+"_train_ssim.txt"
header = get_header(args.model, args.metric, args.quality)
with Path(args.output).open("wb") as f:
......@@ -400,32 +395,43 @@ def encode(argv):
f"frames : {args.frame}\n")
f.write( f"frame | bpp | "
f" psnr |"
f" ssim |"
f" Encoded time (model loading)\n"
f" {0:3d} | ")
total_psnr=0.0
total_ssim=0.0
total_ycbcr=0.0
total_bpp=0.0
total_time=0.0
args.image =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
img=args.image+"_frame"+str(0)+".png"
total_psnr, total_bpp, ref, total_time = _encode(path, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
img =path + args.image+"_768x768_"+str(args.framerate)+"_8bit_444"
total_psnr, total_bpp, ref, total_time, total_ssim, total_ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, 0, total_bpp, 0, args.output, log_path)
for ff in range(1, args.frame):
with Path(log_path).open("a") as f:
f.write(f" {ff:3d} | ")
img=args.image+"_frame"+str(ff)+".png"
psnr, total_bpp, ref, time = _encode(path, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
if ff%25==0:
psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, True, ref, total_bpp, ff, args.output, log_path)
else:
psnr, total_bpp, ref, time, ssim, ycbcr = _encode(args.checkpoint, path, args.image, img, args.model, args.metric, args.quality, args.coder, False, ref, total_bpp, ff, args.output, log_path)
total_psnr+=psnr
total_ssim+=ssim
total_ycbcr+=ycbcr
total_time+=time
total_psnr/=args.frame
total_ssim/=args.frame
total_ycbcr/=args.frame
total_bpp/=args.frame
with Path(log_path).open("a") as f:
f.write( f"\n Total Encoded time: {total_time:.2f}s\n"
f"\n Total PSNR: {total_psnr:.6f}\n"
f"\n Total SSIM: {total_ssim:.6f}\n"
f"\n Total ycbcr: {total_ycbcr:.6f}\n"
f" Total BPP: {total_bpp:.6f}\n")
print(total_psnr)
print(total_ssim)
print(total_ycbcr)
print(total_bpp)
......
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.
This diff is collapsed. Click to expand it.