Removes hardcoded Claude-only Model enum. Pricing constants now read
from _blocks/registries/models.toml at startup; provider/model lookup
goes through a typed Registry returned by registry.rs.
New API surface:
- Registry::load(dir) → (providers, models, profiles)
- pick(profile_id, &Registry) → Result<(provider_id, model_id)>
- cost_micro_cents(model_id, in, out, &Registry) → Option<u64>
- next_model(model_id, &Registry) → Option<&Model> (next rung on the cost
  ladder: ascending cost, same provider, deprecated models skipped)
Files (new = added, rew = rewritten):
- registry_types.rs new 107 LOC (Provider/Model/Profile structs)
- registry.rs new 152 LOC (TOML load + lookups)
- pricing.rs rew 127 LOC (registry-backed, no constants)
- escalate.rs rew 181 LOC (registry-backed ladder + skip deprecated)
- select.rs rew 131 LOC
- select_kernel.rs new 74 LOC (Constructor-Pattern split)
- select_posterior.rs new 178 LOC (Constructor-Pattern split)
- posterior.rs rew 197 LOC
- calibrate.rs rew 175 LOC
- lib.rs rew 53 LOC
- main.rs rew 163 LOC (CLI updated to new API)
- Cargo.toml dep added toml 0.8
Verification (orchestrator-side, RULE 0.13 §Verify-before-commit):
- cargo check → clean
- cargo test --release → 58 passed / 0 failed / 0 ignored
- LOC limit (Constructor) → max 197 / limit 200
- largest fn cmd_select → ~27 LOC / limit 30
DNA-INDEX.md regenerated by kei-registry hook (primitive count
144 → 150 reflects the 6 new/split modules).
=== STATUS-TRUTH MARKER ===
shipped: functional
stubs: 0
cargo-check: PASS
behaviour-verified: yes
follow-up-required:
- select.rs `estimated_cost` still embeds inline cost constants
mirroring models.toml; if non-Anthropic providers need dynamic
pricing in select-time estimation, thread Registry through.
- External callers of old `cost_micro_cents(Model, ...)` signature
will break — intentional, no external callers in this workspace.
175 lines
5.2 KiB
Rust
175 lines
5.2 KiB
Rust
//! Offline calibration of kernel weights from observed ledger outcomes.
|
||
//!
|
||
//! Approach: leave-one-out on each ledger row, coarse grid search over
|
||
//! weight tuples (5 × 4 × 3 × 3 = 180 configs) minimising MSE.
|
||
//!
|
||
//! Constructor Pattern: pure-fn cube; no I/O outside passing a Connection.
|
||
|
||
use crate::kernel::{self, KernelWeights};
|
||
use crate::pricing::Model;
|
||
use rusqlite::{Connection, Result as SqlResult};
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub struct CalibrationResult {
|
||
pub best_weights: KernelWeights,
|
||
pub best_mse: f64,
|
||
pub baseline_mse: f64,
|
||
pub rows_evaluated: usize,
|
||
}
|
||
|
||
#[derive(Debug, Clone)]
|
||
struct Observation {
|
||
task_class: String,
|
||
model: Model,
|
||
success: bool,
|
||
}
|
||
|
||
pub fn calibrate(conn: &Connection) -> SqlResult<CalibrationResult> {
|
||
let observations = load_observations(conn)?;
|
||
let rows_evaluated = observations.len();
|
||
if rows_evaluated < 5 {
|
||
return Ok(CalibrationResult {
|
||
best_weights: KernelWeights::default(),
|
||
best_mse: f64::NAN,
|
||
baseline_mse: f64::NAN,
|
||
rows_evaluated,
|
||
});
|
||
}
|
||
|
||
let baseline_mse = mse(&observations, KernelWeights::default());
|
||
let mut best_weights = KernelWeights::default();
|
||
let mut best_mse = baseline_mse;
|
||
|
||
for ar in &[0.10, 0.25, 0.40, 0.55, 0.70] {
|
||
for ac in &[0.05, 0.15, 0.25, 0.35] {
|
||
for ascope in &[0.05, 0.15, 0.25] {
|
||
for ab in &[0.0, 0.05, 0.10] {
|
||
let w = KernelWeights {
|
||
alpha_role: *ar,
|
||
alpha_caps: *ac,
|
||
alpha_scope: *ascope,
|
||
alpha_body: *ab,
|
||
};
|
||
let m = mse(&observations, w);
|
||
if m < best_mse {
|
||
best_mse = m;
|
||
best_weights = w;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
Ok(CalibrationResult { best_weights, best_mse, baseline_mse, rows_evaluated })
|
||
}
|
||
|
||
fn load_observations(conn: &Connection) -> SqlResult<Vec<Observation>> {
|
||
let mut stmt = conn.prepare(
|
||
"SELECT task_class_dna, model, outcome, COALESCE(escalation_depth, 0)
|
||
FROM agents
|
||
WHERE task_class_dna IS NOT NULL
|
||
AND model IS NOT NULL AND model != ''
|
||
AND outcome IS NOT NULL",
|
||
)?;
|
||
let rows = stmt.query_map([], |r| {
|
||
Ok((
|
||
r.get::<_, String>(0)?,
|
||
r.get::<_, String>(1)?,
|
||
r.get::<_, String>(2)?,
|
||
r.get::<_, i64>(3)?,
|
||
))
|
||
})?;
|
||
let mut out = Vec::new();
|
||
for row in rows {
|
||
let (tc, model_slug, outcome, depth) = row?;
|
||
let Some(model) = Model::from_slug(&model_slug) else { continue };
|
||
let success = outcome == "functional" && depth == 0;
|
||
out.push(Observation { task_class: tc, model, success });
|
||
}
|
||
Ok(out)
|
||
}
|
||
|
||
fn mse(observations: &[Observation], weights: KernelWeights) -> f64 {
|
||
if observations.is_empty() {
|
||
return 0.0;
|
||
}
|
||
let mut sum_sq = 0.0_f64;
|
||
for (i, target) in observations.iter().enumerate() {
|
||
let q_hat = predict_loo(observations, i, target, weights);
|
||
let actual = if target.success { 1.0 } else { 0.0 };
|
||
sum_sq += (actual - q_hat).powi(2);
|
||
}
|
||
sum_sq / observations.len() as f64
|
||
}
|
||
|
||
fn predict_loo(
|
||
observations: &[Observation],
|
||
skip: usize,
|
||
target: &Observation,
|
||
weights: KernelWeights,
|
||
) -> f64 {
|
||
let mut weighted_alpha = 1.0_f64;
|
||
let mut weighted_beta = 1.0_f64;
|
||
for (j, obs) in observations.iter().enumerate() {
|
||
if j == skip || obs.model != target.model {
|
||
continue;
|
||
}
|
||
let sim = kernel::similarity(&target.task_class, &obs.task_class, weights);
|
||
if sim <= 0.0 {
|
||
continue;
|
||
}
|
||
if obs.success {
|
||
weighted_alpha += sim;
|
||
} else {
|
||
weighted_beta += sim;
|
||
}
|
||
}
|
||
weighted_alpha / (weighted_alpha + weighted_beta)
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// In-memory ledger with just the columns `load_observations` reads.
    fn fresh_db() -> Connection {
        let c = Connection::open_in_memory().unwrap();
        c.execute_batch(
            "CREATE TABLE agents (
id TEXT, task_class_dna TEXT, model TEXT,
outcome TEXT, escalation_depth INTEGER DEFAULT 0
);",
        )
        .unwrap();
        c
    }

    /// Inserts one ledger row with escalation depth 0; `None` stores SQL NULL.
    fn insert_row(
        c: &Connection,
        id: &str,
        task_class: Option<&str>,
        model: &str,
        outcome: Option<&str>,
    ) {
        c.execute(
            "INSERT INTO agents VALUES (?1, ?2, ?3, ?4, 0)",
            rusqlite::params![id, task_class, model, outcome],
        )
        .unwrap();
    }

    /// Asserts every field equals the corresponding `KernelWeights::default()` field.
    fn assert_default_weights(w: &KernelWeights) {
        let d = KernelWeights::default();
        assert_eq!(w.alpha_role, d.alpha_role);
        assert_eq!(w.alpha_caps, d.alpha_caps);
        assert_eq!(w.alpha_scope, d.alpha_scope);
        assert_eq!(w.alpha_body, d.alpha_body);
    }

    #[test]
    fn empty_ledger_returns_default_weights() {
        let c = fresh_db();
        let r = calibrate(&c).unwrap();
        assert_eq!(r.rows_evaluated, 0);
        assert!(r.best_mse.is_nan());
        assert!(r.baseline_mse.is_nan());
        assert_default_weights(&r.best_weights);
    }

    #[test]
    fn fewer_than_five_rows_skips_search() {
        let c = fresh_db();
        let haiku = Model::Haiku45.slug();
        for i in 0..4 {
            insert_row(
                &c,
                &format!("a{i}"),
                Some("roleA::caps::scope::body12"),
                &haiku,
                Some("functional"),
            );
        }
        let r = calibrate(&c).unwrap();
        assert_eq!(r.rows_evaluated, 4);
        assert!(r.best_mse.is_nan());
        assert!(r.baseline_mse.is_nan());
        assert_default_weights(&r.best_weights);
    }

    #[test]
    fn rows_excluded_by_query_are_not_counted() {
        let c = fresh_db();
        let haiku = Model::Haiku45.slug();
        let tc = Some("roleA::caps::scope::body12");
        insert_row(&c, "ok1", tc, &haiku, Some("functional"));
        insert_row(&c, "ok2", tc, &haiku, Some("partial"));
        insert_row(&c, "no_outcome", tc, &haiku, None);
        insert_row(&c, "empty_model", tc, "", Some("functional"));
        insert_row(&c, "no_task_class", None, &haiku, Some("functional"));
        let r = calibrate(&c).unwrap();
        assert_eq!(r.rows_evaluated, 2);
    }

    #[test]
    fn calibration_improves_or_matches_baseline() {
        let c = fresh_db();
        let haiku = Model::Haiku45.slug();
        for i in 0..15 {
            insert_row(
                &c,
                &format!("a{i}"),
                Some("roleA::caps::scope::body12"),
                &haiku,
                Some("functional"),
            );
        }
        for i in 0..5 {
            insert_row(
                &c,
                &format!("b{i}"),
                Some("roleB::caps::scope::body12"),
                &haiku,
                Some("partial"),
            );
        }
        let r = calibrate(&c).unwrap();
        assert_eq!(r.rows_evaluated, 20);
        assert!(r.best_mse <= r.baseline_mse + 1e-9);
    }
}
|