niri-tag/src/main.rs

304 lines
11 KiB
Rust

use std::{collections::HashMap, env};
use anyhow::{Context, Result, anyhow};
use niri_ipc::{
Action, Event, Reply, Request, Response, Window, WorkspaceReferenceArg,
state::{EventStreamState, EventStreamStatePart},
};
use nix::unistd::{geteuid, getpid};
use serde::{Deserialize, Serialize};
use smol::{
channel::{self},
future,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::{UnixListener, UnixStream},
};
/// In-memory tag bookkeeping owned by the tag manager.
#[derive(Debug, Default)]
struct NiriTag {
    /// Tag id -> whether that tag is currently enabled (its windows shown).
    tags: HashMap<u8, bool>,
    /// Niri window id -> the single tag assigned to that window.
    windows: HashMap<u64, u8>,
}
impl NiriTag {
    /// Creates empty bookkeeping: no known tags and no tagged windows.
    fn new() -> Self {
        Self::default()
    }
}
async fn query(socket: &mut BufReader<UnixStream>, req: Request) -> Result<Reply> {
let req = serde_json::to_string(&req)?;
tracing::debug!("sending request: {}", req);
socket.write_all(&[req.as_bytes(), b"\n"].concat()).await?;
socket.flush().await?;
let mut rep = String::new();
socket.read_line(&mut rep).await?;
Ok(serde_json::from_str(&rep)?)
}
/// Sends `req` and checks that Niri acknowledged it with `Response::Handled`.
///
/// # Errors
/// Fails when [`query`] fails, when Niri replies with an error message, or
/// when Niri replies successfully but with a response other than `Handled`.
async fn tell(socket: &mut BufReader<UnixStream>, req: Request) -> Result<()> {
    match query(socket, req).await? {
        Ok(Response::Handled) => Ok(()),
        // A successful reply with a different payload is an error, not a panic.
        Ok(other) => Err(anyhow!(
            "Expected Reply::Ok(Response::Handled), got unexpected response {:?}",
            other
        )),
        Err(msg) => Err(anyhow!(
            "Expected Reply::Ok(Response::Handled), got {}",
            msg
        )),
    }
}
/// Opens a new connection to Niri's IPC socket.
///
/// The socket path is read from the environment variable named by
/// `niri_ipc::socket::SOCKET_PATH_ENV` (`$NIRI_SOCKET`).
///
/// # Errors
/// Fails when the variable is unset or the socket cannot be connected to.
async fn create_niri_socket() -> Result<BufReader<UnixStream>> {
    let socket_path = env::var(niri_ipc::socket::SOCKET_PATH_ENV)
        .context("Couldn't find Niri socket path ($NIRI_SOCKET) in environment")?;
    tracing::debug!("socket path is: {}", socket_path);
    let raw = UnixStream::connect(&socket_path)
        .await
        .with_context(|| format!("Couldn't connect to Niri socket at {}", socket_path))?;
    Ok(BufReader::new(raw))
}
#[allow(unreachable_code)]
async fn event_listener(tx: channel::Sender<Event>) -> Result<()> {
tracing::debug!("creating listener socket");
let mut socket = create_niri_socket().await?;
tracing::debug!("requesting event stream");
let mut buf = String::new();
tell(&mut socket, Request::EventStream).await?;
tracing::debug!("beginning event loop");
loop {
let _ = socket.read_line(&mut buf).await?;
let event: Event = serde_json::from_str(&buf)?;
tracing::debug!("event: {:?}", event);
tx.send(event).await?;
}
unreachable!("Listener loop ended");
}
/// A command sent by a niri-tag client over the IPC socket as one JSON line.
///
/// Uses serde's default (externally tagged) enum representation, e.g.
/// `{"Add":1}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum TagCmd {
    /// Assign this tag to the currently focused window.
    Add(u8),
    /// Remove the focused window's tag (the manager ignores the tag value).
    Remove(u8),
    /// Mark this tag as enabled, i.e. its windows should be visible.
    Enable(u8),
    /// Mark this tag as disabled, i.e. its windows should be hidden.
    Disable(u8),
}
/// Anything the tag manager can receive: a Niri event or a client command.
#[derive(Debug)]
enum Receivable {
    Event(Event),
    TagCmd(TagCmd),
}
#[allow(unreachable_code)]
async fn manage_tags(
ev_rx: channel::Receiver<Event>,
tag_rx: channel::Receiver<TagCmd>,
) -> Result<()> {
// notify.recv().await?;
// drop(notify);
let mut state = EventStreamState::default();
let mut tags = NiriTag::new();
let mut socket = create_niri_socket().await?;
// base tag is always visible
tags.tags.insert(0, true);
async fn on_focused_window<O, E>(
socket: &mut BufReader<UnixStream>,
mut ok: O,
mut err: E,
) -> Result<()>
where
O: FnMut(Option<Window>) -> Result<()>,
E: FnMut(Reply) -> Result<()>,
{
let q = query(socket, Request::FocusedWindow).await?;
if let Reply::Ok(Response::FocusedWindow(win)) = q {
ok(win)
} else {
err(q)
}
}
loop {
let recvd: Receivable =
future::or(async { ev_rx.recv().await.map(Receivable::Event) }, async {
tag_rx.recv().await.map(Receivable::TagCmd)
})
.await?;
let res =
match recvd {
Receivable::Event(ev) => {
let _ = state.apply(ev.clone());
Ok(())
}
Receivable::TagCmd(cmd) => {
// get wid of current window, add tag to it
match cmd {
TagCmd::Add(t) => on_focused_window(
&mut socket,
|win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.insert(wid, t);
Ok(())
} else {
Err(anyhow!("No focused window to tag"))
}
},
|q: Reply| {
Err(anyhow!(
"Invalid response from Niri when requesting FocusedWindow: {}",
if q.is_err() {
q.unwrap_err()
} else {
serde_json::to_string(&q.unwrap())?
}
))
},
)
.await,
TagCmd::Remove(_) => on_focused_window(
&mut socket,
|win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.remove(&wid);
Ok(())
} else {
Err(anyhow!("No focused window to untag"))
}
},
|q: Reply| {
Err(anyhow!(
"Invalid response from Niri when requesting FocusedWindow: {}",
if q.is_err() {
q.unwrap_err()
} else {
serde_json::to_string(&q.unwrap())?
}
))
},
)
.await,
TagCmd::Enable(t) => {
tags.tags.insert(t, true);
Ok(())
}
TagCmd::Disable(t) => {
tags.tags.insert(t, false);
Ok(())
}
}
}
};
match res {
Ok(()) => (),
Err(e) => tracing::error!("error occurred in manager loop: {}", e),
}
// use Event::*;
// do we want to catch events or just let them apply and then set things right?
// match ev {
// WorkspaceActivated { .. } => (),
// WorkspacesChanged { .. } => (),
// WorkspaceUrgencyChanged { .. } => (),
// WindowsChanged { .. } => (),
// WindowOpenedOrChanged { .. } => (),
// WindowUrgencyChanged { .. } => (),
// WindowClosed { .. } => (),
// _ => (),
// }
for (&wid, window) in state.windows.windows.iter() {
let (active, inactive): (Vec<_>, Vec<_>) = state
.workspaces
.workspaces
.iter()
.map(|(wsid, ws)| (wsid, ws.is_active))
.partition(|(_, a)| *a);
if let Some(wsid) = window.workspace_id {
if let Some(&window_tag) = tags.windows.get(&wid) {
if let Some(&tag_enabled) = tags.tags.get(&window_tag) {
if tag_enabled && inactive.contains(&(&wsid, false)) {
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Index(0),
focus: false,
}),
)
.await?;
tracing::debug!("making visible {}", wid);
} else if !tag_enabled && active.contains(&(&wsid, true)) {
let hidden = *inactive.first().unwrap().0;
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Id(hidden),
focus: false,
}),
)
.await?;
tracing::debug!("making hidden {}", wid);
}
} else {
tags.windows.insert(wid, 0);
}
}
}
}
}
tracing::error!("Manager loop ended");
unreachable!("Manager loop ended");
Ok(())
}
#[allow(unreachable_code)]
async fn ipc_listener(tx: channel::Sender<TagCmd>) -> Result<()> {
tracing::debug!("creating niri-tag socket");
let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid());
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old niri-tag socket");
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing niri-tag socket connection");
let listen = UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?;
let mut buf = String::new();
tracing::debug!(
"this is what a proper tagcmd call looks like: {}",
serde_json::to_string(&TagCmd::Add(1))?
);
loop {
let (raw_sock, _addr) = listen.accept().await?;
let mut socket = BufReader::new(raw_sock);
tracing::debug!("awaiting ipc socket");
let _ = socket.read_line(&mut buf).await?;
tracing::debug!("forwarding ipc command");
tx.send(serde_json::from_str(&buf)?).await?;
socket.close().await?;
}
tracing::error!("IPC loop ended");
unreachable!("IPC loop ended");
Ok(())
}
fn main() -> Result<()> {
// let systemd know we're ready
let _ = libsystemd::daemon::notify(false, &[libsystemd::daemon::NotifyState::Ready])?;
// debug stuff
tracing_subscriber::fmt()
.with_max_level(tracing::Level::TRACE)
.init();
let span = tracing::span!(tracing::Level::TRACE, "main");
let _ = span.enter();
// spawn socket listener for niri event stream
let (event_tx, event_rx) = smol::channel::unbounded();
smol::spawn(event_listener(event_tx)).detach();
// spawn socket listener for ipc
let (ipc_tx, ipc_rx) = smol::channel::unbounded();
smol::spawn(ipc_listener(ipc_tx)).detach();
// begin managing niri tags
smol::block_on(manage_tags(event_rx, ipc_rx))
}