This commit is contained in:
2025-07-24 17:22:50 +02:00
parent 732411cd50
commit 43e964b5a0
39 changed files with 1406 additions and 2908 deletions

275
src/bin/main.rs Normal file
View File

@@ -0,0 +1,275 @@
#![no_std]
#![no_main]
#![feature(type_alias_impl_trait)]
#![feature(impl_trait_in_assoc_type)]
use core::net::Ipv4Addr;
use core::str::FromStr;
use embassy_executor::Spawner;
use embassy_net::{Ipv4Cidr, Runner, Stack, StackResources, StaticConfigV4};
use embassy_time::{Duration, Timer};
use esp_hal::clock::CpuClock;
use esp_hal::gpio::{Output, OutputConfig};
use esp_hal::peripherals::{GPIO1, GPIO2, UART1};
use esp_hal::timer::systimer::SystemTimer;
use esp_hal::timer::timg::TimerGroup;
use esp_hal::uart::{Config, Uart};
use esp_println::logger::init_logger;
use esp_wifi::wifi::{
AccessPointConfiguration, Configuration, WifiController, WifiDevice, WifiEvent, WifiState,
};
use log::{debug, info};
use picoserve::routing::get;
use picoserve::{AppBuilder, AppRouter};
use static_cell::make_static;
#[panic_handler]
fn panic(_: &core::panic::PanicInfo) -> ! {
loop {}
}
extern crate alloc;
#[esp_hal_embassy::main]
async fn main(spawner: Spawner) {
// ------------------- init ---------------------------
let config = esp_hal::Config::default().with_cpu_clock(CpuClock::max());
let peripherals = esp_hal::init(config);
info!("starting up...");
esp_alloc::heap_allocator!(size: 72 * 1024);
let timer0 = SystemTimer::new(peripherals.SYSTIMER);
esp_hal_embassy::init(timer0.alarm0);
init_logger(log::LevelFilter::Debug);
let timer1 = TimerGroup::new(peripherals.TIMG0);
let mut rng = esp_hal::rng::Rng::new(peripherals.RNG);
debug!("set wlan antenna..");
let mut rf_switch = Output::new(
peripherals.GPIO3,
esp_hal::gpio::Level::Low,
OutputConfig::default(),
);
rf_switch.set_low();
Timer::after_secs(1).await;
let mut antenna_mode = Output::new(
peripherals.GPIO14,
esp_hal::gpio::Level::Low,
OutputConfig::default(),
);
antenna_mode.set_low();
Timer::after_secs(1).await;
// Setup wifi deivce
debug!("setup wifi..");
let esp_wifi_ctrl =
make_static!(esp_wifi::init(timer1.timer0, rng).unwrap());
let (controller, interfaces) = esp_wifi::wifi::new(esp_wifi_ctrl, peripherals.WIFI).unwrap();
// let wifi_interface = interfaces.sta;
let wifi_ap = interfaces.ap;
let gw_ip_addr_str = "192.168.2.1";
let gw_ip_addr = Ipv4Addr::from_str(gw_ip_addr_str).expect("failed to parse gateway ip");
let config = embassy_net::Config::ipv4_static(StaticConfigV4 {
address: Ipv4Cidr::new(gw_ip_addr, 24),
gateway: Some(gw_ip_addr),
dns_servers: Default::default(),
});
let seed = (rng.random() as u64) << 32 | rng.random() as u64;
// Init network stack
let (stack, runner) = embassy_net::new(
wifi_ap,
config,
make_static!(StackResources::<3>::new()),
seed,
);
debug!("Setup complete. Running network tasks");
spawner.spawn(connection(controller)).ok();
spawner.spawn(net_task(runner)).ok();
spawner.spawn(run_dhcp(stack, gw_ip_addr_str)).ok();
spawner
.spawn(rfid_reader_task(
peripherals.UART1,
peripherals.GPIO1,
peripherals.GPIO2,
))
.ok();
loop {
if stack.is_link_up() {
break;
}
Timer::after(Duration::from_millis(500)).await;
if stack.is_config_up() {
break;
}
Timer::after(Duration::from_millis(500)).await;
}
debug!("Starting webserver");
let app = make_static!(AppProps.build_app());
let config = make_static!(picoserve::Config::new(picoserve::Timeouts {
start_read_request: Some(Duration::from_secs(5)),
persistent_start_read_request: Some(Duration::from_secs(1)),
read_request: Some(Duration::from_secs(1)),
write: Some(Duration::from_secs(1)),
}));
let _ = spawner.spawn(webserver_task(0, stack, app, config));
}
struct AppProps;
impl AppBuilder for AppProps {
type PathRouter = impl picoserve::routing::PathRouter;
fn build_app(self) -> picoserve::Router<Self::PathRouter> {
picoserve::Router::new().route("/", get(|| async move { "Hello World" }))
}
}
#[embassy_executor::task]
async fn webserver_task(
id: usize,
stack: embassy_net::Stack<'static>,
app: &'static AppRouter<AppProps>,
config: &'static picoserve::Config<Duration>,
) -> ! {
let mut tcp_rx_buffer = [0u8; 1024];
let mut tcp_tx_buffer = [0u8; 1024];
let mut http_buffer = [0u8; 2048];
picoserve::listen_and_serve(
id,
app,
config,
stack,
80,
&mut tcp_rx_buffer,
&mut tcp_tx_buffer,
&mut http_buffer,
)
.await
}
#[embassy_executor::task]
async fn run_dhcp(stack: Stack<'static>, gw_ip_addr: &'static str) {
debug!("start dhcp");
use core::net::{Ipv4Addr, SocketAddrV4};
use edge_dhcp::{
io::{self, DEFAULT_SERVER_PORT},
server::{Server, ServerOptions},
};
use edge_nal::UdpBind;
use edge_nal_embassy::{Udp, UdpBuffers};
let ip = Ipv4Addr::from_str(gw_ip_addr).expect("dhcp task failed to parse gw ip");
let mut buf = [0u8; 1500];
let mut gw_buf = [Ipv4Addr::UNSPECIFIED];
let buffers = UdpBuffers::<3, 1024, 1024, 10>::new();
let unbound_socket = Udp::new(stack, &buffers);
let mut bound_socket = unbound_socket
.bind(core::net::SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::UNSPECIFIED,
DEFAULT_SERVER_PORT,
)))
.await
.unwrap();
loop {
_ = io::server::run(
&mut Server::<_, 64>::new_with_et(ip),
&ServerOptions::new(ip, Some(&mut gw_buf)),
&mut bound_socket,
&mut buf,
)
.await
.inspect_err(|e| log::warn!("DHCP server error: {e:?}"));
Timer::after(Duration::from_millis(500)).await;
}
}
#[embassy_executor::task]
async fn net_task(mut runner: Runner<'static, WifiDevice<'static>>) {
runner.run().await;
}
#[embassy_executor::task]
async fn connection(mut controller: WifiController<'static>) {
debug!("start connection task");
debug!("Device capabilities: {:?}", controller.capabilities());
loop {
match esp_wifi::wifi::wifi_state() {
WifiState::ApStarted => {
// wait until we're no longer connected
controller.wait_for_event(WifiEvent::ApStop).await;
Timer::after(Duration::from_millis(5000)).await
}
_ => {}
}
if !matches!(controller.is_started(), Ok(true)) {
let client_config = Configuration::AccessPoint(AccessPointConfiguration {
ssid: "esp-wifi".try_into().unwrap(),
..Default::default()
});
controller.set_configuration(&client_config).unwrap();
debug!("Starting wifi");
controller.start_async().await.unwrap();
debug!("Wifi started!");
}
}
}
#[embassy_executor::task]
async fn rfid_reader_task(uart1: UART1<'static>, gpio1: GPIO1<'static>, gpio2: GPIO2<'static>) {
debug!("init rfid reader..");
let uart1_block_result = Uart::new(uart1, Config::default().with_baudrate(9600));
let mut nfc_reader = match uart1_block_result {
Ok(block) => block.with_rx(gpio1).with_tx(gpio2).into_async(),
Err(e) => {
log::error!("Failed to initialize UART: {:?}", e);
return;
}
};
let mut uart_buffer = [0u8; 64];
loop {
debug!("Looking for NFC...");
match nfc_reader.read_async(&mut uart_buffer).await {
Ok(n) => {
let mut hex_str = heapless::String::<128>::new();
for byte in &uart_buffer[..n] {
core::fmt::Write::write_fmt(&mut hex_str, format_args!("{:02X} ", byte)).ok();
}
info!("Read {} bytes from UART: {}", n, hex_str);
}
Err(e) => {
log::error!("Error reading from UART: {:?}", e);
}
}
Timer::after(Duration::from_millis(200)).await;
}
}

