diff --git a/_primitives/_rust/kei-buddy/src/bin/kei-buddy.rs b/_primitives/_rust/kei-buddy/src/bin/kei-buddy.rs index 40c9ebc..b97cb54 100644 --- a/_primitives/_rust/kei-buddy/src/bin/kei-buddy.rs +++ b/_primitives/_rust/kei-buddy/src/bin/kei-buddy.rs @@ -48,10 +48,34 @@ async fn cmd_serve() -> anyhow::Result<()> { db_path: db_path_from_env(), bot_token: require_env("TELEGRAM_BOT_TOKEN")?, webhook_secret: require_env("TELEGRAM_WEBHOOK_SECRET")?, + allowed_chat_ids: allowed_chat_ids_from_env(), + llm_proxy_url: std::env::var("KEI_BUDDY_LLM_PROXY") + .ok() + .or_else(|| Some("https://api.openai.com".to_string())), + llm_api_key: std::env::var("KEI_BUDDY_LLM_KEY") + .ok() + .or_else(|| std::env::var("OPENAI_API_KEY").ok()), + llm_model: std::env::var("KEI_BUDDY_LLM_MODEL").ok(), }; run_serve(cfg).await } +/// Parse `KEI_BUDDY_ALLOWED_CHAT_IDS` CSV → Some(Vec); empty/missing → None. +fn allowed_chat_ids_from_env() -> Option> { + let raw = std::env::var("KEI_BUDDY_ALLOWED_CHAT_IDS").ok()?; + let list: Vec = raw + .split(',') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .filter_map(|s| s.parse::().ok()) + .collect(); + if list.is_empty() { + None + } else { + Some(list) + } +} + #[cfg(not(feature = "serve"))] async fn cmd_serve() -> anyhow::Result<()> { anyhow::bail!("kei-buddy was compiled without the `serve` feature. Rebuild with --features serve."); diff --git a/_primitives/_rust/kei-buddy/src/extractor.rs b/_primitives/_rust/kei-buddy/src/extractor.rs index 4bcf573..41129bd 100644 --- a/_primitives/_rust/kei-buddy/src/extractor.rs +++ b/_primitives/_rust/kei-buddy/src/extractor.rs @@ -140,14 +140,20 @@ pub mod openai { pub struct OpenAiExtractor { pub proxy_url: String, pub api_key: String, + pub model: String, client: reqwest::Client, } impl OpenAiExtractor { pub fn new(proxy_url: String, api_key: String) -> Self { + Self::new_with_model(proxy_url, api_key, DEFAULT_MODEL.to_string()) + } + + pub fn new_with_model(proxy_url: String, api_key: String, model: String) -> Self { Self { proxy_url, api_key, + model, client: reqwest::Client::new(), } } @@ -157,7 +163,7 @@ pub mod openai { impl LlmExtractor for OpenAiExtractor { async fn extract(&self, system: &str, user_text: &str) -> Result { let body = serde_json::json!({ - "model": DEFAULT_MODEL, + "model": &self.model, "temperature": 0, "max_tokens": 200, "messages": [ diff --git a/_primitives/_rust/kei-buddy/src/serve.rs b/_primitives/_rust/kei-buddy/src/serve.rs index acbde35..b32a9bd 100644 --- a/_primitives/_rust/kei-buddy/src/serve.rs +++ b/_primitives/_rust/kei-buddy/src/serve.rs @@ -29,6 +29,15 @@ pub struct ServeConfig { pub db_path: String, pub bot_token: String, pub webhook_secret: String, + /// If `Some`, only these chat_ids are processed; others are warn-logged + ignored. + /// `None` (or empty) means accept all chat_ids. + pub allowed_chat_ids: Option>, + /// Optional OpenAI-compatible LLM proxy. If set together with `llm_api_key`, + /// `run_serve` instantiates `OpenAiExtractor`; otherwise falls back to + /// `MockExtractor` with a warning. + pub llm_proxy_url: Option, + pub llm_api_key: Option, + pub llm_model: Option, } /// Axum state — implements `WebhookContext` for the webhook handler. @@ -40,6 +49,8 @@ pub struct BuddyContext { pub store: Arc, pub extractor: Arc, pub http: reqwest::Client, + /// Whitelist of chat_ids; `None` or empty = accept all. + pub allowed_chat_ids: Arc>>, } impl Clone for BuddyContext { @@ -50,6 +61,7 @@ impl Clone for BuddyContext { store: Arc::clone(&self.store), extractor: Arc::clone(&self.extractor), http: self.http.clone(), + allowed_chat_ids: Arc::clone(&self.allowed_chat_ids), } } } @@ -73,7 +85,18 @@ impl WebhookContext for BuddyContext } impl BuddyContext { + fn chat_allowed(&self, chat_id: i64) -> bool { + match self.allowed_chat_ids.as_ref() { + Some(list) if !list.is_empty() => list.contains(&chat_id), + _ => true, + } + } + async fn handle_text(&self, chat_id: i64, text: String) { + if !self.chat_allowed(chat_id) { + warn!(chat_id, "chat_id not in whitelist; ignoring"); + return; + } if let Err(e) = self.process_text(chat_id, &text).await { error!(chat_id, error = %e, "failed to process text event"); } @@ -138,16 +161,48 @@ where pub async fn run_serve(cfg: ServeConfig) -> anyhow::Result<()> { init_tracing(); let store = Arc::new(SqliteBuddyStore::from_path(&cfg.db_path)?); + let allowed_chat_ids = Arc::new(cfg.allowed_chat_ids); + let http = reqwest::Client::new(); + + #[cfg(feature = "extractor-openai")] + { + if let (Some(proxy), Some(key)) = (cfg.llm_proxy_url, cfg.llm_api_key) { + let model = cfg + .llm_model + .unwrap_or_else(|| "gpt-4o-mini".to_string()); + tracing::info!(model = %model, "using OpenAiExtractor (LiteLLM-compatible)"); + let extractor = Arc::new(crate::extractor::openai::OpenAiExtractor::new_with_model( + proxy, key, model, + )); + return start_listener(cfg.port, BuddyContext { + secret: cfg.webhook_secret, + bot_token: cfg.bot_token, + store, + extractor, + http, + allowed_chat_ids, + }).await; + } + } + + warn!("no LLM extractor configured — using MockExtractor (state machine will advance but field-extraction returns empty)"); let extractor = Arc::new(crate::extractor::MockExtractor::new(json!({}))); - let ctx = BuddyContext { + start_listener(cfg.port, BuddyContext { secret: cfg.webhook_secret, bot_token: cfg.bot_token, store, extractor, - http: reqwest::Client::new(), - }; + http, + allowed_chat_ids, + }).await +} + +async fn start_listener(port: u16, ctx: BuddyContext) -> anyhow::Result<()> +where + E: LlmExtractor + Send + Sync + 'static, +{ let router = build_router(ctx); - let addr = format!("0.0.0.0:{}", cfg.port); + let addr = format!("0.0.0.0:{}", port); let listener = tokio::net::TcpListener::bind(&addr).await?; tracing::info!(addr = %addr, "kei-buddy listening"); axum::serve(listener, router).await?;