文章

CS3: Efficient Online Capability Synergy for Two-Tower Recommendation

CS3: Efficient Online Capability Synergy for Two-Tower Recommendation

论文元信息

  • 标题:CS3: Efficient Online Capability Synergy for Two-Tower Recommendation
  • 作者:Lixiang Wang, Shaoyun Shi, Peng Wang, Wenjin Wu, Peng Jiang
  • 机构:快手 (Kuaishou Technology)
  • 会议:SIGIR 2026(2026.07)
  • arXiv2604.19269(2026-04-21 submitted)
  • 代码github.com/lixiangwang/CS3Rec
  • 线上规模:快手广告,DAU > 4 亿

一句话定位

CS3 在做的事情和 HSNN 方向完全相反

  • HSNN:打破双塔范式去解决根本问题
  • CS3:承认双塔范式不能动,然后在 plug-and-play 的前提下尽可能给它注入跨塔/跨阶段的信息

框架总览

CS3 从三个层面给双塔注入信息:

模块全称作用信息来源
CASCycle-Adaptive Structure内部自我修正/去噪自己
CTSCross-Tower Synchronization之间显式对齐伙伴塔的历史输出(cached)
CMSCascade-Model Sharing阶段知识复用下游精排模型的中间表示(cached)

模块 1:CAS (Cycle-Adaptive Structure)

思路

每个 FC 层做一次”自我去噪 + 重算”,灵感来自 RecycleNet 和 diffusion model 的 cyclic refinement。

流程(替换原本的一个 FC 层)

1
2
3
4
5
6
7
8
9
10
11
输入: h_l
─────────────────────────────────
1. Pre-Forward:    h_{l+1} = σ(W_l · h_l + b_l)
                   (跟标准 FC 一样,先正向算一次)

2. Adaptive Reweighting:
   r_l = σ(W'_l · h_{l+1} + b'_l)      # 每维 ∈ [0,1] 的重要性
   ~h_l = 2 · r_l ⊙ h_l                # 重权输入,×2 保持期望防梯度消失

3. Cycle-Forward:  h_{l+1} = σ(W_l · ~h_l + b_l)
                   (用同一套 W_l, b_l 重算一遍)

关键点

  • 单 cycle(不递归),参数共享 W / b
  • r_l 越小说明该维度越像噪声,被压制
  • 不引入新特征,只对现有表征做”两遍跑”
  • 不破坏双塔独立性,可以同时插在 user 塔和 item 塔里
  • 部署代价:retriever QPS −0.589%

落地建议

直接替换 user_tower_layers / item_tower_layers 里每个 Dense 层的实现 —— 把它从”线性+激活”换成 CAS 三步走。不需要任何外部 infra,不动 serving。


模块 2:CTS (Cross-Tower Synchronization)

这是整篇论文最巧妙的设计。 用”缓存对方塔的输出 + EMA 更新”在不破坏双塔独立 forward 的前提下,把伙伴塔的信息塞回自己塔里。

缓存设计

符号物理含义索引键一句话
c^u[user_id]给 user 塔用的 cacheuser_id“这个 user 历史正反馈过的 item 们的 EMA 表征”
c^v[item_id]给 item 塔用的 cacheitem_id“和这个 item 正反馈过的 user 们的 EMA 表征”

注意:c^uc^v 不是按 user/item 分别维护一组的意思,而是”按 user_id 索引”和”按 item_id 索引”。

输入构造(论文公式 5)

1
2
3
user 塔输入: [原始 user 特征,  c^u[user_id],  g^u[user_id]]   ← 含来自 item 塔的信息
item 塔输入: [原始 item 特征,  c^v[item_id],  g^v[item_id]]   ← 含来自 user 塔的信息
                              (CTS)             (CMS)

EMA 更新公式(论文公式 6)

1
2
3
对每条正样本 (user u, item v) at time t:
  c^u_{t+1} = β · c^u_t + (1-β) · v_t       [更新 user 侧 cache]
  c^v_{t+1} = β · c^v_t + (1-β) · u_t       [更新 item 侧 cache]

更新方向是交叉的

  • user 侧的 cache c^uitem 塔输出 v_t 来更新
  • item 侧的 cache c^vuser 塔输出 u_t 来更新

这就是名字里”Cross”的来源 —— 每个 id 缓存的是对方塔的信息。

一次完整训练步流程

假设当前来一条样本:user_id=A,item_id=X,标签 y。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
Step 1 (查):
  从 ParSvr 读 c^u[A]、c^v[X],作为额外输入塞进两个塔

Step 2 (前向):
  u_t = UserTower([A 的原始特征,  c^u[A],  g^u[A]])
  v_t = ItemTower([X 的原始特征,  c^v[X],  g^v[X]])
  score = u_t · v_t

Step 3 (反传):
  正常算 loss、回传梯度,更新 W、b、sparse embedding
  注意: c^u 和 c^v 不参与梯度反传 (论文: custom gradient 实现 EMA)

Step 4 (EMA 更新, 只在 y=1 时):
  c^u[A] ← β · c^u[A] + (1-β) · v_t       # 用这次 item 塔的输出
  c^v[X] ← β · c^v[X] + (1-β) · u_t       # 用这次 user 塔的输出

Step 5 (写回):
  把更新后的 c^u[A]、c^v[X] 写回 ParSvr

y=0 的样本:照样跑前向和反传,但 跳过 step 4,cache 不动。

数值例子

假设 β=0.9,看 user A 的 c^u[A] 怎么演化。初始 c^u[A] = 0

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
事件 1: A 点击了 item X (正样本)
  当时 ItemTower 输出 v_1 = [0.5, 0.3, ...]
  c^u[A] ← 0.9·0 + 0.1·v_1 = 0.1·v_1

事件 2: A 给 item Y 点了赞
  当时 ItemTower 输出 v_2 = [0.2, 0.8, ...]
  c^u[A] ← 0.9·(0.1·v_1) + 0.1·v_2
        = 0.09·v_1 + 0.1·v_2

事件 3: A 看到 item Z 但没点 (负样本)
  c^u[A] 不更新

事件 4: A 转化了 item W
  c^u[A] ← 0.9·(0.09·v_1 + 0.1·v_2) + 0.1·v_4
        = 0.081·v_1 + 0.09·v_2 + 0.1·v_4

可以看到 c^u[A] 本质是 A 历史所有正反馈 item 的 item-tower 输出的指数衰减加权和

  • 越近的事件权重越大(系数 0.1)
  • 越远的事件权重指数衰减(系数 0.1·β^k)
  • 没有完整序列存储,O(1) 内存

β 怎么选

β 值EMA 行为适用
接近 1(如 0.99)慢更新,记忆长用户/物品兴趣稳定,注重去抖
中间(0.9)平衡默认
接近 0(如 0.5)快更新,记忆短强追时变(爆点物品、用户兴趣漂移)

论文未给具体数值,快手语境一般偏 0.9–0.95。

为什么这个设计在工程上特别巧妙

  1. 绕开了双塔独立 forward 的限制
    • 标准双塔:user 塔不能见 item 信息
    • CTS:user 塔见的不是”当前 item”的信息,而是”该用户历史所有 item 的 EMA 表征” —— 这个东西按 user_id 一查就有,不需要 query time 跨塔通信
    • Serving 时 user 塔输入 = [user feat, c^u[user_id], g^u[user_id]],全部按 user_id 索引可得,ANN 完全兼容
  2. 存储是 O(N_users + N_items)
    • 每个 user / item 各加一个 D 维向量(D ≈ 32 或 64)
    • 跟 sparse embedding 量级一样,工程可承受
  3. EMA 自带”跟着模型走”的能力
    • 用的是当前模型输出的 u_t / v_t 来更新 cache
    • 模型权重在 online learning 中持续更新,cache 也跟着走
    • 不会像离线预计算那样过期
  4. 梯度可控
    • cache 不参与反传 → 不会和 sparse embedding 抢梯度,不会震荡
    • 实现上是 ParSvr 里写一个 custom op,把”梯度”硬解释成”EMA 更新增量”

一个微妙的点:循环依赖

u_t = UserTower(..., c^u_t, ...),而 c^u 由历史的 v_τ 累积,v_τ 又依赖于历史的 c^v_τ,看起来像循环依赖。

