zombienet_file_server/
main.rs
1#![allow(clippy::expect_fun_call)]
2use std::io;
3
4use axum::{
5 extract::{Path, Request, State},
6 http::StatusCode,
7 routing::{get, post},
8 Router,
9};
10use futures::TryStreamExt;
11use tokio::{fs::File, io::BufWriter, net::TcpListener};
12use tokio_util::io::StreamReader;
13use tower_http::services::ServeDir;
14use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
15
/// Shared state handed to request handlers through axum's `State` extractor.
#[derive(Clone, Debug)]
struct AppState {
    /// Root directory that uploaded files are written under (read from the
    /// `UPLOADS_DIRECTORY` environment variable in `main`).
    uploads_directory: String,
}
20
21#[tokio::main]
22async fn main() {
23 let address =
24 std::env::var("LISTENING_ADDRESS").expect("LISTENING_ADDRESS env variable isn't defined");
25 let uploads_directory =
26 std::env::var("UPLOADS_DIRECTORY").expect("UPLOADS_DIRECTORY env variable isn't defined");
27
28 tracing_subscriber::registry()
29 .with(tracing_subscriber::fmt::layer())
30 .init();
31
32 tokio::fs::create_dir_all(&uploads_directory)
33 .await
34 .expect(&format!("failed to create '{uploads_directory}' directory"));
35
36 let app = Router::new()
37 .route("/", get(|| async { "Ok" }))
38 .route(
39 "/*file_path",
40 post(upload).get_service(ServeDir::new(&uploads_directory)),
41 )
42 .with_state(AppState { uploads_directory });
43
44 let listener = TcpListener::bind(&address)
45 .await
46 .expect(&format!("failed to listen on {address}"));
47 tracing::info!("file server started on {}", listener.local_addr().unwrap());
48 axum::serve(listener, app).await.unwrap()
49}
50
51async fn upload(
52 Path(file_path): Path<String>,
53 State(state): State<AppState>,
54 request: Request,
55) -> Result<(), (StatusCode, String)> {
56 if !path_is_valid(&file_path) {
57 return Err((StatusCode::BAD_REQUEST, "Invalid path".to_owned()));
58 }
59
60 async {
61 let path = std::path::Path::new(&state.uploads_directory).join(file_path);
62
63 if let Some(parent_dir) = path.parent() {
64 tokio::fs::create_dir_all(parent_dir).await?;
65 }
66
67 let stream = request.into_body().into_data_stream();
68 let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
69 let body_reader = StreamReader::new(body_with_io_error);
70 futures::pin_mut!(body_reader);
71
72 let mut file = BufWriter::new(File::create(&path).await?);
73 tokio::io::copy(&mut body_reader, &mut file).await?;
74
75 tracing::info!("created file '{}'", path.to_string_lossy());
76
77 Ok::<_, io::Error>(())
78 }
79 .await
80 .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))
81}
82
/// Returns `true` when `path` is a safe relative path to store an upload under.
///
/// Guards against directory traversal: every component must be a plain
/// file/directory name (`Component::Normal`). Absolute paths (`RootDir`),
/// Windows drive/UNC prefixes (`Prefix`), `..` (`ParentDir`) and a leading `.`
/// (`CurDir`) are rejected. Interior `.` segments and repeated separators are
/// normalized away by `Path::components`, so `a/./b` is accepted as `a/b`.
///
/// An empty path is rejected as well: it has no components, and joining it onto
/// the uploads directory would target the directory itself.
fn path_is_valid(path: &str) -> bool {
    let mut components = std::path::Path::new(path).components().peekable();

    // `Iterator::all` is vacuously true for an empty iterator, so require at
    // least one component explicitly.
    components.peek().is_some()
        && components.all(|component| matches!(component, std::path::Component::Normal(_)))
}