fix(kei-cortex): SSRF + atomic token + body limits + capped reads

Group C — kei-cortex daemon security hardening (post-audit 2026-05-02).

- fal_ssrf.rs (new): validate_fal_url whitelist (fal.ai/.media/.run only).
                      Applied to upload_url, file_url, status_url, images[0].url,
                      and download_image. Closes SSRF where compromised fal response
                      could direct daemon to fetch IMDSv1 (169.254.169.254) and
                      stream cloud creds.
- fal_pipeline.rs (new): HTTP step functions extracted from fal.rs; fal.rs trimmed
                          to thin orchestrator (101 LOC, was over 200 LOC limit).
- auth.rs: save_token now writes to <path>.<nanos>.tmp + sync_all + rename. Was
            non-atomic OpenOptions truncate+write — crash mid-write produced empty
            token file -> bootstrap rotated -> stale clients locked out.
- routes.rs + routes_auth.rs (new): explicit DefaultBodyLimit per route — chat 256 KiB,
                                     tool/apply 11 MiB, pet/interaction 64 KiB, tts 32 KiB.
                                     Bearer auth middleware extracted to routes_auth.
- handlers/chat.rs: validate_body enforces MAX_MESSAGE_CHARS = 50_000. Closed cost
                     amplification where 1.99 MiB chat message billed 500K tokens
                     ($1.50/turn at Sonnet pricing) on every send.
- anthropic_sse.rs: SseParser MAX_BUF = 1 MiB cap; was unbounded — a peer streaming
                     1 GB without a `\n\n` frame boundary would OOM the daemon.
- http_helpers.rs (new): HTTP_CLIENT: Lazy<reqwest::Client> shared across handlers
                          (was per-request Client::new() => 100-300ms TLS handshake
                          per chat turn, no HTTP/2 multiplexing, fd leak risk on
                          macOS TIME_WAIT).
- http_helpers.rs::read_capped: per-response body cap (16 KiB error / 64 MiB success).
                                  Applied to anthropic, anthropic_invoker, elevenlabs,
                                  fal_pipeline. Closed unbounded resp.text() / .bytes()
                                  pattern that compromised upstream could exploit.

Test results: 462 passed; 0 failed (single-threaded). cargo check clean.
2 pre-existing port-binding flakes in openai_loop_wiring tests are unrelated.

Findings consensus: fal SSRF + body-size + bearer-token-atomicity appeared in
Wave-A retest; chat message cap + SSE buf cap appeared in Wave-A only. These
findings would have been missed by a single audit pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Parfii-bot 2026-05-02 21:39:57 +08:00
parent 8b0401b9db
commit 9aa29aca15
14 changed files with 579 additions and 303 deletions

View file

