use std::{collections::HashMap, env}; use anyhow::{Context, Result, anyhow}; use niri_ipc::{ Action, Event, Reply, Request, Response, Window, WorkspaceReferenceArg, state::{EventStreamState, EventStreamStatePart}, }; use nix::unistd::{geteuid, getpid}; use serde::{Deserialize, Serialize}; use smol::{ channel::{self}, future, io::{AsyncBufReadExt, AsyncWriteExt, BufReader}, net::unix::{UnixListener, UnixStream}, }; struct NiriTag { tags: HashMap, windows: HashMap, } impl NiriTag { fn new() -> Self { Self { tags: HashMap::new(), windows: HashMap::new(), } } } async fn query(socket: &mut BufReader, req: Request) -> Result { let req = serde_json::to_string(&req)?; tracing::debug!("sending request: {}", req); socket.write_all(&[req.as_bytes(), b"\n"].concat()).await?; socket.flush().await?; let mut rep = String::new(); socket.read_line(&mut rep).await?; Ok(serde_json::from_str(&rep)?) } async fn tell(socket: &mut BufReader, req: Request) -> Result<()> { let rep = query(socket, req).await?; if let Reply::Ok(Response::Handled) = rep { Ok(()) } else { Err(anyhow!( "Expected Reply::Ok(Response::Handled), got {}", rep.unwrap_err() )) } } async fn create_niri_socket() -> Result> { let socket_path = env::var(niri_ipc::socket::SOCKET_PATH_ENV) .context("Couldn't find Niri socket path ($NIRI_SOCKET) in environment")?; tracing::debug!("socket path is: {}", socket_path); let raw = UnixStream::connect(&socket_path).await?; Ok(BufReader::new(raw)) } #[allow(unreachable_code)] async fn event_listener(tx: channel::Sender) -> Result<()> { tracing::debug!("creating listener socket"); let mut socket = create_niri_socket().await?; tracing::debug!("requesting event stream"); let mut buf = String::new(); tell(&mut socket, Request::EventStream).await?; tracing::debug!("beginning event loop"); loop { let _ = socket.read_line(&mut buf).await?; let event: Event = serde_json::from_str(&buf)?; tracing::debug!("event: {:?}", event); tx.send(event).await?; } unreachable!("Listener loop ended"); } #[derive(Serialize, Deserialize)] enum TagCmd { Add(u8), Remove(u8), Enable(u8), Disable(u8), } enum Receivable { Event(Event), TagCmd(TagCmd), } #[allow(unreachable_code)] async fn manage_tags( ev_rx: channel::Receiver, tag_rx: channel::Receiver, ) -> Result<()> { // notify.recv().await?; // drop(notify); let mut state = EventStreamState::default(); let mut tags = NiriTag::new(); let mut socket = create_niri_socket().await?; // base tag is always visible tags.tags.insert(0, true); async fn on_focused_window( socket: &mut BufReader, mut ok: O, mut err: E, ) -> Result<()> where O: FnMut(Option) -> Result<()>, E: FnMut(Reply) -> Result<()>, { let q = query(socket, Request::FocusedWindow).await?; if let Reply::Ok(Response::FocusedWindow(win)) = q { ok(win) } else { err(q) } } loop { let recvd: Receivable = future::or(async { ev_rx.recv().await.map(Receivable::Event) }, async { tag_rx.recv().await.map(Receivable::TagCmd) }) .await?; let res = match recvd { Receivable::Event(ev) => { let _ = state.apply(ev.clone()); Ok(()) } Receivable::TagCmd(cmd) => { // get wid of current window, add tag to it match cmd { TagCmd::Add(t) => on_focused_window( &mut socket, |win| { if let Some(win) = win { let wid = win.id; tags.windows.insert(wid, t); Ok(()) } else { Err(anyhow!("No focused window to tag")) } }, |q: Reply| { Err(anyhow!( "Invalid response from Niri when requesting FocusedWindow: {}", if q.is_err() { q.unwrap_err() } else { serde_json::to_string(&q.unwrap())? } )) }, ) .await, TagCmd::Remove(_) => on_focused_window( &mut socket, |win| { if let Some(win) = win { let wid = win.id; tags.windows.remove(&wid); Ok(()) } else { Err(anyhow!("No focused window to untag")) } }, |q: Reply| { Err(anyhow!( "Invalid response from Niri when requesting FocusedWindow: {}", if q.is_err() { q.unwrap_err() } else { serde_json::to_string(&q.unwrap())? } )) }, ) .await, TagCmd::Enable(t) => { tags.tags.insert(t, true); Ok(()) } TagCmd::Disable(t) => { tags.tags.insert(t, false); Ok(()) } } } }; match res { Ok(()) => (), Err(e) => tracing::error!("error occurred in manager loop: {}", e), } // use Event::*; // do we want to catch events or just let them apply and then set things right? // match ev { // WorkspaceActivated { .. } => (), // WorkspacesChanged { .. } => (), // WorkspaceUrgencyChanged { .. } => (), // WindowsChanged { .. } => (), // WindowOpenedOrChanged { .. } => (), // WindowUrgencyChanged { .. } => (), // WindowClosed { .. } => (), // _ => (), // } for (&wid, window) in state.windows.windows.iter() { let (active, inactive): (Vec<_>, Vec<_>) = state .workspaces .workspaces .iter() .map(|(wsid, ws)| (wsid, ws.is_active)) .partition(|(_, a)| *a); if let Some(wsid) = window.workspace_id { if let Some(&window_tag) = tags.windows.get(&wid) { if let Some(&tag_enabled) = tags.tags.get(&window_tag) { if tag_enabled && inactive.contains(&(&wsid, false)) { tell( &mut socket, Request::Action(Action::MoveWindowToWorkspace { window_id: Some(wid), reference: WorkspaceReferenceArg::Index(0), focus: false, }), ) .await?; tracing::debug!("making visible {}", wid); } else if !tag_enabled && active.contains(&(&wsid, true)) { let hidden = *inactive.first().unwrap().0; tell( &mut socket, Request::Action(Action::MoveWindowToWorkspace { window_id: Some(wid), reference: WorkspaceReferenceArg::Id(hidden), focus: false, }), ) .await?; tracing::debug!("making hidden {}", wid); } } else { tags.windows.insert(wid, 0); } } } } } tracing::error!("Manager loop ended"); unreachable!("Manager loop ended"); Ok(()) } #[allow(unreachable_code)] async fn ipc_listener(tx: channel::Sender) -> Result<()> { tracing::debug!("creating niri-tag socket"); let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid()); if smol::fs::metadata(&sock_path).await.is_ok() { tracing::debug!("removing old niri-tag socket"); smol::fs::remove_file(&sock_path).await?; } tracing::debug!("establishing niri-tag socket connection"); let listen = UnixListener::bind(&sock_path) .inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?; let mut buf = String::new(); tracing::debug!( "this is what a proper tagcmd call looks like: {}", serde_json::to_string(&TagCmd::Add(1))? ); loop { let (raw_sock, _addr) = listen.accept().await?; let mut socket = BufReader::new(raw_sock); tracing::debug!("awaiting ipc socket"); let _ = socket.read_line(&mut buf).await?; tracing::debug!("forwarding ipc command"); tx.send(serde_json::from_str(&buf)?).await?; socket.close().await?; } tracing::error!("IPC loop ended"); unreachable!("IPC loop ended"); Ok(()) } fn main() -> Result<()> { // let systemd know we're ready let _ = libsystemd::daemon::notify(false, &[libsystemd::daemon::NotifyState::Ready])?; // debug stuff tracing_subscriber::fmt() .with_max_level(tracing::Level::TRACE) .init(); let span = tracing::span!(tracing::Level::TRACE, "main"); let _ = span.enter(); // spawn socket listener for niri event stream let (event_tx, event_rx) = smol::channel::unbounded(); smol::spawn(event_listener(event_tx)).detach(); // spawn socket listener for ipc let (ipc_tx, ipc_rx) = smol::channel::unbounded(); smol::spawn(ipc_listener(ipc_tx)).detach(); // begin managing niri tags smol::block_on(manage_tags(event_rx, ipc_rx)) }