View File

@@ -1,37 +0,0 @@
use anyhow::Result;
use rppal::pwm::{Channel, Polarity, Pwm};
use std::time::Duration;
use tokio::time::sleep;
use crate::hardware::Buzzer;
const DEFAULT_PWM_CHANNEL_BUZZER: Channel = Channel::Pwm0; //PWM0 = GPIO18/Physical pin 12
pub struct GPIOBuzzer {
pwm: Pwm,
}
impl GPIOBuzzer {
pub fn new_from_channel(channel: Channel) -> Result<Self, rppal::pwm::Error> {
// Enable with dummy values; we'll set frequency/duty in the tone method
let duty_cycle: f64 = 0.5;
let pwm = Pwm::with_frequency(channel, 1000.0, duty_cycle, Polarity::Normal, true)?;
pwm.disable()?;
Ok(GPIOBuzzer { pwm })
}
pub fn new_default() -> Result<Self, rppal::pwm::Error> {
Self::new_from_channel(DEFAULT_PWM_CHANNEL_BUZZER)
}
}
impl Buzzer for GPIOBuzzer {
async fn modulated_tone(&mut self, frequency_hz: f64, duration: Duration) -> Result<()> {
self.pwm.set_frequency(frequency_hz, 0.5)?; // 50% duty cycle (square wave)
self.pwm.enable()?;
sleep(duration).await;
self.pwm.disable()?;
Ok(())
}
}

