feat: proxy rerank api (#851)

pull/852/head
sigoden 1 month ago committed by GitHub
parent 69965466e6
commit e5cc194598
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -56,6 +56,7 @@ pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
let stop_server = server.run(listener).await?;
println!("Chat Completions API: http://{addr}/v1/chat/completions");
println!("Embeddings API: http://{addr}/v1/embeddings");
println!("Rerank API: http://{addr}/v1/rerank");
println!("LLM Playground: http://{addr}/playground");
println!("LLM Arena: http://{addr}/arena?num=2");
shutdown_signal().await;
@ -158,6 +159,8 @@ impl Server {
self.chat_completions(req).await
} else if path == "/v1/embeddings" {
self.embeddings(req).await
} else if path == "/v1/rerank" {
self.rerank(req).await
} else if path == "/v1/models" {
self.list_models()
} else if path == "/v1/roles" {
@ -498,6 +501,57 @@ impl Server {
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}
async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
let req_body = req.collect().await?.to_bytes();
let req_body: Value = serde_json::from_slice(&req_body)
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
debug!("rerank request: {req_body}");
let req_body = serde_json::from_value(req_body)
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
let RerankReqBody {
model: reranker_model_id,
documents,
query,
top_n,
} = req_body;
let top_n = top_n.unwrap_or(documents.len());
let config = Arc::new(RwLock::new(self.config.clone()));
let reranker_model = Model::retrieve_embedding(&config.read(), &reranker_model_id)?;
let client = init_client(&config, Some(reranker_model))?;
let data = client
.rerank(RerankData {
query,
documents: documents.clone(),
top_n,
})
.await?;
let results: Vec<_> = data
.into_iter()
.map(|v| {
json!({
"index": v.index,
"relevance_score": v.relevance_score,
"document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
})
})
.collect();
let output = json!({
"id": uuid::Uuid::new_v4().to_string(),
"results": results,
});
let res = Response::builder()
.header("Content-Type", "application/json")
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
Ok(res)
}
}
#[derive(Debug, Deserialize)]
@ -520,8 +574,8 @@ struct ChatCompletionsReqBody {
#[derive(Debug, Deserialize)]
struct EmbeddingsReqBody {
pub input: EmbeddingsReqBodyInput,
pub model: String,
input: EmbeddingsReqBodyInput,
model: String,
}
#[derive(Debug, Deserialize)]
@ -531,6 +585,14 @@ enum EmbeddingsReqBodyInput {
Multiple(Vec<String>),
}
/// JSON request body deserialized by the `rerank` handler.
#[derive(Debug, Deserialize)]
struct RerankReqBody {
/// Texts to be ranked; the handler also echoes them back in each result.
documents: Vec<String>,
/// Query passed to the rerank call together with the documents.
query: String,
/// Identifier the handler uses to look up the model for the rerank call.
model: String,
/// Optional `top_n` for the rerank call; the handler falls back to
/// `documents.len()` when it is omitted.
top_n: Option<usize>,
}
#[derive(Debug)]
enum ResEvent {
First(Option<String>),

Loading…
Cancel
Save