Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions cross_router_training/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# cross-router (CrossThrift) — 代码说明

这份文档讲**代码结构与复现方式**;方法/模型本身的设计见同目录 [`ROUTER_v17_CrossThrift.md`](./ROUTER_v17_CrossThrift.md)。

一句话:cross-router 对每个 query 预测「能答对它的最便宜部署模型」,从 7 个模型里选一个;判为无解的题路由到最便宜的模型。它在 **LLMRouterBench** 上训练,**不碰 RouterArena 数据**,提交时只输出 `prediction`,`generated_result` 由官方缓存回填。

---

## 1. 推理侧(提交真正用到的代码)

都在 `router_inference/` 下,框架按 `config` 里的 `router_cls_name` 动态加载:

```
router_inference/
├── config/cross-router.json # 配置:类名 / 权重 / 特征缓存 / 7 个模型
├── router/
│ ├── cross_router.py # CrossRouter 类(推理逻辑)
│ ├── __init__.py # 注册 CrossRouter
│ └── cross_router_assets/
│ ├── model_v17.py # 模型定义 + 常量 + TARGET2ARENA(自包含)
│ ├── cross_router_v17.pt # 冻结的权重(含 model_features)
│ └── features_*.npz # 离线特征缓存(见 §3,大文件,未入库)
└── predictions/
├── cross-router.json # full 预测(已提交)
└── cross-router-robustness.json # robustness 预测(已提交)
```

**`CrossRouter._get_prediction(query)` 的流程**(`cross_router.py`):

1. 用 query 原文串去 `features_*.npz` 查这道题的离线特征(Qwen 句向量 1024 + 归一化长度 + 领域置信度 + 14 维领域 one-hot)。
2. 喂进 `PairProfileRouterV17` → 得到 8 维 logits(7 个模型 + 1 个 `unsolvable`)。
3. `argmax` 取类别;若是 `unsolvable` → 回退到 `cheapest_idx`(最便宜的可部署模型 `qwen3-235b-a22b-2507`)。
4. 把内部目标名经 `TARGET2ARENA` 映射成榜单模型 id 返回。

> 模型结构(cross-transformer 双塔)见 `model_v17.py` 顶部注释与方法文档。所有参数从训练 checkpoint 冻结加载,推理时不再调整。

## 2. 训练侧(`cross_router_training/`,用于复现/审查)

| 文件 | 作用 |
|---|---|
| `build_cheapest_target_table.py` | bench-release 原始评测 → 每题一个标签(最便宜正确模型 / unsolvable):`cheapest_correct_target.csv` |
| `embed_queries.py` | 对题面做 Qwen embedding → `embeddings.npy` + `emb_qids.npy` |
| `split_dataset.py` | 划分 train/val/test → `splits.json` |
| `build_target_matrix.py` | 每题×每目标的正确性(mean 聚合)→ `target_correct_mean_cc_router.npy` |
| `config.py` | 上述产物的路径配置(被训练脚本 import) |
| `train_router_v17.py` | v17 训练脚本:建模型特征、训练(CE + 成对排序)、按 val macro 选 checkpoint、导出 `.pt` |
| `ROUTER_v17_CrossThrift.md` | 方法文档(数据 / 特征 / 候选模型 / 训练过程) |

## 3. 从零复现的步骤

```bash
# 0) 路径:把 cross_router_training/config.py 里的 REPO_ROOT 改成你的本地路径

# 1) 标签表(从 bench-release / LLMRouterBench)
python build_cheapest_target_table.py # -> cheapest_correct_target.csv

# 2) 句向量 + 3) 划分 + 4) 正确性矩阵
python embed_queries.py # -> embeddings.npy, emb_qids.npy
python split_dataset.py # -> splits.json
python build_target_matrix.py --agg mean # -> target_correct_mean_cc_router.npy

# 5) 训练,导出权重
python train_router_v17.py # -> models/...v17....pt (= cross_router_v17.pt)

# 6) 为 RouterArena 的题离线预计算特征缓存 features_*.npz
# (同一个 Qwen 编码器 + 领域分类器 + model_cost.json 成本表,
# 产物放到 router_inference/router/cross_router_assets/)

# 7) 生成提交预测
python router_inference/generate_prediction_file.py cross-router full
python router_inference/generate_prediction_file.py cross-router robustness
```