实际没问题:

  • t 时刻读的是 t 时刻之前累积出来的 c^u_t(用的是上一步更新后的值)
  • t 时刻产生的 v_t 是用来更新 c^u_{t+1} 的(写给未来用)
  • 时间方向单向,没有真正的 feedback loop

和用户行为序列建模(DIN/SIM/TWIN)的对比

CTS 的 c^u[user_id] 本质上和”用户历史交互序列的 attention pooling”在做同一类事,但:

维度DIN/SIM 序列建模CS3 CTS
存储完整序列 list一个 D 维向量
计算query time 做 attentionquery time 一次 lookup
跟模型走每次重算(用最新 emb table)训练时 EMA 同步累积
跨塔双塔做不了(item 在另一侧)天然解决,因为是按 user_id 缓存对方塔输出

所以 CTS 可以理解为 “序列建模的极简压缩版 + 解决了双塔跨塔难题的版本”


模块 3:CMS (Cascade-Model Sharing)

思路:把下游精排模型的中间层输出(penultimate FC output)按 id 缓存,反向喂给召回。

CMS 在数学形式上几乎是 CTS 的翻版,但信息源完全换了,所以即使公式长得像,跑出来的东西意义大不同。

CMS 公式(论文公式 7)

1
2
3
对每条样本 (user u, item v) at time t (无论正负):
  g^u_{t+1} = β · g^u_t + (1-β) · z_t
  g^v_{t+1} = β · g^v_t + (1-β) · z_t

三个变量:

符号物理含义索引键
g^u[user_id]给 user 塔用的 cacheuser_id
g^v[item_id]给 item 塔用的 cacheitem_id
z_t精排模型对 (u,v) pair 的中间表征(u, v) pair-level

z_t 论文明确说了是 “the penultimate FC output before the final prediction layer” —— 也就是精排模型最后一层 FC 之前的那个 D 维向量(业界常说的 “embedding before the head”)。

CMS vs CTS 的逐项对比

维度CTSCMS
更新源召回自己塔的输出 u_t / v_t外部精排模型的中间表征 z_t
更新源是不是 pair 级否,单塔 emb(u_t 只来自 user 塔,v_t 只来自 item 塔)是,pair-level(z_t 来自 ranker(u,v))
更新方向交叉:c^uv_tc^vu_t对称:g^ug^v 都用同一个 z_t
正负样本只用正样本 (y=1)正负样本都用
存储位置ParSvr(和召回联训)EmbSvr(独立 KV 服务)
本质角色解决”双塔跨塔信息互通”解决”召回-精排 capability gap”

为什么 g^u 和 g^v 用同一个 z_t

z_t 是精排对 pair (u, v) 算出来的,自带 user / item / cross 三方面信息。论文的逻辑是把它同时投影到两个索引维度

  • 从 user_id 视角累积g^u[user_id] 是”该 user 历史所有 pair 的精排表征 EMA” → 这个 user 在精排空间里的画像
  • 从 item_id 视角累积g^v[item_id] 是”该 item 历史所有 pair 的精排表征 EMA” → 这个 item 在精排空间里的画像

虽然来自同一个向量 z_t,但累积出来的语义不一样:因为不同 user 见过的 pair 集合不一样,不同 item 见过的 pair 集合也不一样,EMA 路径完全不同。

一次完整训练步流程(含 CTS 和 CMS 同时启用)

来一条样本:user_id=A,item_id=X,无论是正是负:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Step 1 (查):
  从 EmbSvr 读 g^u[A]、g^v[X]
  从 ParSvr 读 c^u[A]、c^v[X]

Step 2 (前向, 召回模型):
  u_t = UserTower([A 的特征, c^u[A], g^u[A]])
  v_t = ItemTower([X 的特征, c^v[X], g^v[X]])
  score = u_t · v_t

Step 3 (前向, 精排模型 — 通常是另一个独立模型/训练流):
  z_t = RankerModel(u, v).penultimate_layer    // 拿倒数第二层 FC 输出

Step 4 (反传 召回): 正常算 loss、更新召回权重
        c^u / c^v / g^u / g^v 不参与反传

Step 5 (CMS EMA 更新, 不挑正负):
  g^u[A] ← β · g^u[A] + (1-β) · z_t            // 写到 EmbSvr
  g^v[X] ← β · g^v[X] + (1-β) · z_t            // 写到 EmbSvr