View File

@@ -1,131 +0,0 @@
use anyhow::{Result, anyhow};
use log::{trace, warn};
use std::env;
use tokio::process::Command;
use crate::hardware::Hotspot;
const SSID: &str = "fwa";
const CON_NAME: &str = "fwa-hotspot";
const PASSWORD: &str = "a9LG2kUVrsRRVUo1";
const IPV4_ADDRES: &str = "192.168.4.1/24";
/// NetworkManager Hotspot
pub struct NMHotspot {
ssid: String,
con_name: String,
password: String,
ipv4: String,
}
impl NMHotspot {
pub fn new_from_env() -> Result<Self> {
let ssid = env::var("HOTSPOT_SSID").unwrap_or(SSID.to_owned());
let password = env::var("HOTSPOT_PW").unwrap_or_else(|_| {
warn!("HOTSPOT_PW not set. Using default password");
PASSWORD.to_owned()
});
if password.len() < 8 {
return Err(anyhow!("Hotspot password to short"));
}
Ok(NMHotspot {
ssid,
con_name: CON_NAME.to_owned(),
password,
ipv4: IPV4_ADDRES.to_owned(),
})
}
async fn create_hotspot(&self) -> Result<()> {
let cmd = Command::new("nmcli")
.args(["device", "wifi", "hotspot"])
.arg("con-name")
.arg(&self.con_name)
.arg("ssid")
.arg(&self.ssid)
.arg("password")
.arg(&self.password)
.output()
.await?;
trace!("nmcli (std): {}", String::from_utf8_lossy(&cmd.stdout));
trace!("nmcli (err): {}", String::from_utf8_lossy(&cmd.stderr));
if !cmd.status.success() {
return Err(anyhow!("nmcli command had non-zero exit code"));
}
let cmd = Command::new("nmcli")
.arg("connection")
.arg("modify")
.arg(&self.con_name)
.arg("ipv4.method")
.arg("shared")
.arg("ipv4.addresses")
.arg(&self.ipv4)
.output()
.await?;
if !cmd.status.success() {
return Err(anyhow!("nmcli command had non-zero exit code"));
}
Ok(())
}
/// Checks if the connection already exists
async fn exists(&self) -> Result<bool> {
let cmd = Command::new("nmcli")
.args(["connection", "show"])
.arg(&self.con_name)
.output()
.await?;
trace!("nmcli (std): {}", String::from_utf8_lossy(&cmd.stdout));
trace!("nmcli (err): {}", String::from_utf8_lossy(&cmd.stderr));
Ok(cmd.status.success())
}
}
impl Hotspot for NMHotspot {
async fn enable_hotspot(&self) -> Result<()> {
if !self.exists().await? {
self.create_hotspot().await?;
}
let cmd = Command::new("nmcli")
.args(["connection", "up"])
.arg(&self.con_name)
.output()
.await?;
trace!("nmcli (std): {}", String::from_utf8_lossy(&cmd.stdout));
trace!("nmcli (err): {}", String::from_utf8_lossy(&cmd.stderr));
if !cmd.status.success() {
return Err(anyhow!("nmcli command had non-zero exit code"));
}
Ok(())
}
async fn disable_hotspot(&self) -> Result<()> {
let cmd = Command::new("nmcli")
.args(["connection", "down"])
.arg(&self.con_name)
.output()
.await?;
trace!("nmcli (std): {}", String::from_utf8_lossy(&cmd.stdout));
trace!("nmcli (err): {}", String::from_utf8_lossy(&cmd.stderr));
if !cmd.status.success() {
return Err(anyhow!("nmcli command had non-zero exit code"));
}
Ok(())
}
}

