KeiSeiKit-1.0/_primitives/_rust/kei-model-router/src/main.rs
Parfii-bot 4d79049eff feat(kei-model-router): registry-driven, three-layer DNA
Removes hardcoded Claude-only Model enum. Pricing constants now read
from _blocks/registries/models.toml at startup; provider/model lookup
goes through a typed Registry returned by registry.rs.

New API surface:
  - Registry::load(dir) → (providers, models, profiles)
  - pick(profile_id, &Registry) → Result<(provider_id, model_id)>
  - cost_micro_cents(model_id, in, out, &Registry) → Option<u64>
  - next_model(model_id, &Registry) → Option<&Model> (ascending cost,
    same provider, skip deprecated)

Files:
  - registry_types.rs      new   107 LOC  (Provider/Model/Profile structs)
  - registry.rs            new   152 LOC  (TOML load + lookups)
  - pricing.rs             rew   127 LOC  (registry-backed, no constants)
  - escalate.rs            rew   181 LOC  (registry-backed ladder + skip deprecated)
  - select.rs              rew   131 LOC
  - select_kernel.rs       new    74 LOC  (Constructor-Pattern split)
  - select_posterior.rs    new   178 LOC  (Constructor-Pattern split)
  - posterior.rs           rew   197 LOC
  - calibrate.rs           rew   175 LOC
  - lib.rs                 rew    53 LOC
  - main.rs                rew   163 LOC  (CLI updated to new API)
  - Cargo.toml             dep   added toml 0.8

Verification (orchestrator-side, RULE 0.13 §Verify-before-commit):
  - cargo check                 → clean
  - cargo test --release        → 58 passed / 0 failed / 0 ignored
  - LOC limit (Constructor)     → max 197 / limit 200
  - largest fn cmd_select       → ~27 LOC / limit 30

DNA-INDEX.md regenerated by kei-registry hook (primitive count
144 → 150 reflects the 6 new/split modules).

=== STATUS-TRUTH MARKER ===
shipped: functional
stubs: 0
cargo-check: PASS
behaviour-verified: yes
follow-up-required:
  - select.rs `estimated_cost` still embeds inline cost constants
    mirroring models.toml; if non-Anthropic providers need dynamic
    pricing in select-time estimation, thread Registry through.
  - External callers of old `cost_micro_cents(Model, ...)` signature
    will break — intentional, no external callers in this workspace.
2026-05-13 21:23:53 +08:00

163 lines
5.9 KiB
Rust

