feat(kei-buddy): wire OpenAiExtractor + chat_id whitelist + env-configurable LLM

Two additions on top of the MVP serve binary:

1. Whitelist by chat_id (KEI_BUDDY_ALLOWED_CHAT_IDS env, CSV).
   * BuddyContext gains Arc<Option<Vec<i64>>> allowed_chat_ids
   * chat_allowed() check fires before process_text
   * Non-whitelisted chats: warn-log + ignore (no response sent)
   * None or empty list = accept all (back-compat with prior behaviour)

2. Real LLM wiring (KEI_BUDDY_LLM_PROXY / _LLM_KEY / _LLM_MODEL).
   * When extractor-openai feature compiled in AND both proxy+key set,
     run_serve instantiates OpenAiExtractor instead of MockExtractor
   * Defaults: proxy=https://api.openai.com, key=OPENAI_API_KEY env,
     model=gpt-4o-mini
   * Fallback: warns + MockExtractor (state machine still walks, but
     LLM-extracted fields are empty)
   * extractor::OpenAiExtractor gains new_with_model(proxy, key, model);
     model is now per-instance instead of compile-time DEFAULT_MODEL

3. start_listener extracted as helper — keeps run_serve readable across
   the two feature-gated branches.

Verify-before-commit:
  * cargo check -p kei-buddy (default): PASS
  * cargo check -p kei-buddy --features extractor-openai: PASS
  * cargo test -p kei-buddy --lib: 20/0 unchanged
This commit is contained in:
Parfii-bot 2026-05-12 14:49:43 +08:00
parent 621ac8685f
commit 44502507a2
3 changed files with 90 additions and 5 deletions

View file

