use std::{ collections::{BTreeMap, HashMap}, net::SocketAddr, os::linux::net::SocketAddrExt, path::PathBuf, str::FromStr, }; use crate::socket::{create_niri_socket, tell}; use anyhow::{Error, Result, anyhow}; use microxdg::Xdg; use niri_ipc::{Event, Request}; use niri_tag::{TagCmd, TagEvent, TagState}; use nix::unistd::geteuid; use smol::{ channel::{self}, future, io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, net::unix::{UnixListener, UnixStream}, stream::StreamExt, }; #[allow(unreachable_code)] pub async fn event_consumer(tx: channel::Sender) -> Result<()> { tracing::debug!("creating event consumer"); let mut socket = create_niri_socket().await?; tracing::debug!("requesting event stream"); let mut buf = String::new(); tell(&mut socket, Request::EventStream).await?; tracing::debug!("beginning event loop"); loop { let _ = socket.read_line(&mut buf).await?; let event: Event = serde_json::from_str(&buf)?; tracing::trace!("event: {:?}", event); tx.send(event).await?; } unreachable!("Listener loop ended"); } fn get_run_path() -> Result { let xdg = Xdg::new()?; Ok(xdg .runtime()? .unwrap_or(PathBuf::from_str(&format!("/run/user/{}", geteuid()))?)) } async fn create_provider_socket(name: &'static str, socket: &'static str) -> Result { let mut sock_path = get_run_path()?; sock_path.push(format!("{}.sock", socket)); if smol::fs::metadata(&sock_path).await.is_ok() { tracing::debug!("removing old {} socket", name); smol::fs::remove_file(&sock_path).await?; } tracing::debug!("establishing {} socket", name); UnixListener::bind(&sock_path) .inspect_err(|f| tracing::error!("failed to listen on {} socket: {}", name, f)) .map_err(|e| anyhow!(e)) } #[allow(unreachable_code)] pub async fn ipc_provider(tx: channel::Sender) -> Result<()> { tracing::debug!("creating ipc provider"); let mut buf = String::new(); let listen = create_provider_socket("ipc provider", "niri-tag").await?; tracing::debug!("beginning ipc provider loop"); loop { let (raw_sock, _addr) = listen.accept().await?; let mut socket = BufReader::new(raw_sock); tracing::debug!("awaiting ipc socket"); let _ = socket.read_line(&mut buf).await?; tracing::debug!("forwarding ipc command {}", buf); tx.send(serde_json::from_str(&buf)?).await?; socket.close().await?; } tracing::error!("IPC loop ended"); unreachable!("IPC loop ended"); Ok(()) } #[derive(Debug)] enum EventsReceivable { Event(TagEvent), Conn(String, UnixStream), } #[allow(unreachable_code)] pub async fn event_provider( rx: channel::Receiver, fullstate_tx: channel::Sender>>, ) -> Result<()> { tracing::debug!("creating event provider"); let listen = create_provider_socket("event provider", "niri-tag-events").await?; let mut sockets = BTreeMap::new(); tracing::debug!("beginning event provider loop"); loop { use EventsReceivable::*; let recvd: EventsReceivable = future::or( async { rx.recv().await.map(Event).map_err(|e| anyhow!(e)) }, async { listen .accept() .await .map(|(s, a)| Conn(format!("{:?}", a), s)) .map_err(|e| anyhow!(e)) }, ) .await?; tracing::debug!("event provider received {:?}", recvd); enum Res { Ok, BadSockets(Vec), } let res = match recvd { Conn(addr, mut socket) => { tracing::debug!("received a new event provider connection"); sockets.insert(addr, socket.clone()); let (t, r) = smol::channel::bounded(1); tracing::debug!("sending fullstate request"); fullstate_tx.send(t).await?; match r.recv().await { Ok(fullstate) => { tracing::debug!("received fullstate, sending"); let data = serde_json::to_string(&TagEvent::TagFullState(fullstate))?; if let Err(e) = socket.write_all(&[data.as_bytes(), b"\n"].concat()).await { tracing::error!("Failed to send fullstate to socket: {}", e); } } Err(e) => tracing::error!("Failed to receive fullstate: {}", e), } Res::Ok } Event(e) => { let data = serde_json::to_string(&e).map_err(|e| anyhow!(e))?; let conns = smol::stream::iter(sockets.iter_mut()); let send_all: Vec<(&String, Result<(), _>)> = conns .then(async |(a, s)| (a, s.write_all(&[data.as_bytes(), b"\n"].concat()).await)) .collect() .await; let bad = send_all .into_iter() .fold(Vec::new(), |mut acc, (addr, res)| { if let Err(e) = res { tracing::warn!("error from event provider client {}: {}", addr, e); acc.push(addr.to_owned()); } acc }); if !bad.is_empty() { Res::BadSockets(bad) } else { Res::Ok } } }; if let Res::BadSockets(bad) = res { bad.into_iter().for_each(|b| { sockets.remove(&b); }); } } Ok(()) }