feat: full poc

This commit is contained in:
atagen 2025-06-19 17:35:38 +10:00
parent b76036038e
commit a8847b93cf
14 changed files with 630 additions and 323 deletions

16
daemon/Cargo.toml Normal file
View file

@ -0,0 +1,16 @@
[package]
name = "niri-tag"
version = "0.1.0"
edition = "2024"
[dependencies]
libsystemd = "0.7"
niri-ipc = { path = "../niri/niri-ipc" }
nix = { version = "0.30", features = ["process", "user"] }
anyhow = "1.0"
tracing = "0.1"
tracing-subscriber = "0.3"
serde_json = "1.0"
smol = "2.0"
serde = { version = "1.0", features = ["derive"] }
niri-tag = { path = "../lib" }

53
daemon/listeners.rs Normal file
View file

@ -0,0 +1,53 @@
use crate::socket::{create_niri_socket, tell};
use anyhow::Result;
use niri_ipc::{Event, Request};
use niri_tag::TagCmd;
use nix::unistd::geteuid;
use smol::{
channel::{self},
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::UnixListener,
};
#[allow(unreachable_code)]
pub async fn event_listener(tx: channel::Sender<Event>) -> Result<()> {
tracing::debug!("creating listener socket");
let mut socket = create_niri_socket().await?;
tracing::debug!("requesting event stream");
let mut buf = String::new();
tell(&mut socket, Request::EventStream).await?;
tracing::debug!("beginning event loop");
loop {
let _ = socket.read_line(&mut buf).await?;
let event: Event = serde_json::from_str(&buf)?;
tracing::debug!("event: {:?}", event);
tx.send(event).await?;
}
unreachable!("Listener loop ended");
}
#[allow(unreachable_code)]
pub async fn ipc_listener(tx: channel::Sender<TagCmd>) -> Result<()> {
tracing::debug!("creating niri-tag socket");
let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid());
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old niri-tag socket");
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing niri-tag socket connection");
let listen = UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?;
let mut buf = String::new();
loop {
let (raw_sock, _addr) = listen.accept().await?;
let mut socket = BufReader::new(raw_sock);
tracing::debug!("awaiting ipc socket");
let _ = socket.read_line(&mut buf).await?;
tracing::debug!("forwarding ipc command {}", buf);
tx.send(serde_json::from_str(&buf)?).await?;
socket.close().await?;
}
tracing::error!("IPC loop ended");
unreachable!("IPC loop ended");
Ok(())
}

25
daemon/main.rs Normal file
View file

@ -0,0 +1,25 @@
mod listeners;
mod manager;
mod socket;
use anyhow::Result;
fn main() -> Result<()> {
// let systemd know we're ready
let _ = libsystemd::daemon::notify(false, &[libsystemd::daemon::NotifyState::Ready])?;
// debug stuff
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();
let span = tracing::span!(tracing::Level::DEBUG, "main");
let _ = span.enter();
// spawn socket listener for niri event stream
let (event_tx, event_rx) = smol::channel::unbounded();
smol::spawn(listeners::event_listener(event_tx)).detach();
// spawn socket listener for ipc
let (ipc_tx, ipc_rx) = smol::channel::unbounded();
smol::spawn(listeners::ipc_listener(ipc_tx)).detach();
// begin managing niri tags
let niri_tag = manager::NiriTag::new();
smol::block_on(niri_tag.manage_tags(event_rx, ipc_rx))
}

199
daemon/manager.rs Normal file
View file

