fix(kei-cortex): SSRF + atomic token + body limits + capped reads
Group C — kei-cortex daemon security hardening (post-audit 2026-05-02).
- fal_ssrf.rs (new): validate_fal_url whitelist (fal.ai/.media/.run only).
Applied to upload_url, file_url, status_url, images[0].url,
and download_image. Closes SSRF where compromised fal response
could direct daemon to fetch IMDSv1 (169.254.169.254) and
stream cloud creds.
- fal_pipeline.rs (new): HTTP step functions extracted from fal.rs; fal.rs trimmed
to thin orchestrator (101 LOC, was over 200 LOC limit).
- auth.rs: save_token now writes to <path>.<nanos>.tmp + sync_all + rename. Was
non-atomic OpenOptions truncate+write — crash mid-write produced empty
token file -> bootstrap rotated -> stale clients locked out.
- routes.rs + routes_auth.rs (new): explicit DefaultBodyLimit per route — chat 256 KiB,
tool/apply 11 MiB, pet/interaction 64 KiB, tts 32 KiB.
Bearer auth middleware extracted to routes_auth.
- handlers/chat.rs: validate_body enforces MAX_MESSAGE_CHARS = 50_000. Closed cost
amplification where 1.99 MiB chat message billed 500K tokens
($1.50/turn at Sonnet pricing) on every send.
- anthropic_sse.rs: SseParser MAX_BUF = 1 MiB cap; was unbounded — peer streaming
1 GB without \\n\\n would OOM daemon.
- http_helpers.rs (new): HTTP_CLIENT: Lazy<reqwest::Client> shared across handlers
(was per-request Client::new() => 100-300ms TLS handshake
per chat turn, no HTTP/2 multiplexing, fd leak risk on
macOS TIME_WAIT).
- http_helpers.rs::read_capped: per-response body cap (16 KiB error / 64 MiB success).
Applied to anthropic, anthropic_invoker, elevenlabs,
fal_pipeline. Closed unbounded resp.text() / .bytes()
pattern that compromised upstream could exploit.
Test results: 462 passed; 0 failed (single-threaded). cargo check clean.
2 pre-existing port-binding flakes in openai_loop_wiring tests are unrelated.
Findings consensus: fal SSRF + body-size + bearer-token-atomicity appeared in
Wave-A retest; chat message cap + SSE buf cap appeared in Wave-A only. Would have
been missed by single audit pass.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
f47c087646
commit
03d57c7395
14 changed files with 579 additions and 303 deletions
|
|
@ -1,8 +1,8 @@
|
|||
[package]
|
||||
name = "kei-cortex"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
rust-version = "1.75"
|
||||
edition.workspace = true
|
||||
rust-version.workspace = true
|
||||
description = "Local HTTP daemon exposing cortex state for UI consumption"
|
||||
authors = ["Denis Parfionovich <info@greendragon.info>"]
|
||||
|
||||
|
|
@ -16,31 +16,31 @@ path = "src/lib.rs"
|
|||
|
||||
[dependencies]
|
||||
axum = { version = "0.7", features = ["multipart", "ws"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal", "net", "time", "process", "fs", "io-util", "sync"] }
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { version = "0.7", features = ["rt"] }
|
||||
tower = { version = "0.4", features = ["limit", "buffer", "util"] }
|
||||
tower = { workspace = true }
|
||||
tower-http = { version = "0.5", features = ["cors", "trace"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
thiserror = "1"
|
||||
rusqlite = { version = "0.31", features = ["bundled"] }
|
||||
anyhow = "1"
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
rusqlite = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
rand = "0.8"
|
||||
reqwest = { version = "0.12", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false }
|
||||
tokio-stream = "0.1"
|
||||
futures = "0.3"
|
||||
reqwest = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
async-stream = "0.3"
|
||||
toml = "0.8"
|
||||
bytes = "1"
|
||||
tempfile = "3"
|
||||
dashmap = "5"
|
||||
walkdir = "2"
|
||||
toml = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
walkdir = { workspace = true }
|
||||
which = "6"
|
||||
once_cell = "1"
|
||||
regex = "1.10"
|
||||
portable-pty = "0.8"
|
||||
regex = { workspace = true }
|
||||
portable-pty = { workspace = true }
|
||||
# Wave 44a — tool-sandbox hardening
|
||||
shell-words = { workspace = true }
|
||||
url = { workspace = true }
|
||||
|
|
@ -64,4 +64,4 @@ kei-model = { path = "../kei-model" }
|
|||
kei-token-tracker = { path = "../kei-token-tracker" }
|
||||
|
||||
[dev-dependencies]
|
||||
reqwest = { version = "0.12", features = ["json", "blocking", "stream", "rustls-tls"], default-features = false }
|
||||
reqwest = { workspace = true, features = ["blocking"] }
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
//! an SSE error event rather than hanging the client.
|
||||
|
||||
use crate::anthropic_sse::SseParser;
|
||||
use crate::http_helpers::{read_capped, HTTP_CLIENT};
|
||||
use async_stream::try_stream;
|
||||
use futures::stream::Stream;
|
||||
use futures::StreamExt;
|
||||
|
|
@ -35,6 +36,9 @@ pub const IDLE: Duration = Duration::from_secs(30);
|
|||
/// large error page into our logs or client.
|
||||
const BODY_PREVIEW_CAP: usize = 512;
|
||||
|
||||
/// Cap on upstream error body reads via `read_capped` (16 KiB).
|
||||
const ERROR_BODY_CAP: usize = 16 * 1024;
|
||||
|
||||
/// A single turn in the conversation.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct Message {
|
||||
|
|
@ -91,7 +95,10 @@ fn body_to_text_stream(
|
|||
};
|
||||
let Some(chunk) = chunk_opt else { break };
|
||||
let chunk = chunk.map_err(Error::Http)?;
|
||||
for text in parser.push(&chunk) {
|
||||
let texts = parser.push(&chunk).map_err(|_| {
|
||||
Error::Upstream { status: 502, body: "SSE frame exceeds 1MB cap".into() }
|
||||
})?;
|
||||
for text in texts {
|
||||
yield text;
|
||||
}
|
||||
}
|
||||
|
|
@ -114,8 +121,7 @@ async fn send_request(
|
|||
api_key: &str,
|
||||
body: &serde_json::Value,
|
||||
) -> Result<reqwest::Response, Error> {
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
let resp = HTTP_CLIENT
|
||||
.post(endpoint().as_ref())
|
||||
.header("x-api-key", api_key)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
|
|
@ -142,7 +148,8 @@ async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, Erro
|
|||
if code == 503 || code == 529 {
|
||||
return Err(Error::ServiceUnavailable);
|
||||
}
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default();
|
||||
let body = String::from_utf8_lossy(&raw).into_owned();
|
||||
Err(Error::Upstream {
|
||||
status: code,
|
||||
body: truncate(&body, BODY_PREVIEW_CAP),
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
//! multi-block content (`text` + `tool_use`) into `Vec<ContentBlock>`.
|
||||
|
||||
use crate::anthropic::{default_model, endpoint, API_VERSION};
|
||||
use crate::http_helpers::{read_capped, HTTP_CLIENT};
|
||||
use crate::tool::loop_driver::{
|
||||
ContentBlock, ConversationMessage, ModelInvoker, ModelTurn, TokenUsage,
|
||||
};
|
||||
|
|
@ -42,9 +43,11 @@ async fn invoke(
|
|||
Ok(Err(e)) => return Err(format!("anthropic request: {e}")),
|
||||
Err(_) => return Err("anthropic request: timeout".into()),
|
||||
};
|
||||
let raw: Value = resp
|
||||
.json()
|
||||
const SUCCESS_CAP: usize = 64 * 1024 * 1024; // 64 MiB
|
||||
let bytes = read_capped(resp, SUCCESS_CAP)
|
||||
.await
|
||||
.map_err(|e| format!("anthropic body read: {e}"))?;
|
||||
let raw: Value = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| format!("anthropic body json: {e}"))?;
|
||||
parse_turn(&raw)
|
||||
}
|
||||
|
|
@ -103,7 +106,7 @@ fn render_assistant_blocks(blocks: &[ContentBlock]) -> Vec<Value> {
|
|||
|
||||
/// Fire the POST request and surface 4xx/5xx as a string error.
|
||||
async fn send(api_key: &str, body: &Value) -> Result<reqwest::Response, reqwest::Error> {
|
||||
let resp = reqwest::Client::new()
|
||||
let resp = HTTP_CLIENT
|
||||
.post(endpoint().as_ref())
|
||||
.header("x-api-key", api_key)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,14 @@ struct Delta {
|
|||
text: Option<String>,
|
||||
}
|
||||
|
||||
/// Maximum buffer size per SSE frame — guards against a runaway upstream
|
||||
/// that never emits a `\n\n` frame boundary.
|
||||
const MAX_BUF: usize = 1 * 1024 * 1024; // 1 MiB
|
||||
|
||||
/// Error returned by `SseParser::push` when the buffer cap is exceeded.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct SseBufOverflow;
|
||||
|
||||
/// Incremental SSE parser — SSE frames are separated by `\n\n`.
|
||||
///
|
||||
/// We buffer partial chunks across `push` calls and return every extracted
|
||||
|
|
@ -37,8 +45,14 @@ impl SseParser {
|
|||
}
|
||||
|
||||
/// Consume a byte chunk, return every text delta completed in this push.
|
||||
pub fn push(&mut self, chunk: &Bytes) -> Vec<String> {
|
||||
/// Returns `Err(SseBufOverflow)` if the internal buffer exceeds `MAX_BUF`
|
||||
/// before a complete frame arrives; caller should abort the stream.
|
||||
pub fn push(&mut self, chunk: &Bytes) -> Result<Vec<String>, SseBufOverflow> {
|
||||
self.buf.push_str(&String::from_utf8_lossy(chunk));
|
||||
if self.buf.len() > MAX_BUF {
|
||||
self.buf.clear();
|
||||
return Err(SseBufOverflow);
|
||||
}
|
||||
let mut out = Vec::new();
|
||||
while let Some(idx) = self.buf.find("\n\n") {
|
||||
let frame: String = self.buf.drain(..idx + 2).collect();
|
||||
|
|
@ -46,7 +60,7 @@ impl SseParser {
|
|||
out.push(text);
|
||||
}
|
||||
}
|
||||
out
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -89,7 +103,17 @@ mod tests {
|
|||
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"ab\"}",
|
||||
);
|
||||
let part2 = Bytes::from("}\n\n");
|
||||
assert!(p.push(&part1).is_empty());
|
||||
assert_eq!(p.push(&part2), vec!["ab".to_string()]);
|
||||
assert!(p.push(&part1).unwrap().is_empty());
|
||||
assert_eq!(p.push(&part2).unwrap(), vec!["ab".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parser_rejects_oversized_buf() {
|
||||
let mut p = SseParser::new();
|
||||
// Push just over MAX_BUF bytes without a frame boundary.
|
||||
let big = Bytes::from(vec![b'x'; MAX_BUF + 1]);
|
||||
assert!(p.push(&big).is_err());
|
||||
// Buffer must be cleared so subsequent calls don't accumulate.
|
||||
assert!(p.push(&Bytes::from("\n\n")).unwrap().is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,30 +103,51 @@ pub fn tokens_match(expected: &str, got: &str) -> bool {
|
|||
diff == 0
|
||||
}
|
||||
|
||||
/// Build a unique temp path next to `path`: `<path>.<nanos>.tmp`.
|
||||
fn tmp_path(path: &Path) -> std::path::PathBuf {
|
||||
let ts = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_nanos())
|
||||
.unwrap_or(0);
|
||||
let name = path.file_name().map(|n| {
|
||||
let mut s = n.to_os_string();
|
||||
s.push(format!(".{ts}.tmp"));
|
||||
s
|
||||
});
|
||||
match (path.parent(), name) {
|
||||
(Some(p), Some(n)) => p.join(n),
|
||||
_ => path.with_extension("tmp"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
let tmp = tmp_path(path);
|
||||
let mut f = fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(path)?;
|
||||
.open(&tmp)?;
|
||||
f.write_all(bytes)?;
|
||||
f.sync_all()?;
|
||||
Ok(())
|
||||
drop(f);
|
||||
fs::rename(&tmp, path)
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
|
||||
let tmp = tmp_path(path);
|
||||
let mut f = fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(path)?;
|
||||
.open(&tmp)?;
|
||||
f.write_all(bytes)?;
|
||||
f.sync_all()?;
|
||||
Ok(())
|
||||
drop(f);
|
||||
fs::rename(&tmp, path)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
//! reference, never hardcoded). We never log the full text; only its length
|
||||
//! and the response status / byte count.
|
||||
|
||||
use crate::http_helpers::{read_capped, HTTP_CLIENT};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Errors surfaced to the caller; handlers map them onto HTTP codes.
|
||||
|
|
@ -29,12 +30,15 @@ const TTS_BASE: &str = "https://api.elevenlabs.io/v1/text-to-speech";
|
|||
const BUDGET: Duration = Duration::from_secs(60);
|
||||
const BODY_PREVIEW_CAP: usize = 512;
|
||||
const MODEL_ID: &str = "eleven_turbo_v2_5";
|
||||
/// Cap on successful audio response bodies (64 MiB).
|
||||
const AUDIO_BODY_CAP: usize = 64 * 1024 * 1024;
|
||||
/// Cap on error response bodies (16 KiB).
|
||||
const ERROR_BODY_CAP: usize = 16 * 1024;
|
||||
|
||||
/// Synthesize speech for `text` using the given `voice_id`. Returns mp3 bytes.
|
||||
pub async fn synthesize(voice_id: &str, text: &str) -> Result<Vec<u8>, Error> {
|
||||
let key = std::env::var("ELEVENLABS_API_KEY").map_err(|_| Error::NoApiKey)?;
|
||||
let client = reqwest::Client::new();
|
||||
let fut = call_tts(&client, &key, voice_id, text);
|
||||
let fut = call_tts(&key, voice_id, text);
|
||||
match tokio::time::timeout(BUDGET, fut).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => Err(Error::Timeout),
|
||||
|
|
@ -44,14 +48,13 @@ pub async fn synthesize(voice_id: &str, text: &str) -> Result<Vec<u8>, Error> {
|
|||
/// POST the JSON body and collect the audio bytes. Split from `synthesize`
|
||||
/// so the timeout wrapper stays a thin shell.
|
||||
async fn call_tts(
|
||||
client: &reqwest::Client,
|
||||
key: &str,
|
||||
voice_id: &str,
|
||||
text: &str,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
let url = format!("{TTS_BASE}/{voice_id}");
|
||||
let body = build_body(text);
|
||||
let resp = client
|
||||
let resp = HTTP_CLIENT
|
||||
.post(&url)
|
||||
.header("xi-api-key", key)
|
||||
.header("Content-Type", "application/json")
|
||||
|
|
@ -77,13 +80,20 @@ fn build_body(text: &str) -> serde_json::Value {
|
|||
|
||||
/// Turn an ElevenLabs response into raw mp3 bytes; map non-2xx to `Upstream`.
|
||||
async fn decode_audio(resp: reqwest::Response) -> Result<Vec<u8>, Error> {
|
||||
use crate::http_helpers::CappedReadError;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
let truncated = truncate(&body, BODY_PREVIEW_CAP);
|
||||
return Err(Error::Upstream(status.as_u16(), truncated));
|
||||
let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default();
|
||||
let body = String::from_utf8_lossy(&raw).into_owned();
|
||||
return Err(Error::Upstream(status.as_u16(), truncate(&body, BODY_PREVIEW_CAP)));
|
||||
}
|
||||
match read_capped(resp, AUDIO_BODY_CAP).await {
|
||||
Ok(bytes) => Ok(bytes),
|
||||
Err(CappedReadError::TooLarge) => {
|
||||
Err(Error::Upstream(502, "audio body exceeded 64 MiB cap".into()))
|
||||
}
|
||||
Err(CappedReadError::Http(e)) => Err(Error::Http(e)),
|
||||
}
|
||||
Ok(resp.bytes().await?.to_vec())
|
||||
}
|
||||
|
||||
/// Cap a string at `max` bytes on a char boundary. Used for error previews
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
//! fal.ai Flux 2 Pro client — stylize a portrait into an anime frame.
|
||||
//!
|
||||
//! The public surface is a single async function `stylize(source_png, style)`.
|
||||
//! Internally we: (1) upload the source image to fal storage via the
|
||||
//! two-step `/storage/upload/initiate` handshake, (2) POST a generation
|
||||
//! request to the queue endpoint, (3) poll status until `COMPLETED`, (4)
|
||||
//! download the first image in the result. The whole pipeline has a 60-second
|
||||
//! wall-clock budget, past which we return `Error::Timeout` so the handler
|
||||
//! can surface an HTTP 504.
|
||||
//! Public surface: `stylize(source_png, style)`.
|
||||
//!
|
||||
//! `FAL_KEY` is read from the environment on every call; the daemon does not
|
||||
//! cache it because the user may rotate it without restarting.
|
||||
//! Pipeline: (1) upload source image to fal storage, (2) POST generation
|
||||
//! request to the queue endpoint, (3) poll status until `COMPLETED`, (4)
|
||||
//! download the first result image. 60-second wall-clock budget; past that
|
||||
//! the function returns `Error::Timeout` so the handler can surface HTTP 504.
|
||||
//!
|
||||
//! `FAL_KEY` is read from the environment on every call (env-rotation-friendly,
|
||||
//! per RULE 0.8 — never hardcoded, never cached in daemon state).
|
||||
//!
|
||||
//! HTTP pipeline steps live in `fal_pipeline`; SSRF mitigation in `fal_ssrf`.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
|
|
@ -31,7 +32,7 @@ impl Style {
|
|||
}
|
||||
}
|
||||
|
||||
fn prompt(self) -> &'static str {
|
||||
pub(crate) fn prompt(self) -> &'static str {
|
||||
match self {
|
||||
Self::Cute => CUTE_PROMPT,
|
||||
Self::Cool => COOL_PROMPT,
|
||||
|
|
@ -57,150 +58,44 @@ pub enum Error {
|
|||
BadShape(&'static str),
|
||||
#[error("fal polling exceeded 60s budget")]
|
||||
Timeout,
|
||||
#[error("non-fal URL rejected (SSRF): {0}")]
|
||||
NonFalUrl(String),
|
||||
}
|
||||
|
||||
const UPLOAD_INITIATE_URL: &str = "https://rest.alpha.fal.ai/storage/upload/initiate";
|
||||
const GENERATION_URL: &str = "https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra";
|
||||
const BUDGET: Duration = Duration::from_secs(60);
|
||||
const POLL_INTERVAL: Duration = Duration::from_millis(800);
|
||||
const BODY_PREVIEW_CAP: usize = 512;
|
||||
pub(crate) const BUDGET: Duration = Duration::from_secs(60);
|
||||
pub(crate) const POLL_INTERVAL: Duration = Duration::from_millis(800);
|
||||
pub(crate) const BODY_PREVIEW_CAP: usize = 512;
|
||||
pub(crate) const ERROR_BODY_CAP: usize = 16 * 1024;
|
||||
pub(crate) const IMAGE_BODY_CAP: usize = 64 * 1024 * 1024;
|
||||
|
||||
/// Stylize `source_png` into an anime portrait. Returns the raw PNG bytes
|
||||
/// of the generated image (caller writes them to disk).
|
||||
/// Stylize `source_png` into an anime portrait. Returns raw PNG bytes.
|
||||
pub async fn stylize(source_png: &[u8], style: Style) -> Result<Vec<u8>, Error> {
|
||||
let key = std::env::var("FAL_KEY").map_err(|_| Error::NoApiKey)?;
|
||||
let client = reqwest::Client::new();
|
||||
let deadline = tokio::time::Instant::now() + BUDGET;
|
||||
let uploaded_url = upload_image(&client, &key, source_png).await?;
|
||||
let status_url = enqueue(&client, &key, &uploaded_url, style).await?;
|
||||
let result_url = poll_until_done(&client, &key, &status_url, deadline).await?;
|
||||
download_image(&client, &key, &result_url).await
|
||||
}
|
||||
|
||||
/// Step 1 — ask fal storage for a signed PUT URL, then PUT the image to it.
|
||||
async fn upload_image(client: &reqwest::Client, key: &str, bytes: &[u8]) -> Result<String, Error> {
|
||||
let body = serde_json::json!({ "file_name": "portrait.png", "content_type": "image/png" });
|
||||
let resp = client
|
||||
.post(UPLOAD_INITIATE_URL)
|
||||
.header("Authorization", format!("Key {key}"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let json = decode_json(resp).await?;
|
||||
let upload_url = json.get("upload_url").and_then(|v| v.as_str()).ok_or(Error::BadShape("upload_url"))?;
|
||||
let file_url = json.get("file_url").and_then(|v| v.as_str()).ok_or(Error::BadShape("file_url"))?;
|
||||
let put = client.put(upload_url).header("Content-Type", "image/png").body(bytes.to_vec()).send().await?;
|
||||
if !put.status().is_success() {
|
||||
return Err(Error::BadStatus(put.status().as_u16(), "PUT upload failed".into()));
|
||||
}
|
||||
Ok(file_url.to_string())
|
||||
}
|
||||
|
||||
/// Step 2 — POST the generation request, return the status poll URL.
|
||||
async fn enqueue(client: &reqwest::Client, key: &str, image_url: &str, style: Style) -> Result<String, Error> {
|
||||
let body = serde_json::json!({
|
||||
"image_url": image_url,
|
||||
"prompt": style.prompt(),
|
||||
"strength": 0.65,
|
||||
"enable_safety_checker": true,
|
||||
});
|
||||
let resp = client
|
||||
.post(GENERATION_URL)
|
||||
.header("Authorization", format!("Key {key}"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let json = decode_json(resp).await?;
|
||||
json.get("status_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::to_string)
|
||||
.ok_or(Error::BadShape("status_url"))
|
||||
}
|
||||
|
||||
/// Step 3 — poll the status URL until `COMPLETED`, or give up at deadline.
|
||||
async fn poll_until_done(client: &reqwest::Client, key: &str, status_url: &str, deadline: tokio::time::Instant) -> Result<String, Error> {
|
||||
loop {
|
||||
if tokio::time::Instant::now() >= deadline {
|
||||
return Err(Error::Timeout);
|
||||
}
|
||||
let resp = client.get(status_url).header("Authorization", format!("Key {key}")).send().await?;
|
||||
let json = decode_json(resp).await?;
|
||||
let status = json.get("status").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if status == "COMPLETED" {
|
||||
return extract_first_image_url(&json);
|
||||
}
|
||||
if status == "FAILED" || status == "ERROR" {
|
||||
return Err(Error::BadStatus(502, format!("fal reported {status}")));
|
||||
}
|
||||
tokio::time::sleep(POLL_INTERVAL).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Step 4 — download the PNG bytes from the fal CDN URL.
|
||||
async fn download_image(client: &reqwest::Client, key: &str, url: &str) -> Result<Vec<u8>, Error> {
|
||||
let resp = client.get(url).header("Authorization", format!("Key {key}")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::BadStatus(resp.status().as_u16(), "download failed".into()));
|
||||
}
|
||||
Ok(resp.bytes().await?.to_vec())
|
||||
}
|
||||
|
||||
/// Dig out `images[0].url` from the completed status payload.
|
||||
fn extract_first_image_url(json: &serde_json::Value) -> Result<String, Error> {
|
||||
json.get("images")
|
||||
.and_then(|a| a.as_array())
|
||||
.and_then(|a| a.first())
|
||||
.and_then(|o| o.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::to_string)
|
||||
.ok_or(Error::BadShape("images[0].url"))
|
||||
}
|
||||
|
||||
/// Decode a fal JSON response, turning non-2xx into `BadStatus` with body.
|
||||
/// Body capped at `BODY_PREVIEW_CAP` so a large upstream error page cannot
|
||||
/// propagate through our logs or error channel.
|
||||
async fn decode_json(resp: reqwest::Response) -> Result<serde_json::Value, Error> {
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(Error::BadStatus(
|
||||
status.as_u16(),
|
||||
truncate(&body, BODY_PREVIEW_CAP),
|
||||
));
|
||||
}
|
||||
Ok(resp.json::<serde_json::Value>().await?)
|
||||
}
|
||||
|
||||
/// Cap a string at `max` bytes on a char boundary. Keeps fal error previews
|
||||
/// bounded regardless of what upstream sent back.
|
||||
fn truncate(s: &str, max: usize) -> String {
|
||||
if s.len() <= max {
|
||||
return s.to_string();
|
||||
}
|
||||
let mut end = max;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
s[..end].to_string()
|
||||
let uploaded_url = crate::fal_pipeline::upload_image(&key, source_png).await?;
|
||||
let status_url = crate::fal_pipeline::enqueue(&key, &uploaded_url, style).await?;
|
||||
let result_url =
|
||||
crate::fal_pipeline::poll_until_done(&key, &status_url, deadline, POLL_INTERVAL)
|
||||
.await?;
|
||||
crate::fal_pipeline::download_image(&key, &result_url).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn truncate_caps_long_strings() {
|
||||
let long = "a".repeat(10_000);
|
||||
assert_eq!(truncate(&long, 256).len(), 256);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_leaves_short_strings() {
|
||||
assert_eq!(truncate("hi", 256), "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn style_from_wire_defaults_to_cute() {
|
||||
assert!(matches!(Style::from_wire("wat"), Style::Cute));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn style_from_wire_cool() {
|
||||
assert!(matches!(Style::from_wire("anime-cool"), Style::Cool));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn style_from_wire_studious() {
|
||||
assert!(matches!(Style::from_wire("anime-studious"), Style::Studious));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
173
_primitives/_rust/kei-cortex/src/fal_pipeline.rs
Normal file
173
_primitives/_rust/kei-cortex/src/fal_pipeline.rs
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
//! fal.ai HTTP pipeline steps — upload, enqueue, poll, download.
|
||||
//!
|
||||
//! Each step is a single HTTPS round-trip. All use the shared `HTTP_CLIENT`
|
||||
//! from `http_helpers` (no per-request `reqwest::Client::new()`).
|
||||
//! Response bodies are read through `read_capped` so a runaway upstream
|
||||
//! cannot allocate unbounded memory.
|
||||
|
||||
use crate::fal::{Error, Style};
|
||||
use crate::fal::{BODY_PREVIEW_CAP, ERROR_BODY_CAP, IMAGE_BODY_CAP};
|
||||
use crate::fal_ssrf::validate_fal_url;
|
||||
use crate::http_helpers::{read_capped, CappedReadError, HTTP_CLIENT};
|
||||
|
||||
pub(crate) const UPLOAD_INITIATE_URL: &str =
|
||||
"https://rest.alpha.fal.ai/storage/upload/initiate";
|
||||
pub(crate) const GENERATION_URL: &str =
|
||||
"https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra";
|
||||
|
||||
/// Step 1 — ask fal storage for a signed PUT URL, then PUT the image to it.
|
||||
pub(crate) async fn upload_image(key: &str, bytes: &[u8]) -> Result<String, Error> {
|
||||
let body =
|
||||
serde_json::json!({ "file_name": "portrait.png", "content_type": "image/png" });
|
||||
let resp = HTTP_CLIENT
|
||||
.post(UPLOAD_INITIATE_URL)
|
||||
.header("Authorization", format!("Key {key}"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let json = decode_json(resp).await?;
|
||||
let upload_url = json
|
||||
.get("upload_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or(Error::BadShape("upload_url"))?;
|
||||
let file_url = json
|
||||
.get("file_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or(Error::BadShape("file_url"))?;
|
||||
validate_fal_url(upload_url)?;
|
||||
validate_fal_url(file_url)?;
|
||||
HTTP_CLIENT
|
||||
.put(upload_url)
|
||||
.header("Content-Type", "image/png")
|
||||
.body(bytes.to_vec())
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()
|
||||
.map_err(Error::Http)?;
|
||||
Ok(file_url.to_string())
|
||||
}
|
||||
|
||||
/// Step 2 — POST the generation request, return the validated status URL.
|
||||
pub(crate) async fn enqueue(
|
||||
key: &str,
|
||||
image_url: &str,
|
||||
style: Style,
|
||||
) -> Result<String, Error> {
|
||||
let body = serde_json::json!({
|
||||
"image_url": image_url,
|
||||
"prompt": style.prompt(),
|
||||
"strength": 0.65,
|
||||
"enable_safety_checker": true,
|
||||
});
|
||||
let resp = HTTP_CLIENT
|
||||
.post(GENERATION_URL)
|
||||
.header("Authorization", format!("Key {key}"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
let json = decode_json(resp).await?;
|
||||
let status_url = json
|
||||
.get("status_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::to_string)
|
||||
.ok_or(Error::BadShape("status_url"))?;
|
||||
validate_fal_url(&status_url)?;
|
||||
Ok(status_url)
|
||||
}
|
||||
|
||||
/// Step 3 — poll `status_url` until `COMPLETED`, or give up at `deadline`.
|
||||
pub(crate) async fn poll_until_done(
|
||||
key: &str,
|
||||
status_url: &str,
|
||||
deadline: tokio::time::Instant,
|
||||
poll_interval: std::time::Duration,
|
||||
) -> Result<String, Error> {
|
||||
loop {
|
||||
if tokio::time::Instant::now() >= deadline {
|
||||
return Err(Error::Timeout);
|
||||
}
|
||||
let resp = HTTP_CLIENT
|
||||
.get(status_url)
|
||||
.header("Authorization", format!("Key {key}"))
|
||||
.send()
|
||||
.await?;
|
||||
let json = decode_json(resp).await?;
|
||||
match json.get("status").and_then(|v| v.as_str()).unwrap_or("") {
|
||||
"COMPLETED" => return extract_first_image_url(&json),
|
||||
s @ ("FAILED" | "ERROR") => {
|
||||
return Err(Error::BadStatus(502, format!("fal reported {s}")));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Step 4 — download the PNG bytes from a validated fal CDN URL.
|
||||
pub(crate) async fn download_image(key: &str, url: &str) -> Result<Vec<u8>, Error> {
|
||||
validate_fal_url(url)?;
|
||||
let resp = HTTP_CLIENT
|
||||
.get(url)
|
||||
.header("Authorization", format!("Key {key}"))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(Error::BadStatus(resp.status().as_u16(), "download failed".into()));
|
||||
}
|
||||
match read_capped(resp, IMAGE_BODY_CAP).await {
|
||||
Ok(b) => Ok(b),
|
||||
Err(CappedReadError::TooLarge) => {
|
||||
Err(Error::BadStatus(502, "image body exceeded 64 MiB cap".into()))
|
||||
}
|
||||
Err(CappedReadError::Http(e)) => Err(Error::Http(e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Dig out `images[0].url` from the completed status payload, then validate.
|
||||
pub(crate) fn extract_first_image_url(json: &serde_json::Value) -> Result<String, Error> {
|
||||
let url = json
|
||||
.get("images")
|
||||
.and_then(|a| a.as_array())
|
||||
.and_then(|a| a.first())
|
||||
.and_then(|o| o.get("url"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::to_string)
|
||||
.ok_or(Error::BadShape("images[0].url"))?;
|
||||
validate_fal_url(&url)?;
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
/// Decode a fal response into JSON, capping error/success bodies.
|
||||
pub(crate) async fn decode_json(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<serde_json::Value, Error> {
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default();
|
||||
let body = String::from_utf8_lossy(&raw).into_owned();
|
||||
return Err(Error::BadStatus(
|
||||
status.as_u16(),
|
||||
truncate(&body, BODY_PREVIEW_CAP),
|
||||
));
|
||||
}
|
||||
let raw = match read_capped(resp, IMAGE_BODY_CAP).await {
|
||||
Ok(b) => b,
|
||||
Err(CappedReadError::TooLarge) => {
|
||||
return Err(Error::BadStatus(502, "response body exceeded 64 MiB cap".into()));
|
||||
}
|
||||
Err(CappedReadError::Http(e)) => return Err(Error::Http(e)),
|
||||
};
|
||||
serde_json::from_slice(&raw).map_err(|e| Error::BadStatus(502, e.to_string()))
|
||||
}
|
||||
|
||||
/// Cap a string at `max` bytes on a char boundary.
|
||||
fn truncate(s: &str, max: usize) -> String {
|
||||
if s.len() <= max {
|
||||
return s.to_string();
|
||||
}
|
||||
let mut end = max;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
s[..end].to_string()
|
||||
}
|
||||
67
_primitives/_rust/kei-cortex/src/fal_ssrf.rs
Normal file
67
_primitives/_rust/kei-cortex/src/fal_ssrf.rs
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
//! SSRF mitigation for fal.ai outbound requests.
|
||||
//!
|
||||
//! `validate_fal_url` rejects any URL whose scheme is not `https://` or
|
||||
//! whose host does not match the fal.ai domain allowlist before the URL
|
||||
//! is ever passed to `reqwest`.
|
||||
|
||||
use crate::fal::Error;
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
|
||||
/// Compiled allowlist for fal.ai host names.
|
||||
/// Accepted: `<one-or-more-subdomains>.fal.ai`, `.fal.media`, `.fal.run`.
|
||||
/// Example: `rest.alpha.fal.ai`, `queue.fal.run`, `cdn.fal.media`.
|
||||
static FAL_HOST_RE: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new(r"^([a-z0-9-]+\.)+fal\.(ai|media|run)$").unwrap());
|
||||
|
||||
/// Reject URLs that do not match the fal.ai HTTPS allowlist.
|
||||
///
|
||||
/// Returns `Err(Error::NonFalUrl)` when:
|
||||
/// - the scheme is not `https://`, or
|
||||
/// - the host does not match `^[a-z0-9-]+\.fal\.(ai|media|run)$`.
|
||||
pub(crate) fn validate_fal_url(url: &str) -> Result<(), Error> {
|
||||
let stripped = url.strip_prefix("https://").ok_or_else(|| {
|
||||
Error::NonFalUrl(url.chars().take(80).collect())
|
||||
})?;
|
||||
let host = stripped.split('/').next().unwrap_or("");
|
||||
if !FAL_HOST_RE.is_match(host) {
|
||||
return Err(Error::NonFalUrl(url.chars().take(80).collect()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn accepts_valid_fal_urls() {
|
||||
assert!(validate_fal_url("https://rest.alpha.fal.ai/storage/upload/initiate").is_ok());
|
||||
assert!(validate_fal_url("https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra").is_ok());
|
||||
assert!(validate_fal_url("https://cdn.fal.media/image.png").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_http_scheme() {
|
||||
assert!(matches!(
|
||||
validate_fal_url("http://queue.fal.run/foo"),
|
||||
Err(Error::NonFalUrl(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_non_fal_domain() {
|
||||
assert!(matches!(
|
||||
validate_fal_url("https://evil.example.com/steal"),
|
||||
Err(Error::NonFalUrl(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_ssrf_internal_ip() {
|
||||
assert!(matches!(
|
||||
validate_fal_url("https://169.254.169.254/latest/meta-data/"),
|
||||
Err(Error::NonFalUrl(_))
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
@ -97,10 +97,20 @@ fn validate_provider(state: &AppState, name: &str) -> Result<(), AppError> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Character ceiling for chat messages. Prevents runaway prompt injection
|
||||
/// and upstream token cost abuse.
|
||||
const MAX_MESSAGE_CHARS: usize = 50_000;
|
||||
|
||||
fn validate_body(req: &ChatRequest) -> Result<(), AppError> {
|
||||
if req.message.is_empty() {
|
||||
return Err(AppError::BadRequest("message is empty".into()));
|
||||
}
|
||||
let chars = req.message.chars().count();
|
||||
if chars > MAX_MESSAGE_CHARS {
|
||||
return Err(AppError::PayloadTooLarge(format!(
|
||||
"{chars} chars > {MAX_MESSAGE_CHARS}"
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
57
_primitives/_rust/kei-cortex/src/http_helpers.rs
Normal file
57
_primitives/_rust/kei-cortex/src/http_helpers.rs
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
//! Shared HTTP utilities: a process-wide `reqwest::Client` and a capped
|
||||
//! response-body reader.
|
||||
//!
|
||||
//! A single `reqwest::Client` is reused for all outbound calls to avoid
|
||||
//! exhausting OS connection-table entries when the daemon is under load.
|
||||
//! The client is initialized once via `once_cell::sync::Lazy`.
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// The process-wide HTTP client. Shared by `anthropic`, `anthropic_invoker`,
|
||||
/// `elevenlabs`, and `fal` so that connection pooling is effective.
|
||||
pub static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.build()
|
||||
.expect("reqwest::Client::builder().build() — no TLS config error expected")
|
||||
});
|
||||
|
||||
/// Error returned when a response body exceeds the caller-supplied cap.
|
||||
#[derive(Debug)]
|
||||
pub struct BodyTooLarge;
|
||||
|
||||
impl std::fmt::Display for BodyTooLarge {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "response body exceeded size cap")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for BodyTooLarge {}
|
||||
|
||||
/// Read a response body up to `max_bytes`. Returns `Err(BodyTooLarge)` if
|
||||
/// the stream exceeds the cap; the partial buffer is discarded on overflow
|
||||
/// so the caller does not accidentally use truncated data.
|
||||
pub async fn read_capped(
|
||||
resp: reqwest::Response,
|
||||
max_bytes: usize,
|
||||
) -> Result<Vec<u8>, CappedReadError> {
|
||||
use futures::StreamExt;
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
let mut stream = resp.bytes_stream();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.map_err(CappedReadError::Http)?;
|
||||
if bytes.len() + chunk.len() > max_bytes {
|
||||
return Err(CappedReadError::TooLarge);
|
||||
}
|
||||
bytes.extend_from_slice(&chunk);
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
/// Combined error for `read_capped`.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum CappedReadError {
|
||||
#[error("response body exceeded size cap")]
|
||||
TooLarge,
|
||||
#[error("http stream error: {0}")]
|
||||
Http(reqwest::Error),
|
||||
}
|
||||
|
|
@ -20,11 +20,15 @@ pub mod context;
|
|||
pub mod elevenlabs;
|
||||
pub mod error;
|
||||
pub mod fal;
|
||||
pub(crate) mod fal_pipeline;
|
||||
pub(crate) mod fal_ssrf;
|
||||
pub mod handlers;
|
||||
pub mod http_helpers;
|
||||
pub mod persona;
|
||||
pub mod whisper_local;
|
||||
pub mod rig_clone;
|
||||
pub mod routes;
|
||||
pub(crate) mod routes_auth;
|
||||
pub mod sentiment;
|
||||
pub mod state;
|
||||
pub mod tool;
|
||||
|
|
|
|||
|
|
@ -1,59 +1,47 @@
|
|||
//! Router assembly + bearer-token middleware + CORS layer.
|
||||
//! Router assembly + CORS layer.
|
||||
//!
|
||||
//! `/healthz` is mounted OUTSIDE the auth middleware so monitors can hit it
|
||||
//! without a token. Everything under `/api` goes through `require_bearer`.
|
||||
//! without a token. Everything under `/api` goes through `require_bearer`
|
||||
//! (defined in `routes_auth`).
|
||||
//!
|
||||
//! Per-route concurrency caps protect us from a runaway client draining our
|
||||
//! upstream budget — `fal.ai` in particular bills per run, so we cap
|
||||
//! `/portrait/stylize` at 2 concurrent installs system-wide. Other expensive
|
||||
//! routes (`/tts`, `/stt`, `/chat`) get matching caps tuned to their bottleneck.
|
||||
|
||||
use crate::auth::tokens_match;
|
||||
use crate::error::AppError;
|
||||
use crate::handlers::{
|
||||
chat, fs_list, health, ledger, memory, pet, portrait, stt, summary, term, tool_apply, tts,
|
||||
usage,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use axum::error_handling::HandleErrorLayer;
|
||||
use axum::extract::{DefaultBodyLimit, Request, State};
|
||||
use axum::extract::DefaultBodyLimit;
|
||||
use axum::http::{header, HeaderValue, Method, StatusCode};
|
||||
use axum::middleware::{self, Next};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::middleware;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::Router;
|
||||
|
||||
// OpenAI-compatible /v1/* surface. Lives under `routes/openai/` as a
|
||||
// sibling tree to this file; declared via `#[path]` so the directory
|
||||
// can coexist with `routes.rs` without becoming an inline `routes/mod.rs`.
|
||||
#[path = "routes/openai/mod.rs"]
|
||||
pub mod openai;
|
||||
|
||||
use tower::buffer::BufferLayer;
|
||||
use tower::limit::ConcurrencyLimitLayer;
|
||||
use tower::{BoxError, ServiceBuilder};
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
/// Upper bound on the `/portrait/stylize` multipart body. The handler enforces
|
||||
/// the stricter 10 MiB cap on the `file` field; this is just the pre-parse
|
||||
/// gate that Axum applies before it can see individual fields.
|
||||
const PORTRAIT_BODY_LIMIT: usize = 12 * 1024 * 1024;
|
||||
// --- Body limits (per-route pre-parse gates) --------------------------------
|
||||
const PORTRAIT_BODY_LIMIT: usize = 12 * 1024 * 1024; // 12 MiB (handler re-checks 10)
|
||||
const STT_BODY_LIMIT: usize = 26 * 1024 * 1024; // 26 MiB (handler re-checks 25)
|
||||
const CHAT_BODY_LIMIT: usize = 256 * 1024; // 256 KiB
|
||||
const TOOL_APPLY_BODY_LIMIT: usize = 11 * 1024 * 1024; // 11 MiB (handler checks 10)
|
||||
const INTERACTION_BODY_LIMIT: usize = 64 * 1024; // 64 KiB
|
||||
const TTS_BODY_LIMIT: usize = 32 * 1024; // 32 KiB
|
||||
|
||||
/// Upper bound on the `/stt` multipart body. The handler enforces the
|
||||
/// stricter 25 MiB cap on the `audio` field itself; this is the slack so
|
||||
/// that field headers + form overhead do not trip the pre-parse gate.
|
||||
const STT_BODY_LIMIT: usize = 26 * 1024 * 1024;
|
||||
|
||||
/// Max concurrent Flux stylize runs system-wide. fal.ai bills per run.
|
||||
// --- Concurrency budgets ----------------------------------------------------
|
||||
const PORTRAIT_CONCURRENCY: usize = 2;
|
||||
|
||||
/// Max concurrent ElevenLabs TTS calls.
|
||||
const TTS_CONCURRENCY: usize = 4;
|
||||
|
||||
/// Max concurrent whisper worker runs. CPU-bound, so the cap matches a
|
||||
/// conservative small-laptop core count.
|
||||
const STT_CONCURRENCY: usize = 2;
|
||||
|
||||
/// Max concurrent Anthropic chat streams.
|
||||
const CHAT_CONCURRENCY: usize = 8;
|
||||
|
||||
/// Build the top-level router. `cors_origin` must have been validated at
|
||||
|
|
@ -62,43 +50,11 @@ pub fn build_router(state: AppState) -> Router {
|
|||
let cors = build_cors(state.config().cors_origin.as_str())
|
||||
.expect("cors_origin must be valid — validated in AppConfig::new");
|
||||
|
||||
// Per-route granular caps proved fragile with axum 0.7's MethodRouter
|
||||
// layer bounds (HandleErrorLayer + ConcurrencyLimitLayer service is not
|
||||
// `Clone` in a way that layer() accepts). Apply a single router-wide cap
|
||||
// via route_layer — it wraps every inner route uniformly without the
|
||||
// per-method error-type headache. The cap is the SUM of the per-route
|
||||
// budgets (2+4+2+8 = 16), which is a strict upper bound on simultaneous
|
||||
// expensive work. Finer-grained token-bucket per-route can land later via
|
||||
// tower-governor if a multi-user deployment appears.
|
||||
let api = Router::new()
|
||||
.route("/api/v1/cortex/summary", get(summary::summary))
|
||||
.route("/api/v1/cortex/pet/:user_id", get(pet::get_pet))
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/interaction",
|
||||
post(pet::post_interaction),
|
||||
)
|
||||
.route("/api/v1/cortex/pet/:user_id/chat", post(chat::chat))
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/portrait/stylize",
|
||||
post(portrait::stylize).layer(DefaultBodyLimit::max(PORTRAIT_BODY_LIMIT)),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/cortex/stt",
|
||||
post(stt::transcribe).layer(DefaultBodyLimit::max(STT_BODY_LIMIT)),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/tts",
|
||||
post(tts::synthesize),
|
||||
)
|
||||
.route("/api/v1/cortex/ledger/recent", get(ledger::recent))
|
||||
.route("/api/v1/cortex/memory/search", get(memory::search_memory))
|
||||
.route("/api/v1/cortex/usage", get(usage::usage))
|
||||
.route("/api/v1/cortex/fs/list", get(fs_list::list))
|
||||
.route("/api/v1/cortex/tool/apply", post(tool_apply::apply))
|
||||
.route("/api/v1/cortex/term", get(term::ws_handler))
|
||||
let api = build_api_router();
|
||||
let api = api
|
||||
.route_layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
require_bearer,
|
||||
crate::routes_auth::require_bearer,
|
||||
))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
|
|
@ -119,6 +75,43 @@ pub fn build_router(state: AppState) -> Router {
|
|||
.with_state(state)
|
||||
}
|
||||
|
||||
/// Assemble the protected API sub-router (no auth layer yet — applied by caller).
|
||||
fn build_api_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/cortex/summary", get(summary::summary))
|
||||
.route("/api/v1/cortex/pet/:user_id", get(pet::get_pet))
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/interaction",
|
||||
post(pet::post_interaction)
|
||||
.layer(DefaultBodyLimit::max(INTERACTION_BODY_LIMIT)),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/chat",
|
||||
post(chat::chat).layer(DefaultBodyLimit::max(CHAT_BODY_LIMIT)),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/portrait/stylize",
|
||||
post(portrait::stylize).layer(DefaultBodyLimit::max(PORTRAIT_BODY_LIMIT)),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/cortex/stt",
|
||||
post(stt::transcribe).layer(DefaultBodyLimit::max(STT_BODY_LIMIT)),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/cortex/pet/:user_id/tts",
|
||||
post(tts::synthesize).layer(DefaultBodyLimit::max(TTS_BODY_LIMIT)),
|
||||
)
|
||||
.route("/api/v1/cortex/ledger/recent", get(ledger::recent))
|
||||
.route("/api/v1/cortex/memory/search", get(memory::search_memory))
|
||||
.route("/api/v1/cortex/usage", get(usage::usage))
|
||||
.route("/api/v1/cortex/fs/list", get(fs_list::list))
|
||||
.route(
|
||||
"/api/v1/cortex/tool/apply",
|
||||
post(tool_apply::apply).layer(DefaultBodyLimit::max(TOOL_APPLY_BODY_LIMIT)),
|
||||
)
|
||||
.route("/api/v1/cortex/term", get(term::ws_handler))
|
||||
}
|
||||
|
||||
/// Build the CORS layer locked to a single origin.
|
||||
fn build_cors(origin: &str) -> Result<CorsLayer, String> {
|
||||
let origin_header: HeaderValue = origin
|
||||
|
|
@ -131,60 +124,3 @@ fn build_cors(origin: &str) -> Result<CorsLayer, String> {
|
|||
.allow_credentials(true))
|
||||
}
|
||||
|
||||
/// Bearer-token middleware.
|
||||
///
|
||||
/// Two acceptable transports — checked in order:
|
||||
/// 1. `Authorization: Bearer <token>` — standard HTTP requests.
|
||||
/// 2. `Sec-WebSocket-Protocol: bearer, <token>` — WS upgrade only,
|
||||
/// because browsers cannot set the Authorization header on a
|
||||
/// `new WebSocket(url, [...protocols])` call.
|
||||
///
|
||||
/// Missing → 401; mismatch → 403.
|
||||
async fn require_bearer(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
let token = state.token().to_string();
|
||||
if let Some(got) = bearer_from_authorization(&req) {
|
||||
return finish_auth(&token, &got, req, next).await;
|
||||
}
|
||||
if let Some(got) = bearer_from_websocket_subprotocol(&req) {
|
||||
return finish_auth(&token, &got, req, next).await;
|
||||
}
|
||||
Err(AppError::Unauthorized)
|
||||
}
|
||||
|
||||
/// Pull `<token>` from `Authorization: Bearer <token>` if present.
|
||||
fn bearer_from_authorization(req: &Request) -> Option<String> {
|
||||
let v = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?;
|
||||
Some(v.strip_prefix("Bearer ")?.trim().to_string())
|
||||
}
|
||||
|
||||
/// Pull `<token>` from `Sec-WebSocket-Protocol: bearer, <token>`. The
|
||||
/// browser's `new WebSocket(url, ['bearer', tok])` produces this header.
|
||||
fn bearer_from_websocket_subprotocol(req: &Request) -> Option<String> {
|
||||
let v = req
|
||||
.headers()
|
||||
.get("sec-websocket-protocol")?
|
||||
.to_str()
|
||||
.ok()?;
|
||||
let mut parts = v.split(',').map(str::trim);
|
||||
if parts.next()? != "bearer" {
|
||||
return None;
|
||||
}
|
||||
Some(parts.next()?.to_string())
|
||||
}
|
||||
|
||||
/// Compare expected vs. supplied; on match, call `next`.
|
||||
async fn finish_auth(
|
||||
expected: &str,
|
||||
got: &str,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
if !tokens_match(expected, got) {
|
||||
return Err(AppError::Forbidden);
|
||||
}
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
|
|
|||
69
_primitives/_rust/kei-cortex/src/routes_auth.rs
Normal file
69
_primitives/_rust/kei-cortex/src/routes_auth.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
//! Bearer-token middleware for the cortex API router.
|
||||
//!
|
||||
//! Separated from `routes.rs` so each file stays under 200 LOC.
|
||||
|
||||
use crate::auth::tokens_match;
|
||||
use crate::error::AppError;
|
||||
use crate::state::AppState;
|
||||
use axum::extract::{Request, State};
|
||||
use axum::http::header;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
|
||||
/// Bearer-token middleware.
|
||||
///
|
||||
/// Two acceptable transports — checked in order:
|
||||
/// 1. `Authorization: Bearer <token>` — standard HTTP requests.
|
||||
/// 2. `Sec-WebSocket-Protocol: bearer, <token>` — WS upgrade only,
|
||||
/// because browsers cannot set the Authorization header on a
|
||||
/// `new WebSocket(url, [...protocols])` call.
|
||||
///
|
||||
/// Missing → 401; mismatch → 403.
|
||||
pub async fn require_bearer(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
let token = state.token().to_string();
|
||||
if let Some(got) = bearer_from_authorization(&req) {
|
||||
return finish_auth(&token, &got, req, next).await;
|
||||
}
|
||||
if let Some(got) = bearer_from_websocket_subprotocol(&req) {
|
||||
return finish_auth(&token, &got, req, next).await;
|
||||
}
|
||||
Err(AppError::Unauthorized)
|
||||
}
|
||||
|
||||
/// Pull `<token>` from `Authorization: Bearer <token>` if present.
|
||||
fn bearer_from_authorization(req: &Request) -> Option<String> {
|
||||
let v = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?;
|
||||
Some(v.strip_prefix("Bearer ")?.trim().to_string())
|
||||
}
|
||||
|
||||
/// Pull `<token>` from `Sec-WebSocket-Protocol: bearer, <token>`. The
|
||||
/// browser's `new WebSocket(url, ['bearer', tok])` produces this header.
|
||||
fn bearer_from_websocket_subprotocol(req: &Request) -> Option<String> {
|
||||
let v = req
|
||||
.headers()
|
||||
.get("sec-websocket-protocol")?
|
||||
.to_str()
|
||||
.ok()?;
|
||||
let mut parts = v.split(',').map(str::trim);
|
||||
if parts.next()? != "bearer" {
|
||||
return None;
|
||||
}
|
||||
Some(parts.next()?.to_string())
|
||||
}
|
||||
|
||||
/// Compare expected vs. supplied; on match, call `next`.
|
||||
async fn finish_auth(
|
||||
expected: &str,
|
||||
got: &str,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
if !tokens_match(expected, got) {
|
||||
return Err(AppError::Forbidden);
|
||||
}
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
Loading…
Reference in a new issue