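// ML-API/src/endpoints.rs
//
// REST and WebSocket endpoints that forward metrics to the model.
// (Source-viewer metadata: 123 lines, 4.0 KiB, Rust.)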

use axum::{
response::IntoResponse,
http::StatusCode,
extract::{ Json, State },
};
use crate::{models::ModelResponse, schemas::InputMetric};
use crate::models::{ApiSessionConfig, OutputModel};
use std::sync::Arc;
/// Shared API session configuration; the handlers below receive it via `State<Config>`.
type Config = Arc<ApiSessionConfig>;
pub mod openapi {
    //! OpenAPI / Swagger documentation endpoint (not implemented yet).
    use super::{IntoResponse, StatusCode};

    /// Placeholder Swagger endpoint.
    ///
    /// Always answers `501 Not Implemented` with a short plain-text notice.
    pub async fn swagger() -> impl IntoResponse {
        let notice = "still in development ...";
        (StatusCode::NOT_IMPLEMENTED, notice)
    }
}
pub mod rest {
    //! Plain HTTP endpoint that forwards a batch of metrics to the model.
    use tracing::trace;
    use super::{IntoResponse, StatusCode, InputMetric, Json, Config, State};

    /// Forwards the metrics from the JSON request body to the model and returns
    /// the model's textual answer.
    ///
    /// Responds with `200 OK` and the model output on success. On any failure it
    /// responds with `500 Internal Server Error` and a description of the error.
    ///
    /// NOTE(review): the endpoint is documented as `GET` but expects a JSON body;
    /// some clients/proxies drop bodies on GET. Consider POST where the route is
    /// registered.
    #[utoipa::path(
        get,
        path = "/api/metrics/rest",
        // The metrics arrive in the JSON body (see the `Json` extractor), not in the path.
        request_body(
            content = Vec<InputMetric>,
            description = "Metrics list to work with",
            content_type = "application/json"
        ),
        responses(
            (status = 200, description = "Model successfully processed all given data and returned its response", body = String),
            (status = 500, description = "Some errors with model", body = String)
        )
    )]
    pub async fn model_rest_handler(
        State(config): State<Config>,
        Json(req): Json<Vec<InputMetric>>,
    ) -> impl IntoResponse {
        trace!("GET on /api/metrics/rest");
        // `config` is an owned `Arc` already, so it can be moved in without cloning.
        match super::send_message_to_model(config, req).await {
            Ok(resp) => (StatusCode::OK, resp),
            Err(er) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("cannot get model response: {er}"),
            ),
        }
    }
}
pub mod ws {
    //! WebSocket endpoint.
    //!
    //! Each incoming text or binary frame is parsed as a JSON `Vec<InputMetric>`
    //! and forwarded to the model. The model's reply, or an error description, is
    //! sent back on the same socket.
    use axum::extract::{ws::{Message, WebSocket}, WebSocketUpgrade};
    use tracing::trace;
    use super::{IntoResponse, InputMetric, Config, State};

    /// Upgrades the HTTP connection to a WebSocket and hands it to `model_ws_worker`.
    pub async fn model_ws_handler(
        State(config): State<Config>,
        ws: WebSocketUpgrade,
    ) -> impl IntoResponse {
        trace!("working with new WebSocket connection ...");
        ws.on_upgrade(|socket| model_ws_worker(config, socket))
    }

    /// Serves one upgraded socket until the client closes it, a receive error
    /// occurs, or a reply can no longer be sent.
    async fn model_ws_worker(config: Config, mut ws: WebSocket) {
        trace!("handling WebSocket connection ...");
        while let Some(received) = ws.recv().await {
            let msg = match received {
                Ok(msg) => msg,
                Err(er) => {
                    trace!("WebSocket receive error: {er}");
                    break;
                }
            };

            // Control frames carry no request payload. Previously they were fed to
            // the JSON parser and produced bogus error replies. Close ends the
            // session; Ping/Pong are skipped.
            match &msg {
                Message::Close(_) => break,
                Message::Ping(_) | Message::Pong(_) => continue,
                _ => {}
            }

            let reply = match msg.to_text() {
                Err(er) => format!("Cannot convert input message: {er}"),
                Ok(text) => match serde_json::from_str::<Vec<InputMetric>>(text) {
                    Err(er) => format!("Cannot convert input message: {er}"),
                    Ok(req) => match super::send_message_to_model(config.clone(), req).await {
                        Ok(resp) => resp,
                        Err(er) => format!("Cannot get model's response: {er}"),
                    },
                },
            };

            // Every reply, including a successful model response, must be awaited.
            // The original never awaited the success case, so it was never sent.
            if let Err(er) = ws.send(Message::Text(reply.into())).await {
                trace!("cannot send WebSocket reply: {er}");
                break;
            }
        }
        trace!("WebSocket connection finished");
    }
}
/// Serializes `req` to JSON, appends `crate::PROMPT`, and posts the resulting
/// prompt (built for `config.model_name`) to `config.target_url`.
///
/// Returns the `response` field of the model's JSON reply.
///
/// # Errors
/// Fails in any of these cases:
/// - `req` cannot be serialized.
/// - The HTTP request fails or exceeds `config.request_timeout` seconds.
/// - The endpoint answers with a non-success status; the error includes the
///   status and the response body.
/// - The reply body does not decode as `ModelResponse`.
async fn send_message_to_model(
    config: Config,
    req: Vec<InputMetric>,
) -> anyhow::Result<String> {
    let prompt = OutputModel::build(
        &config.model_name,
        format!("{} {}", serde_json::to_string(&req)?, &*crate::PROMPT),
    );
    // `post` only needs a shared reference to the client, so the per-request
    // clone of the client is unnecessary.
    let response = config
        .client
        .post(&config.target_url)
        .json(&prompt)
        .timeout(tokio::time::Duration::from_secs(config.request_timeout as u64))
        .send()
        .await?;

    // Report HTTP-level failures explicitly. Otherwise they surface as a
    // confusing JSON decode error and the server's error body is lost.
    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        anyhow::bail!("model endpoint returned {status}: {body}");
    }

    Ok(response.json::<ModelResponse>().await?.response)
}