refactor: use persistent TagState instead of ad hoc calculations

This commit is contained in:
atagen 2025-06-23 17:13:46 +10:00
parent 1ec5921248
commit 6746e94f6d
2 changed files with 124 additions and 97 deletions

View File

@ -5,17 +5,12 @@ use niri_ipc::{
state::{EventStreamState, EventStreamStatePart},
};
use niri_tag::{Config, TagCmd, TagEvent, TagState};
use smol::{
channel::{self, Sender},
future,
io::BufReader,
net::unix::UnixStream,
};
use std::collections::HashMap;
use smol::{channel, future, io::BufReader, net::unix::UnixStream};
use std::collections::{HashMap, HashSet};
pub struct NiriTag {
config: Config,
tags: HashMap<u8, bool>,
tags: HashMap<u8, TagState>,
windows: HashMap<u64, u8>,
active_ws: Vec<u64>,
state: EventStreamState,
@ -42,6 +37,42 @@ impl NiriTag {
})
}
async fn change_window_tag(&mut self, wid: u64, replace: Option<u8>) -> Result<()> {
let destination = replace.unwrap_or(0);
if let Some(old_tag) = self.windows.insert(wid, destination) {
self.tags.entry(old_tag).and_modify(|ts| {
ts.windows.remove(&wid);
if ts.windows.is_empty() {
ts.occupied = false;
}
});
if let Some(old) = self.tags.get(&old_tag) {
if old_tag != 0 && !old.occupied {
self.fire_event(TagEvent::TagEmpty(old_tag)).await;
}
};
}
let modified_tag = self
.tags
.entry(destination)
.and_modify(|ts| {
ts.windows.insert(wid);
ts.occupied = true;
})
.or_insert(TagState {
enabled: true,
occupied: true,
urgent: false,
windows: HashSet::from([wid]),
});
if destination != 0 && modified_tag.windows.len() == 1 {
self.fire_event(TagEvent::TagOccupied(destination)).await;
}
Ok(())
}
fn same_output(&self, wsid: u64, candidates: &HashMap<u64, Workspace>) -> Result<Workspace> {
candidates
.values()
@ -116,7 +147,7 @@ impl NiriTag {
match action {
Window(wid) => {
let current_tag = *self.windows.entry(wid).or_insert(0);
let tag_visible = *self.tags.entry(current_tag).or_insert(true);
let tag_visible = self.tags.entry(current_tag).or_default().enabled;
let win = self
.state
.windows
@ -157,7 +188,7 @@ impl NiriTag {
}
Tag(tag) => {
tracing::debug!("Changing tag {}", tag);
let tag_visible = *self.tags.entry(tag).or_insert(true);
let tag_visible = self.tags.entry(tag).or_default().enabled;
let affected_windows: Vec<u64> = self
.windows
.iter()
@ -228,36 +259,18 @@ impl NiriTag {
}
}
async fn fire_event(&mut self, event: TagEvent) {
let tx = self.ev_tx.clone();
smol::spawn(async move {
tx.send(event)
.await
.inspect_err(|e| tracing::error!("Failed to send event: {}", e))
})
.detach();
}
async fn handle_recvd(&mut self, recvd: Receivable) -> Result<()> {
use TagAction::*;
let send_event = async |tx: Sender<TagEvent>, ev| {
smol::spawn(async move {
tx.send(ev)
.await
.inspect_err(|e| tracing::error!("Failed to send event: {}", e))
})
.detach();
};
let add_tag_fire = async |tx: Sender<TagEvent>, windows: &HashMap<u64, u8>, t| {
if windows
.iter()
.filter(|(_, tag)| **tag == t)
.collect::<HashMap<_, _>>()
.is_empty()
{
send_event(tx, TagEvent::TagOccupied(t)).await;
}
};
let rm_tag_fire = async |tx: Sender<TagEvent>, windows: &HashMap<u64, u8>, wid, old_tag| {
let same_tagged = windows
.iter()
.filter(|(w, tag)| **tag == old_tag && **w != wid)
.count();
if same_tagged == 0 && old_tag != 0 {
send_event(tx, TagEvent::TagEmpty(old_tag)).await;
}
};
// first do any local mutations
let action: TagAction = match recvd {
Receivable::Event(ev) => {
@ -268,83 +281,75 @@ impl NiriTag {
tracing::debug!("received request for full state");
let fullstate: HashMap<u8, TagState> = self
.tags
.iter()
.filter(|(t, _)| **t != 0)
.map(|(&t, &enabled)| {
(
t,
TagState {
enabled,
occupied: self.windows.values().filter(|w_t| **w_t == t).count()
> 0,
urgent: false, // urgency is TODO
},
)
})
.clone()
.into_iter()
.filter(|(t, _)| *t != 0)
.collect();
return tx.send(fullstate).await.map_err(|e| anyhow!(e));
}
Receivable::TagCmd(cmd) => match cmd {
TagCmd::AddTagToWin(t) => {
let win = self.get_focused_window().await?;
let wid = win.id;
self.windows.insert(wid, t);
tracing::debug!("adding tag {} to {}", t, wid);
let tx = self.ev_tx.clone();
add_tag_fire(tx, &self.windows, t).await;
let wid = self.get_focused_window().await?.id;
self.change_window_tag(wid, Some(t)).await?;
Window(wid)
}
TagCmd::RemoveTagFromWin(_) => {
let win = self.get_focused_window().await?;
let wid = win.id;
let old_tag = self.windows.insert(wid, 0).unwrap_or(0);
tracing::debug!("resetting tag on {}", wid);
let tx = self.ev_tx.clone();
rm_tag_fire(tx, &self.windows, wid, old_tag).await;
let wid = self.get_focused_window().await?.id;
self.change_window_tag(wid, None).await?;
Window(wid)
}
TagCmd::ToggleTagOnWin(t) => {
let win = self.get_focused_window().await?;
let wid = win.id;
tracing::debug!("{} has tag {:?}", wid, self.windows.get(&wid));
let this_tag = *self.windows.entry(wid).or_insert(0);
let toggle = if this_tag == t { 0 } else { t };
let tx = self.ev_tx.clone();
if toggle == 0 {
rm_tag_fire(tx, &self.windows, wid, this_tag).await;
let wid = self.get_focused_window().await?.id;
let new_tag = if *self.windows.entry(wid).or_insert(0) == t {
0
} else {
add_tag_fire(tx, &self.windows, t).await;
}
tracing::debug!("toggling {} to tag {}", wid, toggle);
self.windows.insert(wid, toggle);
t
};
self.change_window_tag(wid, Some(new_tag)).await?;
tracing::debug!("toggling {} to tag {}", wid, new_tag);
Window(wid)
}
TagCmd::EnableTag(t) => {
self.tags.insert(t, true);
send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await;
self.tags
.entry(t)
.and_modify(|ts| ts.enabled = true)
.or_default();
self.fire_event(TagEvent::TagEnabled(t)).await;
Tag(t)
}
TagCmd::DisableTag(t) => {
self.tags.insert(t, false);
send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await;
self.tags
.entry(t)
.and_modify(|ts| ts.enabled = false)
.or_default();
self.fire_event(TagEvent::TagDisabled(t)).await;
Tag(t)
}
TagCmd::ToggleTag(t) => {
let new_state = !*self.tags.entry(t).or_insert(false);
if new_state {
send_event(self.ev_tx.clone(), TagEvent::TagEnabled(t)).await;
let new_state = self
.tags
.entry(t)
.and_modify(|ts| ts.enabled = !ts.enabled)
.or_default()
.enabled;
self.fire_event(if new_state {
TagEvent::TagEnabled(t)
} else {
send_event(self.ev_tx.clone(), TagEvent::TagDisabled(t)).await;
}
tracing::debug!("toggling tag {} to {}", t, new_state);
self.tags.insert(t, new_state);
TagEvent::TagDisabled(t)
})
.await;
Tag(t)
}
TagCmd::ExclusiveTag(t) => {
self.tags.entry(t).insert_entry(true);
self.tags.iter_mut().for_each(|(it, en)| *en = *it == t);
send_event(self.ev_tx.clone(), TagEvent::TagExclusive(t)).await;
self.tags
.entry(t)
.and_modify(|ts| ts.enabled = true)
.or_default();
self.tags
.iter_mut()
.for_each(|(it, ts)| ts.enabled = *it == t);
self.fire_event(TagEvent::TagExclusive(t)).await;
TagExclusive(t)
}
},
@ -366,7 +371,17 @@ impl NiriTag {
Ok(())
}
WindowClosed { id } => {
self.windows.remove(&id);
if let Some(t) = self.windows.remove(&id) {
self.tags.entry(t).and_modify(|ts| {
ts.windows.remove(&id);
ts.occupied = !ts.windows.is_empty();
});
if let Some(tag) = self.tags.get(&t) {
if !tag.occupied {
self.fire_event(TagEvent::TagEmpty(t)).await;
}
}
}
Ok(())
}
WorkspaceActivated { id, .. } => {
@ -411,7 +426,7 @@ impl NiriTag {
// WorkspaceUrgencyChanged { .. } => (),
WindowsChanged { windows } => {
for w in windows {
self.windows.entry(w.id).or_insert(0);
self.change_window_tag(w.id, None).await?;
let action = self.do_action(TagAction::Window(w.id)).await;
if let Err(e) = action {
tracing::warn!("Failed to ChangeWindow on {}: {}", w.id, e);
@ -432,9 +447,8 @@ impl NiriTag {
fullstate_rx: channel::Receiver<channel::Sender<HashMap<u8, TagState>>>,
) -> Result<()> {
// prepopulate tags
self.tags.insert(0, true);
(1..=self.config.prepopulate).for_each(|i| {
self.tags.insert(i, true);
(0..=self.config.prepopulate).for_each(|i| {
self.tags.entry(i).or_default();
});
loop {

View File

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
@ -13,7 +13,7 @@ pub enum TagCmd {
ExclusiveTag(u8),
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Debug)]
pub enum TagEvent {
TagEmpty(u8),
TagOccupied(u8),
@ -24,11 +24,24 @@ pub enum TagEvent {
TagFullState(HashMap<u8, TagState>),
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Debug, Clone)]
pub struct TagState {
pub enabled: bool,
pub occupied: bool,
pub urgent: bool,
#[serde(skip_serializing)]
pub windows: HashSet<u64>,
}
impl Default for TagState {
fn default() -> Self {
Self {
enabled: true,
occupied: false,
urgent: false,
windows: HashSet::new(),
}
}
}
#[derive(Default, Deserialize)]