feat: change socket nomenclature + implement events provider

This commit is contained in:
atagen 2025-06-21 17:20:01 +10:00
parent 389c4b3ee6
commit 243582307c
5 changed files with 214 additions and 65 deletions

139
daemon/ipc.rs Normal file
View File

@ -0,0 +1,139 @@
use std::{
collections::{BTreeMap, HashMap},
net::SocketAddr,
os::linux::net::SocketAddrExt,
};
use crate::socket::{create_niri_socket, tell};
use anyhow::{Error, Result, anyhow};
use niri_ipc::{Event, Request};
use niri_tag::{TagCmd, TagEvent};
use nix::unistd::geteuid;
use smol::{
channel::{self},
future,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
net::unix::{UnixListener, UnixStream},
stream::StreamExt,
};
#[allow(unreachable_code)]
pub async fn event_consumer(tx: channel::Sender<Event>) -> Result<()> {
tracing::debug!("creating event consumer");
let mut socket = create_niri_socket().await?;
tracing::debug!("requesting event stream");
let mut buf = String::new();
tell(&mut socket, Request::EventStream).await?;
tracing::debug!("beginning event loop");
loop {
let _ = socket.read_line(&mut buf).await?;
let event: Event = serde_json::from_str(&buf)?;
tracing::debug!("event: {:?}", event);
tx.send(event).await?;
}
unreachable!("Listener loop ended");
}
async fn create_provider_socket(name: &'static str, socket: &'static str) -> Result<UnixListener> {
let sock_path = format!("/run/user/{}/{}.sock", geteuid(), socket);
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old {} socket", name);
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing {} socket", name);
UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen on {} socket: {}", name, f))
.map_err(|e| anyhow!(e))
}
#[allow(unreachable_code)]
pub async fn ipc_provider(tx: channel::Sender<TagCmd>) -> Result<()> {
tracing::debug!("creating ipc provider");
let mut buf = String::new();
let listen = create_provider_socket("ipc provider", "niri-tag").await?;
tracing::debug!("beginning ipc provider loop");
loop {
let (raw_sock, _addr) = listen.accept().await?;
let mut socket = BufReader::new(raw_sock);
tracing::debug!("awaiting ipc socket");
let _ = socket.read_line(&mut buf).await?;
tracing::debug!("forwarding ipc command {}", buf);
tx.send(serde_json::from_str(&buf)?).await?;
socket.close().await?;
}
tracing::error!("IPC loop ended");
unreachable!("IPC loop ended");
Ok(())
}
#[derive(Debug)]
enum EventsReceivable {
Event(TagEvent),
Conn(String, UnixStream),
}
#[allow(unreachable_code)]
pub async fn event_provider(rx: channel::Receiver<TagEvent>) -> Result<()> {
tracing::debug!("creating event provider");
let listen = create_provider_socket("event provider", "niri-tag-events").await?;
let mut sockets = BTreeMap::new();
loop {
use EventsReceivable::*;
let recvd: EventsReceivable = future::or(
async { rx.recv().await.map(Event).map_err(|e| anyhow!(e)) },
async {
listen
.accept()
.await
.map(|(s, a)| Conn(format!("{:?}", a), s))
.map_err(|e| anyhow!(e))
},
)
.await?;
tracing::debug!("event provider received {:?}", recvd);
enum Res {
Ok,
BadSockets(Vec<String>),
}
let res = match recvd {
Conn(addr, socket) => {
sockets.insert(addr, socket);
Res::Ok
}
Event(e) => {
let data = serde_json::to_string(&e).map_err(|e| anyhow!(e))?;
let conns = smol::stream::iter(sockets.iter_mut());
let send_all: Vec<(&String, Result<(), _>)> = conns
.then(async |(a, s)| (a, s.write_all(&[data.as_bytes(), b"\n"].concat()).await))
.collect()
.await;
// .for_each(f);
let bad = send_all
.into_iter()
.fold(Vec::new(), |mut acc, (addr, res)| {
if let Err(e) = res {
tracing::warn!("error on event provider socket {}: {}", addr, e);
acc.push(addr.to_owned());
}
acc
});
if bad.len() > 0 {
Res::BadSockets(bad)
} else {
Res::Ok
}
}
};
match res {
Res::BadSockets(bad) => bad.into_iter().for_each(|b| {
sockets.remove(&b);
}),
_ => (),
}
}
tracing::debug!("beginning ipc provider loop");
Ok(())
}

View File

@ -1,53 +0,0 @@
use crate::socket::{create_niri_socket, tell};
use anyhow::Result;
use niri_ipc::{Event, Request};
use niri_tag::TagCmd;
use nix::unistd::geteuid;
use smol::{
channel::{self},
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::UnixListener,
};
#[allow(unreachable_code)]
pub async fn event_listener(tx: channel::Sender<Event>) -> Result<()> {
tracing::debug!("creating listener socket");
let mut socket = create_niri_socket().await?;
tracing::debug!("requesting event stream");
let mut buf = String::new();
tell(&mut socket, Request::EventStream).await?;
tracing::debug!("beginning event loop");
loop {
let _ = socket.read_line(&mut buf).await?;
let event: Event = serde_json::from_str(&buf)?;
tracing::debug!("event: {:?}", event);
tx.send(event).await?;
}
unreachable!("Listener loop ended");
}
#[allow(unreachable_code)]
pub async fn ipc_listener(tx: channel::Sender<TagCmd>) -> Result<()> {
tracing::debug!("creating niri-tag socket");
let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid());
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old niri-tag socket");
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing niri-tag socket connection");
let listen = UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?;
let mut buf = String::new();
loop {
let (raw_sock, _addr) = listen.accept().await?;
let mut socket = BufReader::new(raw_sock);
tracing::debug!("awaiting ipc socket");
let _ = socket.read_line(&mut buf).await?;
tracing::debug!("forwarding ipc command {}", buf);
tx.send(serde_json::from_str(&buf)?).await?;
socket.close().await?;
}
tracing::error!("IPC loop ended");
unreachable!("IPC loop ended");
Ok(())
}

View File

@ -1,4 +1,4 @@
mod listeners;
mod ipc;
mod manager;
mod socket;
@ -13,18 +13,21 @@ fn main() -> Result<()> {
.init();
let span = tracing::span!(tracing::Level::DEBUG, "main");
let _ = span.enter();
// spawn socket listener for niri event stream
let (event_tx, event_rx) = smol::channel::unbounded();
smol::spawn(listeners::event_listener(event_tx)).detach();
// spawn socket consumer for niri event stream
let (niri_tx, niri_rx) = smol::channel::unbounded();
smol::spawn(ipc::event_consumer(niri_tx)).detach();
// spawn socket listener for ipc
let (ipc_tx, ipc_rx) = smol::channel::unbounded();
smol::spawn(listeners::ipc_listener(ipc_tx)).detach();
smol::spawn(ipc::ipc_provider(ipc_tx)).detach();
// spawn socket listener for events
let (event_tx, event_rx) = smol::channel::unbounded();
smol::spawn(ipc::event_provider(event_rx)).detach();
// begin managing niri tags
smol::block_on(async {
let niri_tag = manager::NiriTag::new()
let niri_tag = manager::NiriTag::new(event_tx)
.await
.context("Initialising niri tag manager")
.unwrap();
niri_tag.manage_tags(event_rx, ipc_rx).await
niri_tag.manage_tags(niri_rx, ipc_rx).await
})
}

View File