Step 6 (CTS EMA 更新, 仅正样本):
  if y == 1:
    c^u[A] ← β · c^u[A] + (1-β) · v_t
    c^v[X] ← β · c^v[X] + (1-β) · u_t

注意:Step 5 不挑正负,Step 6 只在正样本做。这是 CMS 和 CTS 在更新策略上的核心差异。

为什么 CMS 不挑正负,CTS 挑正样本

论文原话:”CMS treats both positive and negative z_t as useful, since cascade models encode richer knowledge.”

理解:

  • CTS 挑正样本:因为 CTS cache 的是召回自己塔的输出。负样本意味着这个 pair 没有真实兴趣信号,把它的塔输出累进 cache 反而引入噪声。”用户喜欢什么”必须看正反馈。
  • CMS 不挑:因为 z_t 来自精排模型,精排已经做过正负判别了 —— 它的中间表征本身就包含了”这是个负样本”这个信号(比如 z_t 在某些维度上是低的)。把负样本的 z_t 累进去,相当于让 user/item 的 cache 也吸收了”哪些组合是不感兴趣的”这种结构性信息。

简单类比:

  • CTS 像”只记你点过的”
  • CMS 像”记下精排对你打的所有分(包括打低分的)”,因为精排打分本身就是带评判的高级信息

CMS 数值例子

设 β=0.9,看 user A 的 g^u[A] 演化(初始为 0):

1
2
3
4
5
6
7
8
9
10
事件 1: A 看到 X (负)。精排 z_1 = [0.1, 0.2, ...]
  g^u[A] ← 0.9·0 + 0.1·z_1 = 0.1·z_1                   ← 负样本也更新

事件 2: A 点击 Y (正)。精排 z_2 = [0.7, 0.5, ...]
  g^u[A] ← 0.9·(0.1·z_1) + 0.1·z_2

事件 3: A 看到 W (负)。精排 z_3 = [0.05, 0.3, ...]
  g^u[A] ← 0.9·(...) + 0.1·z_3                         ← 继续更新

  (对照: c^u[A] 在事件 1 和 3 不动, 只在事件 2 更新)

可以看到 g^u[A] 累积的是 A 经历过的所有曝光(无论结果)在精排表征空间里的”轨迹”。

CMS 和传统知识蒸馏的本质差别

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
传统蒸馏 (Tang & Wang 2018, Reddi 2021):
  teacher_logit = TeacherModel(u, v)
  student_logit = StudentModel(u, v)
  L = CE(student, hard_label) + KL(student, teacher_soft_label)
  → 在 LOSS 层面注入信息
  → student inference 时不需要 teacher
  → 信息带宽:1 个 scalar (logit)

CMS:
  z_t = TeacherModel(u, v).penultimate          // D 维向量
  g^u[user_id] ← EMA + z_t
  v_emb = ItemTower([item_feat, ..., g^v[item_id]])
  → 在 FEATURE 层面注入信息
  → student inference 时仍需要 cached z (从 EmbSvr 读)
  → 信息带宽:D 维向量 (D 通常 32~128)

这也是为什么论文里 ablation 显示 CMS 是单组件最大贡献者(线上 +6.2% revenue 中 CMS 拿走大头):信息带宽差几十倍,效果自然差不少。

工程上的关键:EmbSvr 为什么必须独立于 ParSvr

论文 §2.4.3 解释得很清楚:

  1. ParSvr 作用域是单训练任务。召回的 ParSvr 不能直接看到精排的 ParSvr。
  2. 召回和精排的训练流通常不共享。它们有不同的:
    • 数据流(曝光-点击 vs. 整个 ranking 阶段的 pair)
    • 采样策略(召回有大量负采样,精排负样本少)
    • 训练频率
  3. 即使联训,把 z_t 直接经 ParSvr 传也很挤(精排维度通常比召回大)。

所以快手单独搞了一个 EmbSvr:

  • 专门为 embedding 存储优化(KV 高 QPS)
  • 跨训练任务共享:精排训练时写 z_t,召回训练时读
  • p99 < 5ms,serving 时和其它特征处理并行
  • 召回 serving QPS 几乎没影响