View File

@@ -24,22 +24,3 @@ pub trait Buzzer {
) -> impl Future<Output = Result<()>> + std::marker::Send;
}
pub trait Hotspot {
fn enable_hotspot(&self) -> impl std::future::Future<Output = Result<()>> + std::marker::Send;
fn disable_hotspot(&self) -> impl std::future::Future<Output = Result<()>> + std::marker::Send;
}
/// Create a struct to manage the hotspot
/// Respects the `mock_pi` flag.
pub fn create_hotspot() -> Result<impl Hotspot> {
#[cfg(feature = "mock_pi")]
{
Ok(mock::MockHotspot {})
}
#[cfg(not(feature = "mock_pi"))]
{
hotspot::NMHotspot::new_from_env()
}
}

View File

@@ -1,33 +0,0 @@
use anyhow::Result;
use rppal::spi::{Bus, Mode, SlaveSelect, Spi};
use smart_leds::SmartLedsWrite;
use ws2812_spi::Ws2812;
use crate::hardware::StatusLed;
const SPI_CLOCK_SPEED: u32 = 3_800_000;
pub struct SpiLed {
controller: Ws2812<Spi>,
}
impl SpiLed {
pub fn new() -> Result<Self, rppal::spi::Error> {
let spi = Spi::new(Bus::Spi0, SlaveSelect::Ss0, SPI_CLOCK_SPEED, Mode::Mode0)?;
let controller = Ws2812::new(spi);
Ok(SpiLed { controller })
}
}
impl StatusLed for SpiLed {
fn turn_off(&mut self) -> Result<()> {
self.controller
.write(vec![rgb::RGB8::new(0, 0, 0)].into_iter())?;
Ok(())
}
fn turn_on(&mut self, color: rgb::RGB8) -> Result<()> {
self.controller.write(vec![color].into_iter())?;
Ok(())
}
}

1
src/lib.rs Normal file
View File

@@ -0,0 +1 @@
#![no_std]

View File

@@ -1,25 +0,0 @@
use std::env;
use log::LevelFilter;
use simplelog::{ConfigBuilder, SimpleLogger};
pub fn setup_logger() {
let log_level = env::var("LOG_LEVEL")
.ok()
.and_then(|level| level.parse::<LevelFilter>().ok())
.unwrap_or({
if cfg!(debug_assertions) {
LevelFilter::Debug
} else {
LevelFilter::Warn
}
});
let config = ConfigBuilder::new()
.set_target_level(LevelFilter::Off)
.set_location_level(LevelFilter::Off)
.set_thread_level(LevelFilter::Off)
.build();
let _ = SimpleLogger::init(log_level, config);
}

View File