@ -48,10 +48,34 @@ async fn cmd_serve() -> anyhow::Result<()> {
db_path: db_path_from_env(),
bot_token: require_env("TELEGRAM_BOT_TOKEN")?,
webhook_secret: require_env("TELEGRAM_WEBHOOK_SECRET")?,
allowed_chat_ids: allowed_chat_ids_from_env(),
llm_proxy_url: std::env::var("KEI_BUDDY_LLM_PROXY")
.ok()
.or_else(|| Some("https://api.openai.com".to_string())),
llm_api_key: std::env::var("KEI_BUDDY_LLM_KEY")
.ok()
.or_else(|| std::env::var("OPENAI_API_KEY").ok()),
llm_model: std::env::var("KEI_BUDDY_LLM_MODEL").ok(),
};
run_serve(cfg).await
}
/// Parse `KEI_BUDDY_ALLOWED_CHAT_IDS` CSV → Some(Vec<i64>); empty/missing → None.
fn allowed_chat_ids_from_env() -> Option<Vec<i64>> {
let raw = std::env::var("KEI_BUDDY_ALLOWED_CHAT_IDS").ok()?;
let list: Vec<i64> = raw
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.filter_map(|s| s.parse::<i64>().ok())
.collect();
if list.is_empty() {
None
} else {
Some(list)
}
}
#[cfg(not(feature = "serve"))]
async fn cmd_serve() -> anyhow::Result<()> {
anyhow::bail!("kei-buddy was compiled without the `serve` feature. Rebuild with --features serve.");

View file

@ -140,14 +140,20 @@ pub mod openai {
pub struct OpenAiExtractor {
pub proxy_url: String,
pub api_key: String,
pub model: String,
client: reqwest::Client,
}
impl OpenAiExtractor {
pub fn new(proxy_url: String, api_key: String) -> Self {
Self::new_with_model(proxy_url, api_key, DEFAULT_MODEL.to_string())
}
pub fn new_with_model(proxy_url: String, api_key: String, model: String) -> Self {
Self {
proxy_url,
api_key,
model,
client: reqwest::Client::new(),
}
}
@ -157,7 +163,7 @@ pub mod openai {
impl LlmExtractor for OpenAiExtractor {
async fn extract(&self, system: &str, user_text: &str) -> Result<Value, BuddyError> {
let body = serde_json::json!({
"model": DEFAULT_MODEL,
"model": &self.model,
"temperature": 0,
"max_tokens": 200,
"messages": [

View file

@ -29,6 +29,15 @@ pub struct ServeConfig {
pub db_path: String,
pub bot_token: String,
pub webhook_secret: String,
/// If `Some`, only these chat_ids are processed; others are warn-logged + ignored.
/// `None` (or empty) means accept all chat_ids.
pub allowed_chat_ids: Option<Vec<i64>>,
/// Optional OpenAI-compatible LLM proxy. If set together with `llm_api_key`,
/// `run_serve` instantiates `OpenAiExtractor`; otherwise falls back to
/// `MockExtractor` with a warning.
pub llm_proxy_url: Option<String>,
pub llm_api_key: Option<String>,
pub llm_model: Option<String>,
}
/// Axum state — implements `WebhookContext` for the webhook handler.
@ -40,6 +49,8 @@ pub struct BuddyContext<E: LlmExtractor + Send + Sync + 'static> {
pub store: Arc<SqliteBuddyStore>,
pub extractor: Arc<E>,
pub http: reqwest::Client,
/// Whitelist of chat_ids; `None` or empty = accept all.
pub allowed_chat_ids: Arc<Option<Vec<i64>>>,
}
impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
@ -50,6 +61,7 @@ impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
store: Arc::clone(&self.store),
extractor: Arc::clone(&self.extractor),
http: self.http.clone(),
allowed_chat_ids: Arc::clone(&self.allowed_chat_ids),
}
}
}
@ -73,7 +85,18 @@ impl<E: LlmExtractor + Send + Sync + 'static> WebhookContext for BuddyContext<E>
}
impl<E: LlmExtractor + Send + Sync + 'static> BuddyContext<E> {
fn chat_allowed(&self, chat_id: i64) -> bool {
match self.allowed_chat_ids.as_ref() {
Some(list) if !list.is_empty() => list.contains(&chat_id),
_ => true,
}
}
async fn handle_text(&self, chat_id: i64, text: String) {
if !self.chat_allowed(chat_id) {
warn!(chat_id, "chat_id not in whitelist; ignoring");
return;
}
if let Err(e) = self.process_text(chat_id, &text).await {
error!(chat_id, error = %e, "failed to process text event");
}
@ -138,16 +161,48 @@ where
pub async fn run_serve(cfg: ServeConfig) -> anyhow::Result<()> {
init_tracing();
let store = Arc::new(SqliteBuddyStore::from_path(&cfg.db_path)?);
let allowed_chat_ids = Arc::new(cfg.allowed_chat_ids);
let http = reqwest::Client::new();
#[cfg(feature = "extractor-openai")]
{
if let (Some(proxy), Some(key)) = (cfg.llm_proxy_url, cfg.llm_api_key) {
let model = cfg
.llm_model
.unwrap_or_else(|| "gpt-4o-mini".to_string());
tracing::info!(model = %model, "using OpenAiExtractor (LiteLLM-compatible)");
let extractor = Arc::new(crate::extractor::openai::OpenAiExtractor::new_with_model(
proxy, key, model,
));
return start_listener(cfg.port, BuddyContext {
secret: cfg.webhook_secret,
bot_token: cfg.bot_token,
store,
extractor,
http,
allowed_chat_ids,
}).await;
}
}
warn!("no LLM extractor configured — using MockExtractor (state machine will advance but field-extraction returns empty)");
let extractor = Arc::new(crate::extractor::MockExtractor::new(json!({})));
let ctx = BuddyContext {
start_listener(cfg.port, BuddyContext {
secret: cfg.webhook_secret,
bot_token: cfg.bot_token,
store,
extractor,
http: reqwest::Client::new(),
};
http,
allowed_chat_ids,
}).await
}
async fn start_listener<E>(port: u16, ctx: BuddyContext<E>) -> anyhow::Result<()>
where
E: LlmExtractor + Send + Sync + 'static,
{
let router = build_router(ctx);
let addr = format!("0.0.0.0:{}", cfg.port);
let addr = format!("0.0.0.0:{}", port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!(addr = %addr, "kei-buddy listening");
axum::serve(listener, router).await?;