|
|
|
@ -56,6 +56,7 @@ pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
|
|
|
|
|
let stop_server = server.run(listener).await?;
|
|
|
|
|
println!("Chat Completions API: http://{addr}/v1/chat/completions");
|
|
|
|
|
println!("Embeddings API: http://{addr}/v1/embeddings");
|
|
|
|
|
println!("Rerank API: http://{addr}/v1/rerank");
|
|
|
|
|
println!("LLM Playground: http://{addr}/playground");
|
|
|
|
|
println!("LLM Arena: http://{addr}/arena?num=2");
|
|
|
|
|
shutdown_signal().await;
|
|
|
|
@ -158,6 +159,8 @@ impl Server {
|
|
|
|
|
self.chat_completions(req).await
|
|
|
|
|
} else if path == "/v1/embeddings" {
|
|
|
|
|
self.embeddings(req).await
|
|
|
|
|
} else if path == "/v1/rerank" {
|
|
|
|
|
self.rerank(req).await
|
|
|
|
|
} else if path == "/v1/models" {
|
|
|
|
|
self.list_models()
|
|
|
|
|
} else if path == "/v1/roles" {
|
|
|
|
@ -498,6 +501,57 @@ impl Server {
|
|
|
|
|
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
|
|
|
|
|
Ok(res)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
async fn rerank(&self, req: hyper::Request<Incoming>) -> Result<AppResponse> {
|
|
|
|
|
let req_body = req.collect().await?.to_bytes();
|
|
|
|
|
let req_body: Value = serde_json::from_slice(&req_body)
|
|
|
|
|
.map_err(|err| anyhow!("Invalid request json, {err}"))?;
|
|
|
|
|
|
|
|
|
|
debug!("rerank request: {req_body}");
|
|
|
|
|
let req_body = serde_json::from_value(req_body)
|
|
|
|
|
.map_err(|err| anyhow!("Invalid request body, {err}"))?;
|
|
|
|
|
|
|
|
|
|
let RerankReqBody {
|
|
|
|
|
model: reranker_model_id,
|
|
|
|
|
documents,
|
|
|
|
|
query,
|
|
|
|
|
top_n,
|
|
|
|
|
} = req_body;
|
|
|
|
|
|
|
|
|
|
let top_n = top_n.unwrap_or(documents.len());
|
|
|
|
|
|
|
|
|
|
let config = Arc::new(RwLock::new(self.config.clone()));
|
|
|
|
|
|
|
|
|
|
let reranker_model = Model::retrieve_embedding(&config.read(), &reranker_model_id)?;
|
|
|
|
|
|
|
|
|
|
let client = init_client(&config, Some(reranker_model))?;
|
|
|
|
|
let data = client
|
|
|
|
|
.rerank(RerankData {
|
|
|
|
|
query,
|
|
|
|
|
documents: documents.clone(),
|
|
|
|
|
top_n,
|
|
|
|
|
})
|
|
|
|
|
.await?;
|
|
|
|
|
|
|
|
|
|
let results: Vec<_> = data
|
|
|
|
|
.into_iter()
|
|
|
|
|
.map(|v| {
|
|
|
|
|
json!({
|
|
|
|
|
"index": v.index,
|
|
|
|
|
"relevance_score": v.relevance_score,
|
|
|
|
|
"document": documents.get(v.index).map(|v| json!(v)).unwrap_or_default(),
|
|
|
|
|
})
|
|
|
|
|
})
|
|
|
|
|
.collect();
|
|
|
|
|
let output = json!({
|
|
|
|
|
"id": uuid::Uuid::new_v4().to_string(),
|
|
|
|
|
"results": results,
|
|
|
|
|
});
|
|
|
|
|
let res = Response::builder()
|
|
|
|
|
.header("Content-Type", "application/json")
|
|
|
|
|
.body(Full::new(Bytes::from(output.to_string())).boxed())?;
|
|
|
|
|
Ok(res)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
@ -520,8 +574,8 @@ struct ChatCompletionsReqBody {
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
|
struct EmbeddingsReqBody {
|
|
|
|
|
pub input: EmbeddingsReqBodyInput,
|
|
|
|
|
pub model: String,
|
|
|
|
|
input: EmbeddingsReqBodyInput,
|
|
|
|
|
model: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
@ -531,6 +585,14 @@ enum EmbeddingsReqBodyInput {
|
|
|
|
|
Multiple(Vec<String>),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
|
struct RerankReqBody {
|
|
|
|
|
documents: Vec<String>,
|
|
|
|
|
query: String,
|
|
|
|
|
model: String,
|
|
|
|
|
top_n: Option<usize>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
|
enum ResEvent {
|
|
|
|
|
First(Option<String>),
|
|
|
|
|