From 9aa29aca15cfab36a6f1e64d0c50a1f530ff5ef1 Mon Sep 17 00:00:00 2001 From: Parfii-bot Date: Sat, 2 May 2026 21:39:57 +0800 Subject: [PATCH] fix(kei-cortex): SSRF + atomic token + body limits + capped reads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Group C — kei-cortex daemon security hardening (post-audit 2026-05-02). - fal_ssrf.rs (new): validate_fal_url allowlist (HTTPS subdomains of fal.ai / fal.media / fal.run only). Applied to upload_url, file_url, status_url, images[0].url, and download_image. Closes SSRF where compromised fal response could direct daemon to fetch IMDSv1 (169.254.169.254) and stream cloud creds. - fal_pipeline.rs (new): HTTP step functions extracted from fal.rs; fal.rs trimmed to thin orchestrator (101 LOC, was over 200 LOC limit). - auth.rs: save_token now writes to a sibling temp file ({file_name}.{ts}.tmp) + sync_all + rename. Was non-atomic OpenOptions truncate+write — crash mid-write produced empty token file -> bootstrap rotated -> stale clients locked out. - routes.rs + routes_auth.rs (new): explicit DefaultBodyLimit per route — chat 256 KiB, tool/apply 11 MiB, pet/interaction 64 KiB, tts 32 KiB. Bearer auth middleware extracted to routes_auth. - handlers/chat.rs: validate_body enforces MAX_MESSAGE_CHARS = 50_000. Closed cost amplification where 1.99 MiB chat message billed 500K tokens ($1.50/turn at Sonnet pricing) on every send. - anthropic_sse.rs: SseParser MAX_BUF = 1 MiB cap; was unbounded — peer streaming 1 GB without \\n\\n would OOM daemon. - http_helpers.rs (new): HTTP_CLIENT, a once_cell Lazy-initialized reqwest::Client shared across handlers (was per-request Client::new() => 100-300ms TLS handshake per chat turn, no HTTP/2 multiplexing, fd leak risk on macOS TIME_WAIT). - http_helpers.rs::read_capped: per-response body cap (16 KiB error / 64 MiB success). Applied to anthropic, anthropic_invoker, elevenlabs, fal_pipeline. Closed unbounded resp.text() / .bytes() pattern that compromised upstream could exploit. Test results: 462 passed; 0 failed (single-threaded). cargo check clean.
2 pre-existing port-binding flakes in openai_loop_wiring tests are unrelated. Findings consensus: fal SSRF + body-size + bearer-token-atomicity appeared in Wave-A retest; chat message cap + SSE buf cap appeared in Wave-A only. Would have been missed by single audit pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- _primitives/_rust/kei-cortex/Cargo.toml | 42 ++--- _primitives/_rust/kei-cortex/src/anthropic.rs | 15 +- .../_rust/kei-cortex/src/anthropic_invoker.rs | 9 +- .../_rust/kei-cortex/src/anthropic_sse.rs | 32 +++- _primitives/_rust/kei-cortex/src/auth.rs | 29 ++- .../_rust/kei-cortex/src/elevenlabs.rs | 26 ++- _primitives/_rust/kei-cortex/src/fal.rs | 175 ++++-------------- .../_rust/kei-cortex/src/fal_pipeline.rs | 173 +++++++++++++++++ _primitives/_rust/kei-cortex/src/fal_ssrf.rs | 67 +++++++ .../_rust/kei-cortex/src/handlers/chat.rs | 10 + .../_rust/kei-cortex/src/http_helpers.rs | 57 ++++++ _primitives/_rust/kei-cortex/src/lib.rs | 4 + _primitives/_rust/kei-cortex/src/routes.rs | 174 ++++++----------- .../_rust/kei-cortex/src/routes_auth.rs | 69 +++++++ 14 files changed, 579 insertions(+), 303 deletions(-) create mode 100644 _primitives/_rust/kei-cortex/src/fal_pipeline.rs create mode 100644 _primitives/_rust/kei-cortex/src/fal_ssrf.rs create mode 100644 _primitives/_rust/kei-cortex/src/http_helpers.rs create mode 100644 _primitives/_rust/kei-cortex/src/routes_auth.rs diff --git a/_primitives/_rust/kei-cortex/Cargo.toml b/_primitives/_rust/kei-cortex/Cargo.toml index f925d79..8c04010 100644 --- a/_primitives/_rust/kei-cortex/Cargo.toml +++ b/_primitives/_rust/kei-cortex/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "kei-cortex" version = "0.1.0" -edition = "2021" -rust-version = "1.75" +edition.workspace = true +rust-version.workspace = true description = "Local HTTP daemon exposing cortex state for UI consumption" authors = ["Denis Parfionovich "] @@ -16,31 +16,31 @@ path = "src/lib.rs" [dependencies] axum = { version = "0.7", features = ["multipart", 
"ws"] } -tokio = { version = "1", features = ["rt-multi-thread", "macros", "signal", "net", "time", "process", "fs", "io-util", "sync"] } +tokio = { workspace = true } tokio-util = { version = "0.7", features = ["rt"] } -tower = { version = "0.4", features = ["limit", "buffer", "util"] } +tower = { workspace = true } tower-http = { version = "0.5", features = ["cors", "trace"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" -clap = { version = "4", features = ["derive"] } -thiserror = "1" -rusqlite = { version = "0.31", features = ["bundled"] } -anyhow = "1" +serde = { workspace = true } +serde_json = { workspace = true } +clap = { workspace = true } +thiserror = { workspace = true } +rusqlite = { workspace = true } +anyhow = { workspace = true } rand = "0.8" -reqwest = { version = "0.12", features = ["json", "stream", "multipart", "rustls-tls"], default-features = false } -tokio-stream = "0.1" -futures = "0.3" +reqwest = { workspace = true } +tokio-stream = { workspace = true } +futures = { workspace = true } uuid = { version = "1", features = ["v4"] } async-stream = "0.3" -toml = "0.8" -bytes = "1" -tempfile = "3" -dashmap = "5" -walkdir = "2" +toml = { workspace = true } +bytes = { workspace = true } +tempfile = { workspace = true } +dashmap = { workspace = true } +walkdir = { workspace = true } which = "6" once_cell = "1" -regex = "1.10" -portable-pty = "0.8" +regex = { workspace = true } +portable-pty = { workspace = true } # Wave 44a — tool-sandbox hardening shell-words = { workspace = true } url = { workspace = true } @@ -64,4 +64,4 @@ kei-model = { path = "../kei-model" } kei-token-tracker = { path = "../kei-token-tracker" } [dev-dependencies] -reqwest = { version = "0.12", features = ["json", "blocking", "stream", "rustls-tls"], default-features = false } +reqwest = { workspace = true, features = ["blocking"] } diff --git a/_primitives/_rust/kei-cortex/src/anthropic.rs b/_primitives/_rust/kei-cortex/src/anthropic.rs index 
1ce86c7..ff1f981 100644 --- a/_primitives/_rust/kei-cortex/src/anthropic.rs +++ b/_primitives/_rust/kei-cortex/src/anthropic.rs @@ -12,6 +12,7 @@ //! an SSE error event rather than hanging the client. use crate::anthropic_sse::SseParser; +use crate::http_helpers::{read_capped, HTTP_CLIENT}; use async_stream::try_stream; use futures::stream::Stream; use futures::StreamExt; @@ -35,6 +36,9 @@ pub const IDLE: Duration = Duration::from_secs(30); /// large error page into our logs or client. const BODY_PREVIEW_CAP: usize = 512; +/// Cap on upstream error body reads via `read_capped` (16 KiB). +const ERROR_BODY_CAP: usize = 16 * 1024; + /// A single turn in the conversation. #[derive(Debug, Clone, Serialize)] pub struct Message { @@ -91,7 +95,10 @@ fn body_to_text_stream( }; let Some(chunk) = chunk_opt else { break }; let chunk = chunk.map_err(Error::Http)?; - for text in parser.push(&chunk) { + let texts = parser.push(&chunk).map_err(|_| { + Error::Upstream { status: 502, body: "SSE frame exceeds 1MB cap".into() } + })?; + for text in texts { yield text; } } @@ -114,8 +121,7 @@ async fn send_request( api_key: &str, body: &serde_json::Value, ) -> Result { - let client = reqwest::Client::new(); - let resp = client + let resp = HTTP_CLIENT .post(endpoint().as_ref()) .header("x-api-key", api_key) .header("anthropic-version", API_VERSION) @@ -142,7 +148,8 @@ async fn check_status(resp: reqwest::Response) -> Result`. 
use crate::anthropic::{default_model, endpoint, API_VERSION}; +use crate::http_helpers::{read_capped, HTTP_CLIENT}; use crate::tool::loop_driver::{ ContentBlock, ConversationMessage, ModelInvoker, ModelTurn, TokenUsage, }; @@ -42,9 +43,11 @@ async fn invoke( Ok(Err(e)) => return Err(format!("anthropic request: {e}")), Err(_) => return Err("anthropic request: timeout".into()), }; - let raw: Value = resp - .json() + const SUCCESS_CAP: usize = 64 * 1024 * 1024; // 64 MiB + let bytes = read_capped(resp, SUCCESS_CAP) .await + .map_err(|e| format!("anthropic body read: {e}"))?; + let raw: Value = serde_json::from_slice(&bytes) .map_err(|e| format!("anthropic body json: {e}"))?; parse_turn(&raw) } @@ -103,7 +106,7 @@ fn render_assistant_blocks(blocks: &[ContentBlock]) -> Vec { /// Fire the POST request and surface 4xx/5xx as a string error. async fn send(api_key: &str, body: &Value) -> Result { - let resp = reqwest::Client::new() + let resp = HTTP_CLIENT .post(endpoint().as_ref()) .header("x-api-key", api_key) .header("anthropic-version", API_VERSION) diff --git a/_primitives/_rust/kei-cortex/src/anthropic_sse.rs b/_primitives/_rust/kei-cortex/src/anthropic_sse.rs index 1b726de..0ff6f9a 100644 --- a/_primitives/_rust/kei-cortex/src/anthropic_sse.rs +++ b/_primitives/_rust/kei-cortex/src/anthropic_sse.rs @@ -23,6 +23,14 @@ struct Delta { text: Option, } +/// Maximum buffer size per SSE frame — guards against a runaway upstream +/// that never emits a `\n\n` frame boundary. +const MAX_BUF: usize = 1 * 1024 * 1024; // 1 MiB + +/// Error returned by `SseParser::push` when the buffer cap is exceeded. +#[derive(Debug)] +pub(crate) struct SseBufOverflow; + /// Incremental SSE parser — SSE frames are separated by `\n\n`. /// /// We buffer partial chunks across `push` calls and return every extracted @@ -37,8 +45,14 @@ impl SseParser { } /// Consume a byte chunk, return every text delta completed in this push. 
- pub fn push(&mut self, chunk: &Bytes) -> Vec { + /// Returns `Err(SseBufOverflow)` if the buffer, after appending `chunk`, exceeds `MAX_BUF` + /// (checked before any frames are extracted); caller should abort the stream. + pub fn push(&mut self, chunk: &Bytes) -> Result, SseBufOverflow> { self.buf.push_str(&String::from_utf8_lossy(chunk)); + if self.buf.len() > MAX_BUF { + self.buf.clear(); + return Err(SseBufOverflow); + } let mut out = Vec::new(); while let Some(idx) = self.buf.find("\n\n") { let frame: String = self.buf.drain(..idx + 2).collect(); @@ -46,7 +60,7 @@ impl SseParser { out.push(text); } } - out + Ok(out) } } @@ -89,7 +103,17 @@ mod tests { "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"ab\"}", ); let part2 = Bytes::from("}\n\n"); - assert!(p.push(&part1).is_empty()); - assert_eq!(p.push(&part2), vec!["ab".to_string()]); + assert!(p.push(&part1).unwrap().is_empty()); + assert_eq!(p.push(&part2).unwrap(), vec!["ab".to_string()]); + } + + #[test] + fn parser_rejects_oversized_buf() { + let mut p = SseParser::new(); + // Push just over MAX_BUF bytes without a frame boundary. + let big = Bytes::from(vec![b'x'; MAX_BUF + 1]); + assert!(p.push(&big).is_err()); + // Buffer must be cleared so subsequent calls don't accumulate. + assert!(p.push(&Bytes::from("\n\n")).unwrap().is_empty()); } } diff --git a/_primitives/_rust/kei-cortex/src/auth.rs b/_primitives/_rust/kei-cortex/src/auth.rs index ff6715e..9c7ff38 100644 --- a/_primitives/_rust/kei-cortex/src/auth.rs +++ b/_primitives/_rust/kei-cortex/src/auth.rs @@ -103,30 +103,51 @@ pub fn tokens_match(expected: &str, got: &str) -> bool { diff == 0 } +/// Build a timestamp-suffixed temp path next to `path`: `{file_name}.{nanos}.tmp`.
+fn tmp_path(path: &Path) -> std::path::PathBuf { + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + let name = path.file_name().map(|n| { + let mut s = n.to_os_string(); + s.push(format!(".{ts}.tmp")); + s + }); + match (path.parent(), name) { + (Some(p), Some(n)) => p.join(n), + _ => path.with_extension("tmp"), + } +} + #[cfg(unix)] fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> { use std::os::unix::fs::OpenOptionsExt; + let tmp = tmp_path(path); let mut f = fs::OpenOptions::new() .write(true) .create(true) .truncate(true) .mode(0o600) - .open(path)?; + .open(&tmp)?; f.write_all(bytes)?; f.sync_all()?; - Ok(()) + drop(f); + fs::rename(&tmp, path) } #[cfg(not(unix))] fn write_mode_600(path: &Path, bytes: &[u8]) -> std::io::Result<()> { + let tmp = tmp_path(path); let mut f = fs::OpenOptions::new() .write(true) .create(true) .truncate(true) - .open(path)?; + .open(&tmp)?; f.write_all(bytes)?; f.sync_all()?; - Ok(()) + drop(f); + fs::rename(&tmp, path) } #[cfg(test)] diff --git a/_primitives/_rust/kei-cortex/src/elevenlabs.rs b/_primitives/_rust/kei-cortex/src/elevenlabs.rs index ec2c0d4..5d71208 100644 --- a/_primitives/_rust/kei-cortex/src/elevenlabs.rs +++ b/_primitives/_rust/kei-cortex/src/elevenlabs.rs @@ -10,6 +10,7 @@ //! reference, never hardcoded). We never log the full text; only its length //! and the response status / byte count. +use crate::http_helpers::{read_capped, HTTP_CLIENT}; use std::time::Duration; /// Errors surfaced to the caller; handlers map them onto HTTP codes. @@ -29,12 +30,15 @@ const TTS_BASE: &str = "https://api.elevenlabs.io/v1/text-to-speech"; const BUDGET: Duration = Duration::from_secs(60); const BODY_PREVIEW_CAP: usize = 512; const MODEL_ID: &str = "eleven_turbo_v2_5"; +/// Cap on successful audio response bodies (64 MiB). +const AUDIO_BODY_CAP: usize = 64 * 1024 * 1024; +/// Cap on error response bodies (16 KiB). 
+const ERROR_BODY_CAP: usize = 16 * 1024; /// Synthesize speech for `text` using the given `voice_id`. Returns mp3 bytes. pub async fn synthesize(voice_id: &str, text: &str) -> Result, Error> { let key = std::env::var("ELEVENLABS_API_KEY").map_err(|_| Error::NoApiKey)?; - let client = reqwest::Client::new(); - let fut = call_tts(&client, &key, voice_id, text); + let fut = call_tts(&key, voice_id, text); match tokio::time::timeout(BUDGET, fut).await { Ok(r) => r, Err(_) => Err(Error::Timeout), @@ -44,14 +48,13 @@ pub async fn synthesize(voice_id: &str, text: &str) -> Result, Error> { /// POST the JSON body and collect the audio bytes. Split from `synthesize` /// so the timeout wrapper stays a thin shell. async fn call_tts( - client: &reqwest::Client, key: &str, voice_id: &str, text: &str, ) -> Result, Error> { let url = format!("{TTS_BASE}/{voice_id}"); let body = build_body(text); - let resp = client + let resp = HTTP_CLIENT .post(&url) .header("xi-api-key", key) .header("Content-Type", "application/json") @@ -77,13 +80,20 @@ fn build_body(text: &str) -> serde_json::Value { /// Turn an ElevenLabs response into raw mp3 bytes; map non-2xx to `Upstream`. 
async fn decode_audio(resp: reqwest::Response) -> Result, Error> { + use crate::http_helpers::CappedReadError; let status = resp.status(); if !status.is_success() { - let body = resp.text().await.unwrap_or_default(); - let truncated = truncate(&body, BODY_PREVIEW_CAP); - return Err(Error::Upstream(status.as_u16(), truncated)); + let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default(); + let body = String::from_utf8_lossy(&raw).into_owned(); + return Err(Error::Upstream(status.as_u16(), truncate(&body, BODY_PREVIEW_CAP))); + } + match read_capped(resp, AUDIO_BODY_CAP).await { + Ok(bytes) => Ok(bytes), + Err(CappedReadError::TooLarge) => { + Err(Error::Upstream(502, "audio body exceeded 64 MiB cap".into())) + } + Err(CappedReadError::Http(e)) => Err(Error::Http(e)), } - Ok(resp.bytes().await?.to_vec()) } /// Cap a string at `max` bytes on a char boundary. Used for error previews diff --git a/_primitives/_rust/kei-cortex/src/fal.rs b/_primitives/_rust/kei-cortex/src/fal.rs index 229d3fb..6fdd46a 100644 --- a/_primitives/_rust/kei-cortex/src/fal.rs +++ b/_primitives/_rust/kei-cortex/src/fal.rs @@ -1,15 +1,16 @@ //! fal.ai Flux 2 Pro client — stylize a portrait into an anime frame. //! -//! The public surface is a single async function `stylize(source_png, style)`. -//! Internally we: (1) upload the source image to fal storage via the -//! two-step `/storage/upload/initiate` handshake, (2) POST a generation -//! request to the queue endpoint, (3) poll status until `COMPLETED`, (4) -//! download the first image in the result. The whole pipeline has a 60-second -//! wall-clock budget, past which we return `Error::Timeout` so the handler -//! can surface an HTTP 504. +//! Public surface: `stylize(source_png, style)`. //! -//! `FAL_KEY` is read from the environment on every call; the daemon does not -//! cache it because the user may rotate it without restarting. +//! Pipeline: (1) upload source image to fal storage, (2) POST generation +//! 
request to the queue endpoint, (3) poll status until `COMPLETED`, (4) +//! download the first result image. 60-second wall-clock budget; past that +//! the function returns `Error::Timeout` so the handler can surface HTTP 504. +//! +//! `FAL_KEY` is read from the environment on every call (env-rotation-friendly, +//! per RULE 0.8 — never hardcoded, never cached in daemon state). +//! +//! HTTP pipeline steps live in `fal_pipeline`; SSRF mitigation in `fal_ssrf`. use std::time::Duration; @@ -31,7 +32,7 @@ impl Style { } } - fn prompt(self) -> &'static str { + pub(crate) fn prompt(self) -> &'static str { match self { Self::Cute => CUTE_PROMPT, Self::Cool => COOL_PROMPT, @@ -57,150 +58,44 @@ pub enum Error { BadShape(&'static str), #[error("fal polling exceeded 60s budget")] Timeout, + #[error("non-fal URL rejected (SSRF): {0}")] + NonFalUrl(String), } -const UPLOAD_INITIATE_URL: &str = "https://rest.alpha.fal.ai/storage/upload/initiate"; -const GENERATION_URL: &str = "https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra"; -const BUDGET: Duration = Duration::from_secs(60); -const POLL_INTERVAL: Duration = Duration::from_millis(800); -const BODY_PREVIEW_CAP: usize = 512; +pub(crate) const BUDGET: Duration = Duration::from_secs(60); +pub(crate) const POLL_INTERVAL: Duration = Duration::from_millis(800); +pub(crate) const BODY_PREVIEW_CAP: usize = 512; +pub(crate) const ERROR_BODY_CAP: usize = 16 * 1024; +pub(crate) const IMAGE_BODY_CAP: usize = 64 * 1024 * 1024; -/// Stylize `source_png` into an anime portrait. Returns the raw PNG bytes -/// of the generated image (caller writes them to disk). +/// Stylize `source_png` into an anime portrait. Returns raw PNG bytes. 
pub async fn stylize(source_png: &[u8], style: Style) -> Result, Error> { let key = std::env::var("FAL_KEY").map_err(|_| Error::NoApiKey)?; - let client = reqwest::Client::new(); let deadline = tokio::time::Instant::now() + BUDGET; - let uploaded_url = upload_image(&client, &key, source_png).await?; - let status_url = enqueue(&client, &key, &uploaded_url, style).await?; - let result_url = poll_until_done(&client, &key, &status_url, deadline).await?; - download_image(&client, &key, &result_url).await -} - -/// Step 1 — ask fal storage for a signed PUT URL, then PUT the image to it. -async fn upload_image(client: &reqwest::Client, key: &str, bytes: &[u8]) -> Result { - let body = serde_json::json!({ "file_name": "portrait.png", "content_type": "image/png" }); - let resp = client - .post(UPLOAD_INITIATE_URL) - .header("Authorization", format!("Key {key}")) - .json(&body) - .send() - .await?; - let json = decode_json(resp).await?; - let upload_url = json.get("upload_url").and_then(|v| v.as_str()).ok_or(Error::BadShape("upload_url"))?; - let file_url = json.get("file_url").and_then(|v| v.as_str()).ok_or(Error::BadShape("file_url"))?; - let put = client.put(upload_url).header("Content-Type", "image/png").body(bytes.to_vec()).send().await?; - if !put.status().is_success() { - return Err(Error::BadStatus(put.status().as_u16(), "PUT upload failed".into())); - } - Ok(file_url.to_string()) -} - -/// Step 2 — POST the generation request, return the status poll URL. 
-async fn enqueue(client: &reqwest::Client, key: &str, image_url: &str, style: Style) -> Result { - let body = serde_json::json!({ - "image_url": image_url, - "prompt": style.prompt(), - "strength": 0.65, - "enable_safety_checker": true, - }); - let resp = client - .post(GENERATION_URL) - .header("Authorization", format!("Key {key}")) - .json(&body) - .send() - .await?; - let json = decode_json(resp).await?; - json.get("status_url") - .and_then(|v| v.as_str()) - .map(str::to_string) - .ok_or(Error::BadShape("status_url")) -} - -/// Step 3 — poll the status URL until `COMPLETED`, or give up at deadline. -async fn poll_until_done(client: &reqwest::Client, key: &str, status_url: &str, deadline: tokio::time::Instant) -> Result { - loop { - if tokio::time::Instant::now() >= deadline { - return Err(Error::Timeout); - } - let resp = client.get(status_url).header("Authorization", format!("Key {key}")).send().await?; - let json = decode_json(resp).await?; - let status = json.get("status").and_then(|v| v.as_str()).unwrap_or(""); - if status == "COMPLETED" { - return extract_first_image_url(&json); - } - if status == "FAILED" || status == "ERROR" { - return Err(Error::BadStatus(502, format!("fal reported {status}"))); - } - tokio::time::sleep(POLL_INTERVAL).await; - } -} - -/// Step 4 — download the PNG bytes from the fal CDN URL. -async fn download_image(client: &reqwest::Client, key: &str, url: &str) -> Result, Error> { - let resp = client.get(url).header("Authorization", format!("Key {key}")).send().await?; - if !resp.status().is_success() { - return Err(Error::BadStatus(resp.status().as_u16(), "download failed".into())); - } - Ok(resp.bytes().await?.to_vec()) -} - -/// Dig out `images[0].url` from the completed status payload. 
-fn extract_first_image_url(json: &serde_json::Value) -> Result { - json.get("images") - .and_then(|a| a.as_array()) - .and_then(|a| a.first()) - .and_then(|o| o.get("url")) - .and_then(|v| v.as_str()) - .map(str::to_string) - .ok_or(Error::BadShape("images[0].url")) -} - -/// Decode a fal JSON response, turning non-2xx into `BadStatus` with body. -/// Body capped at `BODY_PREVIEW_CAP` so a large upstream error page cannot -/// propagate through our logs or error channel. -async fn decode_json(resp: reqwest::Response) -> Result { - let status = resp.status(); - if !status.is_success() { - let body = resp.text().await.unwrap_or_default(); - return Err(Error::BadStatus( - status.as_u16(), - truncate(&body, BODY_PREVIEW_CAP), - )); - } - Ok(resp.json::().await?) -} - -/// Cap a string at `max` bytes on a char boundary. Keeps fal error previews -/// bounded regardless of what upstream sent back. -fn truncate(s: &str, max: usize) -> String { - if s.len() <= max { - return s.to_string(); - } - let mut end = max; - while end > 0 && !s.is_char_boundary(end) { - end -= 1; - } - s[..end].to_string() + let uploaded_url = crate::fal_pipeline::upload_image(&key, source_png).await?; + let status_url = crate::fal_pipeline::enqueue(&key, &uploaded_url, style).await?; + let result_url = + crate::fal_pipeline::poll_until_done(&key, &status_url, deadline, POLL_INTERVAL) + .await?; + crate::fal_pipeline::download_image(&key, &result_url).await } #[cfg(test)] mod tests { use super::*; - #[test] - fn truncate_caps_long_strings() { - let long = "a".repeat(10_000); - assert_eq!(truncate(&long, 256).len(), 256); - } - - #[test] - fn truncate_leaves_short_strings() { - assert_eq!(truncate("hi", 256), "hi"); - } - #[test] fn style_from_wire_defaults_to_cute() { assert!(matches!(Style::from_wire("wat"), Style::Cute)); } + + #[test] + fn style_from_wire_cool() { + assert!(matches!(Style::from_wire("anime-cool"), Style::Cool)); + } + + #[test] + fn style_from_wire_studious() { + 
assert!(matches!(Style::from_wire("anime-studious"), Style::Studious)); + } } diff --git a/_primitives/_rust/kei-cortex/src/fal_pipeline.rs b/_primitives/_rust/kei-cortex/src/fal_pipeline.rs new file mode 100644 index 0000000..aaf8677 --- /dev/null +++ b/_primitives/_rust/kei-cortex/src/fal_pipeline.rs @@ -0,0 +1,173 @@ +//! fal.ai HTTP pipeline steps — upload, enqueue, poll, download. +//! +//! Each step is a single HTTPS round-trip. All use the shared `HTTP_CLIENT` +//! from `http_helpers` (no per-request `reqwest::Client::new()`). +//! Response bodies are read through `read_capped` so a runaway upstream +//! cannot allocate unbounded memory. + +use crate::fal::{Error, Style}; +use crate::fal::{BODY_PREVIEW_CAP, ERROR_BODY_CAP, IMAGE_BODY_CAP}; +use crate::fal_ssrf::validate_fal_url; +use crate::http_helpers::{read_capped, CappedReadError, HTTP_CLIENT}; + +pub(crate) const UPLOAD_INITIATE_URL: &str = + "https://rest.alpha.fal.ai/storage/upload/initiate"; +pub(crate) const GENERATION_URL: &str = + "https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra"; + +/// Step 1 — ask fal storage for a signed PUT URL, then PUT the image to it. +pub(crate) async fn upload_image(key: &str, bytes: &[u8]) -> Result { + let body = + serde_json::json!({ "file_name": "portrait.png", "content_type": "image/png" }); + let resp = HTTP_CLIENT + .post(UPLOAD_INITIATE_URL) + .header("Authorization", format!("Key {key}")) + .json(&body) + .send() + .await?; + let json = decode_json(resp).await?; + let upload_url = json + .get("upload_url") + .and_then(|v| v.as_str()) + .ok_or(Error::BadShape("upload_url"))?; + let file_url = json + .get("file_url") + .and_then(|v| v.as_str()) + .ok_or(Error::BadShape("file_url"))?; + validate_fal_url(upload_url)?; + validate_fal_url(file_url)?; + HTTP_CLIENT + .put(upload_url) + .header("Content-Type", "image/png") + .body(bytes.to_vec()) + .send() + .await? 
+ .error_for_status() + .map_err(Error::Http)?; + Ok(file_url.to_string()) +} + +/// Step 2 — POST the generation request, return the validated status URL. +pub(crate) async fn enqueue( + key: &str, + image_url: &str, + style: Style, +) -> Result { + let body = serde_json::json!({ + "image_url": image_url, + "prompt": style.prompt(), + "strength": 0.65, + "enable_safety_checker": true, + }); + let resp = HTTP_CLIENT + .post(GENERATION_URL) + .header("Authorization", format!("Key {key}")) + .json(&body) + .send() + .await?; + let json = decode_json(resp).await?; + let status_url = json + .get("status_url") + .and_then(|v| v.as_str()) + .map(str::to_string) + .ok_or(Error::BadShape("status_url"))?; + validate_fal_url(&status_url)?; + Ok(status_url) +} + +/// Step 3 — poll `status_url` until `COMPLETED`, or give up at `deadline`. +pub(crate) async fn poll_until_done( + key: &str, + status_url: &str, + deadline: tokio::time::Instant, + poll_interval: std::time::Duration, +) -> Result { + loop { + if tokio::time::Instant::now() >= deadline { + return Err(Error::Timeout); + } + let resp = HTTP_CLIENT + .get(status_url) + .header("Authorization", format!("Key {key}")) + .send() + .await?; + let json = decode_json(resp).await?; + match json.get("status").and_then(|v| v.as_str()).unwrap_or("") { + "COMPLETED" => return extract_first_image_url(&json), + s @ ("FAILED" | "ERROR") => { + return Err(Error::BadStatus(502, format!("fal reported {s}"))); + } + _ => {} + } + tokio::time::sleep(poll_interval).await; + } +} + +/// Step 4 — download the PNG bytes from a validated fal CDN URL. 
+pub(crate) async fn download_image(key: &str, url: &str) -> Result, Error> { + validate_fal_url(url)?; + let resp = HTTP_CLIENT + .get(url) + .header("Authorization", format!("Key {key}")) + .send() + .await?; + if !resp.status().is_success() { + return Err(Error::BadStatus(resp.status().as_u16(), "download failed".into())); + } + match read_capped(resp, IMAGE_BODY_CAP).await { + Ok(b) => Ok(b), + Err(CappedReadError::TooLarge) => { + Err(Error::BadStatus(502, "image body exceeded 64 MiB cap".into())) + } + Err(CappedReadError::Http(e)) => Err(Error::Http(e)), + } +} + +/// Dig out `images[0].url` from the completed status payload, then validate. +pub(crate) fn extract_first_image_url(json: &serde_json::Value) -> Result { + let url = json + .get("images") + .and_then(|a| a.as_array()) + .and_then(|a| a.first()) + .and_then(|o| o.get("url")) + .and_then(|v| v.as_str()) + .map(str::to_string) + .ok_or(Error::BadShape("images[0].url"))?; + validate_fal_url(&url)?; + Ok(url) +} + +/// Decode a fal response into JSON, capping error/success bodies. +pub(crate) async fn decode_json( + resp: reqwest::Response, +) -> Result { + let status = resp.status(); + if !status.is_success() { + let raw = read_capped(resp, ERROR_BODY_CAP).await.unwrap_or_default(); + let body = String::from_utf8_lossy(&raw).into_owned(); + return Err(Error::BadStatus( + status.as_u16(), + truncate(&body, BODY_PREVIEW_CAP), + )); + } + let raw = match read_capped(resp, IMAGE_BODY_CAP).await { + Ok(b) => b, + Err(CappedReadError::TooLarge) => { + return Err(Error::BadStatus(502, "response body exceeded 64 MiB cap".into())); + } + Err(CappedReadError::Http(e)) => return Err(Error::Http(e)), + }; + serde_json::from_slice(&raw).map_err(|e| Error::BadStatus(502, e.to_string())) +} + +/// Cap a string at `max` bytes on a char boundary. 
+fn truncate(s: &str, max: usize) -> String { + if s.len() <= max { + return s.to_string(); + } + let mut end = max; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + s[..end].to_string() +} diff --git a/_primitives/_rust/kei-cortex/src/fal_ssrf.rs b/_primitives/_rust/kei-cortex/src/fal_ssrf.rs new file mode 100644 index 0000000..d226afd --- /dev/null +++ b/_primitives/_rust/kei-cortex/src/fal_ssrf.rs @@ -0,0 +1,67 @@ +//! SSRF mitigation for fal.ai outbound requests. +//! +//! `validate_fal_url` rejects any URL whose scheme is not `https://` or +//! whose host does not match the fal.ai domain allowlist before the URL +//! is ever passed to `reqwest`. NOTE(review): only the initial URL is checked; `HTTP_CLIENT` is built with reqwest's default redirect policy, so a redirect issued by an allowed host is not re-validated — confirm whether redirects should be disabled for fal requests. + +use crate::fal::Error; +use once_cell::sync::Lazy; +use regex::Regex; + +/// Compiled allowlist for fal.ai host names. +/// Accepted: one or more lowercase subdomain labels followed by `.fal.ai`, `.fal.media`, or `.fal.run` (bare apex hosts and hosts with a port do not match). +/// Example: `rest.alpha.fal.ai`, `queue.fal.run`, `cdn.fal.media`. +static FAL_HOST_RE: Lazy = + Lazy::new(|| Regex::new(r"^([a-z0-9-]+\.)+fal\.(ai|media|run)$").unwrap()); + +/// Reject URLs that do not match the fal.ai HTTPS allowlist. +/// +/// Returns `Err(Error::NonFalUrl)` when: +/// - the scheme is not `https://`, or +/// - the host does not match `^([a-z0-9-]+\.)+fal\.(ai|media|run)$`.
+pub(crate) fn validate_fal_url(url: &str) -> Result<(), Error> { + let stripped = url.strip_prefix("https://").ok_or_else(|| { + Error::NonFalUrl(url.chars().take(80).collect()) + })?; + let host = stripped.split('/').next().unwrap_or(""); + if !FAL_HOST_RE.is_match(host) { + return Err(Error::NonFalUrl(url.chars().take(80).collect())); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accepts_valid_fal_urls() { + assert!(validate_fal_url("https://rest.alpha.fal.ai/storage/upload/initiate").is_ok()); + assert!(validate_fal_url("https://queue.fal.run/fal-ai/flux-pro/v1.1-ultra").is_ok()); + assert!(validate_fal_url("https://cdn.fal.media/image.png").is_ok()); + } + + #[test] + fn rejects_http_scheme() { + assert!(matches!( + validate_fal_url("http://queue.fal.run/foo"), + Err(Error::NonFalUrl(_)) + )); + } + + #[test] + fn rejects_non_fal_domain() { + assert!(matches!( + validate_fal_url("https://evil.example.com/steal"), + Err(Error::NonFalUrl(_)) + )); + } + + #[test] + fn rejects_ssrf_internal_ip() { + assert!(matches!( + validate_fal_url("https://169.254.169.254/latest/meta-data/"), + Err(Error::NonFalUrl(_)) + )); + } +} diff --git a/_primitives/_rust/kei-cortex/src/handlers/chat.rs b/_primitives/_rust/kei-cortex/src/handlers/chat.rs index ac2237c..aaa3f5c 100644 --- a/_primitives/_rust/kei-cortex/src/handlers/chat.rs +++ b/_primitives/_rust/kei-cortex/src/handlers/chat.rs @@ -97,10 +97,20 @@ fn validate_provider(state: &AppState, name: &str) -> Result<(), AppError> { } } +/// Character ceiling for chat messages. Prevents runaway prompt injection +/// and upstream token cost abuse. 
+const MAX_MESSAGE_CHARS: usize = 50_000; + fn validate_body(req: &ChatRequest) -> Result<(), AppError> { if req.message.is_empty() { return Err(AppError::BadRequest("message is empty".into())); } + let chars = req.message.chars().count(); + if chars > MAX_MESSAGE_CHARS { + return Err(AppError::PayloadTooLarge(format!( + "{chars} chars > {MAX_MESSAGE_CHARS}" + ))); + } Ok(()) } diff --git a/_primitives/_rust/kei-cortex/src/http_helpers.rs b/_primitives/_rust/kei-cortex/src/http_helpers.rs new file mode 100644 index 0000000..1a4c597 --- /dev/null +++ b/_primitives/_rust/kei-cortex/src/http_helpers.rs @@ -0,0 +1,57 @@ +//! Shared HTTP utilities: a process-wide `reqwest::Client` and a capped +//! response-body reader. +//! +//! A single `reqwest::Client` is reused for all outbound calls to avoid +//! exhausting OS connection-table entries when the daemon is under load. +//! The client is initialized once via `once_cell::sync::Lazy`. + +use once_cell::sync::Lazy; + +/// The process-wide HTTP client. Shared by `anthropic`, `anthropic_invoker`, +/// `elevenlabs`, and `fal_pipeline` so that connection pooling is effective. +pub static HTTP_CLIENT: Lazy = Lazy::new(|| { + reqwest::Client::builder() + .build() + .expect("reqwest::Client::builder().build() — no TLS config error expected") +}); + +/// Error returned when a response body exceeds the caller-supplied cap. NOTE(review): `read_capped` reports overflow via `CappedReadError::TooLarge`, not this type; it appears unused in this patch — confirm before keeping it. +#[derive(Debug)] +pub struct BodyTooLarge; + +impl std::fmt::Display for BodyTooLarge { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "response body exceeded size cap") + } +} + +impl std::error::Error for BodyTooLarge {} + +/// Read a response body up to `max_bytes`. Returns `Err(CappedReadError::TooLarge)` +/// if the stream exceeds the cap; the partial buffer is discarded on overflow +/// so the caller does not accidentally use truncated data.
+pub async fn read_capped( + resp: reqwest::Response, + max_bytes: usize, +) -> Result, CappedReadError> { + use futures::StreamExt; + let mut bytes: Vec = Vec::new(); + let mut stream = resp.bytes_stream(); + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(CappedReadError::Http)?; + if bytes.len() + chunk.len() > max_bytes { + return Err(CappedReadError::TooLarge); + } + bytes.extend_from_slice(&chunk); + } + Ok(bytes) +} + +/// Combined error for `read_capped`. +#[derive(Debug, thiserror::Error)] +pub enum CappedReadError { + #[error("response body exceeded size cap")] + TooLarge, + #[error("http stream error: {0}")] + Http(reqwest::Error), +} diff --git a/_primitives/_rust/kei-cortex/src/lib.rs b/_primitives/_rust/kei-cortex/src/lib.rs index f9f017a..4c7c9a8 100644 --- a/_primitives/_rust/kei-cortex/src/lib.rs +++ b/_primitives/_rust/kei-cortex/src/lib.rs @@ -20,11 +20,15 @@ pub mod context; pub mod elevenlabs; pub mod error; pub mod fal; +pub(crate) mod fal_pipeline; +pub(crate) mod fal_ssrf; pub mod handlers; +pub mod http_helpers; pub mod persona; pub mod whisper_local; pub mod rig_clone; pub mod routes; +pub(crate) mod routes_auth; pub mod sentiment; pub mod state; pub mod tool; diff --git a/_primitives/_rust/kei-cortex/src/routes.rs b/_primitives/_rust/kei-cortex/src/routes.rs index d5dd680..40a18f1 100644 --- a/_primitives/_rust/kei-cortex/src/routes.rs +++ b/_primitives/_rust/kei-cortex/src/routes.rs @@ -1,59 +1,47 @@ -//! Router assembly + bearer-token middleware + CORS layer. +//! Router assembly + CORS layer. //! //! `/healthz` is mounted OUTSIDE the auth middleware so monitors can hit it -//! without a token. Everything under `/api` goes through `require_bearer`. +//! without a token. Everything under `/api` goes through `require_bearer` +//! (defined in `routes_auth`). //! //! Per-route concurrency caps protect us from a runaway client draining our //! upstream budget — `fal.ai` in particular bills per run, so we cap //! 
`/portrait/stylize` at 2 concurrent installs system-wide. Other expensive //! routes (`/tts`, `/stt`, `/chat`) get matching caps tuned to their bottleneck. -use crate::auth::tokens_match; -use crate::error::AppError; use crate::handlers::{ chat, fs_list, health, ledger, memory, pet, portrait, stt, summary, term, tool_apply, tts, usage, }; use crate::state::AppState; use axum::error_handling::HandleErrorLayer; -use axum::extract::{DefaultBodyLimit, Request, State}; +use axum::extract::DefaultBodyLimit; use axum::http::{header, HeaderValue, Method, StatusCode}; -use axum::middleware::{self, Next}; -use axum::response::{IntoResponse, Response}; +use axum::middleware; +use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::Router; -// OpenAI-compatible /v1/* surface. Lives under `routes/openai/` as a -// sibling tree to this file; declared via `#[path]` so the directory -// can coexist with `routes.rs` without becoming an inline `routes/mod.rs`. #[path = "routes/openai/mod.rs"] pub mod openai; + use tower::buffer::BufferLayer; use tower::limit::ConcurrencyLimitLayer; use tower::{BoxError, ServiceBuilder}; use tower_http::cors::CorsLayer; -/// Upper bound on the `/portrait/stylize` multipart body. The handler enforces -/// the stricter 10 MiB cap on the `file` field; this is just the pre-parse -/// gate that Axum applies before it can see individual fields. -const PORTRAIT_BODY_LIMIT: usize = 12 * 1024 * 1024; +// --- Body limits (per-route pre-parse gates) -------------------------------- +const PORTRAIT_BODY_LIMIT: usize = 12 * 1024 * 1024; // 12 MiB (handler re-checks 10) +const STT_BODY_LIMIT: usize = 26 * 1024 * 1024; // 26 MiB (handler re-checks 25) +const CHAT_BODY_LIMIT: usize = 256 * 1024; // 256 KiB +const TOOL_APPLY_BODY_LIMIT: usize = 11 * 1024 * 1024; // 11 MiB (handler checks 10) +const INTERACTION_BODY_LIMIT: usize = 64 * 1024; // 64 KiB +const TTS_BODY_LIMIT: usize = 32 * 1024; // 32 KiB -/// Upper bound on the `/stt` multipart body. 
The handler enforces the -/// stricter 25 MiB cap on the `audio` field itself; this is the slack so -/// that field headers + form overhead do not trip the pre-parse gate. -const STT_BODY_LIMIT: usize = 26 * 1024 * 1024; - -/// Max concurrent Flux stylize runs system-wide. fal.ai bills per run. +// --- Concurrency budgets ---------------------------------------------------- const PORTRAIT_CONCURRENCY: usize = 2; - -/// Max concurrent ElevenLabs TTS calls. const TTS_CONCURRENCY: usize = 4; - -/// Max concurrent whisper worker runs. CPU-bound, so the cap matches a -/// conservative small-laptop core count. const STT_CONCURRENCY: usize = 2; - -/// Max concurrent Anthropic chat streams. const CHAT_CONCURRENCY: usize = 8; /// Build the top-level router. `cors_origin` must have been validated at @@ -62,43 +50,11 @@ pub fn build_router(state: AppState) -> Router { let cors = build_cors(state.config().cors_origin.as_str()) .expect("cors_origin must be valid — validated in AppConfig::new"); - // Per-route granular caps proved fragile with axum 0.7's MethodRouter - // layer bounds (HandleErrorLayer + ConcurrencyLimitLayer service is not - // `Clone` in a way that layer() accepts). Apply a single router-wide cap - // via route_layer — it wraps every inner route uniformly without the - // per-method error-type headache. The cap is the SUM of the per-route - // budgets (2+4+2+8 = 16), which is a strict upper bound on simultaneous - // expensive work. Finer-grained token-bucket per-route can land later via - // tower-governor if a multi-user deployment appears. 
- let api = Router::new() - .route("/api/v1/cortex/summary", get(summary::summary)) - .route("/api/v1/cortex/pet/:user_id", get(pet::get_pet)) - .route( - "/api/v1/cortex/pet/:user_id/interaction", - post(pet::post_interaction), - ) - .route("/api/v1/cortex/pet/:user_id/chat", post(chat::chat)) - .route( - "/api/v1/cortex/pet/:user_id/portrait/stylize", - post(portrait::stylize).layer(DefaultBodyLimit::max(PORTRAIT_BODY_LIMIT)), - ) - .route( - "/api/v1/cortex/stt", - post(stt::transcribe).layer(DefaultBodyLimit::max(STT_BODY_LIMIT)), - ) - .route( - "/api/v1/cortex/pet/:user_id/tts", - post(tts::synthesize), - ) - .route("/api/v1/cortex/ledger/recent", get(ledger::recent)) - .route("/api/v1/cortex/memory/search", get(memory::search_memory)) - .route("/api/v1/cortex/usage", get(usage::usage)) - .route("/api/v1/cortex/fs/list", get(fs_list::list)) - .route("/api/v1/cortex/tool/apply", post(tool_apply::apply)) - .route("/api/v1/cortex/term", get(term::ws_handler)) + let api = build_api_router(); + let api = api .route_layer(middleware::from_fn_with_state( state.clone(), - require_bearer, + crate::routes_auth::require_bearer, )) .layer( ServiceBuilder::new() @@ -119,6 +75,43 @@ pub fn build_router(state: AppState) -> Router { .with_state(state) } +/// Assemble the protected API sub-router (no auth layer yet — applied by caller). 
+fn build_api_router() -> Router { + Router::new() + .route("/api/v1/cortex/summary", get(summary::summary)) + .route("/api/v1/cortex/pet/:user_id", get(pet::get_pet)) + .route( + "/api/v1/cortex/pet/:user_id/interaction", + post(pet::post_interaction) + .layer(DefaultBodyLimit::max(INTERACTION_BODY_LIMIT)), + ) + .route( + "/api/v1/cortex/pet/:user_id/chat", + post(chat::chat).layer(DefaultBodyLimit::max(CHAT_BODY_LIMIT)), + ) + .route( + "/api/v1/cortex/pet/:user_id/portrait/stylize", + post(portrait::stylize).layer(DefaultBodyLimit::max(PORTRAIT_BODY_LIMIT)), + ) + .route( + "/api/v1/cortex/stt", + post(stt::transcribe).layer(DefaultBodyLimit::max(STT_BODY_LIMIT)), + ) + .route( + "/api/v1/cortex/pet/:user_id/tts", + post(tts::synthesize).layer(DefaultBodyLimit::max(TTS_BODY_LIMIT)), + ) + .route("/api/v1/cortex/ledger/recent", get(ledger::recent)) + .route("/api/v1/cortex/memory/search", get(memory::search_memory)) + .route("/api/v1/cortex/usage", get(usage::usage)) + .route("/api/v1/cortex/fs/list", get(fs_list::list)) + .route( + "/api/v1/cortex/tool/apply", + post(tool_apply::apply).layer(DefaultBodyLimit::max(TOOL_APPLY_BODY_LIMIT)), + ) + .route("/api/v1/cortex/term", get(term::ws_handler)) +} + /// Build the CORS layer locked to a single origin. fn build_cors(origin: &str) -> Result { let origin_header: HeaderValue = origin @@ -131,60 +124,3 @@ fn build_cors(origin: &str) -> Result { .allow_credentials(true)) } -/// Bearer-token middleware. -/// -/// Two acceptable transports — checked in order: -/// 1. `Authorization: Bearer ` — standard HTTP requests. -/// 2. `Sec-WebSocket-Protocol: bearer, ` — WS upgrade only, -/// because browsers cannot set the Authorization header on a -/// `new WebSocket(url, [...protocols])` call. -/// -/// Missing → 401; mismatch → 403. 
-async fn require_bearer( - State(state): State, - req: Request, - next: Next, -) -> Result { - let token = state.token().to_string(); - if let Some(got) = bearer_from_authorization(&req) { - return finish_auth(&token, &got, req, next).await; - } - if let Some(got) = bearer_from_websocket_subprotocol(&req) { - return finish_auth(&token, &got, req, next).await; - } - Err(AppError::Unauthorized) -} - -/// Pull `` from `Authorization: Bearer ` if present. -fn bearer_from_authorization(req: &Request) -> Option { - let v = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?; - Some(v.strip_prefix("Bearer ")?.trim().to_string()) -} - -/// Pull `` from `Sec-WebSocket-Protocol: bearer, `. The -/// browser's `new WebSocket(url, ['bearer', tok])` produces this header. -fn bearer_from_websocket_subprotocol(req: &Request) -> Option { - let v = req - .headers() - .get("sec-websocket-protocol")? - .to_str() - .ok()?; - let mut parts = v.split(',').map(str::trim); - if parts.next()? != "bearer" { - return None; - } - Some(parts.next()?.to_string()) -} - -/// Compare expected vs. supplied; on match, call `next`. -async fn finish_auth( - expected: &str, - got: &str, - req: Request, - next: Next, -) -> Result { - if !tokens_match(expected, got) { - return Err(AppError::Forbidden); - } - Ok(next.run(req).await) -} diff --git a/_primitives/_rust/kei-cortex/src/routes_auth.rs b/_primitives/_rust/kei-cortex/src/routes_auth.rs new file mode 100644 index 0000000..a873512 --- /dev/null +++ b/_primitives/_rust/kei-cortex/src/routes_auth.rs @@ -0,0 +1,69 @@ +//! Bearer-token middleware for the cortex API router. +//! +//! Separated from `routes.rs` so each file stays under 200 LOC. + +use crate::auth::tokens_match; +use crate::error::AppError; +use crate::state::AppState; +use axum::extract::{Request, State}; +use axum::http::header; +use axum::middleware::Next; +use axum::response::Response; + +/// Bearer-token middleware. +/// +/// Two acceptable transports — checked in order: +/// 1. 
`Authorization: Bearer ` — standard HTTP requests. +/// 2. `Sec-WebSocket-Protocol: bearer, ` — WS upgrade only, +/// because browsers cannot set the Authorization header on a +/// `new WebSocket(url, [...protocols])` call. +/// +/// Missing → 401; mismatch → 403. +pub async fn require_bearer( + State(state): State, + req: Request, + next: Next, +) -> Result { + let token = state.token().to_string(); + if let Some(got) = bearer_from_authorization(&req) { + return finish_auth(&token, &got, req, next).await; + } + if let Some(got) = bearer_from_websocket_subprotocol(&req) { + return finish_auth(&token, &got, req, next).await; + } + Err(AppError::Unauthorized) +} + +/// Pull `` from `Authorization: Bearer ` if present. +fn bearer_from_authorization(req: &Request) -> Option { + let v = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?; + Some(v.strip_prefix("Bearer ")?.trim().to_string()) +} + +/// Pull `` from `Sec-WebSocket-Protocol: bearer, `. The +/// browser's `new WebSocket(url, ['bearer', tok])` produces this header. +fn bearer_from_websocket_subprotocol(req: &Request) -> Option { + let v = req + .headers() + .get("sec-websocket-protocol")? + .to_str() + .ok()?; + let mut parts = v.split(',').map(str::trim); + if parts.next()? != "bearer" { + return None; + } + Some(parts.next()?.to_string()) +} + +/// Compare expected vs. supplied; on match, call `next`. +async fn finish_auth( + expected: &str, + got: &str, + req: Request, + next: Next, +) -> Result { + if !tokens_match(expected, got) { + return Err(AppError::Forbidden); + } + Ok(next.run(req).await) +}