@ -1,8 +1,8 @@
[package] [package]
name = "kei-cortex" name = "kei-cortex"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition.workspace = true
rust-version = "1.75" rust-version.workspace = true
description = "Local HTTP daemon exposing cortex state for UI consumption" description = "Local HTTP daemon exposing cortex state for UI consumption"
authors = ["Denis Parfionovich <info@greendragon.info>"] authors = ["Denis Parfionovich <info@greendragon.info>"]
@ -16,31 +16,31 @@ path = "src/lib.rs"
[dependencies] [dependencies]
axum = { version = "0.7", features = ["multipart", "ws"] } axum = { version = "0.7", features = ["multipart", "ws"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal", "net", "time", "process", "fs", "io-util", "sync"] } tokio = { workspace = true }
tokio-util = { version = "0.7", features = ["rt"] } tokio-util = { version = "0.7", features = ["rt"] }
tower = { version = "0.4", features = ["limit", "buffer", "util"] } tower = { workspace = true }
tower-http = { version = "0.5", features = ["cors", "trace"] } tower-http = { version = "0.5", features = ["cors", "trace"] }
serde = { version = "1", features = ["derive"] } serde = { workspace = true }
serde_json = "1" serde_json = { workspace = true }
clap = { version = "4", features = ["derive"] } clap = { workspace = true }
thiserror = "1" thiserror = { workspace = true }
rusqlite = { version = "0.31", features = ["bundled"] } rusqlite = { workspace = true }
anyhow = "1" anyhow = { workspace = true }
rand = "0.8" rand = "0.8"
reqwest = { version = "0.12", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false } reqwest = { workspace = true }
tokio-stream = "0.1" tokio-stream = { workspace = true }
futures = "0.3" futures = { workspace = true }
uuid = { version = "1", features = ["v4"] } uuid = { version = "1", features = ["v4"] }
async-stream = "0.3" async-stream = "0.3"
toml = "0.8" toml = { workspace = true }
bytes = "1" bytes = { workspace = true }
tempfile = "3" tempfile = { workspace = true }
dashmap = "5" dashmap = { workspace = true }
walkdir = "2" walkdir = { workspace = true }
which = "6" which = "6"
once_cell = "1" once_cell = "1"
regex = "1.10" regex = { workspace = true }
portable-pty = "0.8" portable-pty = { workspace = true }
# Wave 44a — tool-sandbox hardening # Wave 44a — tool-sandbox hardening
shell-words = { workspace = true } shell-words = { workspace = true }
url = { workspace = true } url = { workspace = true }
@ -64,4 +64,4 @@ kei-model = { path = "../kei-model" }
kei-token-tracker = { path = "../kei-token-tracker" } kei-token-tracker = { path = "../kei-token-tracker" }
[dev-dependencies] [dev-dependencies]
reqwest = { version = "0.12", features = ["json", "blocking", "stream", "rustls-tls"], default-features = false } reqwest = { workspace = true, features = ["blocking"] }

View file

@ -12,6 +12,7 @@
//! an SSE error event rather than hanging the client. //! an SSE error event rather than hanging the client.
use crate::anthropic_sse::SseParser; use crate::anthropic_sse::SseParser;
use crate::http_helpers::{read_capped, HTTP_CLIENT};
use async_stream::try_stream; use async_stream::try_stream;
use futures::stream::Stream; use futures::stream::Stream;
use futures::StreamExt; use futures::StreamExt;
@ -35,6 +36,9 @@ pub const IDLE: Duration = Duration::from_secs(30);
/// large error page into our logs or client. /// large error page into our logs or client.
const BODY_PREVIEW_CAP: usize = 512; const BODY_PREVIEW_CAP: usize = 512;
/// Cap on upstream error body reads via `read_capped` (16 KiB).
const ERROR_BODY_CAP: usize = 16 * 1024;
/// A single turn in the conversation. /// A single turn in the conversation.
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct Message { pub struct Message {
@ -91,7 +95,10 @@ fn body_to_text_stream(
}; };
let Some(chunk) = chunk_opt else { break }; let Some(chunk) = chunk_opt else { break };
let chunk = chunk.map_err(Error::Http)?; let chunk = chunk.map_err(Error::Http)?;
for text in parser.push(&chunk) { let texts = parser.push(&chunk).map_err(|_| {
Error::Upstream { status: 502, body: "SSE frame exceeds 1MB cap".into() }
})?;
for text in texts {
yield text; yield text;
} }
} }
@ -114,8 +121,7 @@ async fn send_request(
api_key: &str, api_key: &str,
body: &serde_json::Value, body: &serde_json::Value,
) -> Result<reqwest::Response, Error> { ) -> Result<reqwest::Response, Error> {
let client = reqwest::Client::new(); let resp = HTTP_CLIENT
let resp = client
.post(endpoint().as_ref()) .post(endpoint().as_ref())
.header("x-api-key", api_key) .header("x-api-key", api_key)
.header("anthropic-version", API_VERSION) .header("anthropic-version", API_VERSION)
@ -142,7 +148,8 @@ async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, Erro
if code == 503 || code == 529 { if code == 503 || code == 529 {
return Err(Error::ServiceUnavailable); return Err(Error::ServiceUnavailable);
} }
let body = resp.text().await.unwrap_or_default(); let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default();
let body = String::from_utf8_lossy(&raw).into_owned();
Err(Error::Upstream { Err(Error::Upstream {
status: code, status: code,
body: truncate(&body, BODY_PREVIEW_CAP), body: truncate(&body, BODY_PREVIEW_CAP),

View file

@ -5,6 +5,7 @@
//! multi-block content (`text` + `tool_use`) into `Vec<ContentBlock>`. //! multi-block content (`text` + `tool_use`) into `Vec<ContentBlock>`.
use crate::anthropic::{default_model, endpoint, API_VERSION}; use crate::anthropic::{default_model, endpoint, API_VERSION};
use crate::http_helpers::{read_capped, HTTP_CLIENT};
use crate::tool::loop_driver::{ use crate::tool::loop_driver::{
ContentBlock, ConversationMessage, ModelInvoker, ModelTurn, TokenUsage, ContentBlock, ConversationMessage, ModelInvoker, ModelTurn, TokenUsage,
}; };
@ -42,9 +43,11 @@ async fn invoke(
Ok(Err(e)) => return Err(format!("anthropic request: {e}")), Ok(Err(e)) => return Err(format!("anthropic request: {e}")),
Err(_) => return Err("anthropic request: timeout".into()), Err(_) => return Err("anthropic request: timeout".into()),
}; };
let raw: Value = resp const SUCCESS_CAP: usize = 64 * 1024 * 1024; // 64 MiB
.json() let bytes = read_capped(resp, SUCCESS_CAP)
.await .await
.map_err(|e| format!("anthropic body read: {e}"))?;
let raw: Value = serde_json::from_slice(&bytes)
.map_err(|e| format!("anthropic body json: {e}"))?; .map_err(|e| format!("anthropic body json: {e}"))?;
parse_turn(&raw) parse_turn(&raw)
} }
@ -103,7 +106,7 @@ fn render_assistant_blocks(blocks: &[ContentBlock]) -> Vec<Value> {
/// Fire the POST request and surface 4xx/5xx as a string error. /// Fire the POST request and surface 4xx/5xx as a string error.
async fn send(api_key: &str, body: &Value) -> Result<reqwest::Response, reqwest::Error> { async fn send(api_key: &str, body: &Value) -> Result<reqwest::Response, reqwest::Error> {
let resp = reqwest::Client::new() let resp = HTTP_CLIENT
.post(endpoint().as_ref()) .post(endpoint().as_ref())
.header("x-api-key", api_key) .header("x-api-key", api_key)
.header("anthropic-version", API_VERSION) .header("anthropic-version", API_VERSION)

View file

@ -23,6 +23,14 @@ struct Delta {
text: Option<String>, text: Option<String>,
} }
/// Maximum number of bytes `SseParser` may hold in its pending buffer.
///
/// The check runs right after a chunk is appended and before any complete
/// frames are drained, so it bounds the whole pending buffer rather than a
/// single frame. It guards against a runaway upstream that never emits a
/// `\n\n` frame boundary.
const MAX_BUF: usize = 1024 * 1024; // 1 MiB
/// Error returned by `SseParser::push` when the buffer cap is exceeded.
///
/// Carries no payload; the parser clears its buffer before returning it.
// NOTE(review): `push` is declared `pub` while this type is `pub(crate)` —
// confirm `SseParser` itself is not exported outside the crate.
#[derive(Debug)]
pub(crate) struct SseBufOverflow;
/// Incremental SSE parser — SSE frames are separated by `\n\n`. /// Incremental SSE parser — SSE frames are separated by `\n\n`.
/// ///
/// We buffer partial chunks across `push` calls and return every extracted /// We buffer partial chunks across `push` calls and return every extracted
@ -37,8 +45,14 @@ impl SseParser {
} }
/// Consume a byte chunk, return every text delta completed in this push. /// Consume a byte chunk, return every text delta completed in this push.
pub fn push(&mut self, chunk: &Bytes) -> Vec<String> { /// Returns `Err(SseBufOverflow)` if the internal buffer exceeds `MAX_BUF`
/// before a complete frame arrives; caller should abort the stream.
pub fn push(&mut self, chunk: &Bytes) -> Result<Vec<String>, SseBufOverflow> {
self.buf.push_str(&String::from_utf8_lossy(chunk)); self.buf.push_str(&String::from_utf8_lossy(chunk));
if self.buf.len() > MAX_BUF {
self.buf.clear();
return Err(SseBufOverflow);
}
let mut out = Vec::new(); let mut out = Vec::new();
while let Some(idx) = self.buf.find("\n\n") { while let Some(idx) = self.buf.find("\n\n") {
let frame: String = self.buf.drain(..idx + 2).collect(); let frame: String = self.buf.drain(..idx + 2).collect();
@ -46,7 +60,7 @@ impl SseParser {
out.push(text); out.push(text);
} }
} }
out Ok(out)
} }
} }
@ -89,7 +103,17 @@ mod tests {
"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"ab\"}", "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"ab\"}",
); );
let part2 = Bytes::from("}\n\n"); let part2 = Bytes::from("}\n\n");
assert!(p.push(&part1).is_empty()); assert!(p.push(&part1).unwrap().is_empty());
assert_eq!(p.push(&part2), vec!["ab".to_string()]); assert_eq!(p.push(&part2).unwrap(), vec!["ab".to_string()]);
}
#[test]
fn parser_rejects_oversized_buf() {
let mut p = SseParser::new();
// Push just over MAX_BUF bytes without a frame boundary.
let big = Bytes::from(vec![b'x'; MAX_BUF + 1]);
assert!(p.push(&big).is_err());
// Buffer must be cleared so subsequent calls don't accumulate.
assert!(p.push(&Bytes::from("\n\n")).unwrap().is_empty());
} }
} }

View file

