feat(kei-buddy): wire OpenAiExtractor + chat_id whitelist + env-configurable LLM
Two additions on top of the MVP serve binary:
1. Whitelist by chat_id (KEI_BUDDY_ALLOWED_CHAT_IDS env, CSV).
* BuddyContext gains Arc<Option<Vec<i64>>> allowed_chat_ids
* chat_allowed() check fires before process_text
* Non-whitelisted chats: warn-log + ignore (no response sent)
* None or empty list = accept all (back-compat with prior behaviour)
2. Real LLM wiring (KEI_BUDDY_LLM_PROXY / _LLM_KEY / _LLM_MODEL).
* When extractor-openai feature compiled in AND both proxy+key set,
run_serve instantiates OpenAiExtractor instead of MockExtractor
* Defaults: proxy=https://api.openai.com, key=OPENAI_API_KEY env,
model=gpt-4o-mini
* Fallback: warns + MockExtractor (state machine still walks, but
LLM-extracted fields are empty)
* extractor::OpenAiExtractor gains new_with_model(proxy, key, model);
model is now per-instance instead of compile-time DEFAULT_MODEL
3. start_listener extracted as helper — keeps run_serve readable across
the two feature-gated branches.
Verify-before-commit:
* cargo check -p kei-buddy (default): PASS
* cargo check -p kei-buddy --features extractor-openai: PASS
* cargo test -p kei-buddy --lib: 20/0 unchanged
This commit is contained in:
parent
621ac8685f
commit
44502507a2
3 changed files with 90 additions and 5 deletions
|
|
@ -48,10 +48,34 @@ async fn cmd_serve() -> anyhow::Result<()> {
|
||||||
db_path: db_path_from_env(),
|
db_path: db_path_from_env(),
|
||||||
bot_token: require_env("TELEGRAM_BOT_TOKEN")?,
|
bot_token: require_env("TELEGRAM_BOT_TOKEN")?,
|
||||||
webhook_secret: require_env("TELEGRAM_WEBHOOK_SECRET")?,
|
webhook_secret: require_env("TELEGRAM_WEBHOOK_SECRET")?,
|
||||||
|
allowed_chat_ids: allowed_chat_ids_from_env(),
|
||||||
|
llm_proxy_url: std::env::var("KEI_BUDDY_LLM_PROXY")
|
||||||
|
.ok()
|
||||||
|
.or_else(|| Some("https://api.openai.com".to_string())),
|
||||||
|
llm_api_key: std::env::var("KEI_BUDDY_LLM_KEY")
|
||||||
|
.ok()
|
||||||
|
.or_else(|| std::env::var("OPENAI_API_KEY").ok()),
|
||||||
|
llm_model: std::env::var("KEI_BUDDY_LLM_MODEL").ok(),
|
||||||
};
|
};
|
||||||
run_serve(cfg).await
|
run_serve(cfg).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parse `KEI_BUDDY_ALLOWED_CHAT_IDS` CSV → Some(Vec<i64>); empty/missing → None.
|
||||||
|
fn allowed_chat_ids_from_env() -> Option<Vec<i64>> {
|
||||||
|
let raw = std::env::var("KEI_BUDDY_ALLOWED_CHAT_IDS").ok()?;
|
||||||
|
let list: Vec<i64> = raw
|
||||||
|
.split(',')
|
||||||
|
.map(|s| s.trim())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.filter_map(|s| s.parse::<i64>().ok())
|
||||||
|
.collect();
|
||||||
|
if list.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "serve"))]
|
#[cfg(not(feature = "serve"))]
|
||||||
async fn cmd_serve() -> anyhow::Result<()> {
|
async fn cmd_serve() -> anyhow::Result<()> {
|
||||||
anyhow::bail!("kei-buddy was compiled without the `serve` feature. Rebuild with --features serve.");
|
anyhow::bail!("kei-buddy was compiled without the `serve` feature. Rebuild with --features serve.");
|
||||||
|
|
|
||||||
|
|
@ -140,14 +140,20 @@ pub mod openai {
|
||||||
pub struct OpenAiExtractor {
|
pub struct OpenAiExtractor {
|
||||||
pub proxy_url: String,
|
pub proxy_url: String,
|
||||||
pub api_key: String,
|
pub api_key: String,
|
||||||
|
pub model: String,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiExtractor {
|
impl OpenAiExtractor {
|
||||||
pub fn new(proxy_url: String, api_key: String) -> Self {
|
pub fn new(proxy_url: String, api_key: String) -> Self {
|
||||||
|
Self::new_with_model(proxy_url, api_key, DEFAULT_MODEL.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_with_model(proxy_url: String, api_key: String, model: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
proxy_url,
|
proxy_url,
|
||||||
api_key,
|
api_key,
|
||||||
|
model,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -157,7 +163,7 @@ pub mod openai {
|
||||||
impl LlmExtractor for OpenAiExtractor {
|
impl LlmExtractor for OpenAiExtractor {
|
||||||
async fn extract(&self, system: &str, user_text: &str) -> Result<Value, BuddyError> {
|
async fn extract(&self, system: &str, user_text: &str) -> Result<Value, BuddyError> {
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
"model": DEFAULT_MODEL,
|
"model": &self.model,
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_tokens": 200,
|
"max_tokens": 200,
|
||||||
"messages": [
|
"messages": [
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,15 @@ pub struct ServeConfig {
|
||||||
pub db_path: String,
|
pub db_path: String,
|
||||||
pub bot_token: String,
|
pub bot_token: String,
|
||||||
pub webhook_secret: String,
|
pub webhook_secret: String,
|
||||||
|
/// If `Some`, only these chat_ids are processed; others are warn-logged + ignored.
|
||||||
|
/// `None` (or empty) means accept all chat_ids.
|
||||||
|
pub allowed_chat_ids: Option<Vec<i64>>,
|
||||||
|
/// Optional OpenAI-compatible LLM proxy. If set together with `llm_api_key`,
|
||||||
|
/// `run_serve` instantiates `OpenAiExtractor`; otherwise falls back to
|
||||||
|
/// `MockExtractor` with a warning.
|
||||||
|
pub llm_proxy_url: Option<String>,
|
||||||
|
pub llm_api_key: Option<String>,
|
||||||
|
pub llm_model: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Axum state — implements `WebhookContext` for the webhook handler.
|
/// Axum state — implements `WebhookContext` for the webhook handler.
|
||||||
|
|
@ -40,6 +49,8 @@ pub struct BuddyContext<E: LlmExtractor + Send + Sync + 'static> {
|
||||||
pub store: Arc<SqliteBuddyStore>,
|
pub store: Arc<SqliteBuddyStore>,
|
||||||
pub extractor: Arc<E>,
|
pub extractor: Arc<E>,
|
||||||
pub http: reqwest::Client,
|
pub http: reqwest::Client,
|
||||||
|
/// Whitelist of chat_ids; `None` or empty = accept all.
|
||||||
|
pub allowed_chat_ids: Arc<Option<Vec<i64>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
|
impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
|
||||||
|
|
@ -50,6 +61,7 @@ impl<E: LlmExtractor + Send + Sync + 'static> Clone for BuddyContext<E> {
|
||||||
store: Arc::clone(&self.store),
|
store: Arc::clone(&self.store),
|
||||||
extractor: Arc::clone(&self.extractor),
|
extractor: Arc::clone(&self.extractor),
|
||||||
http: self.http.clone(),
|
http: self.http.clone(),
|
||||||
|
allowed_chat_ids: Arc::clone(&self.allowed_chat_ids),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -73,7 +85,18 @@ impl<E: LlmExtractor + Send + Sync + 'static> WebhookContext for BuddyContext<E>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: LlmExtractor + Send + Sync + 'static> BuddyContext<E> {
|
impl<E: LlmExtractor + Send + Sync + 'static> BuddyContext<E> {
|
||||||
|
fn chat_allowed(&self, chat_id: i64) -> bool {
|
||||||
|
match self.allowed_chat_ids.as_ref() {
|
||||||
|
Some(list) if !list.is_empty() => list.contains(&chat_id),
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_text(&self, chat_id: i64, text: String) {
|
async fn handle_text(&self, chat_id: i64, text: String) {
|
||||||
|
if !self.chat_allowed(chat_id) {
|
||||||
|
warn!(chat_id, "chat_id not in whitelist; ignoring");
|
||||||
|
return;
|
||||||
|
}
|
||||||
if let Err(e) = self.process_text(chat_id, &text).await {
|
if let Err(e) = self.process_text(chat_id, &text).await {
|
||||||
error!(chat_id, error = %e, "failed to process text event");
|
error!(chat_id, error = %e, "failed to process text event");
|
||||||
}
|
}
|
||||||
|
|
@ -138,16 +161,48 @@ where
|
||||||
pub async fn run_serve(cfg: ServeConfig) -> anyhow::Result<()> {
|
pub async fn run_serve(cfg: ServeConfig) -> anyhow::Result<()> {
|
||||||
init_tracing();
|
init_tracing();
|
||||||
let store = Arc::new(SqliteBuddyStore::from_path(&cfg.db_path)?);
|
let store = Arc::new(SqliteBuddyStore::from_path(&cfg.db_path)?);
|
||||||
|
let allowed_chat_ids = Arc::new(cfg.allowed_chat_ids);
|
||||||
|
let http = reqwest::Client::new();
|
||||||
|
|
||||||
|
#[cfg(feature = "extractor-openai")]
|
||||||
|
{
|
||||||
|
if let (Some(proxy), Some(key)) = (cfg.llm_proxy_url, cfg.llm_api_key) {
|
||||||
|
let model = cfg
|
||||||
|
.llm_model
|
||||||
|
.unwrap_or_else(|| "gpt-4o-mini".to_string());
|
||||||
|
tracing::info!(model = %model, "using OpenAiExtractor (LiteLLM-compatible)");
|
||||||
|
let extractor = Arc::new(crate::extractor::openai::OpenAiExtractor::new_with_model(
|
||||||
|
proxy, key, model,
|
||||||
|
));
|
||||||
|
return start_listener(cfg.port, BuddyContext {
|
||||||
|
secret: cfg.webhook_secret,
|
||||||
|
bot_token: cfg.bot_token,
|
||||||
|
store,
|
||||||
|
extractor,
|
||||||
|
http,
|
||||||
|
allowed_chat_ids,
|
||||||
|
}).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!("no LLM extractor configured — using MockExtractor (state machine will advance but field-extraction returns empty)");
|
||||||
let extractor = Arc::new(crate::extractor::MockExtractor::new(json!({})));
|
let extractor = Arc::new(crate::extractor::MockExtractor::new(json!({})));
|
||||||
let ctx = BuddyContext {
|
start_listener(cfg.port, BuddyContext {
|
||||||
secret: cfg.webhook_secret,
|
secret: cfg.webhook_secret,
|
||||||
bot_token: cfg.bot_token,
|
bot_token: cfg.bot_token,
|
||||||
store,
|
store,
|
||||||
extractor,
|
extractor,
|
||||||
http: reqwest::Client::new(),
|
http,
|
||||||
};
|
allowed_chat_ids,
|
||||||
|
}).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_listener<E>(port: u16, ctx: BuddyContext<E>) -> anyhow::Result<()>
|
||||||
|
where
|
||||||
|
E: LlmExtractor + Send + Sync + 'static,
|
||||||
|
{
|
||||||
let router = build_router(ctx);
|
let router = build_router(ctx);
|
||||||
let addr = format!("0.0.0.0:{}", cfg.port);
|
let addr = format!("0.0.0.0:{}", port);
|
||||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||||
tracing::info!(addr = %addr, "kei-buddy listening");
|
tracing::info!(addr = %addr, "kei-buddy listening");
|
||||||
axum::serve(listener, router).await?;
|
axum::serve(listener, router).await?;
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue