HALO

Raphael Pisoni 提出的一种名为 HALO (Hyperspherical Alignment & Latent Optimization, 超球面界定与潜在优化) 的全新损失函数。

该算法的提出旨在解决现代神经网络分类器存在的“盲目自信”和“幻觉”问题,通过重构潜在空间(Latent Space)的几何结构,使得模型在保持高准确率的同时,具备强大的“分布外(OOD)检测”能力。

博客:pisoni.ai — Soap Bubbles and Attention Sinks: The Theory and History of the HALO-Loss | pisoni.ai


motivation

在标准的神经网络中,分类头通常使用分类交叉熵损失,通过不受约束的“点积”计算 logits。 由于最后一步使用了 Softmax 激活函数,模型为了达到 100% 的置信度并将损失值降为0,会在训练中被迫将特征向量无限推离原点(即所谓的“径向爆炸”)。

方法

假设有一个分类任务,共有 $K$ 个目标类别,特征维度为 $D$。

  1. 输入提取:样本 $\mathbf{x}$ 经过骨干网络(如 ResNet, ViT,或大模型的特征提取层),输出一个 $D$ 维的特征向量 $\mathbf{z}$。

  2. 中心点矩阵:模型内部维护着一个可学习的参数矩阵 $\mathbf{C} \in \mathbb{R}^{K \times D}$,代表 $K$ 个类别的几何中心。

  3. 零均值化干预:为了防止整个潜在空间在训练中发生“整体漂移”,在每一步计算前,强制将所有中心点减去它们的均值,将其锚定在原点周围:

    $$\mathbf{C} \leftarrow \mathbf{C} - \frac{1}{K}\sum_{i=1}^{K}\mathbf{c}_i$$

在 HALO 算法中,模型不再计算普通的点积,而是计算输入特征 $x$ 和某个类别中心点 $c$ 之间的负平方欧氏距离,以此作为未归一化的得分(logit):

$$\text{原始 Logit} = -|z - c|^2$$

将其展开,我们得到:

$$-|z - c|^2 = -|z|^2 + 2(z \cdot c) - |c|^2$$

对于某一个特定的输入图片 $x$ 来说,它与所有类别中心计算距离时,$-|x|^2$ 这一项都是完全一样的常数

利用 Softmax 的平移不变性,整体加上一个偏置 $|\mathbf{z}|^2$,转化为计算“平移后的 Logit(Shifted Logit)”。

为了控制高维空间的数值方差,引入一个基于维度的缩放系数 $\frac{1}{D}$,以及一个动态温度参数 $\gamma$。对于第 $i$ 个类别,其得分为:

$$s_i = \frac{\gamma}{D} \left( 2(\mathbf{z} \cdot \mathbf{c}_i) - |\mathbf{c}_i|^2 \right)$$

为了从架构层面拦截分布外(OOD)的噪音数据,HALO 在 $K$ 个常规类别之外,虚拟叠加了第 $K+1$ 个类别——“弃权类 (Abstain Class)”。

根据前面的平移法则,假设弃权类的中心死死钉在绝对原点 $\mathbf{0}$,那么它原始的得分为 $0$。

2(z*0)-||0||^2=0

但在工程实现中,为了提供一个数学上最优的“拒识阈值”,作者根据空间的初始几何状态(目标半径 $r_{target}^2$ 和交叉熵裕度 $\text{margin}{ce}$),直接算出了一个固定的理想偏差值 $s{abstain}$:

$$s_{abstain} = t_{ideal} - \text{margin}_{ce}$$

将这个标量 $s_{abstain}$ 直接拼接到前面算出的 $K$ 个 Logit 后面,我们得到了最终长度为 $K+1$ 的 Logit 向量 $\mathbf{S}$:

$$\mathbf{S} = [s_1, s_2, \dots, s_K, s_{abstain}]$$

高维向量

在高维空间 $D$ 中,如果不加干预,将特征向量强行压向中心点(距离为 0)会导致表示坍缩。因此,计算特征向量 $\mathbf{z}$ 的真实平方范数 $r^2 = \frac{|\mathbf{z}|^2}{D}$,并对其施加一个符合高维球壳分布的负对数似然损失:

$$L_{radial} = - \left( \left(\frac{1}{2} - \frac{1}{D}\right) \log(r^2) - \frac{1}{2} r^2 \right)$$

  • $-\frac{1}{2}r^2$:高斯先验,像引力一样防止向量飞向无限远。
  • $\left(\frac{1}{2} - \frac{1}{D}\right)\log(r^2)$:斥力项,模拟高维体积的剧烈扩张,将向量向外推,使其稳稳地落在 $D$ 维的“肥皂泡”壳上。