@ -103,30 +103,51 @@ pub fn tokens_match(expected: &str, got: &str) -> bool {
diff == 0 diff == 0
} }
/// Derive a sibling temp path for `path`, shaped `<path>.<nanos>.tmp`.
///
/// `<nanos>` is the current wall-clock time in nanoseconds since the Unix
/// epoch, or `0` when the clock reads earlier than the epoch. When `path`
/// has no parent or no final component, the result is
/// `path.with_extension("tmp")` instead.
fn tmp_path(path: &Path) -> std::path::PathBuf {
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    if let (Some(dir), Some(base)) = (path.parent(), path.file_name()) {
        // Append to the full file name so an existing extension is kept.
        let mut leaf = base.to_os_string();
        leaf.push(format!(".{nanos}.tmp"));
        return dir.join(leaf);
    }
    path.with_extension("tmp")
}
#[cfg(unix)] #[cfg(unix)]
fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> { fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
use std::os::unix::fs::OpenOptionsExt; use std::os::unix::fs::OpenOptionsExt;
let tmp = tmp_path(path);
let mut f = fs::OpenOptions::new() let mut f = fs::OpenOptions::new()
.write(true) .write(true)
.create(true) .create(true)
.truncate(true) .truncate(true)
.mode(0o600) .mode(0o600)
.open(path)?; .open(&tmp)?;
f.write_all(bytes)?; f.write_all(bytes)?;
f.sync_all()?; f.sync_all()?;
Ok(()) drop(f);
fs::rename(&tmp, path)
} }
#[cfg(not(unix))] #[cfg(not(unix))]
fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> { fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
let tmp = tmp_path(path);
let mut f = fs::OpenOptions::new() let mut f = fs::OpenOptions::new()
.write(true) .write(true)
.create(true) .create(true)
.truncate(true) .truncate(true)
.open(path)?; .open(&tmp)?;
f.write_all(bytes)?; f.write_all(bytes)?;
f.sync_all()?; f.sync_all()?;
Ok(()) drop(f);
fs::rename(&tmp, path)
} }
#[cfg(test)] #[cfg(test)]

View file