## 4. 不随仓库提交的大文件

为控制仓库体积,以下**未入库**,需按上面步骤本地再生:

- `embeddings.npy`(~98 MB)、`target_correct_mean_cc_router.npy`、`cheapest_correct_target.csv` 等训练产物;
- `router_inference/router/cross_router_assets/features_*.npz`(~48 MB)推理特征缓存。

缺少 `features_*.npz` 时 `CrossRouter` 会直接报错提示先再生;**已生成的 `predictions/*.json` 不依赖它们**,提交不受影响。

## 5. 提交说明

- 三个提交文件(`config/cross-router.json` + `predictions/cross-router.json` + `predictions/cross-router-robustness.json`)已就绪。
- `predictions` 里 `generated_result` 已用「同模型同题」的现成输出本地预填一部分;剩余 null 由官方 `/evaluate` 用完整缓存回填、并以官方定价计成本。提交端只需保证 `prediction` 正确。
100 changes: 100 additions & 0 deletions cross_router_training/ROUTER_v17_CrossThrift.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# CrossThrift — Cheapest-Correct Cross-Transformer Router (cc-router / v17)

> **命名**:**CrossThrift**(正式名 *Cheapest-Correct Cross-transformer Router, CC-XR*)。
> *Cross* = query×model 的 cross-transformer 交互;*Thrift* = 目标是「能答对的最便宜模型」。
> 配置 id 沿用仓库里的 `cc-router`。

---

## 1. 一句话

对每个 query,直接预测**能正确回答它的最便宜的部署模型**,从 7 个候选里选一个路由过去;判为「无解」的题交给最便宜的模型兜底。

## 2. 核心思路

不去预测「每个模型各自对不对」(这会被覆盖率混淆——强模型只在难题上被测、弱模型只在易题上被测),而是把**最便宜正确模型**当作一个直接的目标来学。让 query 特征和每个候选模型的特征**显式交互**,由一个小 cross-transformer 判断这一对 (query, model) 配不配。

## 3. 候选路由模型(7 个)

按「成本—能力」阶梯选出 7 个部署目标,覆盖从最便宜到最强:

| 目标 | 角色 |
|---|---|
| `qwen3-235b-a22b-2507` | 部署成本最低 → **无解题的兜底模型** |
| `Qwen3.5-9b` | 小模型,易题主力 |
| `qwen3-coder-30b-a3b-instruct` | 代码类 |
| `deepseek-reasoner` | 推理类 |
| `deepseek/deepseek-v4-pro` | 通用强模型 |
| `gemini-2.5-flash` | 通用 + 非常规来源的归口 |
| `gpt-5.1` | 能力上限,难题锚点(榜单映射为 `gpt-5`) |

数据集里实际被评测过的是 **37 个来源模型**;它们通过一张「来源→部署」映射表归并到上面这 7 个目标层,所以训练标签只在这 7 个目标上取值。

## 4. 训练数据(标签怎么来)

底层数据是 **LLMRouterBench(bench-release)**,**完全不碰 RouterArena 数据**。
一张脚本(`build_cheapest_target_table.py`)从原始评测一路处理到每个 query 的单一标签:

1. **裁切 split**:每个多 split 数据集只留一个(mmlupro→test_3000、hle→test、simpleqa→test),去掉重复评测模型、跳过空分数。
2. **修复 openrouter**:把 `openrouter` 壳记录展开成它真实的 `actual_model`。
3. **取最便宜正确解**:某 query 上,所有答对(score≥0.5)且能映射到 7 目标的模型里,按来源成本取**最便宜**的那个 → 它对应的目标层就是标签。
4. **无解**:没有任何模型答对 → 标签 `unsolvable`。
5. **非常规归口**:只有不在映射表里的模型答对 → 归到 `gemini-2.5-flash`。

产物:`cheapest_correct_target.csv`,每行一个 query 及其最终标签。
再按固定 split 切成 train/val/test,**所有统计与选模型只用 train/val**。