无教师自蒸馏与软标签构建

这是为了保护类别之间的相对语义拓扑关系。假设当前样本的正确类别(Positive Target)是 $y$。

  1. 生成当前分布:先用 Softmax 将 $\mathbf{S}$ 转化为概率分布 $\mathbf{P} = \text{Softmax}(\mathbf{S})$。

  2. 屏蔽正类:将正确类别 $y$ 的 Logit 设为负无穷 $-\infty$,再次计算 Softmax,得到一个仅包含负类和弃权类的软标签分布 $\mathbf{P}_{soft}$。由于这是基于模型内部真实的相对距离生成的,它完美保留了负类之间的语义相似度(比如猫距离狗近,距离飞机远)。

  3. 计算蒸馏损失:将网络当前的非正类输出概率,与这个 $\mathbf{P}_{soft}$ 对齐,计算 KL 散度或交叉熵损失:

    $$L_{distill} = \text{KL}(\mathbf{P}{soft} | \mathbf{P}{\text{negative_classes}})$$

损失函数

最后,整个 HALO 算法的前向传播会输出一个综合的 Loss。优化器在反向传播时,会同时满足三个几何目标:

$$L_{total} = L_{CE}(s_y, 1.0) + \alpha L_{distill} + \beta L_{radial}$$

  1. $L_{CE}$ (正类吸引):标准的交叉熵,驱使特征 $\mathbf{z}$ 靠近正确的类别中心 $\mathbf{c}_y$。
  2. $L_{distill}$ (语义排斥):用自蒸馏的软标签,温和地将特征推离其他负类,同时保留负类间的相对距离,防止空间撕裂。
  3. $L_{radial}$ (流形约束):维持整个系统的“肥皂泡”高维球形结构,确保所有计算都不会发生径向爆炸。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class HALOModel(nn.Module):
"""
模型包装器:管理特征提取与类别中心点 (Centroids)
"""
def __init__(self, model, n_classes, embedding_dim):
super(HALOModel, self).__init__()
self.model = model
self.n_classes = n_classes
self.emb_dims = embedding_dim

# 初始化 K 个类别中心点
self.centroids = nn.Parameter(
torch.randn(n_classes, embedding_dim, dtype=torch.float32)
)

def forward(self, x):
# Step 1: 特征提取
embeddings = self.model(x) # [B, D]

# Step 1: 中心点居中 (Centering)
# 强制所有中心点减去均值,防止潜在空间整体漂移
centroids = self.centroids
centroids = centroids - centroids.mean(dim=0, keepdim=True) # [K, D]

return embeddings, centroids


class HALOLoss(nn.Module):
"""
HALO 损失函数:超球面界定与潜在优化
"""
def __init__(self, emb_dims, num_classes, learn_gamma=True, distill=True, label_smoothing=0.1, reduction="mean"):
super().__init__()
assert emb_dims > 1, "Embedding dimensions must be > 1"
self.D = float(emb_dims)
self.K = float(num_classes)
self.distill = distill
self.label_smoothing = label_smoothing
self.reduction = reduction

# --- 初始化高维空间几何目标 ---
r_sq_target = 1.0 - (2.0 / self.D) # 目标“肥皂泡”超球面半径
r_sq_init = 2.0 # 随机初始化时的预期距离
init_gamma = 20.0 / (r_sq_init - r_sq_target)

# --- Step 3: 计算理想的弃权偏置 (Ideal Abstain Bias) ---
if label_smoothing > 0:
max_prob = 1.0 - label_smoothing + (label_smoothing / self.K)
min_prob = label_smoothing / self.K
else:
max_prob = 0.99
min_prob = 0.01 / self.K

margin_ce = math.log(max_prob / min_prob) # 交叉熵所需对数裕度
t_ideal = init_gamma * (1.0 - r_sq_target) # 正确类的理想 Logit 锚点

# 将拒识阈值锚定在比理想锚点低一个 margin_ce 的位置
self.abstain_bias = t_ideal - margin_ce

# --- 动态温度参数 gamma ---
if init_gamma > 20.0:
gamma_start = init_gamma
else:
gamma_start = math.log(math.expm1(init_gamma)) # 逆 Softplus
self.gamma = nn.Parameter(
torch.tensor([gamma_start], dtype=torch.float32),
requires_grad=learn_gamma
)

def forward(self, embeddings, targets, centroids):
"""
embeddings: [B, D] - 当前批次的特征向量
targets: [B] - 真实类别标签 (0 ~ K-1)
centroids: [K, D] - 类别中心点
"""
B = embeddings.size(0)