@ -10,6 +10,7 @@
//! reference, never hardcoded). We never log the full text; only its length //! reference, never hardcoded). We never log the full text; only its length
//! and the response status / byte count. //! and the response status / byte count.
use crate::http_helpers::{read_capped, HTTP_CLIENT};
use std::time::Duration; use std::time::Duration;
/// Errors surfaced to the caller; handlers map them onto HTTP codes. /// Errors surfaced to the caller; handlers map them onto HTTP codes.
@ -29,12 +30,15 @@ const TTS_BASE: &str = "https://api.elevenlabs.io/v1/text-to-speech";
const BUDGET: Duration = Duration::from_secs(60); const BUDGET: Duration = Duration::from_secs(60);
const BODY_PREVIEW_CAP: usize = 512; const BODY_PREVIEW_CAP: usize = 512;
const MODEL_ID: &str = "eleven_turbo_v2_5"; const MODEL_ID: &str = "eleven_turbo_v2_5";
/// Cap on successful audio response bodies (64 MiB).
const AUDIO_BODY_CAP: usize = 64 * 1024 * 1024;
/// Cap on error response bodies (16 KiB).
const ERROR_BODY_CAP: usize = 16 * 1024;
/// Synthesize speech for `text` using the given `voice_id`. Returns mp3 bytes. /// Synthesize speech for `text` using the given `voice_id`. Returns mp3 bytes.
pub async fn synthesize(voice_id: &str, text: &str) -> Result<Vec<u8>, Error> { pub async fn synthesize(voice_id: &str, text: &str) -> Result<Vec<u8>, Error> {
let key = std::env::var("ELEVENLABS_API_KEY").map_err(|_| Error::NoApiKey)?; let key = std::env::var("ELEVENLABS_API_KEY").map_err(|_| Error::NoApiKey)?;
let client = reqwest::Client::new(); let fut = call_tts(&key, voice_id, text);
let fut = call_tts(&client, &key, voice_id, text);
match tokio::time::timeout(BUDGET, fut).await { match tokio::time::timeout(BUDGET, fut).await {
Ok(r) => r, Ok(r) => r,
Err(_) => Err(Error::Timeout), Err(_) => Err(Error::Timeout),
@ -44,14 +48,13 @@ pub async fn synthesize(voice_id: &str, text: &str) -> Result<Vec<u8>, Error> {
/// POST the JSON body and collect the audio bytes. Split from `synthesize` /// POST the JSON body and collect the audio bytes. Split from `synthesize`
/// so the timeout wrapper stays a thin shell. /// so the timeout wrapper stays a thin shell.
async fn call_tts( async fn call_tts(
client: &reqwest::Client,
key: &str, key: &str,
voice_id: &str, voice_id: &str,
text: &str, text: &str,
) -> Result<Vec<u8>, Error> { ) -> Result<Vec<u8>, Error> {
let url = format!("{TTS_BASE}/{voice_id}"); let url = format!("{TTS_BASE}/{voice_id}");
let body = build_body(text); let body = build_body(text);
let resp = client let resp = HTTP_CLIENT
.post(&url) .post(&url)
.header("xi-api-key", key) .header("xi-api-key", key)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
@ -77,13 +80,20 @@ fn build_body(text: &str) -> serde_json::Value {
/// Turn an ElevenLabs response into raw mp3 bytes; map non-2xx to `Upstream`. /// Turn an ElevenLabs response into raw mp3 bytes; map non-2xx to `Upstream`.
async fn decode_audio(resp: reqwest::Response) -> Result<Vec<u8>, Error> { async fn decode_audio(resp: reqwest::Response) -> Result<Vec<u8>, Error> {
use crate::http_helpers::CappedReadError;
let status = resp.status(); let status = resp.status();
if !status.is_success() { if !status.is_success() {
let body = resp.text().await.unwrap_or_default(); let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default();
let truncated = truncate(&body, BODY_PREVIEW_CAP); let body = String::from_utf8_lossy(&raw).into_owned();
return Err(Error::Upstream(status.as_u16(), truncated)); return Err(Error::Upstream(status.as_u16(), truncate(&body, BODY_PREVIEW_CAP)));
}
match read_capped(resp, AUDIO_BODY_CAP).await {
Ok(bytes) => Ok(bytes),
Err(CappedReadError::TooLarge) => {
Err(Error::Upstream(502, "audio body exceeded 64 MiB cap".into()))
}
Err(CappedReadError::Http(e)) => Err(Error::Http(e)),
} }
Ok(resp.bytes().await?.to_vec())
} }
/// Cap a string at `max` bytes on a char boundary. Used for error previews /// Cap a string at `max` bytes on a char boundary. Used for error previews

View file

@ -1,15 +1,16 @@
//! fal.ai Flux 2 Pro client — stylize a portrait into an anime frame. //! fal.ai Flux 2 Pro client — stylize a portrait into an anime frame.
//! //!
//! The public surface is a single async function `stylize(source_png, style)`. //! Public surface: `stylize(source_png, style)`.
//! Internally we: (1) upload the source image to fal storage via the
//! two-step `/storage/upload/initiate` handshake, (2) POST a generation
//! request to the queue endpoint, (3) poll status until `COMPLETED`, (4)
//! download the first image in the result. The whole pipeline has a 60-second
//! wall-clock budget, past which we return `Error::Timeout` so the handler
//! can surface an HTTP 504.
//! //!
//! `FAL_KEY` is read from the environment on every call; the daemon does not //! Pipeline: (1) upload source image to fal storage, (2) POST generation
//! cache it because the user may rotate it without restarting. //! request to the queue endpoint, (3) poll status until `COMPLETED`, (4)
//! download the first result image. 60-second wall-clock budget; past that
//! the function returns `Error::Timeout` so the handler can surface HTTP 504.
//!
//! `FAL_KEY` is read from the environment on every call (env-rotation-friendly,
//! per RULE 0.8 — never hardcoded, never cached in daemon state).
//!
//! HTTP pipeline steps live in `fal_pipeline`; SSRF mitigation in `fal_ssrf`.
use std::time::Duration; use std::time::Duration;
@ -31,7 +32,7 @@ impl Style {
} }
} }
fn prompt(self) -> &'static str { pub(crate) fn prompt(self) -> &'static str {
match self { match self {
Self::Cute => CUTE_PROMPT, Self::Cute => CUTE_PROMPT,
Self::Cool => COOL_PROMPT, Self::Cool => COOL_PROMPT,
@ -57,150 +58,44 @@ pub enum Error {
BadShape(&'static str), BadShape(&'static str),
#[error("fal polling exceeded 60s budget")] #[error("fal polling exceeded 60s budget")]
Timeout, Timeout,
#[error("non-fal URL rejected (SSRF): {0}")]
NonFalUrl(String),
} }
const UPLOAD_INITIATE_URL: &str = "https://rest.alpha.fal.ai/storage/upload/initiate"; pub(crate) const BUDGET: Duration = Duration::from_secs(60);
const GENERATION_URL: &str = "https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra"; pub(crate) const POLL_INTERVAL: Duration = Duration::from_millis(800);
const BUDGET: Duration = Duration::from_secs(60); pub(crate) const BODY_PREVIEW_CAP: usize = 512;
const POLL_INTERVAL: Duration = Duration::from_millis(800); pub(crate) const ERROR_BODY_CAP: usize = 16 * 1024;
const BODY_PREVIEW_CAP: usize = 512; pub(crate) const IMAGE_BODY_CAP: usize = 64 * 1024 * 1024;
/// Stylize `source_png` into an anime portrait. Returns the raw PNG bytes /// Stylize `source_png` into an anime portrait. Returns raw PNG bytes.
/// of the generated image (caller writes them to disk).
pub async fn stylize(source_png: &[u8], style: Style) -> Result<Vec<u8>, Error> { pub async fn stylize(source_png: &[u8], style: Style) -> Result<Vec<u8>, Error> {
let key = std::env::var("FAL_KEY").map_err(|_| Error::NoApiKey)?; let key = std::env::var("FAL_KEY").map_err(|_| Error::NoApiKey)?;
let client = reqwest::Client::new();
let deadline = tokio::time::Instant::now() + BUDGET; let deadline = tokio::time::Instant::now() + BUDGET;
let uploaded_url = upload_image(&client, &key, source_png).await?; let uploaded_url = crate::fal_pipeline::upload_image(&key, source_png).await?;
let status_url = enqueue(&client, &key, &uploaded_url, style).await?; let status_url = crate::fal_pipeline::enqueue(&key, &uploaded_url, style).await?;
let result_url = poll_until_done(&client, &key, &status_url, deadline).await?; let result_url =
download_image(&client, &key, &result_url).await crate::fal_pipeline::poll_until_done(&key, &status_url, deadline, POLL_INTERVAL)
}
/// Step 1 — ask fal storage for a signed PUT URL, then PUT the image to it.
async fn upload_image(client: &reqwest::Client, key: &str, bytes: &[u8]) -> Result<String, Error> {
let body = serde_json::json!({ "file_name": "portrait.png", "content_type": "image/png" });
let resp = client
.post(UPLOAD_INITIATE_URL)
.header("Authorization", format!("Key {key}"))
.json(&body)
.send()
.await?; .await?;
let json = decode_json(resp).await?; crate::fal_pipeline::download_image(&key, &result_url).await
let upload_url = json.get("upload_url").and_then(|v| v.as_str()).ok_or(Error::BadShape("upload_url"))?;
let file_url = json.get("file_url").and_then(|v| v.as_str()).ok_or(Error::BadShape("file_url"))?;
let put = client.put(upload_url).header("Content-Type", "image/png").body(bytes.to_vec()).send().await?;
if !put.status().is_success() {
return Err(Error::BadStatus(put.status().as_u16(), "PUT upload failed".into()));
}
Ok(file_url.to_string())
}
/// Step 2 — POST the generation request, return the status poll URL.
async fn enqueue(client: &reqwest::Client, key: &str, image_url: &str, style: Style) -> Result<String, Error> {
let body = serde_json::json!({
"image_url": image_url,
"prompt": style.prompt(),
"strength": 0.65,
"enable_safety_checker": true,
});
let resp = client
.post(GENERATION_URL)
.header("Authorization", format!("Key {key}"))
.json(&body)
.send()
.await?;
let json = decode_json(resp).await?;
json.get("status_url")
.and_then(|v| v.as_str())
.map(str::to_string)
.ok_or(Error::BadShape("status_url"))
}
/// Step 3 — poll the status URL until `COMPLETED`, or give up at deadline.
async fn poll_until_done(client: &reqwest::Client, key: &str, status_url: &str, deadline: tokio::time::Instant) -> Result<String, Error> {
loop {
if tokio::time::Instant::now() >= deadline {
return Err(Error::Timeout);
}
let resp = client.get(status_url).header("Authorization", format!("Key {key}")).send().await?;
let json = decode_json(resp).await?;
let status = json.get("status").and_then(|v| v.as_str()).unwrap_or("");
if status == "COMPLETED" {
return extract_first_image_url(&json);
}
if status == "FAILED" || status == "ERROR" {
return Err(Error::BadStatus(502, format!("fal reported {status}")));
}
tokio::time::sleep(POLL_INTERVAL).await;
}
}
/// Step 4 — download the PNG bytes from the fal CDN URL.
async fn download_image(client: &reqwest::Client, key: &str, url: &str) -> Result<Vec<u8>, Error> {
let resp = client.get(url).header("Authorization", format!("Key {key}")).send().await?;
if !resp.status().is_success() {
return Err(Error::BadStatus(resp.status().as_u16(), "download failed".into()));
}
Ok(resp.bytes().await?.to_vec())
}
/// Dig out `images[0].url` from the completed status payload.
fn extract_first_image_url(json: &serde_json::Value) -> Result<String, Error> {
json.get("images")
.and_then(|a| a.as_array())
.and_then(|a| a.first())
.and_then(|o| o.get("url"))
.and_then(|v| v.as_str())
.map(str::to_string)
.ok_or(Error::BadShape("images[0].url"))
}
/// Decode a fal JSON response, turning non-2xx into `BadStatus` with body.
/// Body capped at `BODY_PREVIEW_CAP` so a large upstream error page cannot
/// propagate through our logs or error channel.
async fn decode_json(resp: reqwest::Response) -> Result<serde_json::Value, Error> {
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
return Err(Error::BadStatus(
status.as_u16(),
truncate(&body, BODY_PREVIEW_CAP),
));
}
Ok(resp.json::<serde_json::Value>().await?)
}
/// Cap a string at `max` bytes on a char boundary. Keeps fal error previews
/// bounded regardless of what upstream sent back.
fn truncate(s: &str, max: usize) -> String {
if s.len() <= max {
return s.to_string();
}
let mut end = max;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
s[..end].to_string()
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn truncate_caps_long_strings() {
let long = "a".repeat(10_000);
assert_eq!(truncate(&long, 256).len(), 256);
}
#[test]
fn truncate_leaves_short_strings() {
assert_eq!(truncate("hi", 256), "hi");
}
#[test] #[test]
fn style_from_wire_defaults_to_cute() { fn style_from_wire_defaults_to_cute() {
assert!(matches!(Style::from_wire("wat"), Style::Cute)); assert!(matches!(Style::from_wire("wat"), Style::Cute));
} }
#[test]
fn style_from_wire_cool() {
assert!(matches!(Style::from_wire("anime-cool"), Style::Cool));
}
#[test]
fn style_from_wire_studious() {
assert!(matches!(Style::from_wire("anime-studious"), Style::Studious));
}
} }

View file

@ -0,0 +1,173 @@
//! fal.ai HTTP pipeline steps — upload, enqueue, poll, download.
//!
//! Each step issues one or more HTTPS requests (upload is a POST followed by
//! a PUT; polling repeats a GET until the job finishes). All use the shared
//! `HTTP_CLIENT` from `http_helpers` (no per-request `reqwest::Client::new()`).
//! Response bodies that are consumed are read through `read_capped` so a
//! runaway upstream cannot allocate unbounded memory.
use crate::fal::{Error, Style};
use crate::fal::{BODY_PREVIEW_CAP, ERROR_BODY_CAP, IMAGE_BODY_CAP};
use crate::fal_ssrf::validate_fal_url;
use crate::http_helpers::{read_capped, CappedReadError, HTTP_CLIENT};
/// Endpoint `upload_image` POSTs to in order to obtain the `upload_url` /
/// `file_url` pair for the source image.
pub(crate) const UPLOAD_INITIATE_URL: &str =
"https://rest.alpha.fal.ai/storage/upload/initiate";
/// Queue endpoint `enqueue` POSTs the generation request to.
pub(crate) const GENERATION_URL: &str =
"https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra";
/// Step 1 — ask fal storage for a signed PUT URL, then PUT the image to it.
pub(crate) async fn upload_image(key: &str, bytes: &[u8]) -> Result<String, Error> {
let body =
serde_json::json!({ "file_name": "portrait.png", "content_type": "image/png" });
let resp = HTTP_CLIENT
.post(UPLOAD_INITIATE_URL)
.header("Authorization", format!("Key {key}"))
.json(&body)
.send()
.await?;
let json = decode_json(resp).await?;
let upload_url = json
.get("upload_url")
.and_then(|v| v.as_str())
.ok_or(Error::BadShape("upload_url"))?;
let file_url = json
.get("file_url")
.and_then(|v| v.as_str())
.ok_or(Error::BadShape("file_url"))?;
validate_fal_url(upload_url)?;
validate_fal_url(file_url)?;
HTTP_CLIENT
.put(upload_url)
.header("Content-Type", "image/png")
.body(bytes.to_vec())
.send()
.await?
.error_for_status()
.map_err(Error::Http)?;
Ok(file_url.to_string())
}
/// Step 2 — POST the generation request, return the validated status URL.
pub(crate) async fn enqueue(
key: &str,
image_url: &str,
style: Style,
) -> Result<String, Error> {
let body = serde_json::json!({
"image_url": image_url,
"prompt": style.prompt(),
"strength": 0.65,
"enable_safety_checker": true,
});
let resp = HTTP_CLIENT
.post(GENERATION_URL)
.header("Authorization", format!("Key {key}"))
.json(&body)
.send()
.await?;
let json = decode_json(resp).await?;
let status_url = json
.get("status_url")
.and_then(|v| v.as_str())
.map(str::to_string)
.ok_or(Error::BadShape("status_url"))?;
validate_fal_url(&status_url)?;
Ok(status_url)
}
/// Step 3 — poll `status_url` until `COMPLETED`, or give up at `deadline`.
pub(crate) async fn poll_until_done(
key: &str,
status_url: &str,
deadline: tokio::time::Instant,
poll_interval: std::time::Duration,
) -> Result<String, Error> {
loop {
if tokio::time::Instant::now() >= deadline {
return Err(Error::Timeout);
}
let resp = HTTP_CLIENT
.get(status_url)
.header("Authorization", format!("Key {key}"))
.send()
.await?;
let json = decode_json(resp).await?;
match json.get("status").and_then(|v| v.as_str()).unwrap_or("") {
"COMPLETED" => return extract_first_image_url(&json),
s @ ("FAILED" | "ERROR") => {
return Err(Error::BadStatus(502, format!("fal reported {s}")));
}
_ => {}
}
tokio::time::sleep(poll_interval).await;
}
}
/// Step 4 — download the PNG bytes from a validated fal CDN URL.
pub(crate) async fn download_image(key: &str, url: &str) -> Result<Vec<u8>, Error> {
validate_fal_url(url)?;
let resp = HTTP_CLIENT
.get(url)
.header("Authorization", format!("Key {key}"))
.send()
.await?;
if !resp.status().is_success() {
return Err(Error::BadStatus(resp.status().as_u16(), "download failed".into()));
}
match read_capped(resp, IMAGE_BODY_CAP).await {
Ok(b) => Ok(b),
Err(CappedReadError::TooLarge) => {
Err(Error::BadStatus(502, "image body exceeded 64 MiB cap".into()))
}
Err(CappedReadError::Http(e)) => Err(Error::Http(e)),
}
}
/// Dig out `images[0].url` from the completed status payload, then validate.
pub(crate) fn extract_first_image_url(json: &serde_json::Value) -> Result<String, Error> {
let url = json
.get("images")
.and_then(|a| a.as_array())
.and_then(|a| a.first())
.and_then(|o| o.get("url"))
.and_then(|v| v.as_str())
.map(str::to_string)
.ok_or(Error::BadShape("images[0].url"))?;
validate_fal_url(&url)?;
Ok(url)
}
/// Decode a fal response into JSON, capping error/success bodies.
pub(crate) async fn decode_json(
resp: reqwest::Response,
) -> Result<serde_json::Value, Error> {
let status = resp.status();
if !status.is_success() {
let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default();
let body = String::from_utf8_lossy(&raw).into_owned();
return Err(Error::BadStatus(
status.as_u16(),
truncate(&body, BODY_PREVIEW_CAP),
));
}
let raw = match read_capped(resp, IMAGE_BODY_CAP).await {
Ok(b) => b,
Err(CappedReadError::TooLarge) => {
return Err(Error::BadStatus(502, "response body exceeded 64 MiB cap".into()));
}
Err(CappedReadError::Http(e)) => return Err(Error::Http(e)),
};
serde_json::from_slice(&raw).map_err(|e| Error::BadStatus(502, e.to_string()))
}
/// Return `s` shortened to at most `max` bytes without splitting a character.
///
/// Strings of `max` bytes or fewer come back unchanged. Longer strings are
/// cut at the largest char boundary that does not exceed `max`.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s[..cut].to_owned()
}

View file

@ -0,0 +1,67 @@
//! SSRF mitigation for fal.ai outbound requests.
//!
//! `validate_fal_url` rejects any URL whose scheme is not `https://` or
//! whose host does not match the fal.ai domain allowlist before the URL
//! is ever passed to `reqwest`.
use crate::fal::Error;
use once_cell::sync::Lazy;
use regex::Regex;
/// Compiled allowlist for fal.ai host names.
/// Accepted: `<one-or-more-subdomains>.fal.ai`, `.fal.media`, `.fal.run`.
/// Example: `rest.alpha.fal.ai`, `queue.fal.run`, `cdn.fal.media`.
///
/// The pattern is anchored at both ends and each label may only contain
/// `a-z`, `0-9` and `-`. At least one label must precede `fal.`, so the bare
/// apex (e.g. `fal.ai`) and upper-case host names do not match.
// The `unwrap` can only fail if the pattern literal itself is invalid.
static FAL_HOST_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^([a-z0-9-]+\.)+fal\.(ai|media|run)$").unwrap());
/// Reject URLs that do not match the fal.ai HTTPS allowlist.
///
/// Returns `Err(Error::NonFalUrl)` when:
/// - the scheme is not `https://`, or
/// - the host does not match `^[a-z0-9-]+\.fal\.(ai|media|run)$`.
pub(crate) fn validate_fal_url(url: &str) -> Result<(), Error> {
let stripped = url.strip_prefix("https://").ok_or_else(|| {
Error::NonFalUrl(url.chars().take(80).collect())
})?;
let host = stripped.split('/').next().unwrap_or("");
if !FAL_HOST_RE.is_match(host) {
return Err(Error::NonFalUrl(url.chars().take(80).collect()));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `url` is refused with `Error::NonFalUrl`.
    fn is_rejected(url: &str) -> bool {
        matches!(validate_fal_url(url), Err(Error::NonFalUrl(_)))
    }

    #[test]
    fn accepts_valid_fal_urls() {
        assert!(validate_fal_url("https://rest.alpha.fal.ai/storage/upload/initiate").is_ok());
        assert!(validate_fal_url("https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra").is_ok());
        assert!(validate_fal_url("https://cdn.fal.media/image.png").is_ok());
        // A URL with no path component is still a valid host.
        assert!(validate_fal_url("https://queue.fal.run").is_ok());
    }

    #[test]
    fn rejects_http_scheme() {
        assert!(is_rejected("http://queue.fal.run/foo"));
        // The scheme prefix is matched literally, so other spellings fail too.
        assert!(is_rejected("HTTPS://queue.fal.run/foo"));
    }

    #[test]
    fn rejects_non_fal_domain() {
        assert!(is_rejected("https://evil.example.com/steal"));
    }

    #[test]
    fn rejects_ssrf_internal_ip() {
        assert!(is_rejected("https://169.254.169.254/latest/meta-data/"));
    }

    #[test]
    fn rejects_lookalike_hosts() {
        // No dot before `fal`: a different registrable domain.
        assert!(is_rejected("https://evilfal.ai/x"));
        // Allowlisted name used as a subdomain prefix of an attacker domain.
        assert!(is_rejected("https://cdn.fal.ai.evil.com/x"));
        // Allowlisted name only appears in the path.
        assert!(is_rejected("https://evil.example.com/cdn.fal.media/x"));
    }

    #[test]
    fn rejects_userinfo_and_port_tricks() {
        assert!(is_rejected("https://evil.com@cdn.fal.media/x"));
        assert!(is_rejected("https://cdn.fal.media@evil.com/x"));
        assert!(is_rejected("https://cdn.fal.media:443/image.png"));
    }

    #[test]
    fn rejects_bare_apex_and_empty_input() {
        // The pattern requires at least one subdomain label.
        assert!(is_rejected("https://fal.ai/"));
        assert!(is_rejected("https://"));
        assert!(is_rejected(""));
    }
}

View file

@ -97,10 +97,20 @@ fn validate_provider(state: &AppState, name: &str) -> Result<(), AppError> {
} }
} }
/// Character ceiling for chat messages. Prevents runaway prompt injection
/// and upstream token cost abuse.
const MAX_MESSAGE_CHARS: usize = 50_000;
fn validate_body(req: &ChatRequest) -> Result<(), AppError> { fn validate_body(req: &ChatRequest) -> Result<(), AppError> {
if req.message.is_empty() { if req.message.is_empty() {
return Err(AppError::BadRequest("message is empty".into())); return Err(AppError::BadRequest("message is empty".into()));
} }
let chars = req.message.chars().count();
if chars > MAX_MESSAGE_CHARS {
return Err(AppError::PayloadTooLarge(format!(
"{chars} chars > {MAX_MESSAGE_CHARS}"
)));
}
Ok(()) Ok(())
} }

View file

@ -0,0 +1,57 @@
//! Shared HTTP utilities: a process-wide `reqwest::Client` and a capped
//! response-body reader.
//!
//! A single `reqwest::Client` is reused for all outbound calls to avoid
//! exhausting OS connection-table entries when the daemon is under load.
//! The client is initialized once via `once_cell::sync::Lazy`.
use once_cell::sync::Lazy;
/// The process-wide HTTP client. Shared by `anthropic`, `anthropic_invoker`,
/// `elevenlabs`, and `fal` so that connection pooling is effective.
///
/// Built on first access through `Lazy`; if the builder returns an error the
/// `expect` below panics at that first access.
///
/// NOTE(review): the builder is used with no options set, so reqwest's
/// default redirect and timeout behavior applies to every caller. Confirm
/// that a followed redirect cannot sidestep URL allowlisting done before the
/// request (e.g. `validate_fal_url`), and whether a connect timeout is wanted.
pub static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder()
.build()
.expect("reqwest::Client::builder().build() — no TLS config error expected")
});
/// Marker error: a response body grew past the cap chosen by the caller.
///
/// Carries no data; its `Display` text is a fixed message.
#[derive(Debug)]
pub struct BodyTooLarge;

impl std::error::Error for BodyTooLarge {}

impl std::fmt::Display for BodyTooLarge {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("response body exceeded size cap")
    }
}
/// Read a response body into memory, refusing to buffer more than `max_bytes`.
///
/// The body is consumed chunk by chunk. On overflow nothing is returned — the
/// partial buffer is dropped so the caller cannot accidentally act on
/// truncated data.
///
/// # Errors
/// - [`CappedReadError::TooLarge`] when the declared `Content-Length`, or the
///   bytes actually received, exceed `max_bytes`.
/// - [`CappedReadError::Http`] when the underlying byte stream fails.
pub async fn read_capped(
    resp: reqwest::Response,
    max_bytes: usize,
) -> Result<Vec<u8>, CappedReadError> {
    use futures::StreamExt;

    // Fail fast: when the peer already announced an oversized body there is
    // no reason to pull up to `max_bytes` of it off the wire before refusing.
    // (`usize` -> `u64` is a lossless widening on supported targets.)
    if let Some(declared) = resp.content_length() {
        if declared > max_bytes as u64 {
            return Err(CappedReadError::TooLarge);
        }
    }

    let mut bytes: Vec<u8> = Vec::new();
    let mut stream = resp.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(CappedReadError::Http)?;
        // Saturating so the comparison stays meaningful even at `usize::MAX`.
        if bytes.len().saturating_add(chunk.len()) > max_bytes {
            return Err(CappedReadError::TooLarge);
        }
        bytes.extend_from_slice(&chunk);
    }
    Ok(bytes)
}
/// Combined error for `read_capped`.
///
/// Distinguishes "the body was bigger than the caller allowed" from "the
/// transport failed while reading", so callers can map each case separately.
#[derive(Debug, thiserror::Error)]
pub enum CappedReadError {
/// The body would have exceeded the `max_bytes` passed to `read_capped`.
#[error("response body exceeded size cap")]
TooLarge,
/// The response byte stream yielded an error while being read.
#[error("http stream error: {0}")]
Http(reqwest::Error),
}