//! kei-model-router CLI.
//!
//! Subcommands:
//! pricing — print pricing table from models.toml
//! select <agent> [--prompt P]
//! — query router decision for given agent spawn
//! calibrate — re-fit kernel weights against ledger outcomes
//! --help
use kei_model_router::{
calibrate, pick, select, DecisionInput, KernelWeights, Model, Registry,
OPUS_47_TOKENIZER_OVERHEAD,
};
use rusqlite::Connection;
fn main() {
let args: Vec<String> = std::env::args().collect();
match args.get(1).map(String::as_str) {
Some("pricing") | None => print_pricing(),
Some("select") => cmd_select(&args[2..]),
Some("calibrate") => cmd_calibrate(),
Some("--help") | Some("-h") => print_help(),
Some(other) => {
eprintln!("unknown command: {other}");
print_help();
std::process::exit(2);
}
}
}
fn print_help() {
print!(concat!(
"kei-model-router — model selection for Claude Code Agent spawns\n\n",
"Usage:\n",
" kei-model-router [pricing] print pricing table from models.toml\n",
" kei-model-router select <agent> [--prompt P]\n",
" kei-model-router calibrate re-fit kernel weights\n",
" kei-model-router --help\n\n",
"Env:\n",
" KEI_LEDGER_DB override ledger path\n",
" KEI_REGISTRIES_DIR override registries dir\n",
));
}
fn print_pricing() {
let reg = match Registry::load() {
Ok(r) => r,
Err(e) => { eprintln!("registry load error: {e}"); std::process::exit(1); }
};
println!("kei-model-router — pricing from models.toml\n");
println!("{:<30} {:>12} {:>12} {:>12}", "model", "input/M", "output/M", "cache_r");
for m in &reg.models {
println!(
"{:<30} {:>12} {:>12} {:>12}",
m.id,
fmt_micro(m.cost_input_per_mtok_micro),
fmt_micro(m.cost_output_per_mtok_micro),
fmt_micro(m.cache_read_per_mtok_micro),
);
}
println!("\nNote: Opus 4.7 tokenizer may use up to {:.0}% more tokens vs Sonnet/Haiku.",
(OPUS_47_TOKENIZER_OVERHEAD - 1.0) * 100.0);
}
fn cmd_select(args: &[String]) {
let agent = args.first().unwrap_or_else(|| {
eprintln!("usage: kei-model-router select <agent> [--prompt PROMPT]");
std::process::exit(2);
});
let prompt = parse_prompt_flag(args);
if let Ok(reg) = Registry::load() {
if let Some((prov, model)) = pick(agent, &reg) {
println!("agent: {agent}\nprovider: {prov}\nmodel: {model}\nreason: profile_default_model_ref");
return;
}
}
let synthetic_dna = format!("{agent}::?::00000000::00000000-00000000");
let conn = match open_ledger() {
Some(c) => c,
None => {
eprintln!("warning: ledger not available; falling back to default");
print_decision_no_ledger(&synthetic_dna, &prompt);
return;
}
};
let mut input = DecisionInput::new(synthetic_dna.clone(), prompt);
input.kernel_weights = KernelWeights::default();
input.pinned = read_pinned_for_agent(agent);
let d = match select(&input, &conn) {
Ok(d) => d,
Err(e) => { eprintln!("ledger query failed: {e}"); std::process::exit(1); }
};
println!("agent: {agent}");
println!("model: {}", d.model.slug());
println!("expected_cost ${:.4} (microcents={})",
d.expected_cost_micro_cents as f64 / 100_000_000.0, d.expected_cost_micro_cents);
println!("q_lower_bound {:.3} (posterior n={})", d.quality_lower_bound, d.posterior_n);
println!("reason: {}", d.reason);
}
fn parse_prompt_flag(args: &[String]) -> String {
let mut i = 1;
while i < args.len() {
if args[i] == "--prompt" {
if let Some(p) = args.get(i + 1) { return p.clone(); }
}
i += 1;
}
String::new()
}
fn print_decision_no_ledger(dna: &str, prompt: &str) {
let inp = DecisionInput::new(dna.to_string(), prompt.to_string());
let est = kei_model_router::complexity::estimate(
prompt, kei_model_router::dna_class::role(dna));
println!("model: {}\nτ: {:.2}\nreason: no_ledger_fallback",
inp.fallback.slug(), est.tau);
}
fn cmd_calibrate() {
let conn = match open_ledger() {
Some(c) => c,
None => { eprintln!("ledger not found; aborting calibration"); std::process::exit(1); }
};
let r = match calibrate::calibrate(&conn) {
Ok(r) => r,
Err(e) => { eprintln!("calibration query failed: {e}"); std::process::exit(1); }
};
println!("rows evaluated: {}", r.rows_evaluated);
if r.rows_evaluated < 5 {
println!("(too few rows for calibration; using default weights)");
return;
}
println!("baseline MSE: {:.4}\nbest MSE: {:.4}\nimprovement: {:.4}",
r.baseline_mse, r.best_mse, r.baseline_mse - r.best_mse);
println!("calibrated weights:\n alpha_role: {:.2}\n alpha_caps: {:.2}\n alpha_scope: {:.2}\n alpha_body: {:.2}",
r.best_weights.alpha_role, r.best_weights.alpha_caps,
r.best_weights.alpha_scope, r.best_weights.alpha_body);
}
fn open_ledger() -> Option<Connection> {
let path = std::env::var("KEI_LEDGER_DB").unwrap_or_else(|_| {
let home = std::env::var("HOME").unwrap_or_default();
format!("{home}/.claude/agents/ledger.sqlite")
});
Connection::open(&path).ok()
}
fn read_pinned_for_agent(agent: &str) -> Option<Model> {
let home = std::env::var("HOME").ok()?;
let raw = std::fs::read_to_string(format!("{home}/.claude/settings.json")).ok()?;
let json: serde_json::Value = serde_json::from_str(&raw).ok()?;
let model_slug = json.get("router")?.get("pinned")?.get(agent)?.as_str()?;
Model::from_slug(model_slug)
}
fn fmt_micro(uc: u64) -> String {
format!("${:.2}", uc as f64 / 100_000_000.0)
}