## 5. 特征(分别从哪来)

**query 侧特征**(每个 query 一条,离线算好):

| 特征 | 来源 |
|---|---|
| 句向量 1024 维 | Qwen embedding,对题面文本离线编码 |
| 长度 | `log1p(prompt token 数)`,用 train 统计量归一化 |
| 领域 one-hot(14)+ 置信度 | mmbert 领域分类器(14 类 mmlu 体系),存于 `bench-release-domain/.../domain.json` |

**model 侧特征**(每个目标一条,26 维,**全部仅用 train 统计**):

| 维度 | 内容 | 来源 |
|---|---|---|
| 3 | 难度先验(hard/medium/easy 的能力档位) | 人工设定的难度 profile |
| 14 | 各领域 log-lift(该模型在某领域被选为最便宜正确解的相对倾向) | train 标签里 `P(模型\|领域)` 相对其全局被选率,含个别人工修正 |
| 1 | 全局被选率 | train 标签 |
| 1 | `log(train 标签数)` | train 标签 |
| 7 | 价格特征(进/出价、几种混合价、价比、便宜度排名) | `model_cost.json` |

## 6. 架构

```
query 塔: Qwen emb(1024)+长度+置信度+domain(14) → bottleneck(128) → encoder → z_q(64)
model 塔: model_features(26) → encoder → m(64)

每个 (query, model) 组 4 个 token: [ q , m , q⊙m , |q−m| ]
→ 1 层 cross-transformer(4 头) → 该模型 logit (×7)
z_q → unsolvable logit (×1)
→ 8 维输出 = [7 个目标, unsolvable]
```

## 7. 训练过程

- **损失** = 交叉熵 + 成对排序损失:
- 交叉熵:按类别频率做 sqrt 重加权,缓解类别不平衡;
- 成对排序:正确标签模型分 > 其它模型、> unsolvable;标签为 unsolvable 时 unsolvable 分 > 所有模型;多个模型都对时,**更便宜的对的模型分更高**。
- **选 checkpoint**:按验证集 **macro 准确率**(兼顾少数类,不是只看整体 top-1),配早停。
- **纪律**:特征统计、训练、选模型全在 LLMRouterBench 上完成,RouterArena 数据只做前向推理,绝不回看其答案。

## 8. 推理与路由

1. 查离线特征缓存(emb / 长度 / 置信度 / 领域 / 成本),前向得到 8 维 logits。
2. `argmax` 取类别。
3. 若是 **unsolvable → 回退到最便宜的可部署模型**(`qwen3-235b-a22b-2507`)。
4. 缓存未命中时同样兜底到最便宜模型。

## 9. 相关文件

- 训练:`scripts/codex_router/difficulty_clf/train_router_cc_v8_pair_profile_v17_cross_transformer_pairwise_rank.py`
- 建标签:`scripts/codex_router/build_cheapest_target_table.py` → `codex_artifacts/difficulty_clf/cheapest_correct_target.csv`
- 权重:`codex_artifacts/difficulty_clf/models/router_cc_v8_pair_profile_v17_cross_transformer_pairwise_rank.pt`
- 推理侧:`RouterArena-main/router_inference/router/cc_router.py` + `cc_router_assets/cc_v8_pair_profile_v17_model.py` + `config/cc-router.json`
- 预测产物:`router_inference/predictions/cc-router.json`(full)、`cc-router-robustness.json`(robustness)
147 changes: 147 additions & 0 deletions cross_router_training/build_cheapest_target_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
"""Full pipeline: raw bench-release -> clean per-query routing label table.

Each query gets ONE label among the 7 deploy targets (or "unsolvable"):
the cheapest source model that answered it correctly, mapped to its target tier.

Operations (all in this one script):
1. split trimming : keep one split per multi-split dataset (SPLIT_SELECTION) +
drop the duplicate eval model glm-4.6-3888 + skip None scores.
2. openrouter fix : bench-release "openrouter" files collapse 9 real models under
model_name="openrouter"; we expand each record to its true
extra_fields.actual_model so it counts as that real model.
3. cheapest target : per query, the cheapest (by source cost) CORRECT model whose
name maps to one of the 7 targets -> that target is the label.
4. unsolvable : if NO model answered correctly, label = "unsolvable".
5. off-scheme : if correct only by off-scheme models (not mappable), label = "gemini-2.5-flash".

Run: python build_cheapest_target_table.py
Out: codex_artifacts/difficulty_clf/cheapest_correct_target.csv
"""