# -------------------------------------------------------------
# Step 2: 距离到 Logit 的代数转换 (The Shifted Logit)
# -------------------------------------------------------------
# 保证 gamma 为正,并应用特征维度缩放控制方差
gamma = F.softplus(self.gamma)
scale = gamma / self.D

# 矩阵乘法计算点积: 2(z \cdot c) -> 尺寸 [B, K]
dot_product = 2.0 * F.linear(embeddings, centroids)

# 计算中心点的 L2 范数惩罚: ||c||^2 -> 尺寸 [K]
centroid_sq_norm = (centroids ** 2).sum(dim=-1)

# 获得代数平移后的 K 个基础类别 Logits
logits = scale * (dot_product - centroid_sq_norm) # [B, K]

# -------------------------------------------------------------
# Step 3: 植入“弃权”水槽 (Injecting the Abstain Sink)
# -------------------------------------------------------------
# 将提前算好的拒识标量扩展到 Batch Size 维度: [B, 1]
abstain_logits = torch.full((B, 1), self.abstain_bias, device=embeddings.device)

# 拼接到最后,形成 [B, K+1] 的完整空间
full_logits = torch.cat([logits, abstain_logits], dim=-1)

# 计算主分类交叉熵。注意:如果有分布外(OOD)样本,将其 target 设为 K 即可自然触发弃权惩罚
loss_ce = F.cross_entropy(
full_logits, targets,
label_smoothing=self.label_smoothing,
reduction='none'
)

# -------------------------------------------------------------
# Step 4: 计算“肥皂泡”空间正则化 (Soap Bubble Regularization)
# -------------------------------------------------------------
# 计算每个特征向量的真实归一化范数 r^2
r_sq = (embeddings ** 2).sum(dim=-1) / self.D # [B]

# 斥力系数 (模拟高维体积扩张)
volume_coeff = 0.5 - (1.0 / self.D)

# -0.5*r^2 是向内拉的引力(高斯先验),log(r_sq) 是向外推的斥力。加上 1e-8 防止 Log(0)
loss_radial = -(volume_coeff * torch.log(r_sq + 1e-8) - 0.5 * r_sq)

# -------------------------------------------------------------
# Step 5: 无教师自蒸馏与软标签构建 (Teacher-Free Self-Distillation)
# -------------------------------------------------------------
if self.distill:
with torch.no_grad():
soft_logits = full_logits.clone()

# 屏蔽正类 (Zeroing the positive target)
# 使用 scatter_ 创建掩码,将当前样本正确类别的 Logit 强制设为负无穷
mask = torch.zeros_like(soft_logits, dtype=torch.bool).scatter_(1, targets.unsqueeze(1), True)
soft_logits.masked_fill_(mask, float('-inf'))

# 在剩余的负类和弃权类中,基于当前真实的相对几何距离,生成软概率分布
soft_targets = F.softmax(soft_logits, dim=-1)

# 让网络的负类预测输出拟合这个相对距离分布
log_probs = F.log_softmax(full_logits, dim=-1)
# KL散度:迫使网络在推开负类时,尊重它们之间的相对语义相似度
loss_distill = F.kl_div(log_probs, soft_targets, reduction='none', log_target=False).sum(dim=-1)
else:
loss_distill = torch.zeros_like(loss_ce)

# -------------------------------------------------------------
# Step 6: 最终损失函数的拼接 (Final Loss Assembly)
# -------------------------------------------------------------
# 原文实现中,三者的权重系数均为隐式的 1.0
total_loss = loss_ce + loss_distill + loss_radial

if self.reduction == "mean":
return total_loss.mean()
return total_loss

实验

CIFAR-10 Benchmark (ResNet-18)
Metric Standard CCE HALO
ID Accuracy (↑) 96.30% 96.53%
Calibration (ECE) (↓) 0.0798 0.0151
Far OOD (SVHN) AUROC (↑) 92.51% 98.08%
Far OOD (SVHN) FPR@95 (↓) 22.08% 10.27%
Near OOD (CIFAR-100) AUROC (↑) 82.83% 91.72%
Near OOD (CIFAR-100) FPR@95 (↓) 48.94% 37.63%
CIFAR-100 Benchmark (ResNet-18)
Metric Standard CCE HALO
ID Accuracy (↑) 80.94% 80.80%
Calibration (ECE) (↓) 0.1102 0.0283
Far OOD (SVHN) AUROC (↑) 81.01% 86.91%
Far OOD (SVHN) FPR@95 (↓) 81.00% 63.70%
Near OOD (CIFAR-10) AUROC (↑) 79.75% 81.00%
Near OOD (CIFAR-10) FPR@95 (↓) 76.77% 75.38%