use crate::socket::{create_niri_socket, query, tell}; use anyhow::{Context, Result, anyhow}; use niri_ipc::{ Action, Event, Reply, Request, Response, Window, Workspace, WorkspaceReferenceArg, state::{EventStreamState, EventStreamStatePart}, }; use niri_tag::{Config, TagCmd, TagEvent, TagState}; use smol::{ channel::{self, Sender}, future, io::BufReader, net::unix::UnixStream, }; use std::collections::HashMap; pub struct NiriTag { config: Config, tags: HashMap, windows: HashMap, active_ws: Vec, state: EventStreamState, socket: BufReader, ev_tx: channel::Sender, } enum TagAction { Window(u64), Tag(u8), TagExclusive(u8), } impl NiriTag { pub async fn new(config: Config, ev_tx: channel::Sender) -> Result { Ok(Self { config, tags: HashMap::new(), windows: HashMap::new(), active_ws: Vec::new(), state: EventStreamState::default(), socket: create_niri_socket().await?, ev_tx, }) } fn same_output(&self, wsid: u64, candidates: &HashMap) -> Result { candidates .values() .filter_map(|ws| { let output = ws.output.clone()?; let win_output = self .state .workspaces .workspaces .get(&wsid)? .output .clone()?; (win_output == output).then_some(ws) }) .last() .context(anyhow!( "No inactive workspaces on output of workspace {} found", wsid )) .cloned() } async fn move_windows( &mut self, candidates: &HashMap, affected_windows: Vec, ) { for wid in affected_windows { tracing::debug!("Changing affected window {}", wid); if let Some(win) = self.state.windows.windows.get(&wid) { let wsid = win.workspace_id.unwrap(); match self.same_output(wsid, candidates) { Ok(status_same_output) => { if let Err(e) = tell( &mut self.socket, Request::Action(Action::MoveWindowToWorkspace { window_id: Some(wid), reference: WorkspaceReferenceArg::Id(status_same_output.id), focus: false, }), ) .await { tracing::error!( "Failed to move window {} to workspace {}: {}", wid, status_same_output.id, e ); } } Err(e) => { tracing::error!("Failed to get workspace on same output as {}: {}", wsid, e) } } } else { tracing::warn!("Failed to get wid {} from niri state", wid); continue; } } } async fn do_action(&mut self, action: TagAction) -> Result<()> { use TagAction::*; let (active, inactive): (HashMap<_, _>, HashMap<_, _>) = self .state .workspaces .workspaces .clone() .into_iter() .partition(|(_, ws)| ws.is_active); match action { Window(wid) => { let current_tag = *self.windows.entry(wid).or_insert(0); let tag_visible = *self.tags.entry(current_tag).or_insert(true); let win = self .state .windows .windows .get(&wid) .ok_or(anyhow!("Failed to retrieve window {} from niri state", wid))?; let wsid: u64 = win .workspace_id .ok_or(anyhow!("Retrieving workspace id of a changed window"))?; let win_visible = active.contains_key(&wsid); match (win_visible, tag_visible) { (true, false) => { let inactive_same_output = self.same_output(wsid, &inactive)?; tell( &mut self.socket, Request::Action(Action::MoveWindowToWorkspace { window_id: Some(wid), reference: WorkspaceReferenceArg::Id(inactive_same_output.id), focus: false, }), ) .await } (false, true) => { let active_same_output = self.same_output(wsid, &active)?; tell( &mut self.socket, Request::Action(Action::MoveWindowToWorkspace { window_id: Some(wid), reference: WorkspaceReferenceArg::Id(active_same_output.id), focus: true, }), ) .await } _ => Ok(()), } } Tag(tag) => { tracing::debug!("Changing tag {}", tag); let tag_visible = *self.tags.entry(tag).or_insert(true); let affected_windows: Vec = self .windows .iter() .filter(|(_, t)| tag == **t) .map(|(wid, _)| *wid) .collect(); tracing::debug!( "{} affected windows of tag {}: {:?}", affected_windows.len(), tag, affected_windows ); let focus = affected_windows.last().cloned(); self.move_windows( if tag_visible { &active } else { &inactive }, affected_windows, ) .await; if let Some(focus) = focus { if tag_visible { tell( &mut self.socket, Request::Action(Action::FocusWindow { id: focus }), ) .await?; } } Ok(()) } TagExclusive(t) => { tracing::debug!("Changing all tags"); let (active_wid, inactive_wid): (HashMap, HashMap) = self.windows.iter().partition(|(_, it)| **it == t); let focus = active_wid.keys().last(); self.move_windows(&inactive, inactive_wid.keys().cloned().collect()) .await; self.move_windows(&active, active_wid.keys().cloned().collect()) .await; if let Some(f) = focus { tell( &mut self.socket, Request::Action(Action::FocusWindow { id: *f }), ) .await?; } Ok(()) } } } async fn get_focused_window(&mut self) -> Result { let q = query(&mut self.socket, Request::FocusedWindow).await?; if let Reply::Ok(Response::FocusedWindow(win)) = q { if let Some(win) = win { Ok(win) } else { Err(anyhow!("No focused window to operate on")) } } else { Err(anyhow!( "Invalid response from Niri when requesting FocusedWindow: {}", if q.is_err() { q.unwrap_err() } else { serde_json::to_string(&q.unwrap())? } )) } } async fn handle_recvd(&mut self, recvd: Receivable) -> Result<()> { use TagAction::*; let send_event = async |tx: Sender, ev| { smol::spawn(async move { tx.send(ev) .await .inspect_err(|e| tracing::error!("Failed to send event: {}", e)) }) .detach(); }; let add_tag_fire = async |tx: Sender, windows: &HashMap, t| { if windows .iter() .filter(|(_, tag)| **tag == t) .collect::>() .is_empty() { send_event(tx, TagEvent::TagOccupied(t)).await; } }; let rm_tag_fire = async |tx: Sender, windows: &HashMap, wid, old_tag| { let same_tagged = windows .iter() .filter(|(w, tag)| **tag == old_tag && **w != wid) .count(); if same_tagged == 0 && old_tag != 0 { send_event(tx, TagEvent::TagEmpty(old_tag)).await; } }; // first do any local mutations let action: TagAction = match recvd { Receivable::Event(ev) => { let _ = self.state.apply(ev.clone()); return self.handle_event(ev).await; } Receivable::FullState(tx) => { tracing::debug!("received request for full state"); let fullstate: HashMap = self .tags .iter() .filter(|(t, _)| **t != 0) .map(|(&t, &enabled)| { ( t, TagState { enabled, occupied: self.windows.values().filter(|w_t| **w_t == t).count() > 0, urgent: false, // urgency is TODO }, ) }) .collect(); return tx.send(fullstate).await.map_err(|e| anyhow!(e)); } Receivable::TagCmd(cmd) => match cmd { TagCmd::AddTagToWin(t) => { let win = self.get_focused_window().await?; let wid = win.id; self.windows.insert(wid, t); tracing::debug!("adding tag {} to {}", t, wid); let tx = self.ev_tx.clone(); add_tag_fire(tx, &self.windows, t).await; Window(wid) } TagCmd::RemoveTagFromWin(_) => { let win = self.get_focused_window().await?; let wid = win.id; let old_tag = self.windows.insert(wid, 0).unwrap_or(0); tracing::debug!("resetting tag on {}", wid); let tx = self.ev_tx.clone(); rm_tag_fire(tx, &self.windows, wid, old_tag).await; Window(wid) } TagCmd::ToggleTagOnWin(t) => { let win = self.get_focused_window().await?; let wid = win.id; tracing::debug!("{} has tag {:?}", wid, self.windows.get(&wid)); let this_tag = *self.windows.entry(wid).or_insert(0); let toggle = if this_tag == t { 0 } else { t }; let tx = self.ev_tx.clone(); if toggle == 0 { rm_tag_fire(tx, &self.windows, wid, this_tag).await; } else { add_tag_fire(tx, &self.windows, t).await; } tracing::debug!("toggling {} to tag {}", wid, toggle); self.windows.insert(wid, toggle); Window(wid) } TagCmd::EnableTag(t) => { self.tags.insert(t, true); send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await; Tag(t) } TagCmd::DisableTag(t) => { self.tags.insert(t, false); send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await; Tag(t) } TagCmd::ToggleTag(t) => { let new_state = !*self.tags.entry(t).or_insert(false); if new_state { send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await; } else { send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await; } tracing::debug!("toggling tag {} to {}", t, new_state); self.tags.insert(t, new_state); Tag(t) } TagCmd::ExclusiveTag(t) => { self.tags.entry(t).insert_entry(true); self.tags.iter_mut().for_each(|(it, en)| *en = *it == t); send_event(self.ev_tx.clone(), TagEvent::TagExclusive(t)).await; TagExclusive(t) } }, }; // then arrange corresponding state in the compositor self.do_action(action).await?; tell( &mut self.socket, Request::Action(Action::CenterVisibleColumns {}), ) .await } async fn handle_event(&mut self, ev: Event) -> Result<()> { use Event::*; match ev { WindowOpenedOrChanged { window } => { self.windows.entry(window.id).or_insert(0); Ok(()) } WindowClosed { id } => { self.windows.remove(&id); Ok(()) } WorkspaceActivated { id, .. } => { if self.config.strict && !self.active_ws.contains(&id) { let q = query(&mut self.socket, Request::Workspaces).await?; let wsid = if let Reply::Ok(Response::Workspaces(workspaces)) = q { let new_ws = workspaces .iter() .find(|ws| ws.id == id) .expect("Activated workspace not found in workspace query"); workspaces .iter() .find(|ws| { ws.output == new_ws.output && ws.id != new_ws.id && self.active_ws.contains(&ws.id) }) .expect("Could not find a valid niri-tag workspace to return to") .id } else { return Err(anyhow!("Invalid response to workspace query")); }; tell( &mut self.socket, Request::Action(Action::FocusWorkspace { reference: WorkspaceReferenceArg::Id(wsid), }), ) .await } else { Ok(()) } } WorkspacesChanged { workspaces } => { self.active_ws = workspaces .into_iter() .filter(|ws| ws.is_active) .map(|ws| ws.id) .collect(); Ok(()) } // WorkspaceUrgencyChanged { .. } => (), WindowsChanged { windows } => { for w in windows { self.windows.entry(w.id).or_insert(0); let action = self.do_action(TagAction::Window(w.id)).await; if let Err(e) = action { tracing::warn!("Failed to ChangeWindow on {}: {}", w.id, e); } } Ok(()) } // WindowUrgencyChanged { .. } => (), _ => Ok(()), } } #[allow(unreachable_code)] pub async fn manage_tags( mut self, ev_rx: channel::Receiver, tag_rx: channel::Receiver, fullstate_rx: channel::Receiver>>, ) -> Result<()> { // prepopulate tags self.tags.insert(0, true); (1..=self.config.prepopulate).for_each(|i| { self.tags.insert(i, true); }); loop { let recvd: Receivable = future::or( async { ev_rx.recv().await.map(Receivable::Event) }, future::or( async { tag_rx.recv().await.map(Receivable::TagCmd) }, async { fullstate_rx.recv().await.map(Receivable::FullState) }, ), ) .await?; tracing::trace!("manager received {:?}", recvd); let res = self.handle_recvd(recvd).await; match res { Ok(()) => (), Err(e) => tracing::error!("error occurred in manager loop: {}", e), } } tracing::error!("Manager loop ended"); unreachable!("Manager loop ended"); } } #[derive(Debug)] enum Receivable { Event(Event), TagCmd(TagCmd), FullState(channel::Sender>), }