@@ -1,192 +0,0 @@
#![allow(dead_code)]
use anyhow::Result;
use feedback::{Feedback, FeedbackImpl};
use log::{error, info, warn};
use std::{
env::{self, args},
sync::Arc,
time::Duration,
};
use tally_id::TallyID;
use tokio::{
fs,
signal::unix::{SignalKind, signal},
sync::{
Mutex,
broadcast::{self, Receiver, Sender},
},
try_join,
};
use webserver::start_webserver;
use crate::{hardware::{create_hotspot, Hotspot}, pm3::run_pm3, store::IDStore, webserver::{spawn_idle_watcher, ActivityNotifier}};
mod feedback;
mod hardware;
mod pm3;
mod logger;
mod tally_id;
mod webserver;
mod store;
const STORE_PATH: &str = "./data.json";
async fn run_webserver<H>(
store: Arc<Mutex<IDStore>>,
id_channel: Sender<String>,
hotspot: Arc<Mutex<H>>,
user_feedback: Arc<Mutex<FeedbackImpl>>,
) -> Result<()>
where
H: Hotspot + Send + Sync + 'static,
{
let activity_channel = spawn_idle_watcher(Duration::from_secs(60 * 30), move || {
info!("No activity on webserver. Disabling hotspot");
let cloned_hotspot = hotspot.clone();
let cloned_user_feedback = user_feedback.clone();
tokio::spawn(async move {
let _ = cloned_hotspot.lock().await.disable_hotspot().await;
cloned_user_feedback
.lock()
.await
.set_device_status(feedback::DeviceStatus::Ready);
});
});
let notifier = ActivityNotifier {
sender: activity_channel,
};
start_webserver(store, id_channel, notifier).await?;
Ok(())
}
async fn load_or_create_store() -> Result<IDStore> {
if fs::try_exists(STORE_PATH).await? {
info!("Loading data from file");
IDStore::new_from_json(STORE_PATH).await
} else {
info!("No data file found. Creating empty one.");
Ok(IDStore::new())
}
}
fn get_hotspot_enable_ids() -> Vec<TallyID> {
let hotspot_ids: Vec<TallyID> = env::var("HOTSPOT_IDS")
.map(|ids| ids.split(";").map(|id| TallyID(id.to_owned())).collect())
.unwrap_or_default();
if hotspot_ids.is_empty() {
warn!(
"HOTSPOT_IDS is not set or empty. You will not be able to activate the hotspot via a tally!"
);
}
hotspot_ids
}
async fn handle_ids_loop(
mut id_channel: Receiver<String>,
hotspot_enable_ids: Vec<TallyID>,
id_store: Arc<Mutex<IDStore>>,
hotspot: Arc<Mutex<impl Hotspot>>,
user_feedback: Arc<Mutex<FeedbackImpl>>,
) -> Result<()> {
while let Ok(tally_id_string) = id_channel.recv().await {
let tally_id = TallyID(tally_id_string);
if hotspot_enable_ids.contains(&tally_id) {
info!("Enableing hotspot");
let hotspot_enable_result = hotspot.lock().await.enable_hotspot().await;
match hotspot_enable_result {
Ok(_) => {
user_feedback
.lock()
.await
.set_device_status(feedback::DeviceStatus::HotspotEnabled);
}
Err(e) => {
error!("Hotspot: {e}");
}
}
// TODO: Should the ID be added anyway or ignored ?
}
if id_store.lock().await.add_id(tally_id) {
info!("Added new id to current day");
user_feedback.lock().await.success().await;
if let Err(e) = id_store.lock().await.export_json(STORE_PATH).await {
error!("Failed to save id store to file: {e}");
user_feedback.lock().await.failure().await;
// TODO: How to handle a failure to save ?
}
}
}
Ok(())
}
async fn enter_error_state(feedback: Arc<Mutex<FeedbackImpl>>, hotspot: Arc<Mutex<impl Hotspot>>) {
let _ = feedback.lock().await.activate_error_state().await;
let _ = hotspot.lock().await.enable_hotspot().await;
let mut sigterm = signal(SignalKind::terminate()).unwrap();
sigterm.recv().await;
}
#[tokio::main]
async fn main() -> Result<()> {
logger::setup_logger();
info!("Starting application");
let user_feedback = Arc::new(Mutex::new(Feedback::new()?));
let hotspot = Arc::new(Mutex::new(create_hotspot()?));
let error_flag_set = args().any(|e| e == "--error" || e == "-e");
if error_flag_set {
error!("Error flag set. Entering error state");
enter_error_state(user_feedback.clone(), hotspot).await;
return Ok(());
}
let store: Arc<Mutex<IDStore>> = Arc::new(Mutex::new(load_or_create_store().await?));
let hotspot_enable_ids = get_hotspot_enable_ids();
let (tx, rx) = broadcast::channel::<String>(32);
let sse_tx = tx.clone();
let pm3_handle = run_pm3(tx);
user_feedback.lock().await.startup().await;
let loop_handle = handle_ids_loop(
rx,
hotspot_enable_ids,
store.clone(),
hotspot.clone(),
user_feedback.clone(),
);
let webserver_handle = run_webserver(
store.clone(),
sse_tx,
hotspot.clone(),
user_feedback.clone(),
);
let run_result = try_join!(pm3_handle, loop_handle, webserver_handle);
if let Err(e) = run_result {
error!("Failed to run application: {e}");
return Err(e);
}
Ok(())
}

