feat(kei-buddy): wire OpenAiExtractor + chat_id whitelist + env-configurable LLM
Two additions on top of the MVP serve binary:
1. Whitelist by chat_id (KEI_BUDDY_ALLOWED_CHAT_IDS env, CSV).
* BuddyContext gains Arc<Option<Vec<i64>>> allowed_chat_ids
* chat_allowed() check fires before process_text
* Non-whitelisted chats: warn-log + ignore (no response sent)
* None or empty list = accept all (back-compat with prior behaviour)
2. Real LLM wiring (KEI_BUDDY_LLM_PROXY / _LLM_KEY / _LLM_MODEL).
* When extractor-openai feature compiled in AND both proxy+key set,
run_serve instantiates OpenAiExtractor instead of MockExtractor
* Defaults: proxy=https://api.openai.com, key=OPENAI_API_KEY env,
model=gpt-4o-mini
* Fallback: warns + MockExtractor (state machine still walks, but
LLM-extracted fields are empty)
* extractor::OpenAiExtractor gains new_with_model(proxy, key, model);
model is now per-instance instead of compile-time DEFAULT_MODEL
3. start_listener extracted as helper — keeps run_serve readable across
the two feature-gated branches.
Verify-before-commit:
* cargo check -p kei-buddy (default): PASS
* cargo check -p kei-buddy --features extractor-openai: PASS
* cargo test -p kei-buddy --lib: 20/0 unchanged
This commit is contained in:
parent
7414d14cc7
commit
0045b6ac77
3 changed files with 90 additions and 5 deletions
|
|
@ -48,10 +48,34 @@ async fn cmd_serve() -> anyhow::Result<()> {
|
|||
db_path: db_path_from_env(),
|
||||
bot_token: require_env("TELEGRAM_BOT_TOKEN")?,
|
||||
webhook_secret: require_env("TELEGRAM_WEBHOOK_SECRET")?,
|
||||
allowed_chat_ids: allowed_chat_ids_from_env(),
|
||||
llm_proxy_url: std::env::var("KEI_BUDDY_LLM_PROXY")
|
||||
.ok()
|
||||
.or_else(|| Some("https://api.openai.com".to_string())),
|
||||
llm_api_key: std::env::var("KEI_BUDDY_LLM_KEY")
|
||||
.ok()
|
||||
.or_else(|| std::env::var("OPENAI_API_KEY").ok()),
|
||||
llm_model: std::env::var("KEI_BUDDY_LLM_MODEL").ok(),
|
||||
};
|
||||
run_serve(cfg).await
|
||||
}
|
||||
|
||||
/// Parse `KEI_BUDDY_ALLOWED_CHAT_IDS` CSV → Some(Vec<i64>); empty/missing → None.
|
||||
fn allowed_chat_ids_from_env() -> Option<Vec<i64>> {
|
||||
let raw = std::env::var("KEI_BUDDY_ALLOWED_CHAT_IDS").ok()?;
|
||||
let list: Vec<i64> = raw
|
||||
.split(',')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter_map(|s| s.parse::<i64>().ok())
|
||||
.collect();
|
||||
if list.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(list)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "serve"))]
|
||||
async fn cmd_serve() -> anyhow::Result<()> {
|
||||
anyhow::bail!("kei-buddy was compiled without the `serve` feature. Rebuild with --features serve.");
|
||||
|
|
|
|||
|
|
@ -140,14 +140,20 @@ pub mod openai {
|
|||
pub struct OpenAiExtractor {
|
||||
pub proxy_url: String,
|
||||
pub api_key: String,
|
||||
pub model: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl OpenAiExtractor {
|
||||
pub fn new(proxy_url: String, api_key: String) -> Self {
|
||||
Self::new_with_model(proxy_url, api_key, DEFAULT_MODEL.to_string())
|
||||
}
|
||||
|
||||
pub fn new_with_model(proxy_url: String, api_key: String, model: String) -> Self {
|
||||
Self {
|
||||
proxy_url,
|
||||
api_key,
|
||||
model,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
|
@ -157,7 +163,7 @@ pub mod openai {
|
|||
impl LlmExtractor for OpenAiExtractor {
|
||||
async fn extract(&self, system: &str, user_text: &str) -> Result<Value, BuddyError> {
|
||||
let body = serde_json::json!({
|
||||
"model": DEFAULT_MODEL,
|
||||
"model": &self.model,
|
||||
"temperature": 0,
|
||||
"max_tokens": 200,
|
||||
"messages": [
|
||||
|
|
|
|||
|
|
@ -29,6 +29,15 @@ pub struct ServeConfig {
|
|||
pub db_path: String,
|
||||
pub bot_token: String,
|
||||
pub webhook_secret: String,
|
||||
/// If `Some`, only these chat_ids are processed; others are warn-logged + ignored.
|
||||
/// `None` (or empty) means accept all chat_ids.
|
||||
pub allowed_chat_ids: Option<Vec<i64>>,
|
||||
/// Optional OpenAI-compatible LLM proxy. If set together with `llm_api_key`,
|
||||
/// `run_serve` instantiates `OpenAiExtractor`; otherwise falls back to
|
||||
/// `MockExtractor` with a warning.
|
||||
pub llm_proxy_url: Option<String>,
|
||||
pub llm_api_key: Option<String>,
|
||||
pub llm_model: Option<String>,
|
||||
}
|
||||
|
||||
/// Axum state — implements `WebhookContext` for the webhook handler.
|
||||
|
|
@ -40,6 +49,8 @@ pub struct BuddyContext<E: LlmExtractor + Send + Sync + 'static> {
|
|||
pub store: Arc<SqliteBuddyStore>,
|
||||
pub extractor: Arc<E>,
|
||||
pub http: reqwest::Client,
|
||||
/// Whitelist of chat_ids; `None` or empty = accept all.
|
||||
pub allowed_chat_ids: Arc<Option<Vec<i64>>>,
|
||||
}
|
||||
|
||||
impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
|
||||
|
|
@ -50,6 +61,7 @@ impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
|
|||
store: Arc::clone(&self.store),
|
||||
extractor: Arc::clone(&self.extractor),
|
||||
http: self.http.clone(),
|
||||
allowed_chat_ids: Arc::clone(&self.allowed_chat_ids),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -73,7 +85,18 @@ impl<E: LlmExtractor + Send + Sync + 'static> WebhookContext for BuddyContext<E>
|
|||
}
|
||||
|
||||
impl<E: LlmExtractor + Send + Sync + 'static> BuddyContext<E> {
|
||||
fn chat_allowed(&self, chat_id: i64) -> bool {
|
||||
match self.allowed_chat_ids.as_ref() {
|
||||
Some(list) if !list.is_empty() => list.contains(&chat_id),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_text(&self, chat_id: i64, text: String) {
|
||||
if !self.chat_allowed(chat_id) {
|
||||
warn!(chat_id, "chat_id not in whitelist; ignoring");
|
||||
return;
|
||||
}
|
||||
if let Err(e) = self.process_text(chat_id, &text).await {
|
||||
error!(chat_id, error = %e, "failed to process text event");
|
||||
}
|
||||
|
|
@ -138,16 +161,48 @@ where
|
|||
pub async fn run_serve(cfg: ServeConfig) -> anyhow::Result<()> {
|
||||
init_tracing();
|
||||
let store = Arc::new(SqliteBuddyStore::from_path(&cfg.db_path)?);
|
||||
let allowed_chat_ids = Arc::new(cfg.allowed_chat_ids);
|
||||
let http = reqwest::Client::new();
|
||||
|
||||
#[cfg(feature = "extractor-openai")]
|
||||
{
|
||||
if let (Some(proxy), Some(key)) = (cfg.llm_proxy_url, cfg.llm_api_key) {
|
||||
let model = cfg
|
||||
.llm_model
|
||||
.unwrap_or_else(|| "gpt-4o-mini".to_string());
|
||||
tracing::info!(model = %model, "using OpenAiExtractor (LiteLLM-compatible)");
|
||||
let extractor = Arc::new(crate::extractor::openai::OpenAiExtractor::new_with_model(
|
||||
proxy, key, model,
|
||||
));
|
||||
return start_listener(cfg.port, BuddyContext {
|
||||
secret: cfg.webhook_secret,
|
||||
bot_token: cfg.bot_token,
|
||||
store,
|
||||
extractor,
|
||||
http,
|
||||
allowed_chat_ids,
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
|
||||
warn!("no LLM extractor configured — using MockExtractor (state machine will advance but field-extraction returns empty)");
|
||||
let extractor = Arc::new(crate::extractor::MockExtractor::new(json!({})));
|
||||
let ctx = BuddyContext {
|
||||
start_listener(cfg.port, BuddyContext {
|
||||
secret: cfg.webhook_secret,
|
||||
bot_token: cfg.bot_token,
|
||||
store,
|
||||
extractor,
|
||||
http: reqwest::Client::new(),
|
||||
};
|
||||
http,
|
||||
allowed_chat_ids,
|
||||
}).await
|
||||
}
|
||||
|
||||
async fn start_listener<E>(port: u16, ctx: BuddyContext<E>) -> anyhow::Result<()>
|
||||
where
|
||||
E: LlmExtractor + Send + Sync + 'static,
|
||||
{
|
||||
let router = build_router(ctx);
|
||||
let addr = format!("0.0.0.0:{}", cfg.port);
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
tracing::info!(addr = %addr, "kei-buddy listening");
|
||||
axum::serve(listener, router).await?;
|
||||
|
|
|
|||
Loading…
Reference in a new issue