from __future__ import annotations

import ast
import csv
import hashlib
import json
import math
from collections import defaultdict
from pathlib import Path

BENCH = Path("/Users/jiahg/Desktop/work/LLMRouter/dataset/bench-release")
OUT = Path("/Users/jiahg/Desktop/work/LLMRouter/RouterArena-main-Codex/codex_artifacts/difficulty_clf/cheapest_correct_target.csv")

# op 1: split trimming
SPLIT_SELECTION = {"mmlupro": "test_3000", "hle": "test", "simpleqa": "test"}
EXCLUDED_MODELS = {"glm-4.6-3888"}

TARGETS = ["Qwen3.5-9b", "qwen3-235b-a22b-2507", "qwen3-coder-30b-a3b-instruct",
"deepseek-reasoner", "deepseek/deepseek-v4-pro", "gemini-2.5-flash", "gpt-5.1"]

# source/right model -> target/left tier. (37 base + resolved extras + openrouter actual_models)
_MAP = """gpt-5 gpt-5.1|Qwen3-8B Qwen3.5-9b|gemini-2.5-pro gpt-5.1|DeepSeek-R1-0528-Qwen3-8B deepseek-reasoner|GLM-4.1B-0414 Qwen3.5-9b|NVIDIA-Nemotron-Nano-9B-v2 Qwen3.5-9b|qwen3-235b-a22b-thinking-2507 qwen3-235b-a22b-2507|MiniCPM4.1-8B Qwen3.5-9b|qwen3-235b-a22b-2507 qwen3-235b-a22b-2507|Intern-S1-mini Qwen3.5-9b|deepseek-r1-0528 deepseek-reasoner|deepseek-v3.1-terminus deepseek/deepseek-v4-pro|gpt-5-chat gpt-5.1|kimi-k2-0905 qwen3-235b-a22b-2507|glm-4.6 qwen3-235b-a22b-2507|qwen3-235b-a22b-thinking qwen3-235b-a22b-2507|intern-s1-new deepseek/deepseek-v4-pro|gemini-2.5-flash gemini-2.5-flash|DeepSeek-R1-Distill-Qwen-7B deepseek-reasoner|claude-sonnet-4 qwen3-235b-a22b-2507|Intern-s1 deepseek/deepseek-v4-pro|Fin-R1 deepseek-reasoner|deepseek-v3-0324 deepseek/deepseek-v4-pro|internlm3-8b-instruct Qwen3.5-9b|claude-opus-4.1 gpt-5.1|Qwen2.5-Coder-7B-Instruct qwen3-coder-30b-a3b-instruct|cogito-v1-preview-llama-8B deepseek-reasoner|gemma-2-9b-it Qwen3.5-9b|Llama-3.1-Nemotron-Nano-8B-v1 Qwen3.5-9b|glm-4-flash gemini-2.5-flash|granite-3.3-8b-instruct Qwen3.5-9b|Llama-3.1-8B-Instruct Qwen3.5-9b|OpenThinker3-7B deepseek-reasoner|DeepHermes-3-Llama-3-8B-Preview deepseek-reasoner|Llama-3.1-8B-UltraMedical qwen3-coder-30b-a3b-instruct|qwen3-235b-a22b-no-thinking qwen3-235b-a22b-2507|MiMo-7B-RL-0530 deepseek-reasoner|glm-4-9b-chat Qwen3.5-9b|gpt-4.1 qwen3-235b-a22b-2507|GLM-Z1-9B-0414 Qwen3.5-9b|perplexity/sonar Qwen3.5-9b|mistralai/mistral-nemo Qwen3.5-9b|openai/gpt-5 gpt-5.1|anthropic/claude-opus-4.1 gpt-5.1|openai/gpt-5-mini gpt-5.1|openai/gpt-5-nano gpt-5.1"""
R2L = {p.rsplit(" ", 1)[0]: p.rsplit(" ", 1)[1] for p in _MAP.split("|")}
NORM = {k.lower(): TARGETS.index(v) for k, v in R2L.items()}


