diff --git a/daemon/ipc.rs b/daemon/ipc.rs new file mode 100644 index 0000000..ff98950 --- /dev/null +++ b/daemon/ipc.rs @@ -0,0 +1,139 @@ +use std::{ + collections::{BTreeMap, HashMap}, + net::SocketAddr, + os::linux::net::SocketAddrExt, +}; + +use crate::socket::{create_niri_socket, tell}; +use anyhow::{Error, Result, anyhow}; +use niri_ipc::{Event, Request}; +use niri_tag::{TagCmd, TagEvent}; +use nix::unistd::geteuid; + +use smol::{ + channel::{self}, + future, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::unix::{UnixListener, UnixStream}, + stream::StreamExt, +}; + +#[allow(unreachable_code)] +pub async fn event_consumer(tx: channel::Sender) -> Result<()> { + tracing::debug!("creating event consumer"); + let mut socket = create_niri_socket().await?; + tracing::debug!("requesting event stream"); + let mut buf = String::new(); + tell(&mut socket, Request::EventStream).await?; + tracing::debug!("beginning event loop"); + loop { + let _ = socket.read_line(&mut buf).await?; + let event: Event = serde_json::from_str(&buf)?; + tracing::debug!("event: {:?}", event); + tx.send(event).await?; + } + unreachable!("Listener loop ended"); +} + +async fn create_provider_socket(name: &'static str, socket: &'static str) -> Result { + let sock_path = format!("/run/user/{}/{}.sock", geteuid(), socket); + if smol::fs::metadata(&sock_path).await.is_ok() { + tracing::debug!("removing old {} socket", name); + smol::fs::remove_file(&sock_path).await?; + } + tracing::debug!("establishing {} socket", name); + UnixListener::bind(&sock_path) + .inspect_err(|f| tracing::error!("failed to listen on {} socket: {}", name, f)) + .map_err(|e| anyhow!(e)) +} + +#[allow(unreachable_code)] +pub async fn ipc_provider(tx: channel::Sender) -> Result<()> { + tracing::debug!("creating ipc provider"); + let mut buf = String::new(); + let listen = create_provider_socket("ipc provider", "niri-tag").await?; + tracing::debug!("beginning ipc provider loop"); + loop { + let (raw_sock, _addr) = listen.accept().await?; + let mut socket = BufReader::new(raw_sock); + tracing::debug!("awaiting ipc socket"); + let _ = socket.read_line(&mut buf).await?; + tracing::debug!("forwarding ipc command {}", buf); + tx.send(serde_json::from_str(&buf)?).await?; + socket.close().await?; + } + tracing::error!("IPC loop ended"); + unreachable!("IPC loop ended"); + Ok(()) +} + +#[derive(Debug)] +enum EventsReceivable { + Event(TagEvent), + Conn(String, UnixStream), +} + +#[allow(unreachable_code)] +pub async fn event_provider(rx: channel::Receiver) -> Result<()> { + tracing::debug!("creating event provider"); + let listen = create_provider_socket("event provider", "niri-tag-events").await?; + let mut sockets = BTreeMap::new(); + loop { + use EventsReceivable::*; + let recvd: EventsReceivable = future::or( + async { rx.recv().await.map(Event).map_err(|e| anyhow!(e)) }, + async { + listen + .accept() + .await + .map(|(s, a)| Conn(format!("{:?}", a), s)) + .map_err(|e| anyhow!(e)) + }, + ) + .await?; + tracing::debug!("event provider received {:?}", recvd); + + enum Res { + Ok, + BadSockets(Vec), + } + let res = match recvd { + Conn(addr, socket) => { + sockets.insert(addr, socket); + Res::Ok + } + Event(e) => { + let data = serde_json::to_string(&e).map_err(|e| anyhow!(e))?; + let conns = smol::stream::iter(sockets.iter_mut()); + let send_all: Vec<(&String, Result<(), _>)> = conns + .then(async |(a, s)| (a, s.write_all(&[data.as_bytes(), b"\n"].concat()).await)) + .collect() + .await; + // .for_each(f); + let bad = send_all + .into_iter() + .fold(Vec::new(), |mut acc, (addr, res)| { + if let Err(e) = res { + tracing::warn!("error on event provider socket {}: {}", addr, e); + acc.push(addr.to_owned()); + } + acc + }); + if bad.len() > 0 { + Res::BadSockets(bad) + } else { + Res::Ok + } + } + }; + match res { + Res::BadSockets(bad) => bad.into_iter().for_each(|b| { + sockets.remove(&b); + }), + _ => (), + } + } + tracing::debug!("beginning ipc provider loop"); + + Ok(()) +} diff --git a/daemon/listeners.rs b/daemon/listeners.rs deleted file mode 100644 index 1fd1439..0000000 --- a/daemon/listeners.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::socket::{create_niri_socket, tell}; -use anyhow::Result; -use niri_ipc::{Event, Request}; -use niri_tag::TagCmd; -use nix::unistd::geteuid; -use smol::{ - channel::{self}, - io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, - net::unix::UnixListener, -}; - -#[allow(unreachable_code)] -pub async fn event_listener(tx: channel::Sender) -> Result<()> { - tracing::debug!("creating listener socket"); - let mut socket = create_niri_socket().await?; - tracing::debug!("requesting event stream"); - let mut buf = String::new(); - tell(&mut socket, Request::EventStream).await?; - tracing::debug!("beginning event loop"); - loop { - let _ = socket.read_line(&mut buf).await?; - let event: Event = serde_json::from_str(&buf)?; - tracing::debug!("event: {:?}", event); - tx.send(event).await?; - } - unreachable!("Listener loop ended"); -} - -#[allow(unreachable_code)] -pub async fn ipc_listener(tx: channel::Sender) -> Result<()> { - tracing::debug!("creating niri-tag socket"); - let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid()); - if smol::fs::metadata(&sock_path).await.is_ok() { - tracing::debug!("removing old niri-tag socket"); - smol::fs::remove_file(&sock_path).await?; - } - tracing::debug!("establishing niri-tag socket connection"); - let listen = UnixListener::bind(&sock_path) - .inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?; - let mut buf = String::new(); - loop { - let (raw_sock, _addr) = listen.accept().await?; - let mut socket = BufReader::new(raw_sock); - tracing::debug!("awaiting ipc socket"); - let _ = socket.read_line(&mut buf).await?; - tracing::debug!("forwarding ipc command {}", buf); - tx.send(serde_json::from_str(&buf)?).await?; - socket.close().await?; - } - tracing::error!("IPC loop ended"); - unreachable!("IPC loop ended"); - Ok(()) -} diff --git a/daemon/main.rs b/daemon/main.rs index bafaa5a..7509e60 100644 --- a/daemon/main.rs +++ b/daemon/main.rs @@ -1,4 +1,4 @@ -mod listeners; +mod ipc; mod manager; mod socket; @@ -13,18 +13,21 @@ fn main() -> Result<()> { .init(); let span = tracing::span!(tracing::Level::DEBUG, "main"); let _ = span.enter(); - // spawn socket listener for niri event stream - let (event_tx, event_rx) = smol::channel::unbounded(); - smol::spawn(listeners::event_listener(event_tx)).detach(); + // spawn socket consumer for niri event stream + let (niri_tx, niri_rx) = smol::channel::unbounded(); + smol::spawn(ipc::event_consumer(niri_tx)).detach(); // spawn socket listener for ipc let (ipc_tx, ipc_rx) = smol::channel::unbounded(); - smol::spawn(listeners::ipc_listener(ipc_tx)).detach(); + smol::spawn(ipc::ipc_provider(ipc_tx)).detach(); + // spawn socket listener for events + let (event_tx, event_rx) = smol::channel::unbounded(); + smol::spawn(ipc::event_provider(event_rx)).detach(); // begin managing niri tags smol::block_on(async { - let niri_tag = manager::NiriTag::new() + let niri_tag = manager::NiriTag::new(event_tx) .await .context("Initialising niri tag manager") .unwrap(); - niri_tag.manage_tags(event_rx, ipc_rx).await + niri_tag.manage_tags(niri_rx, ipc_rx).await }) } diff --git a/daemon/manager.rs b/daemon/manager.rs index 5f0a05d..ac3bd42 100644 --- a/daemon/manager.rs +++ b/daemon/manager.rs @@ -4,9 +4,9 @@ use niri_ipc::{ Action, Event, Reply, Request, Response, Window, Workspace, WorkspaceReferenceArg, state::{EventStreamState, EventStreamStatePart}, }; -use niri_tag::TagCmd; +use niri_tag::{TagCmd, TagEvent}; use smol::{ - channel::{self}, + channel::{self, Sender}, future, io::BufReader, net::unix::UnixStream, @@ -18,6 +18,7 @@ pub struct NiriTag { windows: HashMap, state: EventStreamState, socket: BufReader, + ev_tx: channel::Sender, } enum TagAction { @@ -26,12 +27,13 @@ enum TagAction { } impl NiriTag { - pub async fn new() -> Result { + pub async fn new(ev_tx: channel::Sender) -> Result { Ok(Self { tags: HashMap::new(), windows: HashMap::new(), state: EventStreamState::default(), socket: create_niri_socket().await?, + ev_tx, }) } @@ -196,6 +198,36 @@ impl NiriTag { async fn handle_recvd(&mut self, recvd: Receivable) -> Result<()> { use TagAction::*; + let send_event = async |tx: Sender, ev| { + smol::spawn(async move { + tx.send(ev) + .await + .inspect_err(|e| tracing::error!("Failed to send event: {}", e)) + }) + .detach(); + }; + let add_tag = async |tx: Sender, windows: &HashMap, t| { + if windows + .iter() + .filter(|(_, tag)| **tag == t) + .collect::>() + .is_empty() + { + send_event(tx, TagEvent::TagOccupied(t)).await; + } + }; + let rm_tag = async |tx: Sender, windows: &HashMap, wid, old_tag| { + if old_tag != 0 + && windows + .iter() + .filter(|(w, tag)| **tag == old_tag && **w != wid) + .collect::>() + .is_empty() + { + send_event(tx, TagEvent::TagEmpty(old_tag)).await; + } + }; + // first do any local mutations let action: TagAction = match recvd { Receivable::Event(ev) => { @@ -208,13 +240,17 @@ impl NiriTag { let wid = win.id; self.windows.insert(wid, t); tracing::debug!("adding tag {} to {}", t, wid); + let tx = self.ev_tx.clone(); + add_tag(tx, &self.windows, t).await; ChangeWindow(wid) } TagCmd::RemoveTagFromWin(_) => { let win = self.get_focused_window().await?; let wid = win.id; - self.windows.insert(wid, 0); + let old_tag = self.windows.insert(wid, 0).unwrap_or(0); tracing::debug!("resetting tag on {}", wid); + let tx = self.ev_tx.clone(); + rm_tag(tx, &self.windows, wid, old_tag).await; ChangeWindow(wid) } TagCmd::ToggleTagOnWin(t) => { @@ -223,6 +259,12 @@ impl NiriTag { tracing::debug!("{} has tag {:?}", wid, self.windows.get(&wid)); let this_tag = *self.windows.entry(wid).or_insert(0); let toggle = if this_tag == t { 0 } else { t }; + let tx = self.ev_tx.clone(); + if toggle == 0 { + rm_tag(tx, &self.windows, wid, this_tag).await; + } else { + add_tag(tx, &self.windows, t).await; + } tracing::debug!("toggling {} to tag {}", wid, toggle); self.windows.insert(wid, toggle); ChangeWindow(wid) @@ -230,14 +272,21 @@ impl NiriTag { TagCmd::EnableTag(t) => { self.tags.insert(t, true); + send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await; ChangeTag(t) } TagCmd::DisableTag(t) => { self.tags.insert(t, false); + send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await; ChangeTag(t) } TagCmd::ToggleTag(t) => { let visible = *self.tags.entry(t).or_insert(false); + if visible { + send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await; + } else { + send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await; + } tracing::debug!("toggling tag {} to {}", t, !visible); self.tags.insert(t, !visible); ChangeTag(t) @@ -283,7 +332,7 @@ impl NiriTag { tag_rx.recv().await.map(Receivable::TagCmd) }) .await?; - tracing::debug!("received {:?}", recvd); + tracing::debug!("manager received {:?}", recvd); let res = self.handle_recvd(recvd).await; match res { diff --git a/lib/main.rs b/lib/main.rs index b4eff8b..7841a4e 100644 --- a/lib/main.rs +++ b/lib/main.rs @@ -8,4 +8,15 @@ pub enum TagCmd { DisableTag(u8), ToggleTagOnWin(u8), ToggleTag(u8), + // TODO + // ExclusiveTag(u8), +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum TagEvent { + TagEmpty(u8), + TagOccupied(u8), + TagUrgent(u8), + TagEnabled(u8), + TagDisabled(u8), }