CompressAI uses ANS (asymmetric numeral systems) as its entropy coder. During training, the input distribution is estimated with the univariate non-parametric density model proposed in the paper Variational Image Compression With a Scale Hyperprior; see the post 基于熵编码VAE的图像压缩模型 for a detailed analysis.

Utility functions and classes

LowerBound

LowerBound implements a differentiable clamping of the input to a lower bound set at initialization.

def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
    return torch.max(x, bound)


def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor):
    # Pass the gradient through when x >= bound, or when the gradient is
    # negative (i.e. a descent step would push x up towards the bound).
    pass_through_if = (x >= bound) | (grad_output < 0)
    # The gradient w.r.t. x of the max operation is thus masked to 0 or 1.
    return pass_through_if * grad_output, None


class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator."""

    @staticmethod
    def forward(ctx, x, bound):
        ctx.save_for_backward(x, bound)
        return lower_bound_fwd(x, bound)

    @staticmethod
    def backward(ctx, grad_output):
        x, bound = ctx.saved_tensors
        return lower_bound_bwd(x, bound, grad_output)


class LowerBound(nn.Module):
    """Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    The derivative is replaced by the identity function when `x` is moved
    towards the `bound`, otherwise the gradient is kept to zero.
    """

    bound: Tensor

    def __init__(self, bound: float):
        super().__init__()
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)
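
A minimal sketch of the custom gradient (assuming LowerBound is importable from compressai.ops): values below the bound are clamped in the forward pass, but only gradients that would push them back up towards the bound are propagated.

import torch
from compressai.ops import LowerBound  # assumed import path

lb = LowerBound(1e-9)
x = torch.tensor([1e-12, 0.5], requires_grad=True)

y = lb(x)           # forward: torch.max(x, 1e-9) -> [1e-9, 0.5]
y.sum().backward()  # upstream gradient is +1 for both elements

# x[0] is below the bound and a +1 gradient would push it further down,
# so its gradient is blocked; x[1] is above the bound, so it passes through.
print(y)       # tensor([1.0000e-09, 5.0000e-01], ...)
print(x.grad)  # tensor([0., 1.])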

_pmf_to_cdf

def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length: int):
    cdf = torch.zeros(
        (len(pmf_length), max_length + 2), dtype=torch.int32, device=pmf.device
    )
    for i, p in enumerate(pmf):
        # Append the tail mass to each PMF, then quantize it to an integer CDF
        # with `entropy_coder_precision` bits.
        prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
        _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
        cdf[i, : _cdf.size(0)] = _cdf
    return cdf

EntropyModel

EntropyModel is defined in compressai/entropy_models/entropy_models.py. Since it inherits from nn.Module, .train()/.eval() can be used to switch between the quantization behaviours used during training and inference. It mainly implements the quantization modes and the compress/decompress pipeline based on ANS.
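
A minimal sketch of how the train/eval switch changes the quantization in the forward pass (behaviour as implemented by the EntropyBottleneck subclass; the tensor sizes are illustrative):

import torch
from compressai.entropy_models import EntropyBottleneck

eb = EntropyBottleneck(channels=8)
z = torch.randn(2, 8, 4, 4)

eb.train()                       # training: quantization simulated with additive uniform noise
z_noisy, _ = eb(z)
print((z_noisy - z).abs().max())  # up to ~0.5: z plus U(-0.5, 0.5) noise

eb.eval()                        # inference: rounding around the learned medians
z_hat, _ = eb(z)
print((z_hat - z).abs().max())    # also bounded by 0.5, but z_hat lies on the (median-shifted) integer grid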

def __init__(
    self,
    likelihood_bound: float = 1e-9,
    entropy_coder: Optional[str] = None,
    entropy_coder_precision: int = 16,  # precision of the entropy coder
):
    super().__init__()

    if entropy_coder is None:
        entropy_coder = default_entropy_coder()  # "ans" or "rangecoder"; ANS by default
    self.entropy_coder = _EntropyCoder(entropy_coder)
    self.entropy_coder_precision = int(entropy_coder_precision)

    self.use_likelihood_bound = likelihood_bound > 0
    if self.use_likelihood_bound:
        self.likelihood_lower_bound = LowerBound(likelihood_bound)

    # to be filled on update()
    self.register_buffer("_offset", torch.IntTensor())
    self.register_buffer("_quantized_cdf", torch.IntTensor())
    self.register_buffer("_cdf_length", torch.IntTensor())
  • likelihood_bound: lower-bounds the returned likelihoods so that the entropy loss does not produce extremely large gradients (see the sketch after this list).
  • entropy_coder_precision: precision of the entropy coder, 16 by default. For ANS this means the compressed data is represented with integers in the range $[0, 2^{16})$. The precision trades off compression efficiency against coding speed.
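
A minimal sketch of why the likelihood lower bound matters (the numbers are made up; same assumed compressai.ops import as above): without the bound, a likelihood of exactly zero makes the log-likelihood, and hence the rate gradient, blow up to inf/NaN.

import torch
from compressai.ops import LowerBound  # assumed import path

likelihood_lower_bound = LowerBound(1e-9)

likelihoods = torch.tensor([0.0, 1e-12, 0.3], requires_grad=True)
bounded = likelihood_lower_bound(likelihoods)

# torch.log(0) would give -inf and NaN gradients; with the bound the rate term
# stays finite and the gradient magnitude is capped by 1 / (bound * ln 2),
# still pushing the clamped likelihoods upward.
rate = -torch.log2(bounded).sum()
rate.backward()
print(rate)              # finite, no inf/NaN
print(likelihoods.grad)  # large but finite negative gradients for the clamped entries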

quantize

def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    if mode == "noise":
        half = float(0.5)
        noise = torch.empty_like(inputs).uniform_(-half, half)
        inputs = inputs + noise
        return inputs

    outputs = inputs.clone()
    if means is not None:
        outputs -= means

    outputs = torch.round(outputs)

    if mode == "dequantize":
        if means is not None:
            outputs += means
        return outputs

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs

The quantize method supports three quantization modes (see the sketch after this list):

  • noise: adds uniform noise to the input to simulate quantization with step size 1; see END-TO-END OPTIMIZED IMAGE COMPRESSION for the theory. The additive uniform noise avoids the non-differentiability of hard rounding.

  • dequantize: subtracts the mean (if provided), rounds with torch.round, then adds the mean back, producing a floating-point reconstruction on the quantization grid.

  • symbols: subtracts the mean (if provided), rounds with torch.round, and casts to integer symbols for the entropy coder.
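
A small illustrative sketch of the three modes (the tensor values are made up; quantize is called on an EntropyBottleneck instance because EntropyModel itself has no usable forward):

import torch
from compressai.entropy_models import EntropyBottleneck  # concrete subclass of EntropyModel

em = EntropyBottleneck(channels=1)
x = torch.tensor([[1.3, -0.6, 2.5]])
means = torch.tensor([[0.2, 0.2, 0.2]])

print(em.quantize(x, "noise"))              # x + U(-0.5, 0.5), differentiable surrogate for training
print(em.quantize(x, "dequantize", means))  # round(x - means) + means, float reconstruction
print(em.quantize(x, "symbols", means))     # round(x - means) as int symbols for the entropy coder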

dequantize

@staticmethod
def dequantize(
    inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float
) -> Tensor:
    if means is not None:
        outputs = inputs.type_as(means)
        outputs += means
    else:
        outputs = inputs.type(dtype)
    return outputs

The inverse of quantize: the integer symbols are cast back to floating point and, if means are provided, the means are added back to the input.

compress

def compress(self, inputs, indexes, means=None):
    """
    Compress input tensors to char strings.

    Args:
        inputs (torch.Tensor): input tensors
        indexes (torch.IntTensor): tensors CDF indexes
        means (torch.Tensor, optional): optional tensor means
    """
    symbols = self.quantize(inputs, "symbols", means)  # quantize inputs to integer symbols

    if len(inputs.size()) < 2:
        raise ValueError(
            "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
        )

    if inputs.size() != indexes.size():
        raise ValueError("`inputs` and `indexes` should have the same size.")

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    strings = []
    # Entropy-code the quantized symbols into one byte string per batch element.
    for i in range(symbols.size(0)):
        rv = self.entropy_coder.encode_with_indexes(
            symbols[i].reshape(-1).int().tolist(),
            indexes[i].reshape(-1).int().tolist(),
            self._quantized_cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        strings.append(rv)
    return strings

decompress

def decompress(
    self,
    strings: str,
    indexes: torch.IntTensor,
    dtype: torch.dtype = torch.float,
    means: torch.Tensor = None,
):
    """
    Decompress char strings to tensors.

    Args:
        strings (str): compressed tensors
        indexes (torch.IntTensor): tensors CDF indexes
        dtype (torch.dtype): type of dequantized output
        means (torch.Tensor, optional): optional tensor means
    """

    if not isinstance(strings, (tuple, list)):
        raise ValueError("Invalid `strings` parameter type.")
    if not len(strings) == indexes.size(0):
        raise ValueError("Invalid strings or indexes parameters")

    if len(indexes.size()) < 2:
        raise ValueError(
            "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
        )

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    if means is not None:
        if means.size()[:2] != indexes.size()[:2]:
            raise ValueError("Invalid means or indexes parameters")
        if means.size() != indexes.size():
            for i in range(2, len(indexes.size())):
                if means.size(i) != 1:
                    raise ValueError("Invalid means parameters")

    cdf = self._quantized_cdf
    outputs = cdf.new_empty(indexes.size())

    for i, s in enumerate(strings):
        values = self.entropy_coder.decode_with_indexes(
            s,
            indexes[i].reshape(-1).int().tolist(),
            cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        outputs[i] = torch.tensor(
            values, device=outputs.device, dtype=outputs.dtype
        ).reshape(outputs[i].size())
    outputs = self.dequantize(outputs, means, dtype)
    return outputs
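
A hedged end-to-end sketch of the compress/decompress round trip, using the EntropyBottleneck subclass (which builds the CDF indexes itself); update() must be called first to fill the quantized CDF buffers, and the tensor sizes are illustrative.

import torch
from compressai.entropy_models import EntropyBottleneck

eb = EntropyBottleneck(channels=128)
eb.update()  # builds _quantized_cdf / _cdf_length / _offset before coding

y = torch.randn(1, 128, 16, 16)
strings = eb.compress(y)                       # one byte string per batch element
y_hat = eb.decompress(strings, size=(16, 16))  # dequantized reconstruction

print(len(strings), y_hat.shape)  # 1 torch.Size([1, 128, 16, 16])
# y_hat is y rounded on the integer grid shifted by the learned medians,
# so the residual is bounded by the quantization step.
print((y - y_hat).abs().max())    # <= 0.5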

_likelihood

def _likelihood(
    self, inputs: Tensor, stop_gradient: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
    half = float(0.5)
    # P(x) = c(x + 0.5) - c(x - 0.5), where c is the learned CDF
    # (sigmoid of the cumulative logits).
    lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
    upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
    likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
    return likelihood, lower, upper

EntropyBottleneck

EntropyBottleneck inherits from EntropyModel. It implements the entropy estimation of the input tensor during training, using the fully factorized, non-parametric density model.

def __init__(
    self,
    channels: int,  # number of channels of the input tensor
    *args: Any,
    tail_mass: float = 1e-9,
    init_scale: float = 10,
    filters: Tuple[int, ...] = (3, 3, 3, 3),
    **kwargs: Any,
):
    super().__init__(*args, **kwargs)

    self.channels = int(channels)
    self.filters = tuple(int(f) for f in filters)
    self.init_scale = float(init_scale)
    self.tail_mass = float(tail_mass)

    filters = (1,) + self.filters + (1,)
    scale = self.init_scale ** (1 / (len(self.filters) + 1))
    channels = self.channels

    for i in range(len(self.filters) + 1):
        init = np.log(np.expm1(1 / scale / filters[i + 1]))
        matrix = torch.Tensor(channels, filters[i + 1], filters[i])
        matrix.data.fill_(init)
        self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix))

        bias = torch.Tensor(channels, filters[i + 1], 1)
        nn.init.uniform_(bias, -0.5, 0.5)
        self.register_parameter(f"_bias{i:d}", nn.Parameter(bias))

        if i < len(self.filters):
            factor = torch.Tensor(channels, filters[i + 1], 1)
            nn.init.zeros_(factor)
            self.register_parameter(f"_factor{i:d}", nn.Parameter(factor))

    self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
    init = torch.Tensor([-self.init_scale, 0, self.init_scale])
    self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)

    target = np.log(2 / self.tail_mass - 1)
    self.register_buffer("target", torch.Tensor([-target, 0, target]))
  • tail_mass: the tail of the discrete probability mass function is coded with Golomb coding instead of range coding. The default is $10^{-9}$, i.e. on average a fraction of $10^{-9}$ of all values falls into the tail and is coded with a Golomb code.
  • init_scale: determines the initial width of the density model. It should be set large enough that the range of the inputs falls within [-init_scale, init_scale].
  • filters: the number of filters per layer of the density model (a usage sketch follows this list).
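
A short sketch of how the bottleneck is typically used during training (the 128-channel latent and its spatial size are illustrative): the forward pass returns the quantized latent and the per-element likelihoods consumed by the rate loss.

import torch
from compressai.entropy_models import EntropyBottleneck

eb = EntropyBottleneck(channels=128, filters=(3, 3, 3, 3))
eb.train()  # training mode: additive uniform noise is used for quantization

z = torch.randn(4, 128, 8, 8)            # latent produced by an encoder (illustrative)
z_hat, z_likelihoods = eb(z)             # quantized latent and estimated P(ẑ) per element

print(z_hat.shape, z_likelihoods.shape)  # both torch.Size([4, 128, 8, 8])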

_standardized_cumulative

def _standardized_cumulative(self, inputs: Tensor) -> Tensor:
    half = float(0.5)
    const = float(-(2**-0.5))
    return half * torch.erfc(const * inputs)

This uses the complementary error function: for an input $x_0$, the return value is $\frac{1}{2}\,\text{erfc}\!\left(-\frac{x_0}{\sqrt{2}}\right)$, which is exactly the CDF of the standard normal $X\sim\mathcal{N}(0,1)$ evaluated at $x_0$, i.e. the area under the density for $x \le x_0$.
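
A quick numerical check of this identity against scipy.stats.norm.cdf (purely illustrative, standalone re-implementation of the formula):

import torch
import scipy.stats

def standardized_cumulative(inputs: torch.Tensor) -> torch.Tensor:
    # 0.5 * erfc(-x / sqrt(2)) == standard normal CDF Phi(x)
    return 0.5 * torch.erfc(-(2 ** -0.5) * inputs)

x = torch.tensor([-2.0, 0.0, 1.5])
print(standardized_cumulative(x))       # tensor([0.0228, 0.5000, 0.9332])
print(scipy.stats.norm.cdf(x.numpy()))  # array([0.02275, 0.5, 0.93319])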

_standardized_quantile

@staticmethod
def _standardized_quantile(quantile: float):
    return scipy.stats.norm.ppf(quantile)

scipy.stats.norm.ppf is the percent point function (quantile function), i.e. the inverse of the CDF: given a CDF value y, it returns the corresponding x.
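
For instance, this is how the update() method shown below derives the half-width of the PMF support from tail_mass (an illustrative computation, not CompressAI code):

import scipy.stats

tail_mass = 1e-9
# -ppf(tail_mass / 2) is the number of standard deviations beyond which only
# tail_mass of the probability of a standard normal remains (both tails combined).
multiplier = -scipy.stats.norm.ppf(tail_mass / 2)
print(multiplier)                             # ~6.1
print(2 * scipy.stats.norm.cdf(-multiplier))  # ~1e-9, recovering tail_mass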

update_scale_table

def update_scale_table(self, scale_table, force=False):
    # The offsets are only computed and stored when the conditional model is updated.
    if self._offset.numel() > 0 and not force:
        return False
    device = self.scale_table.device
    self.scale_table = self._prepare_scale_table(scale_table).to(device)
    self.update()
    return True

update

def update(self):
    # Half-width of the PMF support for each scale: only tail_mass of the
    # probability is left outside [-pmf_center, pmf_center].
    multiplier: float = -self._standardized_quantile(self.tail_mass / 2)
    pmf_center = torch.ceil(self.scale_table * multiplier).int()
    pmf_length = 2 * pmf_center + 1
    max_length = torch.max(pmf_length).item()

    device = pmf_center.device
    samples = torch.abs(
        torch.arange(max_length, device=device).int() - pmf_center[:, None]
    )
    samples_scale = self.scale_table.unsqueeze(1)
    samples = samples.float()
    samples_scale = samples_scale.float()
    # Probability mass of each integer bin for a zero-mean Gaussian with the given scale.
    upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
    lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
    pmf = upper - lower

    tail_mass = 2 * lower[:, :1]

    quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
    quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
    self._quantized_cdf = quantized_cdf
    self._offset = -pmf_center
    self._cdf_length = pmf_length + 2
def pmf_to_quantized_cdf(pmf: Tensor, precision: int = 16) -> Tensor:
    cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
    cdf = torch.IntTensor(cdf)
    return cdf

GaussianConditional

The mean-scale hyperprior outputs $\mu$ and $\sigma$, which are used to model each element of $y$ with a Gaussian; from this model the probability of each value is computed, and entropy coding then produces the bitstream.

GaussianConditional returns two values: the quantized output $\hat{y}$ and the estimated likelihood of each value to be encoded.
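
A hedged usage sketch (the scale table values, tensor shapes, and the hyper-decoder outputs scales_hat/means_hat are assumptions for illustration): during training the module is called like a layer, while for actual coding build_indexes maps each scale to an entry of the precomputed CDF table.

import math
import torch
from compressai.entropy_models import GaussianConditional

# Scale table as commonly used by the hyperprior example models (assumed defaults).
scale_table = torch.exp(torch.linspace(math.log(0.11), math.log(256), 64))

gc = GaussianConditional(None)
gc.update_scale_table(scale_table)  # fills the per-scale quantized CDFs

y = torch.randn(1, 192, 16, 16)
scales_hat = torch.rand_like(y) + 0.1  # sigma from the hyper-decoder (illustrative)
means_hat = torch.zeros_like(y)        # mu from the hyper-decoder (illustrative)

# Training / rate estimation: quantized y and per-element likelihoods.
y_hat, y_likelihoods = gc(y, scales_hat, means=means_hat)

# Actual entropy coding: map each scale to a CDF index, then compress/decompress.
indexes = gc.build_indexes(scales_hat)
strings = gc.compress(y, indexes, means=means_hat)
y_dec = gc.decompress(strings, indexes, means=means_hat)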

During training, the bpp loss can be computed from the returned likelihoods:

num_pixels = N * H * W

out["bpp_loss"] = sum(
    (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
    for likelihoods in output["likelihoods"].values()
)
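
Equivalently, each element contributes $-\log_2 p(\hat{y})$ bits and the sum is normalized by the number of pixels. A tiny self-contained check with made-up numbers:

import math
import torch

likelihoods = torch.tensor([0.5, 0.25, 1.0])  # made-up per-element probabilities
num_pixels = 4                                # made-up image size

# -log2(0.5) + -log2(0.25) + -log2(1.0) = 1 + 2 + 0 = 3 bits over 4 pixels
bpp = torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
print(bpp)  # tensor(0.7500)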