@ -0,0 +1,199 @@
use crate::socket::{create_niri_socket, query, tell};
use anyhow::{Result, anyhow};
use niri_ipc::{
Action, Event, Reply, Request, Response, Window, WorkspaceReferenceArg,
state::{EventStreamState, EventStreamStatePart},
};
use niri_tag::TagCmd;
use smol::{
channel::{self},
future,
io::BufReader,
net::unix::UnixStream,
};
use std::collections::HashMap;
pub struct NiriTag {
tags: HashMap<u8, bool>,
windows: HashMap<u64, u8>,
}
impl NiriTag {
pub fn new() -> Self {
Self {
tags: HashMap::new(),
windows: HashMap::new(),
}
}
#[allow(unreachable_code)]
pub async fn manage_tags(
self,
ev_rx: channel::Receiver<Event>,
tag_rx: channel::Receiver<TagCmd>,
) -> Result<()> {
let mut state = EventStreamState::default();
let mut tags = NiriTag::new();
let mut socket = create_niri_socket().await?;
// base tag is always visible
tags.tags.insert(0, true);
async fn on_focused_window<O>(socket: &mut BufReader<UnixStream>, mut ok: O) -> Result<()>
where
O: FnMut(Option<Window>) -> Result<()>,
{
let q = query(socket, Request::FocusedWindow).await?;
if let Reply::Ok(Response::FocusedWindow(win)) = q {
ok(win)
} else {
Err(anyhow!(
"Invalid response from Niri when requesting FocusedWindow: {}",
if q.is_err() {
q.unwrap_err()
} else {
serde_json::to_string(&q.unwrap())?
}
))
}
}
loop {
let recvd: Receivable =
future::or(async { ev_rx.recv().await.map(Receivable::Event) }, async {
tag_rx.recv().await.map(Receivable::TagCmd)
})
.await?;
tracing::debug!("received {:?}", recvd);
let res = match recvd {
Receivable::Event(ev) => {
let _ = state.apply(ev.clone());
Ok(())
}
Receivable::TagCmd(cmd) => match cmd {
TagCmd::Add(t) => {
on_focused_window(&mut socket, |win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.insert(wid, t);
Ok(())
} else {
Err(anyhow!("No focused window to tag"))
}
})
.await
}
TagCmd::Remove(_) => {
on_focused_window(&mut socket, |win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.insert(wid, 0);
Ok(())
} else {
Err(anyhow!("No focused window to untag"))
}
})
.await
}
TagCmd::Toggle(t) => {
on_focused_window(&mut socket, |win| {
if let Some(win) = win {
let wid = win.id;
let toggle = if *tags.windows.get(&wid).unwrap_or(&0) == t {
0
} else {
t
};
tracing::debug!("toggling {} to tag {}", wid, toggle);
tags.windows.insert(wid, toggle);
Ok(())
} else {
Err(anyhow!("No focused window to untag"))
}
})
.await
}
TagCmd::Enable(t) => {
tags.tags.insert(t, true);
Ok(())
}
TagCmd::Disable(t) => {
tags.tags.insert(t, false);
Ok(())
}
TagCmd::ToggleWs(t) => {
tracing::debug!("toggling tag {}", t);
tags.tags.insert(t, !tags.tags.get(&t).unwrap_or(&false));
Ok(())
}
},
};
match res {
Ok(()) => (),
Err(e) => tracing::error!("error occurred in manager loop: {}", e),
}
// TODO: react selectively instead of brute forcing window state
// use Event::*;
// match ev {
// WorkspaceActivated { .. } => (),
// WorkspacesChanged { .. } => (),
// WorkspaceUrgencyChanged { .. } => (),
// WindowsChanged { .. } => (),
// WindowOpenedOrChanged { .. } => (),
// WindowUrgencyChanged { .. } => (),
// WindowClosed { .. } => (),
// _ => (),
// }
for (&wid, window) in state.windows.windows.iter() {
let (active, inactive): (Vec<_>, Vec<_>) = state
.workspaces
.workspaces
.iter()
.map(|(wsid, ws)| (wsid, ws.is_active))
.partition(|(_, a)| *a);
if let Some(wsid) = window.workspace_id {
if let Some(&window_tag) = tags.windows.get(&wid) {
if let Some(&tag_enabled) = tags.tags.get(&window_tag) {
if tag_enabled && inactive.contains(&(&wsid, false)) {
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Index(0),
focus: false,
}),
)
.await?;
tracing::debug!("making visible {}", wid);
} else if !tag_enabled && active.contains(&(&wsid, true)) {
let hidden = *inactive.first().unwrap().0;
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Id(hidden),
focus: false,
}),
)
.await?;
tracing::debug!("making hidden {}", wid);
}
} else {
tags.windows.insert(wid, 0);
}
}
}
}
}
tracing::error!("Manager loop ended");
unreachable!("Manager loop ended");
Ok(())
}
}
#[derive(Debug)]
enum Receivable {
Event(Event),
TagCmd(TagCmd),
}

37
daemon/socket.rs Normal file
View file

@ -0,0 +1,37 @@
use anyhow::{Context, Result, anyhow};
use niri_ipc::{Reply, Request, Response};
use smol::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::UnixStream,
};
use std::env;
pub async fn query(socket: &mut BufReader<UnixStream>, req: Request) -> Result<Reply> {
let req = serde_json::to_string(&req)?;
tracing::debug!("sending request: {}", req);
socket.write_all(&[req.as_bytes(), b"\n"].concat()).await?;
socket.flush().await?;
let mut rep = String::new();
socket.read_line(&mut rep).await?;
Ok(serde_json::from_str(&rep)?)
}
pub async fn tell(socket: &mut BufReader<UnixStream>, req: Request) -> Result<()> {
let rep = query(socket, req).await?;
if let Reply::Ok(Response::Handled) = rep {
Ok(())
} else {
Err(anyhow!(
"Expected Reply::Ok(Response::Handled), got {}",
rep.unwrap_err()
))
}
}
pub async fn create_niri_socket() -> Result<BufReader<UnixStream>> {
let socket_path = env::var(niri_ipc::socket::SOCKET_PATH_ENV)
.context("Couldn't find Niri socket path ($NIRI_SOCKET) in environment")?;
tracing::debug!("socket path is: {}", socket_path);
let raw = UnixStream::connect(&socket_path).await?;
Ok(BufReader::new(raw))
}