123 lines
4.0 KiB
Rust
123 lines
4.0 KiB
Rust
use axum::{
|
|
response::IntoResponse,
|
|
http::StatusCode,
|
|
extract::{ Json, State },
|
|
};
|
|
use crate::{models::ModelResponse, schemas::InputMetric};
|
|
use crate::models::{ApiSessionConfig, OutputModel};
|
|
use std::sync::Arc;
|
|
|
|
/// Shared, reference-counted session configuration.
///
/// The REST and WebSocket handlers below receive it through axum's `State`
/// extractor and pass it on to `send_message_to_model`, which reads the HTTP
/// client, target URL, model name and request timeout from it.
type Config = Arc<ApiSessionConfig>;
|
|
|
|
pub mod openapi {
|
|
use super::{IntoResponse, StatusCode};
|
|
|
|
pub async fn swagger() -> impl IntoResponse { (StatusCode::NOT_IMPLEMENTED, "still in development ...") }
|
|
}
|
|
|
|
pub mod rest {
|
|
use tracing::trace;
|
|
use super::{IntoResponse, StatusCode, InputMetric, Json, Config, State};
|
|
|
|
#[utoipa::path(
|
|
get,
|
|
path = "/api/metrics/rest",
|
|
responses(
|
|
(status = 200, description = "Model successfully processed all given data and"),
|
|
(status = 500, description = "Some errors with model")
|
|
),
|
|
params(
|
|
("metrics" = Vec<InputMetric>, Path, description = "Metrics list to work with"),
|
|
)
|
|
)]
|
|
pub async fn model_rest_handler(
|
|
State(config) : State<Config>,
|
|
Json(req) : Json<Vec<InputMetric>>,
|
|
) -> impl IntoResponse {
|
|
trace!("GET on /api/metrics/rest");
|
|
return match super::send_message_to_model(config.clone(), req).await {
|
|
Ok(resp) => (StatusCode::OK, resp),
|
|
Err(er) => (StatusCode::INTERNAL_SERVER_ERROR, format!("cannot get model response: {er}")),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub mod ws {
|
|
use axum::extract::{ws::{Message, WebSocket}, WebSocketUpgrade};
|
|
use tracing::trace;
|
|
use super::{IntoResponse, InputMetric, Config, State};
|
|
|
|
pub async fn model_ws_handler(
|
|
State(config): State<Config>,
|
|
ws: WebSocketUpgrade,
|
|
) -> impl IntoResponse {
|
|
trace!("working with new WebSocket connection ...");
|
|
ws.on_upgrade(|socket| model_ws_worker(config, socket))
|
|
}
|
|
|
|
async fn model_ws_worker(
|
|
config: Config,
|
|
mut ws: WebSocket,
|
|
) {
|
|
trace!("handling WebSocket connection ...");
|
|
let ws_reciever = tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = ws.recv().await {
|
|
match msg.to_text() {
|
|
Err(er) => {
|
|
let _ = ws.send(Message::Text(format!("Cannot convert input message: {er}").into())).await;
|
|
},
|
|
Ok(msg) => {
|
|
match serde_json::from_str::<Vec<InputMetric>>(msg) {
|
|
Err(er) => {
|
|
let _ = ws.send(
|
|
Message::Text(
|
|
format!("Cannot convert input message: {er}")
|
|
.into()
|
|
)
|
|
).await;
|
|
},
|
|
Ok(req) => {
|
|
match super::send_message_to_model(config.clone(), req).await {
|
|
Ok(resp) => {
|
|
let _ = ws.send(Message::Text(resp.into()));
|
|
},
|
|
Err(er) => {
|
|
let _ = ws.send(Message::Text(format!("Cannot get model's response: {er}").into())).await;
|
|
},
|
|
}
|
|
},
|
|
}
|
|
},
|
|
}
|
|
}
|
|
});
|
|
|
|
let _ = ws_reciever.await;
|
|
}
|
|
}
|
|
|
|
async fn send_message_to_model(
|
|
config: Config,
|
|
req: Vec<InputMetric>
|
|
) -> anyhow::Result<String> {
|
|
let prompt = OutputModel::build(
|
|
&config.model_name,
|
|
format!("{} {}", serde_json::to_string(&req)?, &*crate::PROMPT)
|
|
);
|
|
let request = config
|
|
.client
|
|
// TODO: REMOVE CLONE()
|
|
.clone()
|
|
.post(&config.target_url)
|
|
.json(&prompt)
|
|
.timeout(tokio::time::Duration::from_secs(config.request_timeout as u64));
|
|
|
|
Ok(
|
|
request
|
|
.send()
|
|
.await?
|
|
.json::<ModelResponse>()
|
|
.await?
|
|
.response
|
|
)
|
|
} |