View file

@ -20,11 +20,15 @@ pub mod context;
pub mod elevenlabs; pub mod elevenlabs;
pub mod error; pub mod error;
pub mod fal; pub mod fal;
pub(crate) mod fal_pipeline;
pub(crate) mod fal_ssrf;
pub mod handlers; pub mod handlers;
pub mod http_helpers;
pub mod persona; pub mod persona;
pub mod whisper_local; pub mod whisper_local;
pub mod rig_clone; pub mod rig_clone;
pub mod routes; pub mod routes;
pub(crate) mod routes_auth;
pub mod sentiment; pub mod sentiment;
pub mod state; pub mod state;
pub mod tool; pub mod tool;

View file

@ -1,59 +1,47 @@
//! Router assembly + bearer-token middleware + CORS layer. //! Router assembly + CORS layer.
//! //!
//! `/healthz` is mounted OUTSIDE the auth middleware so monitors can hit it //! `/healthz` is mounted OUTSIDE the auth middleware so monitors can hit it
//! without a token. Everything under `/api` goes through `require_bearer`. //! without a token. Everything under `/api` goes through `require_bearer`
//! (defined in `routes_auth`).
//! //!
//! Per-route concurrency caps protect us from a runaway client draining our //! Per-route concurrency caps protect us from a runaway client draining our
//! upstream budget — `fal.ai` in particular bills per run, so we cap //! upstream budget — `fal.ai` in particular bills per run, so we cap
//! `/portrait/stylize` at 2 concurrent installs system-wide. Other expensive //! `/portrait/stylize` at 2 concurrent installs system-wide. Other expensive
//! routes (`/tts`, `/stt`, `/chat`) get matching caps tuned to their bottleneck. //! routes (`/tts`, `/stt`, `/chat`) get matching caps tuned to their bottleneck.
use crate::auth::tokens_match;
use crate::error::AppError;
use crate::handlers::{ use crate::handlers::{
chat, fs_list, health, ledger, memory, pet, portrait, stt, summary, term, tool_apply, tts, chat, fs_list, health, ledger, memory, pet, portrait, stt, summary, term, tool_apply, tts,
usage, usage,
}; };
use crate::state::AppState; use crate::state::AppState;
use axum::error_handling::HandleErrorLayer; use axum::error_handling::HandleErrorLayer;
use axum::extract::{DefaultBodyLimit, Request, State}; use axum::extract::DefaultBodyLimit;
use axum::http::{header, HeaderValue, Method, StatusCode}; use axum::http::{header, HeaderValue, Method, StatusCode};
use axum::middleware::{self, Next}; use axum::middleware;
use axum::response::{IntoResponse, Response}; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::Router; use axum::Router;
// OpenAI-compatible /v1/* surface. Lives under `routes/openai/` as a
// sibling tree to this file; declared via `#[path]` so the directory
// can coexist with `routes.rs` without becoming an inline `routes/mod.rs`.
#[path = "routes/openai/mod.rs"] #[path = "routes/openai/mod.rs"]
pub mod openai; pub mod openai;
use tower::buffer::BufferLayer; use tower::buffer::BufferLayer;
use tower::limit::ConcurrencyLimitLayer; use tower::limit::ConcurrencyLimitLayer;
use tower::{BoxError, ServiceBuilder}; use tower::{BoxError, ServiceBuilder};
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
/// Upper bound on the `/portrait/stylize` multipart body. The handler enforces // --- Body limits (per-route pre-parse gates) --------------------------------
/// the stricter 10 MiB cap on the `file` field; this is just the pre-parse const PORTRAIT_BODY_LIMIT: usize = 12 * 1024 * 1024; // 12 MiB (handler re-checks 10)
/// gate that Axum applies before it can see individual fields. const STT_BODY_LIMIT: usize = 26 * 1024 * 1024; // 26 MiB (handler re-checks 25)
const PORTRAIT_BODY_LIMIT: usize = 12 * 1024 * 1024; const CHAT_BODY_LIMIT: usize = 256 * 1024; // 256 KiB
const TOOL_APPLY_BODY_LIMIT: usize = 11 * 1024 * 1024; // 11 MiB (handler checks 10)
const INTERACTION_BODY_LIMIT: usize = 64 * 1024; // 64 KiB
const TTS_BODY_LIMIT: usize = 32 * 1024; // 32 KiB
/// Upper bound on the `/stt` multipart body. The handler enforces the // --- Concurrency budgets ----------------------------------------------------
/// stricter 25 MiB cap on the `audio` field itself; this is the slack so
/// that field headers + form overhead do not trip the pre-parse gate.
const STT_BODY_LIMIT: usize = 26 * 1024 * 1024;
/// Max concurrent Flux stylize runs system-wide. fal.ai bills per run.
const PORTRAIT_CONCURRENCY: usize = 2; const PORTRAIT_CONCURRENCY: usize = 2;
/// Max concurrent ElevenLabs TTS calls.
const TTS_CONCURRENCY: usize = 4; const TTS_CONCURRENCY: usize = 4;
/// Max concurrent whisper worker runs. CPU-bound, so the cap matches a
/// conservative small-laptop core count.
const STT_CONCURRENCY: usize = 2; const STT_CONCURRENCY: usize = 2;
/// Max concurrent Anthropic chat streams.
const CHAT_CONCURRENCY: usize = 8; const CHAT_CONCURRENCY: usize = 8;
/// Build the top-level router. `cors_origin` must have been validated at /// Build the top-level router. `cors_origin` must have been validated at
@ -62,43 +50,11 @@ pub fn build_router(state: AppState) -> Router {
let cors = build_cors(state.config().cors_origin.as_str()) let cors = build_cors(state.config().cors_origin.as_str())
.expect("cors_origin must be valid — validated in AppConfig::new"); .expect("cors_origin must be valid — validated in AppConfig::new");
// Per-route granular caps proved fragile with axum 0.7's MethodRouter let api = build_api_router();
// layer bounds (HandleErrorLayer + ConcurrencyLimitLayer service is not let api = api
// `Clone` in a way that layer() accepts). Apply a single router-wide cap
// via route_layer — it wraps every inner route uniformly without the
// per-method error-type headache. The cap is the SUM of the per-route
// budgets (2+4+2+8 = 16), which is a strict upper bound on simultaneous
// expensive work. Finer-grained token-bucket per-route can land later via
// tower-governor if a multi-user deployment appears.
let api = Router::new()
.route("/api/v1/cortex/summary", get(summary::summary))
.route("/api/v1/cortex/pet/:user_id", get(pet::get_pet))
.route(
"/api/v1/cortex/pet/:user_id/interaction",
post(pet::post_interaction),
)
.route("/api/v1/cortex/pet/:user_id/chat", post(chat::chat))
.route(
"/api/v1/cortex/pet/:user_id/portrait/stylize",
post(portrait::stylize).layer(DefaultBodyLimit::max(PORTRAIT_BODY_LIMIT)),
)
.route(
"/api/v1/cortex/stt",
post(stt::transcribe).layer(DefaultBodyLimit::max(STT_BODY_LIMIT)),
)
.route(
"/api/v1/cortex/pet/:user_id/tts",
post(tts::synthesize),
)
.route("/api/v1/cortex/ledger/recent", get(ledger::recent))
.route("/api/v1/cortex/memory/search", get(memory::search_memory))
.route("/api/v1/cortex/usage", get(usage::usage))
.route("/api/v1/cortex/fs/list", get(fs_list::list))
.route("/api/v1/cortex/tool/apply", post(tool_apply::apply))
.route("/api/v1/cortex/term", get(term::ws_handler))
.route_layer(middleware::from_fn_with_state( .route_layer(middleware::from_fn_with_state(
state.clone(), state.clone(),
require_bearer, crate::routes_auth::require_bearer,
)) ))
.layer( .layer(
ServiceBuilder::new() ServiceBuilder::new()
@ -119,6 +75,43 @@ pub fn build_router(state: AppState) -> Router {
.with_state(state) .with_state(state)
} }
/// Assemble the protected API sub-router (no auth layer yet — applied by caller).
fn build_api_router() -> Router<AppState> {
Router::new()
.route("/api/v1/cortex/summary", get(summary::summary))
.route("/api/v1/cortex/pet/:user_id", get(pet::get_pet))
.route(
"/api/v1/cortex/pet/:user_id/interaction",
post(pet::post_interaction)
.layer(DefaultBodyLimit::max(INTERACTION_BODY_LIMIT)),
)
.route(
"/api/v1/cortex/pet/:user_id/chat",
post(chat::chat).layer(DefaultBodyLimit::max(CHAT_BODY_LIMIT)),
)
.route(
"/api/v1/cortex/pet/:user_id/portrait/stylize",
post(portrait::stylize).layer(DefaultBodyLimit::max(PORTRAIT_BODY_LIMIT)),
)
.route(
"/api/v1/cortex/stt",
post(stt::transcribe).layer(DefaultBodyLimit::max(STT_BODY_LIMIT)),
)
.route(
"/api/v1/cortex/pet/:user_id/tts",
post(tts::synthesize).layer(DefaultBodyLimit::max(TTS_BODY_LIMIT)),
)
.route("/api/v1/cortex/ledger/recent", get(ledger::recent))
.route("/api/v1/cortex/memory/search", get(memory::search_memory))
.route("/api/v1/cortex/usage", get(usage::usage))
.route("/api/v1/cortex/fs/list", get(fs_list::list))
.route(
"/api/v1/cortex/tool/apply",
post(tool_apply::apply).layer(DefaultBodyLimit::max(TOOL_APPLY_BODY_LIMIT)),
)
.route("/api/v1/cortex/term", get(term::ws_handler))
}
/// Build the CORS layer locked to a single origin. /// Build the CORS layer locked to a single origin.
fn build_cors(origin: &str) -> Result<CorsLayer, String> { fn build_cors(origin: &str) -> Result<CorsLayer, String> {
let origin_header: HeaderValue = origin let origin_header: HeaderValue = origin
@ -131,60 +124,3 @@ fn build_cors(origin: &str) -> Result<CorsLayer, String> {
.allow_credentials(true)) .allow_credentials(true))
} }
/// Bearer-token middleware.
///
/// Two acceptable transports — checked in order:
/// 1. `Authorization: Bearer <token>` — standard HTTP requests.
/// 2. `Sec-WebSocket-Protocol: bearer, <token>` — WS upgrade only,
/// because browsers cannot set the Authorization header on a
/// `new WebSocket(url, [...protocols])` call.
///
/// Missing → 401; mismatch → 403.
async fn require_bearer(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Result<Response, AppError> {
let token = state.token().to_string();
if let Some(got) = bearer_from_authorization(&req) {
return finish_auth(&token, &got, req, next).await;
}
if let Some(got) = bearer_from_websocket_subprotocol(&req) {
return finish_auth(&token, &got, req, next).await;
}
Err(AppError::Unauthorized)
}
/// Pull `<token>` from `Authorization: Bearer <token>` if present.
fn bearer_from_authorization(req: &Request) -> Option<String> {
let v = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?;
Some(v.strip_prefix("Bearer ")?.trim().to_string())
}
/// Pull `<token>` from `Sec-WebSocket-Protocol: bearer, <token>`. The
/// browser's `new WebSocket(url, ['bearer', tok])` produces this header.
fn bearer_from_websocket_subprotocol(req: &Request) -> Option<String> {
let v = req
.headers()
.get("sec-websocket-protocol")?
.to_str()
.ok()?;
let mut parts = v.split(',').map(str::trim);
if parts.next()? != "bearer" {
return None;
}
Some(parts.next()?.to_string())
}
/// Compare expected vs. supplied; on match, call `next`.
async fn finish_auth(
expected: &str,
got: &str,
req: Request,
next: Next,
) -> Result<Response, AppError> {
if !tokens_match(expected, got) {
return Err(AppError::Forbidden);
}
Ok(next.run(req).await)
}

View file

@ -0,0 +1,69 @@
//! Bearer-token middleware for the cortex API router.
//!
//! Separated from `routes.rs` so each file stays under 200 LOC.
use crate::auth::tokens_match;
use crate::error::AppError;
use crate::state::AppState;
use axum::extract::{Request, State};
use axum::http::header;
use axum::middleware::Next;
use axum::response::Response;
/// Middleware that gates a request on the daemon's bearer token.
///
/// The supplied token is looked for in two places, in this order:
///   1. `Authorization: Bearer <token>` — ordinary HTTP requests.
///   2. `Sec-WebSocket-Protocol: bearer, <token>` — WS upgrades, since a
///      browser's `new WebSocket(url, [...protocols])` cannot set the
///      Authorization header.
///
/// The first transport that yields a token is the one that gets compared.
///
/// # Errors
/// - `AppError::Unauthorized` (401) when neither transport carries a token.
/// - `AppError::Forbidden` (403) when a token is present but does not match.
pub async fn require_bearer(
    State(state): State<AppState>,
    req: Request,
    next: Next,
) -> Result<Response, AppError> {
    let expected = state.token().to_string();
    let supplied = bearer_from_authorization(&req)
        .or_else(|| bearer_from_websocket_subprotocol(&req));
    match supplied {
        Some(got) => finish_auth(&expected, &got, req, next).await,
        None => Err(AppError::Unauthorized),
    }
}
/// Extract the token from an `Authorization: Bearer <token>` header.
///
/// Returns `None` when the header is absent, is not valid visible ASCII, or
/// does not start with the exact prefix `Bearer `. Surrounding whitespace on
/// the token is trimmed.
fn bearer_from_authorization(req: &Request) -> Option<String> {
    req.headers()
        .get(header::AUTHORIZATION)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.strip_prefix("Bearer "))
        .map(|tok| tok.trim().to_string())
}
/// Extract the token from a `Sec-WebSocket-Protocol: bearer, <token>` header,
/// which is what a browser sends for `new WebSocket(url, ['bearer', tok])`.
///
/// The header value is split on `,` and each entry trimmed. The first entry
/// must be exactly `bearer` and the second entry is returned as the token;
/// any other shape yields `None`.
fn bearer_from_websocket_subprotocol(req: &Request) -> Option<String> {
    let raw = req.headers().get("sec-websocket-protocol")?;
    let text = raw.to_str().ok()?;
    let mut entries = text.split(',').map(str::trim);
    match (entries.next(), entries.next()) {
        (Some("bearer"), Some(tok)) => Some(tok.to_string()),
        _ => None,
    }
}
/// Final auth step: forward the request only when the tokens agree.
///
/// `expected` and `got` are compared with `tokens_match`. On a match the
/// request is handed to `next` and its response returned; otherwise the
/// request is dropped and `AppError::Forbidden` is returned.
async fn finish_auth(
    expected: &str,
    got: &str,
    req: Request,
    next: Next,
) -> Result<Response, AppError> {
    if tokens_match(expected, got) {
        Ok(next.run(req).await)
    } else {
        Err(AppError::Forbidden)
    }
}