一个容易踩的坑:冷启动

冷启动问题在 CMS 上比 CTS 更明显:

  • 新 user / 新 item 没出现过 → g^u / g^v 是 0 向量
  • 召回模型在训练阶段几乎没见过 0 向量的 g(因为训练样本都是已有 user/item 的曝光)
  • 推到线上遇到新用户,相当于这个特征”塌了”

工程上一般要做:

  • 给冷启用户 / 新 item 一个学习出来的”冷启 embedding”作为 fallback
  • 或者按”出现次数”做软切换(出现次数 < N 时用 fallback,≥ N 时用 EMA cache)
  • 论文未特别讨论,但落地时这是必处理的

在线学习架构(论文 Figure 2)

系统组件名词

ParSvr(Parameter Server,参数服务器)

  • 工业 ML 系统的标配组件,用于在大规模分布式训练中集中存储 + 同步模型参数
  • 一般是分布式 KV 服务:worker 从它拉参数 → 本地算梯度 → 推回更新;周期性导出快照供 serving 加载
  • 典型存储内容:dense weight(W、b)、sparse embedding table(每个 user_id / item_id / categorical 特征 → 向量)
  • 在 CS3 中除了常规 sparse embedding,还存:
    • CTS 的 cross vector c^u[user_id]c^v[item_id]
    • 这两个虽然形式上像 sparse embedding(按 id 索引的向量表),但更新机制不同:sparse embedding 是梯度更新,cross vector 是 EMA 更新
    • 实现上把 EMA 包装成一个 “custom gradient” op 塞进 ParSvr 的更新流程,复用同一套同步基建
  • 作用域:单个训练任务。召回的 ParSvr 看不见精排的 ParSvr —— 这正是 CMS 不能用 ParSvr、必须另起 EmbSvr 的根本原因

EmbSvr(Embedding Server)

  • 一个独立的分布式 KV 服务,专门为 embedding 存储和高 QPS lookup 优化
  • 在 CS3 中存放:CMS 的 cascade vector g^u[user_id]g^v[item_id]
  • 跨训练任务共享:精排训练流把 z_t 写进 EmbSvr,召回训练流和 serving 都从 EmbSvr 读
  • 设计指标:p99 < 5ms,访问可以和其它特征处理并行,对召回 serving QPS 影响 < 1%

一句话区分ParSvr 是”模型自己的参数仓库”,EmbSvr 是”跨模型的共享缓存”。CS3 把 CTS 放 ParSvr、CMS 放 EmbSvr,本质是因为信息源在系统层面就分属两个不同的训练任务。

整体数据流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
┌──────────────────────────────────────────────────────────┐
│            Online Learning Loop (~30 min)                │
│                                                          │
│   曝光日志 → 等 label → 实时训练 → ParSvr 同步 → Serving  │
│                              │                           │
│                              ↓                           │
│                  ┌──────────────────────┐                │
│                  │  ParSvr              │                │
│                  │  - sparse embedding  │ ← 梯度更新     │
│                  │  - CTS cross vector  │ ← EMA (custom) │
│                  └──────────────────────┘                │
│                                                          │
│                  ┌──────────────────────┐                │
│                  │  EmbSvr (独立)       │                │
│                  │  - CMS cascade vec   │ ← EMA          │
│                  │  - 跨训练任务共享    │                │
│                  └──────────────────────┘                │
└──────────────────────────────────────────────────────────┘

效率(论文 §2.4.4):

  • CTS / CMS 几乎零开销(只是 lookup)
  • CAS 让 retriever QPS 降 < 1%
  • EmbSvr p99 < 5ms,可与其它特征处理并行
  • item embedding 仍然预计算缓存,主开销在实时 user 塔

实验结果

离线(4 个 backbone × 3 个数据集)

BackboneTaobaoAd AUCKuaiRand AUCRecSys2017 AUC
DSSM0.6194 → 0.6855 (+CS3)0.6646 → 0.74840.6855 → 0.8380
IntTower0.6507 → 0.68950.7503 → 0.76150.8178 → 0.8657
IHM-DAT0.6302 → 0.67830.7059 → 0.75560.7694 → 0.8660
RCG (transformer)0.6680 → 0.68600.7814 → 0.83040.7870 → 0.8676

