// zombienet_file_server/main.rs
#![allow(clippy::expect_fun_call)]
use std::io;

use axum::{
    extract::{Path, Request, State},
    http::StatusCode,
    routing::{get, post},
    Router,
};
use futures::TryStreamExt;
use tokio::{fs::File, io::BufWriter, net::TcpListener};
use tokio_util::io::StreamReader;
use tower_http::services::ServeDir;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

/// Shared application state handed to request handlers through axum's `State` extractor.
#[derive(Clone)]
struct AppState {
    /// Directory that `upload` joins the requested file path onto when writing a file.
    uploads_directory: String,
}

/// Starts the file server.
///
/// Configuration comes from two required environment variables:
/// - `LISTENING_ADDRESS`: address handed to [`TcpListener::bind`].
/// - `UPLOADS_DIRECTORY`: directory that uploads are written to and that [`ServeDir`]
///   serves from; it is created (including parents) if it does not exist.
///
/// Routes:
/// - `GET /` replies with `Ok`.
/// - `POST /*file_path` is handled by [`upload`].
/// - `GET /*file_path` is delegated to [`ServeDir`] rooted at the uploads directory.
///
/// # Panics
///
/// Panics if either environment variable cannot be read, if the uploads directory
/// cannot be created, if binding the listener fails, if the listener's local address
/// cannot be obtained, or if serving returns an error.
#[tokio::main]
async fn main() {
    let address =
        std::env::var("LISTENING_ADDRESS").expect("LISTENING_ADDRESS env variable isn't defined");
    let uploads_directory =
        std::env::var("UPLOADS_DIRECTORY").expect("UPLOADS_DIRECTORY env variable isn't defined");

    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer())
        .init();

    // The panic messages are built lazily inside the closures, so nothing is formatted
    // or allocated on the success path. The text matches what `expect` would print.
    tokio::fs::create_dir_all(&uploads_directory)
        .await
        .unwrap_or_else(|err| {
            panic!("failed to create '{uploads_directory}' directory: {err:?}")
        });

    let app = Router::new()
        .route("/", get(|| async { "Ok" }))
        .route(
            "/*file_path",
            post(upload).get_service(ServeDir::new(&uploads_directory)),
        )
        .with_state(AppState { uploads_directory });

    let listener = TcpListener::bind(&address)
        .await
        .unwrap_or_else(|err| panic!("failed to listen on {address}: {err:?}"));
    tracing::info!(
        "file server started on {}",
        listener
            .local_addr()
            .expect("failed to read the listener's local address")
    );
    axum::serve(listener, app)
        .await
        .expect("file server terminated with an error");
}

async fn upload(
    Path(file_path): Path<String>,
    State(state): State<AppState>,
    request: Request,
) -> Result<(), (StatusCode, String)> {
    if !path_is_valid(&file_path) {
        return Err((StatusCode::BAD_REQUEST, "Invalid path".to_owned()));
    }

    async {
        let path = std::path::Path::new(&state.uploads_directory).join(file_path);

        if let Some(parent_dir) = path.parent() {
            tokio::fs::create_dir_all(parent_dir).await?;
        }

        let stream = request.into_body().into_data_stream();
        let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
        let body_reader = StreamReader::new(body_with_io_error);
        futures::pin_mut!(body_reader);

        let mut file = BufWriter::new(File::create(&path).await?);
        tokio::io::copy(&mut body_reader, &mut file).await?;

        tracing::info!("created file '{}'", path.to_string_lossy());

        Ok::<_, io::Error>(())
    }
    .await
    .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))
}

/// Returns `true` when `path` is a non-empty relative path made only of normal components.
///
/// Any root (`/`), prefix, `..`, or leading `.` component makes the path invalid, so
/// joining an accepted path onto a base directory cannot lexically step outside it.
/// The check is purely lexical: the file system is not consulted and symlinks are not
/// resolved.
///
/// The empty path is rejected as well. It has no components, so the `all` check alone
/// would accept it, and joining it onto a directory names the directory itself rather
/// than a file inside it.
fn path_is_valid(path: &str) -> bool {
    let mut components = std::path::Path::new(path).components().peekable();

    // Require at least one component before checking that all of them are normal.
    components.peek().is_some()
        && components.all(|component| matches!(component, std::path::Component::Normal(_)))
}