View File

@@ -1,4 +0,0 @@
mod runner;
mod parser;
pub use runner::run_pm3;

View File

@@ -1,10 +0,0 @@
use regex::Regex;
/// Parses the output of PM3 finds the read IDs
/// Example input: `[+] UID.... 3112B710`
pub fn parse_line(line: &str) -> Option<String> {
let regex = Regex::new(r"(?m)^\[\+\] UID.... (.*)$").unwrap();
let result = regex.captures(line);
result.map(|c| c.get(1).unwrap().as_str().to_owned())
}

View File

@@ -1,95 +0,0 @@
use anyhow::{Result, anyhow};
use log::{debug, info, trace, warn};
use std::env;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::select;
use tokio::signal::unix::{SignalKind, signal};
use tokio::sync::broadcast;
/// Runs the pm3 binary and monitors it's output
/// The pm3 binary is ether set in the env var PM3_BIN or found in the path
/// The ouput is parsed and send via the `tx` channel
pub async fn run_pm3(tx: broadcast::Sender<String>) -> Result<()> {
kill_orphans().await;
let pm3_path = match env::var("PM3_BIN") {
Ok(path) => path,
Err(_) => {
info!("PM3_BIN not set. Using default value");
"pm3".to_owned()
}
};
let mut cmd = Command::new("stdbuf")
.arg("-oL")
.arg(pm3_path)
.arg("-c")
.arg("lf hitag reader -@")
.stdout(Stdio::piped())
.stderr(Stdio::null())
.stdin(Stdio::piped())
.spawn()?;
let stdout = cmd.stdout.take().ok_or(anyhow!("Failed to get stdout"))?;
let mut stdin = cmd.stdin.take().ok_or(anyhow!("Failed to get stdin"))?;
let mut reader = BufReader::new(stdout).lines();
let mut sigterm = signal(SignalKind::terminate())?;
let child_handle = tokio::spawn(async move {
let mut last_id: String = "".to_owned();
while let Some(line) = reader.next_line().await.unwrap_or(None) {
trace!("PM3: {line}");
if let Some(uid) = super::parser::parse_line(&line) {
if last_id == uid {
let _ = tx.send(uid.clone());
}
last_id = uid;
}
}
});
select! {
_ = child_handle => {}
_ = sigterm.recv() => {
debug!("Graceful shutdown of PM3");
let _ = stdin.write_all(b"\n").await;
let _ = stdin.flush().await;
}
};
let status = cmd.wait().await?;
// We use the exit code here because status.success() is false if the child was terminated by a
// signal
let code = status.code().unwrap_or(0);
if code == 0 {
Ok(())
} else {
Err(anyhow!("PM3 exited with a non-zero exit code: {code}"))
}
}
/// Kills any open pm3 instances
/// Also funny name. hehehe.
async fn kill_orphans() {
let kill_result = Command::new("pkill")
.arg("-KILL")
.arg("-x")
.arg("proxmark3")
.output()
.await;
match kill_result {
Ok(_) => {
debug!("Successfully killed orphaned pm3 instances");
}
Err(e) => {
warn!("Failed to kill pm3 orphans: {e} Continuing anyway");
}
}
}

View File

@@ -1,46 +0,0 @@
use std::time::Duration;
use log::error;
use rocket::{
Data, Request,
fairing::{Fairing, Info, Kind},
};
use tokio::{sync::mpsc, time::timeout};
pub struct ActivityNotifier {
pub sender: mpsc::Sender<()>,
}
#[rocket::async_trait]
impl Fairing for ActivityNotifier {
fn info(&self) -> Info {
Info {
name: "Keeps track of time since the last request",
kind: Kind::Request | Kind::Response,
}
}
async fn on_request(&self, _: &mut Request<'_>, _: &mut Data<'_>) {
error!("on_request");
let _ = self.sender.try_send(());
}
}
pub fn spawn_idle_watcher<F>(idle_duration: Duration, mut on_idle: F) -> mpsc::Sender<()>
where
F: FnMut() + Send + 'static,
{
let (tx, mut rx) = mpsc::channel::<()>(100);
tokio::spawn(async move {
loop {
let idle = timeout(idle_duration, rx.recv()).await;
if idle.is_err() {
// No activity received in the duration
on_idle();
}
}
});
tx
}

