use std::{ collections::{BTreeMap, HashMap}, net::SocketAddr, os::linux::net::SocketAddrExt, }; use crate::socket::{create_niri_socket, tell}; use anyhow::{Error, Result, anyhow}; use niri_ipc::{Event, Request}; use niri_tag::{TagCmd, TagEvent}; use nix::unistd::geteuid; use smol::{ channel::{self}, future, io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, net::unix::{UnixListener, UnixStream}, stream::StreamExt, }; #[allow(unreachable_code)] pub async fn event_consumer(tx: channel::Sender) -> Result<()> { tracing::debug!("creating event consumer"); let mut socket = create_niri_socket().await?; tracing::debug!("requesting event stream"); let mut buf = String::new(); tell(&mut socket, Request::EventStream).await?; tracing::debug!("beginning event loop"); loop { let _ = socket.read_line(&mut buf).await?; let event: Event = serde_json::from_str(&buf)?; tracing::debug!("event: {:?}", event); tx.send(event).await?; } unreachable!("Listener loop ended"); } async fn create_provider_socket(name: &'static str, socket: &'static str) -> Result { let sock_path = format!("/run/user/{}/{}.sock", geteuid(), socket); if smol::fs::metadata(&sock_path).await.is_ok() { tracing::debug!("removing old {} socket", name); smol::fs::remove_file(&sock_path).await?; } tracing::debug!("establishing {} socket", name); UnixListener::bind(&sock_path) .inspect_err(|f| tracing::error!("failed to listen on {} socket: {}", name, f)) .map_err(|e| anyhow!(e)) } #[allow(unreachable_code)] pub async fn ipc_provider(tx: channel::Sender) -> Result<()> { tracing::debug!("creating ipc provider"); let mut buf = String::new(); let listen = create_provider_socket("ipc provider", "niri-tag").await?; tracing::debug!("beginning ipc provider loop"); loop { let (raw_sock, _addr) = listen.accept().await?; let mut socket = BufReader::new(raw_sock); tracing::debug!("awaiting ipc socket"); let _ = socket.read_line(&mut buf).await?; tracing::debug!("forwarding ipc command {}", buf); tx.send(serde_json::from_str(&buf)?).await?; socket.close().await?; } tracing::error!("IPC loop ended"); unreachable!("IPC loop ended"); Ok(()) } #[derive(Debug)] enum EventsReceivable { Event(TagEvent), Conn(String, UnixStream), } #[allow(unreachable_code)] pub async fn event_provider(rx: channel::Receiver) -> Result<()> { tracing::debug!("creating event provider"); let listen = create_provider_socket("event provider", "niri-tag-events").await?; let mut sockets = BTreeMap::new(); loop { use EventsReceivable::*; let recvd: EventsReceivable = future::or( async { rx.recv().await.map(Event).map_err(|e| anyhow!(e)) }, async { listen .accept() .await .map(|(s, a)| Conn(format!("{:?}", a), s)) .map_err(|e| anyhow!(e)) }, ) .await?; tracing::debug!("event provider received {:?}", recvd); enum Res { Ok, BadSockets(Vec), } let res = match recvd { Conn(addr, socket) => { sockets.insert(addr, socket); Res::Ok } Event(e) => { let data = serde_json::to_string(&e).map_err(|e| anyhow!(e))?; let conns = smol::stream::iter(sockets.iter_mut()); let send_all: Vec<(&String, Result<(), _>)> = conns .then(async |(a, s)| (a, s.write_all(&[data.as_bytes(), b"\n"].concat()).await)) .collect() .await; let bad = send_all .into_iter() .fold(Vec::new(), |mut acc, (addr, res)| { if let Err(e) = res { tracing::warn!("error on event provider socket {}: {}", addr, e); acc.push(addr.to_owned()); } acc }); if !bad.is_empty() { Res::BadSockets(bad) } else { Res::Ok } } }; if let Res::BadSockets(bad) = res { bad.into_iter().for_each(|b| { sockets.remove(&b); }); } } tracing::debug!("beginning ipc provider loop"); Ok(()) }