Fixing overflow when computing log-sum-exp

mac2022-06-30  27

1 Original definition

$$
\operatorname{LogSumExp}\left(x_{1} \ldots x_{n}\right)=\log \left(\sum_{i=1}^{n} e^{x_{i}}\right)
$$

2 SoftMax

$$
\frac{e^{x_{j}}}{\sum_{i=1}^{n} e^{x_{i}}}
$$

3 Taking the logarithm of SoftMax

$$
\log \left(\frac{e^{x_{j}}}{\sum_{i=1}^{n} e^{x_{i}}}\right)
$$

4 Rearranging

$$
\begin{aligned}
\log \left(\frac{e^{x_{j}}}{\sum_{i=1}^{n} e^{x_{i}}}\right) &=\log \left(e^{x_{j}}\right)-\log \left(\sum_{i=1}^{n} e^{x_{i}}\right) \\
&=x_{j}-\log \left(\sum_{i=1}^{n} e^{x_{i}}\right)
\end{aligned}
$$

5 Computing LogSumExp

When the inputs are large, computing $e^{x_{i}}$ can overflow. So we rewrite the second term of step 4, the LogSumExp, one step further. For any constant $c$:

$$
\begin{aligned}
\operatorname{LogSumExp}\left(x_{1} \ldots x_{n}\right) &=\log \left(\sum_{i=1}^{n} e^{x_{i}}\right) \\
&=\log \left(\sum_{i=1}^{n} e^{x_{i}-c} e^{c}\right) \\
&=\log \left(e^{c} \sum_{i=1}^{n} e^{x_{i}-c}\right) \\
&=\log \left(\sum_{i=1}^{n} e^{x_{i}-c}\right)+\log \left(e^{c}\right) \\
&=\log \left(\sum_{i=1}^{n} e^{x_{i}-c}\right)+c
\end{aligned}
$$
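The shift in step 5 can be sketched in a few lines of plain Python. This is only an illustration, not a reference implementation: the function names `logsumexp` and `naive_logsumexp` are made up for this example, the input is assumed to be a non-empty list of floats, and c is taken to be the largest input.

```python
import math


def naive_logsumexp(xs):
    # Direct use of the definition from step 1.
    # math.exp raises OverflowError once x exceeds roughly 709.
    return math.log(sum(math.exp(x) for x in xs))


def logsumexp(xs):
    # Shifted form from step 5, with c = max(xs) (an assumption of this sketch).
    c = max(xs)
    if math.isinf(c):
        # All inputs are -inf, or some input is +inf: return c directly,
        # because x - c would otherwise produce inf - inf = nan.
        return c
    # Every exponent x - c is <= 0, so math.exp cannot overflow here.
    return c + math.log(sum(math.exp(x - c) for x in xs))


if __name__ == "__main__":
    xs = [1000.0, 1000.0]
    print(logsumexp(xs))  # equals 1000 + log(2)
    try:
        naive_logsumexp(xs)
    except OverflowError as err:
        print("naive version failed:", err)
```

With c equal to the maximum, the largest term contributes exactly e^0 = 1, so the sum inside the logarithm is at least 1 and the logarithm is always well defined.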

6 Substituting back into step 4

$$
\begin{aligned}
\log \left(\operatorname{SoftMax}\left(x_{j}, x_{1} \ldots x_{n}\right)\right) &=x_{j}-\operatorname{LogSumExp}\left(x_{1} \ldots x_{n}\right) \\
&=x_{j}-\log \left(\sum_{i=1}^{n} e^{x_{i}-c}\right)-c
\end{aligned}
$$

Here $c$ is taken to be $\max(x_{1} \ldots x_{n})$. With this choice every exponent $x_{i}-c$ is at most 0, so no term $e^{x_{i}-c}$ can overflow.
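The final formula can be sketched as a log-softmax over a whole list. Again this is a minimal illustration under stated assumptions: `log_softmax` is a name chosen for this example, and the input is assumed to be a non-empty list of finite floats.

```python
import math


def log_softmax(xs):
    # log(SoftMax(x_j)) = x_j - log(sum_i exp(x_i - c)) - c, with c = max(xs).
    c = max(xs)
    # The sum is at least 1 because the maximum contributes exp(0) = 1.
    shifted_lse = math.log(sum(math.exp(x - c) for x in xs))
    return [x - shifted_lse - c for x in xs]


if __name__ == "__main__":
    big = log_softmax([1000.0, 1001.0, 1002.0])
    small = log_softmax([0.0, 1.0, 2.0])
    print(big)
    print(small)
    # Exponentiating the result recovers probabilities that sum to 1.
    print(sum(math.exp(v) for v in big))
```

Adding the same constant to every input leaves SoftMax unchanged, which is why the two inputs above give matching results even though exponentiating 1000 directly would overflow.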