View File

@@ -1,6 +0,0 @@
mod server;
mod activity_fairing;
pub use activity_fairing::{ActivityNotifier,spawn_idle_watcher};
pub use server::start_webserver;

View File

@@ -1,142 +0,0 @@
use log::{error, info, warn};
use rocket::http::Status;
use rocket::response::stream::{Event, EventStream};
use rocket::serde::json::Json;
use rocket::{Config, Shutdown, State, post};
use rocket::{get, http::ContentType, response::content::RawHtml, routes};
use rust_embed::Embed;
use serde::Deserialize;
use std::borrow::Cow;
use std::env;
use std::ffi::OsStr;
use std::sync::Arc;
use tokio::select;
use tokio::sync::Mutex;
use tokio::sync::broadcast::Sender;
use crate::store::{IDMapping, IDStore, Name};
use crate::tally_id::TallyID;
use crate::webserver::ActivityNotifier;
#[derive(Embed)]
#[folder = "web/dist"]
struct Asset;
#[derive(Deserialize)]
struct NewMapping {
id: String,
name: Name,
}
pub async fn start_webserver(
store: Arc<Mutex<IDStore>>,
sse_broadcaster: Sender<String>,
fairing: ActivityNotifier,
) -> Result<(), rocket::Error> {
let port = match env::var("HTTP_PORT") {
Ok(port) => port.parse().unwrap_or_else(|_| {
warn!("Failed to parse HTTP_PORT. Using default 80");
80
}),
Err(_) => 80,
};
let config = Config {
address: "0.0.0.0".parse().unwrap(), // Listen on all interfaces
port,
..Config::default()
};
rocket::custom(config)
.attach(fairing)
.mount(
"/",
routes![
static_files,
index,
export_csv,
id_event,
get_mapping,
add_mapping
],
)
.manage(store)
.manage(sse_broadcaster)
.launch()
.await?;
Ok(())
}
#[get("/")]
fn index() -> Option<RawHtml<Cow<'static, [u8]>>> {
let asset = Asset::get("index.html")?;
Some(RawHtml(asset.data))
}
#[get("/<file..>")]
fn static_files(file: std::path::PathBuf) -> Option<(ContentType, Vec<u8>)> {
let filename = file.display().to_string();
let asset = Asset::get(&filename)?;
let content_type = file
.extension()
.and_then(OsStr::to_str)
.and_then(ContentType::from_extension)
.unwrap_or(ContentType::Bytes);
Some((content_type, asset.data.into_owned()))
}
#[get("/api/idevent")]
fn id_event(sse_broadcaster: &State<Sender<String>>, shutdown: Shutdown) -> EventStream![] {
let mut rx = sse_broadcaster.subscribe();
EventStream! {
loop {
select! {
msg = rx.recv() => {
if let Ok(id) = msg {
yield Event::data(id);
}
}
_ = &mut shutdown.clone() => {
// Shutdown signal received, exit the loop
break;
}
}
}
}
}
#[get("/api/csv")]
async fn export_csv(manager: &State<Arc<Mutex<IDStore>>>) -> Result<String, Status> {
info!("Exporting CSV");
match manager.lock().await.export_csv() {
Ok(csv) => Ok(csv),
Err(e) => {
error!("Failed to generate csv: {e}");
Err(Status::InternalServerError)
}
}
}
#[get("/api/mapping")]
async fn get_mapping(store: &State<Arc<Mutex<IDStore>>>) -> Json<IDMapping> {
Json(store.lock().await.mapping.clone())
}
#[post("/api/mapping", format = "json", data = "<new_mapping>")]
async fn add_mapping(store: &State<Arc<Mutex<IDStore>>>, new_mapping: Json<NewMapping>) -> Status {
if new_mapping.id.is_empty()
|| new_mapping.name.first.is_empty()
|| new_mapping.name.last.is_empty()
{
return Status::BadRequest;
}
store
.lock()
.await
.mapping
.add_mapping(TallyID(new_mapping.id.clone()), new_mapping.name.clone());
Status::Created
}