@ -4,9 +4,9 @@ use niri_ipc::{
Action, Event, Reply, Request, Response, Window, Workspace, WorkspaceReferenceArg,
state::{EventStreamState, EventStreamStatePart},
};
use niri_tag::TagCmd;
use niri_tag::{TagCmd, TagEvent};
use smol::{
channel::{self},
channel::{self, Sender},
future,
io::BufReader,
net::unix::UnixStream,
@ -18,6 +18,7 @@ pub struct NiriTag {
windows: HashMap<u64, u8>,
state: EventStreamState,
socket: BufReader<UnixStream>,
ev_tx: channel::Sender<TagEvent>,
}
enum TagAction {
@ -26,12 +27,13 @@ enum TagAction {
}
impl NiriTag {
pub async fn new() -> Result<Self> {
pub async fn new(ev_tx: channel::Sender<TagEvent>) -> Result<Self> {
Ok(Self {
tags: HashMap::new(),
windows: HashMap::new(),
state: EventStreamState::default(),
socket: create_niri_socket().await?,
ev_tx,
})
}
@ -196,6 +198,36 @@ impl NiriTag {
async fn handle_recvd(&mut self, recvd: Receivable) -> Result<()> {
use TagAction::*;
let send_event = async |tx: Sender<TagEvent>, ev| {
smol::spawn(async move {
tx.send(ev)
.await
.inspect_err(|e| tracing::error!("Failed to send event: {}", e))
})
.detach();
};
let add_tag = async |tx: Sender<TagEvent>, windows: &HashMap<u64, u8>, t| {
if windows
.iter()
.filter(|(_, tag)| **tag == t)
.collect::<HashMap<_, _>>()
.is_empty()
{
send_event(tx, TagEvent::TagOccupied(t)).await;
}
};
let rm_tag = async |tx: Sender<TagEvent>, windows: &HashMap<u64, u8>, wid, old_tag| {
if old_tag != 0
&& windows
.iter()
.filter(|(w, tag)| **tag == old_tag && **w != wid)
.collect::<Vec<(_, _)>>()
.is_empty()
{
send_event(tx, TagEvent::TagEmpty(old_tag)).await;
}
};
// first do any local mutations
let action: TagAction = match recvd {
Receivable::Event(ev) => {
@ -208,13 +240,17 @@ impl NiriTag {
let wid = win.id;
self.windows.insert(wid, t);
tracing::debug!("adding tag {} to {}", t, wid);
let tx = self.ev_tx.clone();
add_tag(tx, &self.windows, t).await;
ChangeWindow(wid)
}
TagCmd::RemoveTagFromWin(_) => {
let win = self.get_focused_window().await?;
let wid = win.id;
self.windows.insert(wid, 0);
let old_tag = self.windows.insert(wid, 0).unwrap_or(0);
tracing::debug!("resetting tag on {}", wid);
let tx = self.ev_tx.clone();
rm_tag(tx, &self.windows, wid, old_tag).await;
ChangeWindow(wid)
}
TagCmd::ToggleTagOnWin(t) => {
@ -223,6 +259,12 @@ impl NiriTag {
tracing::debug!("{} has tag {:?}", wid, self.windows.get(&wid));
let this_tag = *self.windows.entry(wid).or_insert(0);
let toggle = if this_tag == t { 0 } else { t };
let tx = self.ev_tx.clone();
if toggle == 0 {
rm_tag(tx, &self.windows, wid, this_tag).await;
} else {
add_tag(tx, &self.windows, t).await;
}
tracing::debug!("toggling {} to tag {}", wid, toggle);
self.windows.insert(wid, toggle);
ChangeWindow(wid)
@ -230,14 +272,21 @@ impl NiriTag {
TagCmd::EnableTag(t) => {
self.tags.insert(t, true);
send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await;
ChangeTag(t)
}
TagCmd::DisableTag(t) => {
self.tags.insert(t, false);
send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await;
ChangeTag(t)
}
TagCmd::ToggleTag(t) => {
let visible = *self.tags.entry(t).or_insert(false);
if visible {
send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await;
} else {
send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await;
}
tracing::debug!("toggling tag {} to {}", t, !visible);
self.tags.insert(t, !visible);
ChangeTag(t)
@ -283,7 +332,7 @@ impl NiriTag {
tag_rx.recv().await.map(Receivable::TagCmd)
})
.await?;
tracing::debug!("received {:?}", recvd);
tracing::debug!("manager received {:?}", recvd);
let res = self.handle_recvd(recvd).await;
match res {

View File

@ -8,4 +8,15 @@ pub enum TagCmd {
DisableTag(u8),
ToggleTagOnWin(u8),
ToggleTag(u8),
// TODO
// ExclusiveTag(u8),
}
#[derive(Serialize, Deserialize, Debug)]
pub enum TagEvent {
TagEmpty(u8),
TagOccupied(u8),
TagUrgent(u8),
TagEnabled(u8),
TagDisabled(u8),
}