def to_target(model: str):
"""source model name -> target index, or None if not mappable (case-insensitive,
with a 'provider/name' prefix-strip fallback)."""
key = model.lower()
if key in NORM:
return NORM[key]
if "/" in key and key.split("/", 1)[1] in NORM: # e.g. google/gemini-2.5-pro
return NORM[key.split("/", 1)[1]]
return None


def num(v, default=None):
try:
x = float(v)
return None if (math.isnan(x) or math.isinf(x)) else x
except (TypeError, ValueError):
return default


def real_model(model_name: str, record: dict) -> str:
"""op 2: expand the 'openrouter' wrapper to the record's true actual_model."""
if model_name != "openrouter":
return model_name
ef = record.get("extra_fields")
if isinstance(ef, str):
try:
ef = ast.literal_eval(ef)
except Exception:
ef = {}
return (ef or {}).get("actual_model") or "openrouter"


def main():
# qid -> {model: (score, cost)} ; qid -> (prompt_hash, origin_query)
pool: dict[tuple, dict] = defaultdict(dict)
meta: dict[tuple, tuple] = {}
for jf in sorted(BENCH.rglob("*.json")):
try:
d = json.load(jf.open())
except Exception:
continue
if not isinstance(d, dict) or not isinstance(d.get("records"), list):
continue
ds, sp, mn = str(d.get("dataset_name")), str(d.get("split")), str(d.get("model_name"))
if mn in EXCLUDED_MODELS: # op 1
continue
if ds in SPLIT_SELECTION and sp != SPLIT_SELECTION[ds]: # op 1
continue
for r in d["records"]:
idx = str(r.get("index", "")).strip()
if not idx:
continue
score = num(r.get("score"))
if score is None: # op 1: skip None
continue
qid = (ds, sp, idx)
pool[qid][real_model(mn, r)] = (score, num(r.get("cost"), 0.0)) # op 2
if qid not in meta:
q = str(r.get("origin_query") or r.get("prompt") or "")
meta[qid] = (hashlib.sha256(q.encode()).hexdigest()[:16], q)

rows, n_uns, n_unmapped = [], 0, 0
for qid in sorted(pool):
ds, sp, idx = qid
models = pool[qid]
n_correct = sum(1 for s, _ in models.values() if s >= 0.5)
# op 3: cheapest correct model that maps to a target
cand = [(c, m, to_target(m)) for m, (s, c) in models.items() if s >= 0.5 and to_target(m) is not None]
if not models or n_correct == 0: # op 4
label, src = "unsolvable", ""
n_uns += 1
elif cand:
c, src, ti = min(cand, key=lambda x: x[0])
label = TARGETS[ti]
else:
# op 5: correct only by off-scheme models (e.g. claude-3.5-haiku) -> gemini-2.5-flash
src = min(((cc, mm) for mm, (ss, cc) in models.items() if ss >= 0.5), key=lambda x: x[0])[1]
label = "gemini-2.5-flash"
n_unmapped += 1
ph, q = meta[qid]
rows.append({"dataset": ds, "split": sp, "index": idx, "qid": f"{ds}/{sp}/{idx}",
"prompt_hash": ph, "n_models": len(models), "n_correct": n_correct,
"cheapest_correct_source": src, "label": label, "origin_query": q})

OUT.parent.mkdir(parents=True, exist_ok=True)
fields = ["dataset", "split", "index", "qid", "prompt_hash", "n_models", "n_correct",
"cheapest_correct_source", "label", "origin_query"]
with OUT.open("w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=fields)
w.writeheader()
w.writerows(rows)

from collections import Counter
print(f"wrote {len(rows)} rows -> {OUT}")
print(f"unsolvable={n_uns} unmapped(correct but off-scheme)={n_unmapped}")
print("label distribution:", dict(Counter(r["label"] for r in rows).most_common()))


if __name__ == "__main__":
main()
Loading
Loading