zombienet_file_server/
main.rs

1#![allow(clippy::expect_fun_call)]
2use std::io;
3
4use axum::{
5    extract::{Path, Request, State},
6    http::StatusCode,
7    routing::{get, post},
8    Router,
9};
10use futures::TryStreamExt;
11use tokio::{fs::File, io::BufWriter, net::TcpListener};
12use tokio_util::io::StreamReader;
13use tower_http::services::ServeDir;
14use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
15
/// Shared state handed to every request handler through axum's `State` extractor.
#[derive(Clone, Debug)]
struct AppState {
    /// Directory (read from `UPLOADS_DIRECTORY`) under which uploaded files are stored.
    uploads_directory: String,
}
20
21#[tokio::main]
22async fn main() {
23    let address =
24        std::env::var("LISTENING_ADDRESS").expect("LISTENING_ADDRESS env variable isn't defined");
25    let uploads_directory =
26        std::env::var("UPLOADS_DIRECTORY").expect("UPLOADS_DIRECTORY env variable isn't defined");
27
28    tracing_subscriber::registry()
29        .with(tracing_subscriber::fmt::layer())
30        .init();
31
32    tokio::fs::create_dir_all(&uploads_directory)
33        .await
34        .expect(&format!("failed to create '{uploads_directory}' directory"));
35
36    let app = Router::new()
37        .route("/", get(|| async { "Ok" }))
38        .route(
39            "/*file_path",
40            post(upload).get_service(ServeDir::new(&uploads_directory)),
41        )
42        .with_state(AppState { uploads_directory });
43
44    let listener = TcpListener::bind(&address)
45        .await
46        .expect(&format!("failed to listen on {address}"));
47    tracing::info!("file server started on {}", listener.local_addr().unwrap());
48    axum::serve(listener, app).await.unwrap()
49}
50
51async fn upload(
52    Path(file_path): Path<String>,
53    State(state): State<AppState>,
54    request: Request,
55) -> Result<(), (StatusCode, String)> {
56    if !path_is_valid(&file_path) {
57        return Err((StatusCode::BAD_REQUEST, "Invalid path".to_owned()));
58    }
59
60    async {
61        let path = std::path::Path::new(&state.uploads_directory).join(file_path);
62
63        if let Some(parent_dir) = path.parent() {
64            tokio::fs::create_dir_all(parent_dir).await?;
65        }
66
67        let stream = request.into_body().into_data_stream();
68        let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
69        let body_reader = StreamReader::new(body_with_io_error);
70        futures::pin_mut!(body_reader);
71
72        let mut file = BufWriter::new(File::create(&path).await?);
73        tokio::io::copy(&mut body_reader, &mut file).await?;
74
75        tracing::info!("created file '{}'", path.to_string_lossy());
76
77        Ok::<_, io::Error>(())
78    }
79    .await
80    .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))
81}
82
/// Returns `true` when `path` is a safe relative path to store an upload under.
///
/// This guards against directory traversal. The path must contain at least one component, and
/// every component must be a plain name (`Component::Normal`). That rejects:
///
/// * absolute paths and drive prefixes,
/// * `..` anywhere,
/// * a leading `.`.
///
/// An empty path is rejected as well, because joining it would resolve to the uploads directory
/// itself.
fn path_is_valid(path: &str) -> bool {
    let mut components = std::path::Path::new(path).components().peekable();

    // `all` is vacuously true on an empty iterator, so require a first component explicitly.
    components.peek().is_some()
        && components.all(|component| matches!(component, std::path::Component::Normal(_)))
}