CompressAI uses ANS as its entropy coder. During training, the distribution of the inputs is estimated with the univariate non-parametric density model proposed in Variational Image Compression With a Scale Hyperprior; a detailed analysis can be found in the companion post "基于熵编码VAE的图像压缩模型" (image compression models based on entropy-coded VAEs).
Utility functions and classes

LowerBound

LowerBound implements a differentiable clamping of its input to a lower bound fixed at initialization.
```python
def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
    return torch.max(x, bound)


def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor):
    pass_through_if = (x >= bound) | (grad_output < 0)
    return pass_through_if * grad_output, None


class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator."""

    @staticmethod
    def forward(ctx, x, bound):
        ctx.save_for_backward(x, bound)
        return lower_bound_fwd(x, bound)

    @staticmethod
    def backward(ctx, grad_output):
        x, bound = ctx.saved_tensors
        return lower_bound_bwd(x, bound, grad_output)


class LowerBound(nn.Module):
    """Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    The derivative is replaced by the identity function when `x` is moved
    towards the `bound`, otherwise the gradient is kept to zero.
    """

    bound: Tensor

    def __init__(self, bound: float):
        super().__init__()
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)
```
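A small sketch of the custom gradient (not part of the original post; LowerBound can be imported from compressai.ops in current CompressAI releases). The element clamped up to the bound gets a zero gradient here, because the positive incoming gradient would only push it further below the bound:

```python
import torch
from compressai.ops import LowerBound

lb = LowerBound(1e-9)
x = torch.tensor([1e-12, 0.5], requires_grad=True)
y = lb(x)              # -> [1e-9, 0.5000]
y.sum().backward()
print(x.grad)          # -> [0., 1.]: no gradient for the clamped element
```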
_pmf_to_cdf
```python
def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length: int):
    # One quantized CDF row per entry of pmf_length: the valid part of the PMF
    # plus the tail mass give at most max_length + 1 probabilities, hence
    # max_length + 2 CDF entries.
    cdf = torch.zeros(
        (len(pmf_length), max_length + 2), dtype=torch.int32, device=pmf.device
    )
    for i, p in enumerate(pmf):
        prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
        _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
        cdf[i, : _cdf.size(0)] = _cdf
    return cdf
```
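A sketch of what one row of that table looks like, assuming pmf_to_quantized_cdf (shown at the end of this post) can be imported from compressai.entropy_models.entropy_models:

```python
import torch
from compressai.entropy_models.entropy_models import pmf_to_quantized_cdf

pmf = torch.tensor([0.2, 0.5, 0.2])    # probabilities of the in-range symbols
tail_mass = torch.tensor([0.1])        # everything outside the modeled range
prob = torch.cat((pmf, tail_mass), dim=0)

cdf = pmf_to_quantized_cdf(prob, 16)   # integer CDF used by the rANS coder
print(cdf)                             # monotone IntTensor of length len(prob) + 1
print(cdf[-1])                         # 2**16: the quantized probabilities sum to one
```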
EntropyModel

EntropyModel is defined in compressai/entropy_models/entropy_models.py. It inherits from nn.Module, so switching between .train() and .eval() selects the quantization behaviour used for training and for inference. The class mainly implements the quantization modes and the compression/decompression pipeline based on ANS.
```python
def __init__(
    self,
    likelihood_bound: float = 1e-9,
    entropy_coder: Optional[str] = None,
    entropy_coder_precision: int = 16,
):
    super().__init__()

    if entropy_coder is None:
        entropy_coder = default_entropy_coder()
    self.entropy_coder = _EntropyCoder(entropy_coder)
    self.entropy_coder_precision = int(entropy_coder_precision)

    self.use_likelihood_bound = likelihood_bound > 0
    if self.use_likelihood_bound:
        self.likelihood_lower_bound = LowerBound(likelihood_bound)

    self.register_buffer("_offset", torch.IntTensor())
    self.register_buffer("_quantized_cdf", torch.IntTensor())
    self.register_buffer("_cdf_length", torch.IntTensor())
```
- likelihood_bound: clamps the returned likelihoods from below, to avoid extremely large gradients in the entropy loss.
- entropy_coder_precision: the precision of the entropy coder, 16 by default; for ANS this means the compressed data is represented with integers in $[0, 2^{16})$. The precision controls the trade-off between compression efficiency and coding speed (a small worked example follows).
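To make the trade-off concrete: with a precision of 16, every probability handed to the coder is rounded to an integer count out of $2^{16}$, so the smallest non-zero probability the coder can represent is

$$p_{\min} = 2^{-16} \approx 1.5\times 10^{-5}, \qquad -\log_2 p_{\min} = 16\ \text{bits},$$

i.e. a single symbol costs at most roughly 16 bits. A higher precision follows the model PMF more closely (slightly better compression) at the cost of larger CDF tables and slower coding.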
quantize
```python
def quantize(
    self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
) -> Tensor:
    if mode not in ("noise", "dequantize", "symbols"):
        raise ValueError(f'Invalid quantization mode: "{mode}"')

    if mode == "noise":
        half = float(0.5)
        noise = torch.empty_like(inputs).uniform_(-half, half)
        inputs = inputs + noise
        return inputs

    outputs = inputs.clone()
    if means is not None:
        outputs -= means

    outputs = torch.round(outputs)

    if mode == "dequantize":
        if means is not None:
            outputs += means
        return outputs

    assert mode == "symbols", mode
    outputs = outputs.int()
    return outputs
```
The quantize method supports three quantization modes:

- noise: adds uniform noise in $[-0.5, 0.5]$ to the input, the differentiable stand-in for rounding used during training;
- dequantize: subtracts the means (if given), rounds, then adds the means back, so the output stays in the original value range;
- symbols: subtracts the means (if given) and rounds to integer symbols, ready for entropy coding.
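A minimal sketch of the three modes on a toy tensor (EntropyBottleneck is used here only because quantize is inherited from EntropyModel; the values are arbitrary):

```python
import torch
from compressai.entropy_models import EntropyBottleneck

em = EntropyBottleneck(channels=1)          # any EntropyModel subclass works
y = torch.tensor([[0.3, 1.6, -2.2]])
means = torch.tensor([[0.5, 0.5, 0.5]])

em.quantize(y, "noise")                     # y + U(-0.5, 0.5), used in training
em.quantize(y, "dequantize", means)         # round(y - means) + means -> [0.5, 1.5, -2.5]
em.quantize(y, "symbols", means)            # round(y - means).int()  -> [0, 1, -3]
```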
dequantize
```python
@staticmethod
def dequantize(
    inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float
) -> Tensor:
    if means is not None:
        outputs = inputs.type_as(means)
        outputs += means
    else:
        outputs = inputs.type(dtype)
    return outputs
```
The inverse of quantize in symbols mode: the symbols are cast back to floating point and, if means were used, the means are added back.
compress
```python
def compress(self, inputs, indexes, means=None):
    """
    Compress input tensors to char strings.

    Args:
        inputs (torch.Tensor): input tensors
        indexes (torch.IntTensor): tensors CDF indexes
        means (torch.Tensor, optional): optional tensor means
    """
    symbols = self.quantize(inputs, "symbols", means)

    if len(inputs.size()) < 2:
        raise ValueError(
            "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
        )

    if inputs.size() != indexes.size():
        raise ValueError("`inputs` and `indexes` should have the same size.")

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    strings = []
    for i in range(symbols.size(0)):
        rv = self.entropy_coder.encode_with_indexes(
            symbols[i].reshape(-1).int().tolist(),
            indexes[i].reshape(-1).int().tolist(),
            self._quantized_cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        strings.append(rv)
    return strings
```
decompress
```python
def decompress(
    self,
    strings: str,
    indexes: torch.IntTensor,
    dtype: torch.dtype = torch.float,
    means: torch.Tensor = None,
):
    """
    Decompress char strings to tensors.

    Args:
        strings (str): compressed tensors
        indexes (torch.IntTensor): tensors CDF indexes
        dtype (torch.dtype): type of dequantized output
        means (torch.Tensor, optional): optional tensor means
    """

    if not isinstance(strings, (tuple, list)):
        raise ValueError("Invalid `strings` parameter type.")

    if not len(strings) == indexes.size(0):
        raise ValueError("Invalid strings or indexes parameters")

    if len(indexes.size()) < 2:
        raise ValueError(
            "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
        )

    self._check_cdf_size()
    self._check_cdf_length()
    self._check_offsets_size()

    if means is not None:
        if means.size()[:2] != indexes.size()[:2]:
            raise ValueError("Invalid means or indexes parameters")
        if means.size() != indexes.size():
            for i in range(2, len(indexes.size())):
                if means.size(i) != 1:
                    raise ValueError("Invalid means parameters")

    cdf = self._quantized_cdf
    outputs = cdf.new_empty(indexes.size())

    for i, s in enumerate(strings):
        values = self.entropy_coder.decode_with_indexes(
            s,
            indexes[i].reshape(-1).int().tolist(),
            cdf.tolist(),
            self._cdf_length.reshape(-1).int().tolist(),
            self._offset.reshape(-1).int().tolist(),
        )
        outputs[i] = torch.tensor(
            values, device=outputs.device, dtype=outputs.dtype
        ).reshape(outputs[i].size())
    outputs = self.dequantize(outputs, means, dtype)
    return outputs
```
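A minimal compress/decompress round trip, sketched with the GaussianConditional subclass described later in this post (the scale table, shapes and scales below are arbitrary). update() must be called first so that the _quantized_cdf, _offset and _cdf_length buffers are populated:

```python
import torch
from compressai.entropy_models import GaussianConditional

gc = GaussianConditional(scale_table=[0.11, 0.5, 1.0, 2.0, 4.0])
gc.update()                              # build the quantized CDF tables

y = torch.randn(1, 8, 4, 4)
scales = torch.ones_like(y)              # per-element standard deviations
indexes = gc.build_indexes(scales)       # pick a CDF row for every element
strings = gc.compress(y, indexes)        # one byte string per batch element
y_hat = gc.decompress(strings, indexes)  # equals torch.round(y) here (no means)
```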
_likelihood
```python
def _likelihood(
    self, inputs: Tensor, stop_gradient: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
    half = float(0.5)
    lower = self._logits_cumulative(inputs - half, stop_gradient=stop_gradient)
    upper = self._logits_cumulative(inputs + half, stop_gradient=stop_gradient)
    likelihood = torch.sigmoid(upper) - torch.sigmoid(lower)
    return likelihood, lower, upper
```
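In other words, writing $c(\cdot)$ for the learned cumulative (the sigmoid of the _logits_cumulative output), the likelihood of a quantized value $\hat{y}$ is the probability mass the density model assigns to the unit-width interval around it:

$$p(\hat{y}) = c\left(\hat{y}+\tfrac{1}{2}\right) - c\left(\hat{y}-\tfrac{1}{2}\right)$$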
EntropyBottleneck

EntropyBottleneck inherits from EntropyModel and implements the entropy estimation of the tensors passed through it during training. A short usage sketch follows the parameter descriptions below.
```python
def __init__(
    self,
    channels: int,
    *args: Any,
    tail_mass: float = 1e-9,
    init_scale: float = 10,
    filters: Tuple[int, ...] = (3, 3, 3, 3),
    **kwargs: Any,
):
    super().__init__(*args, **kwargs)

    self.channels = int(channels)
    self.filters = tuple(int(f) for f in filters)
    self.init_scale = float(init_scale)
    self.tail_mass = float(tail_mass)

    filters = (1,) + self.filters + (1,)
    scale = self.init_scale ** (1 / (len(self.filters) + 1))
    channels = self.channels

    for i in range(len(self.filters) + 1):
        init = np.log(np.expm1(1 / scale / filters[i + 1]))
        matrix = torch.Tensor(channels, filters[i + 1], filters[i])
        matrix.data.fill_(init)
        self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix))

        bias = torch.Tensor(channels, filters[i + 1], 1)
        nn.init.uniform_(bias, -0.5, 0.5)
        self.register_parameter(f"_bias{i:d}", nn.Parameter(bias))

        if i < len(self.filters):
            factor = torch.Tensor(channels, filters[i + 1], 1)
            nn.init.zeros_(factor)
            self.register_parameter(f"_factor{i:d}", nn.Parameter(factor))

    self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
    init = torch.Tensor([-self.init_scale, 0, self.init_scale])
    self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)

    target = np.log(2 / self.tail_mass - 1)
    self.register_buffer("target", torch.Tensor([-target, 0, target]))
```
- tail_mass: the tails of the discrete probability mass function are not range coded but encoded with a Golomb-style code; with the default of 1e-9, roughly a 1e-9 fraction of all values is expected to fall into the tails and be Golomb coded.
- init_scale: determines the initial width of the learned densities. It should be set large enough that the range of the inputs falls within [-init_scale, init_scale].
- filters: the number of filters in each layer of the density model.
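A minimal usage sketch (shapes are arbitrary): in train() mode the returned y_hat is the noise-quantized tensor, in eval() mode it is rounded around the learned medians; y_likelihoods feeds the rate term of the loss, and loss() is the auxiliary loss that pulls the learned quantiles towards the tail-mass targets:

```python
import torch
from compressai.entropy_models import EntropyBottleneck

eb = EntropyBottleneck(channels=128)
y = torch.randn(1, 128, 16, 16)

y_hat, y_likelihoods = eb(y)   # quantized latent + per-element likelihoods
aux = eb.loss()                # auxiliary loss on the quantiles
```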
_standardized_cumulative
```python
def _standardized_cumulative(self, inputs: Tensor) -> Tensor:
    half = float(0.5)
    const = float(-(2**-0.5))
    # Using the complementary error function maximizes numerical precision.
    return half * torch.erfc(const * inputs)
```
This uses the complementary error function: for an input $x_0$ it returns $\frac{1}{2}\cdot\text{erfc}(-\frac{x_0}{\sqrt{2}})$, which is exactly the standard normal CDF $\Phi(x_0)$, i.e. for $X\sim\mathcal{N}(0,1)$ the area under the density to the left of $x_0$, $P(X\le x_0)$.
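A quick numerical sanity check of that identity (a standalone sketch, not CompressAI code):

```python
import torch

def standardized_cumulative(x):
    return 0.5 * torch.erfc(-(2 ** -0.5) * x)

x = torch.tensor([0.0, 1.0, -1.0])
print(standardized_cumulative(x))            # ~[0.5000, 0.8413, 0.1587]
print(0.5 * (1 + torch.erf(x / 2 ** 0.5)))   # the textbook CDF formula, same values
```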
_standardized_quantile
```python
@staticmethod
def _standardized_quantile(quantile: float):
    return scipy.stats.norm.ppf(quantile)
```
scipy.stats.norm.ppf is the percent-point (quantile) function, i.e. the inverse of the CDF: given a CDF value y, it returns the corresponding x.
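The first line of update() below uses it to work out how many standard deviations the modeled support has to cover so that only tail_mass probability is left outside. A small sketch (the ~6.1 value is approximate):

```python
import scipy.stats

tail_mass = 1e-9
multiplier = -scipy.stats.norm.ppf(tail_mass / 2)   # ~6.1 standard deviations
print(2 * scipy.stats.norm.cdf(-multiplier))        # back to ~1e-9: ppf inverts cdf
```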
update_scale_table
```python
def update_scale_table(self, scale_table, force=False):
    if self._offset.numel() > 0 and not force:
        return False
    device = self.scale_table.device
    self.scale_table = self._prepare_scale_table(scale_table).to(device)
    self.update()
    return True
```
update
```python
def update(self):
    # Number of standard deviations the support must cover so that only
    # tail_mass probability is left in the two tails.
    multiplier: float = -self._standardized_quantile(self.tail_mass / 2)
    pmf_center = torch.ceil(self.scale_table * multiplier).int()
    pmf_length = 2 * pmf_center + 1
    max_length = torch.max(pmf_length).item()

    device = pmf_center.device
    # Distance of each support point from the center of its PMF row.
    samples = torch.abs(
        torch.arange(max_length, device=device).int() - pmf_center[:, None]
    )
    samples_scale = self.scale_table.unsqueeze(1)
    samples = samples.float()
    samples_scale = samples_scale.float()
    upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
    lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
    pmf = upper - lower

    # Probability mass left in the two tails of each row.
    tail_mass = 2 * lower[:, :1]

    quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
    quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
    self._quantized_cdf = quantized_cdf
    self._offset = -pmf_center
    self._cdf_length = pmf_length + 2
```
pmf_to_quantized_cdf

```python
def pmf_to_quantized_cdf(pmf: Tensor, precision: int = 16) -> Tensor:
    cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
    cdf = torch.IntTensor(cdf)
    return cdf
```
GaussianConditional

The mean-scale hyperprior outputs $\mu$ and $\sigma$, which are used to model every element of $y$ with a Gaussian; this yields the probability of every value, which in turn drives the entropy coding into a bitstream. GaussianConditional returns two values: the quantized output $\hat{y}$ and the likelihood, an estimate of the probability of each value to be encoded.
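A minimal sketch of the forward pass during training (shapes are arbitrary; in a full model the scales and means come from the hyper-decoder, and the scale table is only needed later for actual compression):

```python
import torch
from compressai.entropy_models import GaussianConditional

gc = GaussianConditional(scale_table=None)   # no table needed for training
y = torch.randn(1, 192, 16, 16)
scales = torch.rand_like(y) + 0.1            # sigma predicted by the hyperprior
means = torch.randn_like(y)                  # mu predicted by the hyperprior

y_hat, y_likelihoods = gc(y, scales, means=means)
```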
During training, the bpp loss can then be computed from the returned likelihoods:
```python
num_pixels = N * H * W
out["bpp_loss"] = sum(
    (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
    for likelihoods in output["likelihoods"].values()
)
```
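This is the usual cross-entropy rate estimate expressed in bits per pixel: torch.log is the natural logarithm, so dividing by $-\ln 2$ converts nats to bits,

$$\text{bpp} = \frac{1}{N H W}\sum_i -\log_2 p(\hat{y}_i) = \frac{\sum_i \ln p(\hat{y}_i)}{-\ln 2 \cdot N H W}.$$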