关键观察:DSSM + CS3 ≈ IntTower / IHM-DAT 这种”专门为双塔设计的复杂结构”。即 CS3 把 vanilla DSSM 的 ceiling 抬到了 SOTA 双塔变体级别。

线上 A/B(快手广告)

Scenario A 消融

配置RevenueDAC
Base0%0%
+ CAS+1.677%+0.144%
+ CAS + CMS+7.880%+0.435%
+ CAS + CMS + CTS (full CS3)+8.356%+0.468%

关键观察CMS 是单组件最大贡献者(CAS+CMS 已经吃了 80% 的收益)。这说明”跨阶段一致性”比”单塔内部去噪”和”跨塔对齐”都更值钱。

3 个场景泛化

场景RevenueDACQPS
Scenario A+8.356%+0.468%−0.589%
Scenario B+1.366%+0.143%−0.388%
Scenario C+2.177%+0.228%−0.456%

QPS 下降都 < 1%,主要是 CAS 让 user tower 多算了一遍。


CS3 vs HSNN:两条完全不同的路线

维度HSNNCS3
思路打破双塔范式保留双塔范式
是否兼容 ANN否,要换 serving 栈是,零改动
Plug-and-play否,重构整个召回是,三个独立模块
在线学习友好一般(cluster collapse 等问题)设计就是为在线学习
工程量大(层次索引 + LTI + JOIM)小(ParSvr + EmbSvr)
能拿到的收益上限高(结构性改造)中(增量优化)
适合谁Meta Ads 那种规模 + infra 的团队现有双塔 + 想吃增量的团队

对当前 twotower 召回模型的具体借鉴

CS3 三个模块都和现有架构兼容,落地成本和收益排序:

1. CAS(最容易)

  • 直接替换 user_tower_layers / item_tower_layers 里的 keras.layers.Dense
  • 实现一个 CASLayer:内部三步 pre-forward / reweight / cycle-forward
  • 不需要任何外部 infra,不动 serving
  • 单组件离线就能拿到稳定收益

2. CTS(中等)

  • 需要按 user_id / item_id 维护 cross vector cache
  • 如果有 ParSvr / Redis 这类 KV 存储,加两个新 sparse embedding(按 user_id / item_id 索引)即可
  • 训练时改成 EMA 更新(不参与梯度),serving 时按 id 直接读
  • 关键工程问题:要确保 user_id 索引的 cross vector 在 serving 时和训练时分布一致 — 即都从 EMA cache 读
  • 模型代码里:在 _build_user_tower 入口增加一个 c^u 向量 concat,_build_item_tower 同理

3. CMS(最值钱但最重)

  • 需要打通”召回模型 ↔ 精排模型”的中间表示传递
  • 需要独立 Embedding Server 跨任务缓存
  • 但论文证明这是单组件最大收益来源(线上 +6.2% revenue 占了 CS3 总收益 75% 以上)
  • 适合先跑通 CAS / CTS 验证框架,再推 CMS

几点判断

  1. CS3 比 HSNN 更适合现阶段。HSNN 是”infra 已经准备好搞层次召回,想把表达力推到极限”的方案,CS3 是”现有 EBR 框架不动,给召回一个 plug-and-play 的能力补丁”。

  2. CMS 是真正的关键贡献。CAS 和 CTS 在学术上更花哨,但线上 ablation 显示 CMS 一个就吃掉 75%+ 的收益。这印证了一个工业经验:召回和精排的 capability gap 才是 EBR 真正的瓶颈,比塔内部结构和塔间对齐都更值钱

  3. EMA + cached cross vector 这个 pattern 值得收藏。它是”我想要跨塔/跨阶段信息但又不想破坏 serving 独立性”的通用解法,可以泛化到很多场景(比如跨域召回 cache 另一个 domain 的 user emb)。

  4. CS3 不解决 HSNN 解决的问题。如果最终目标是”召回能用上 <user, item> 交叉特征 + 复杂 ranking 级模型”,CS3 走不到那一步 —— 它本质还是 late fusion,只是 late fusion 之前每个塔都拿到了对方/下游的”残影”。


Sources

本文由作者按照 CC BY 4.0 进行授权