feat: full poc

This commit is contained in:
atagen 2025-06-19 17:35:38 +10:00
parent b76036038e
commit a8847b93cf
14 changed files with 630 additions and 323 deletions

140
Cargo.lock generated
View File

@ -2,6 +2,56 @@
# It is not intended for manual editing.
version = 4
[[package]]
name = "anstream"
version = "0.6.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
[[package]]
name = "anstyle-parse"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9"
dependencies = [
"windows-sys",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys",
]
[[package]]
name = "anyhow"
version = "1.0.98"
@ -137,9 +187,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "autocfg"
version = "1.4.0"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "bitflags"
@ -187,6 +237,52 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "clap"
version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "clap_lex"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
[[package]]
name = "colorchoice"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "concurrent-queue"
version = "2.5.0"
@ -304,6 +400,12 @@ dependencies = [
"version_check",
]
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.5.2"
@ -319,6 +421,12 @@ dependencies = [
"digest",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itoa"
version = "1.0.15"
@ -343,9 +451,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "libc"
version = "0.2.173"
version = "0.2.174"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8cfeafaffdbc32176b64fb251369d52ea9f0a8fbc6f8759edffef7b525d64bb"
checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776"
[[package]]
name = "libsystemd"
@ -405,13 +513,13 @@ name = "niri-tag"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"libsystemd",
"niri-ipc",
"nix 0.30.1",
"serde",
"serde_json",
"smol",
"thiserror",
"tracing",
"tracing-subscriber",
]
@ -466,6 +574,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "once_cell_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
[[package]]
name = "overload"
version = "0.1.1"
@ -643,6 +757,12 @@ dependencies = [
"futures-lite",
]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "subtle"
version = "2.6.1"
@ -702,9 +822,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.29"
version = "0.1.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b1ffbcf9c6f6b99d386e7444eb608ba646ae452a36b39737deb9663b610f662"
checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903"
dependencies = [
"proc-macro2",
"quote",
@ -758,6 +878,12 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.17.0"

View File

@ -3,14 +3,27 @@ name = "niri-tag"
version = "0.1.0"
edition = "2024"
[lib]
crate-type = [ "lib" ]
path = "lib/main.rs"
[[bin]]
name = "niri-tag"
path = "daemon/main.rs"
[[bin]]
name = "tagctl"
path = "cli/main.rs"
[dependencies]
anyhow = "1.0.98"
libsystemd = "0.7.2"
serde_json = "1.0.140"
thiserror = "2.0.12"
tracing-subscriber = "0.3.19"
libsystemd = "0.7"
niri-ipc = { path = "niri/niri-ipc" }
smol = "2.0.2"
tracing = "0.1.41"
nix = { version = "0.30.1", features = ["process", "user"] }
serde = { version = "1.0.219", features = ["derive"] }
nix = { version = "0.30", features = ["process", "user"] }
anyhow = "1.0"
tracing = "0.1"
tracing-subscriber = "0.3"
serde_json = "1.0"
smol = "2.0"
serde = { version = "1.0", features = ["derive"] }
clap = { version = "4.5", features = ["derive"] }

6
TODO Normal file
View File

@ -0,0 +1,6 @@
- check if tags exist when adding a window to one
- enable all tags by default
- make event loop properly scoped to window events instead of iterating over all windows
- move windows down a workspace instead of to a random invisible one ?
- move windows down, THEN to invisible after a user specified timeout to allow animations to complete ?
- add auto-categorisation of tags based on user prefs ie. all firefox by default goes to x

12
cli/Cargo.toml Normal file
View File

@ -0,0 +1,12 @@
[package]
name = "tagctl"
version = "0.1.0"
edition = "2024"
[dependencies]
anyhow = "1.0"
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
nix = { version = "0.30", features = ["process", "user"] }
clap = { version = "4.5", features = ["derive"] }
niri-tag = { path = "../lib" }

59
cli/main.rs Normal file
View File

@ -0,0 +1,59 @@
use anyhow::{Context, Result, anyhow};
use clap::{Parser, Subcommand};
use niri_tag::TagCmd;
use nix::unistd::geteuid;
use std::{io::Write, os::unix::net::UnixStream};
#[derive(Parser)]
#[command(name = "tagctl")]
#[command(about = "ipc wrapper for niri-tag")]
struct Cli {
#[command(subcommand)]
cmd: Commands,
}
#[derive(Clone, Debug, Subcommand)]
enum Commands {
Add { tag: u8 },
Remove { tag: u8 },
Toggle { tag: u8 },
Enable { tag: u8 },
Disable { tag: u8 },
ToggleWs { tag: u8 },
}
impl From<Commands> for niri_tag::TagCmd {
fn from(value: Commands) -> Self {
match value {
Commands::Add { tag } => TagCmd::Add(tag),
Commands::Remove { tag } => TagCmd::Remove(tag),
Commands::Enable { tag } => TagCmd::Enable(tag),
Commands::Disable { tag } => TagCmd::Disable(tag),
Commands::Toggle { tag } => TagCmd::Toggle(tag),
Commands::ToggleWs { tag } => TagCmd::ToggleWs(tag),
}
}
}
fn main() -> Result<()> {
let cli = Cli::parse();
use Commands::*;
println!("{:?}", cli.cmd);
match cli.cmd {
Add { tag } if tag > 0 => (),
Remove { tag } if tag > 0 => (),
Enable { tag } if tag > 0 => (),
Disable { tag } if tag > 0 => (),
Toggle { tag } if tag > 0 => (),
ToggleWs { tag } if tag > 0 => (),
_ => return Err(anyhow!("Can't change tag 0!")),
};
let cmd = TagCmd::from(cli.cmd);
let mut ipc = UnixStream::connect(format!("/run/user/{}/niri-tag.sock", geteuid()))
.context("Connecting to niri-tag ipc socket")?;
ipc.write_all(serde_json::to_string(&cmd)?.as_bytes())?;
Ok(())
}

16
daemon/Cargo.toml Normal file
View File

@ -0,0 +1,16 @@
[package]
name = "niri-tag"
version = "0.1.0"
edition = "2024"
[dependencies]
libsystemd = "0.7"
niri-ipc = { path = "../niri/niri-ipc" }
nix = { version = "0.30", features = ["process", "user"] }
anyhow = "1.0"
tracing = "0.1"
tracing-subscriber = "0.3"
serde_json = "1.0"
smol = "2.0"
serde = { version = "1.0", features = ["derive"] }
niri-tag = { path = "../lib" }

53
daemon/listeners.rs Normal file
View File

@ -0,0 +1,53 @@
use crate::socket::{create_niri_socket, tell};
use anyhow::Result;
use niri_ipc::{Event, Request};
use niri_tag::TagCmd;
use nix::unistd::geteuid;
use smol::{
channel::{self},
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::UnixListener,
};
#[allow(unreachable_code)]
pub async fn event_listener(tx: channel::Sender<Event>) -> Result<()> {
tracing::debug!("creating listener socket");
let mut socket = create_niri_socket().await?;
tracing::debug!("requesting event stream");
let mut buf = String::new();
tell(&mut socket, Request::EventStream).await?;
tracing::debug!("beginning event loop");
loop {
let _ = socket.read_line(&mut buf).await?;
let event: Event = serde_json::from_str(&buf)?;
tracing::debug!("event: {:?}", event);
tx.send(event).await?;
}
unreachable!("Listener loop ended");
}
#[allow(unreachable_code)]
pub async fn ipc_listener(tx: channel::Sender<TagCmd>) -> Result<()> {
tracing::debug!("creating niri-tag socket");
let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid());
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old niri-tag socket");
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing niri-tag socket connection");
let listen = UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?;
let mut buf = String::new();
loop {
let (raw_sock, _addr) = listen.accept().await?;
let mut socket = BufReader::new(raw_sock);
tracing::debug!("awaiting ipc socket");
let _ = socket.read_line(&mut buf).await?;
tracing::debug!("forwarding ipc command {}", buf);
tx.send(serde_json::from_str(&buf)?).await?;
socket.close().await?;
}
tracing::error!("IPC loop ended");
unreachable!("IPC loop ended");
Ok(())
}

25
daemon/main.rs Normal file
View File

@ -0,0 +1,25 @@
mod listeners;
mod manager;
mod socket;
use anyhow::Result;
fn main() -> Result<()> {
// let systemd know we're ready
let _ = libsystemd::daemon::notify(false, &[libsystemd::daemon::NotifyState::Ready])?;
// debug stuff
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();
let span = tracing::span!(tracing::Level::DEBUG, "main");
let _ = span.enter();
// spawn socket listener for niri event stream
let (event_tx, event_rx) = smol::channel::unbounded();
smol::spawn(listeners::event_listener(event_tx)).detach();
// spawn socket listener for ipc
let (ipc_tx, ipc_rx) = smol::channel::unbounded();
smol::spawn(listeners::ipc_listener(ipc_tx)).detach();
// begin managing niri tags
let niri_tag = manager::NiriTag::new();
smol::block_on(niri_tag.manage_tags(event_rx, ipc_rx))
}

199
daemon/manager.rs Normal file
View File

@ -0,0 +1,199 @@
use crate::socket::{create_niri_socket, query, tell};
use anyhow::{Result, anyhow};
use niri_ipc::{
Action, Event, Reply, Request, Response, Window, WorkspaceReferenceArg,
state::{EventStreamState, EventStreamStatePart},
};
use niri_tag::TagCmd;
use smol::{
channel::{self},
future,
io::BufReader,
net::unix::UnixStream,
};
use std::collections::HashMap;
pub struct NiriTag {
tags: HashMap<u8, bool>,
windows: HashMap<u64, u8>,
}
impl NiriTag {
pub fn new() -> Self {
Self {
tags: HashMap::new(),
windows: HashMap::new(),
}
}
#[allow(unreachable_code)]
pub async fn manage_tags(
self,
ev_rx: channel::Receiver<Event>,
tag_rx: channel::Receiver<TagCmd>,
) -> Result<()> {
let mut state = EventStreamState::default();
let mut tags = NiriTag::new();
let mut socket = create_niri_socket().await?;
// base tag is always visible
tags.tags.insert(0, true);
async fn on_focused_window<O>(socket: &mut BufReader<UnixStream>, mut ok: O) -> Result<()>
where
O: FnMut(Option<Window>) -> Result<()>,
{
let q = query(socket, Request::FocusedWindow).await?;
if let Reply::Ok(Response::FocusedWindow(win)) = q {
ok(win)
} else {
Err(anyhow!(
"Invalid response from Niri when requesting FocusedWindow: {}",
if q.is_err() {
q.unwrap_err()
} else {
serde_json::to_string(&q.unwrap())?
}
))
}
}
loop {
let recvd: Receivable =
future::or(async { ev_rx.recv().await.map(Receivable::Event) }, async {
tag_rx.recv().await.map(Receivable::TagCmd)
})
.await?;
tracing::debug!("received {:?}", recvd);
let res = match recvd {
Receivable::Event(ev) => {
let _ = state.apply(ev.clone());
Ok(())
}
Receivable::TagCmd(cmd) => match cmd {
TagCmd::Add(t) => {
on_focused_window(&mut socket, |win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.insert(wid, t);
Ok(())
} else {
Err(anyhow!("No focused window to tag"))
}
})
.await
}
TagCmd::Remove(_) => {
on_focused_window(&mut socket, |win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.insert(wid, 0);
Ok(())
} else {
Err(anyhow!("No focused window to untag"))
}
})
.await
}
TagCmd::Toggle(t) => {
on_focused_window(&mut socket, |win| {
if let Some(win) = win {
let wid = win.id;
let toggle = if *tags.windows.get(&wid).unwrap_or(&0) == t {
0
} else {
t
};
tracing::debug!("toggling {} to tag {}", wid, toggle);
tags.windows.insert(wid, toggle);
Ok(())
} else {
Err(anyhow!("No focused window to untag"))
}
})
.await
}
TagCmd::Enable(t) => {
tags.tags.insert(t, true);
Ok(())
}
TagCmd::Disable(t) => {
tags.tags.insert(t, false);
Ok(())
}
TagCmd::ToggleWs(t) => {
tracing::debug!("toggling tag {}", t);
tags.tags.insert(t, !tags.tags.get(&t).unwrap_or(&false));
Ok(())
}
},
};
match res {
Ok(()) => (),
Err(e) => tracing::error!("error occurred in manager loop: {}", e),
}
// TODO: react selectively instead of brute forcing window state
// use Event::*;
// match ev {
// WorkspaceActivated { .. } => (),
// WorkspacesChanged { .. } => (),
// WorkspaceUrgencyChanged { .. } => (),
// WindowsChanged { .. } => (),
// WindowOpenedOrChanged { .. } => (),
// WindowUrgencyChanged { .. } => (),
// WindowClosed { .. } => (),
// _ => (),
// }
for (&wid, window) in state.windows.windows.iter() {
let (active, inactive): (Vec<_>, Vec<_>) = state
.workspaces
.workspaces
.iter()
.map(|(wsid, ws)| (wsid, ws.is_active))
.partition(|(_, a)| *a);
if let Some(wsid) = window.workspace_id {
if let Some(&window_tag) = tags.windows.get(&wid) {
if let Some(&tag_enabled) = tags.tags.get(&window_tag) {
if tag_enabled && inactive.contains(&(&wsid, false)) {
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Index(0),
focus: false,
}),
)
.await?;
tracing::debug!("making visible {}", wid);
} else if !tag_enabled && active.contains(&(&wsid, true)) {
let hidden = *inactive.first().unwrap().0;
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Id(hidden),
focus: false,
}),
)
.await?;
tracing::debug!("making hidden {}", wid);
}
} else {
tags.windows.insert(wid, 0);
}
}
}
}
}
tracing::error!("Manager loop ended");
unreachable!("Manager loop ended");
Ok(())
}
}
#[derive(Debug)]
enum Receivable {
Event(Event),
TagCmd(TagCmd),
}

37
daemon/socket.rs Normal file
View File

@ -0,0 +1,37 @@
use anyhow::{Context, Result, anyhow};
use niri_ipc::{Reply, Request, Response};
use smol::{
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::UnixStream,
};
use std::env;
pub async fn query(socket: &mut BufReader<UnixStream>, req: Request) -> Result<Reply> {
let req = serde_json::to_string(&req)?;
tracing::debug!("sending request: {}", req);
socket.write_all(&[req.as_bytes(), b"\n"].concat()).await?;
socket.flush().await?;
let mut rep = String::new();
socket.read_line(&mut rep).await?;
Ok(serde_json::from_str(&rep)?)
}
pub async fn tell(socket: &mut BufReader<UnixStream>, req: Request) -> Result<()> {
let rep = query(socket, req).await?;
if let Reply::Ok(Response::Handled) = rep {
Ok(())
} else {
Err(anyhow!(
"Expected Reply::Ok(Response::Handled), got {}",
rep.unwrap_err()
))
}
}
pub async fn create_niri_socket() -> Result<BufReader<UnixStream>> {
let socket_path = env::var(niri_ipc::socket::SOCKET_PATH_ENV)
.context("Couldn't find Niri socket path ($NIRI_SOCKET) in environment")?;
tracing::debug!("socket path is: {}", socket_path);
let raw = UnixStream::connect(&socket_path).await?;
Ok(BufReader::new(raw))
}

View File

@ -39,27 +39,38 @@
RUST_SRC_DIR = "${pkgs.rustPlatform.rustLibSrc}";
RUST_LOG = "debug";
shellHook = ''
cp -R ${niri-flake.inputs.niri-unstable} niri/
cp --no-preserve=mode,ownership -R ${niri-flake.inputs.niri-unstable} niri/
'';
};
});
packages = forAllSystems (pkgs: {
unstable = deps.${pkgs.system}.naersk.buildPackage {
src = ./.;
RUSTFLAGS = "--cfg tokio_unstable";
preConfigure = ''
cp -R ${niri-flake.inputs.niri-unstable} niri/
'';
meta.mainProgram = "niri-tag";
};
stable = deps.${pkgs.system}.naersk.buildPackage {
src = ./.;
RUSTFLAGS = "--cfg tokio_unstable";
preBuild = ''
cp -R ${niri-flake.inputs.niri-stable} niri/
'';
meta.mainProgram = "niri-tag";
};
default = self.packages.${pkgs.system}.unstable;
});
# TODO make module that compiles based on niri version
nixosModules.default = self.nixosModules.niri-tag;
nixosModules.niri-tag =
{ pkgs, ... }:
{
imports = [
./module.nix
];
services.niri-tag.package = self.packages.${pkgs.system}.unstable;
};
};
}

11
lib/main.rs Normal file
View File

@ -0,0 +1,11 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
pub enum TagCmd {
Add(u8),
Remove(u8),
Enable(u8),
Disable(u8),
Toggle(u8),
ToggleWs(u8),
}

42
module.nix Normal file
View File

@ -0,0 +1,42 @@
{
config,
pkgs,
lib,
...
}:
let
inherit (lib)
mkEnableOption
mkPackageOption
mkIf
getExe
;
name = "Niri Tag Manager";
in
{
options.services.niri-tag = {
enable = mkEnableOption name;
package = mkPackageOption pkgs name {
nullable = true;
default = "niri-tag";
};
};
config =
let
cfg = config.services.niri-tag;
in
mkIf (cfg.enable) {
systemd.user.services.niri-tag = {
enable = true;
description = name;
wantedBy = [ "graphical-session.target" ];
partOf = [ "graphical-session.target" ];
serviceConfig = {
Type = "notify";
Restart = "always";
ExecStart = "${getExe cfg.package}";
PrivateTmp = true;
};
};
};
}

View File

@ -1,303 +0,0 @@
use std::{collections::HashMap, env};
use anyhow::{Context, Result, anyhow};
use niri_ipc::{
Action, Event, Reply, Request, Response, Window, WorkspaceReferenceArg,
state::{EventStreamState, EventStreamStatePart},
};
use nix::unistd::{geteuid, getpid};
use serde::{Deserialize, Serialize};
use smol::{
channel::{self},
future,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
net::unix::{UnixListener, UnixStream},
};
struct NiriTag {
tags: HashMap<u8, bool>,
windows: HashMap<u64, u8>,
}
impl NiriTag {
fn new() -> Self {
Self {
tags: HashMap::new(),
windows: HashMap::new(),
}
}
}
async fn query(socket: &mut BufReader<UnixStream>, req: Request) -> Result<Reply> {
let req = serde_json::to_string(&req)?;
tracing::debug!("sending request: {}", req);
socket.write_all(&[req.as_bytes(), b"\n"].concat()).await?;
socket.flush().await?;
let mut rep = String::new();
socket.read_line(&mut rep).await?;
Ok(serde_json::from_str(&rep)?)
}
async fn tell(socket: &mut BufReader<UnixStream>, req: Request) -> Result<()> {
let rep = query(socket, req).await?;
if let Reply::Ok(Response::Handled) = rep {
Ok(())
} else {
Err(anyhow!(
"Expected Reply::Ok(Response::Handled), got {}",
rep.unwrap_err()
))
}
}
async fn create_niri_socket() -> Result<BufReader<UnixStream>> {
let socket_path = env::var(niri_ipc::socket::SOCKET_PATH_ENV)
.context("Couldn't find Niri socket path ($NIRI_SOCKET) in environment")?;
tracing::debug!("socket path is: {}", socket_path);
let raw = UnixStream::connect(&socket_path).await?;
Ok(BufReader::new(raw))
}
#[allow(unreachable_code)]
async fn event_listener(tx: channel::Sender<Event>) -> Result<()> {
tracing::debug!("creating listener socket");
let mut socket = create_niri_socket().await?;
tracing::debug!("requesting event stream");
let mut buf = String::new();
tell(&mut socket, Request::EventStream).await?;
tracing::debug!("beginning event loop");
loop {
let _ = socket.read_line(&mut buf).await?;
let event: Event = serde_json::from_str(&buf)?;
tracing::debug!("event: {:?}", event);
tx.send(event).await?;
}
unreachable!("Listener loop ended");
}
#[derive(Serialize, Deserialize)]
enum TagCmd {
Add(u8),
Remove(u8),
Enable(u8),
Disable(u8),
}
enum Receivable {
Event(Event),
TagCmd(TagCmd),
}
#[allow(unreachable_code)]
async fn manage_tags(
ev_rx: channel::Receiver<Event>,
tag_rx: channel::Receiver<TagCmd>,
) -> Result<()> {
// notify.recv().await?;
// drop(notify);
let mut state = EventStreamState::default();
let mut tags = NiriTag::new();
let mut socket = create_niri_socket().await?;
// base tag is always visible
tags.tags.insert(0, true);
async fn on_focused_window<O, E>(
socket: &mut BufReader<UnixStream>,
mut ok: O,
mut err: E,
) -> Result<()>
where
O: FnMut(Option<Window>) -> Result<()>,
E: FnMut(Reply) -> Result<()>,
{
let q = query(socket, Request::FocusedWindow).await?;
if let Reply::Ok(Response::FocusedWindow(win)) = q {
ok(win)
} else {
err(q)
}
}
loop {
let recvd: Receivable =
future::or(async { ev_rx.recv().await.map(Receivable::Event) }, async {
tag_rx.recv().await.map(Receivable::TagCmd)
})
.await?;
let res =
match recvd {
Receivable::Event(ev) => {
let _ = state.apply(ev.clone());
Ok(())
}
Receivable::TagCmd(cmd) => {
// get wid of current window, add tag to it
match cmd {
TagCmd::Add(t) => on_focused_window(
&mut socket,
|win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.insert(wid, t);
Ok(())
} else {
Err(anyhow!("No focused window to tag"))
}
},
|q: Reply| {
Err(anyhow!(
"Invalid response from Niri when requesting FocusedWindow: {}",
if q.is_err() {
q.unwrap_err()
} else {
serde_json::to_string(&q.unwrap())?
}
))
},
)
.await,
TagCmd::Remove(_) => on_focused_window(
&mut socket,
|win| {
if let Some(win) = win {
let wid = win.id;
tags.windows.remove(&wid);
Ok(())
} else {
Err(anyhow!("No focused window to untag"))
}
},
|q: Reply| {
Err(anyhow!(
"Invalid response from Niri when requesting FocusedWindow: {}",
if q.is_err() {
q.unwrap_err()
} else {
serde_json::to_string(&q.unwrap())?
}
))
},
)
.await,
TagCmd::Enable(t) => {
tags.tags.insert(t, true);
Ok(())
}
TagCmd::Disable(t) => {
tags.tags.insert(t, false);
Ok(())
}
}
}
};
match res {
Ok(()) => (),
Err(e) => tracing::error!("error occurred in manager loop: {}", e),
}
// use Event::*;
// do we want to catch events or just let them apply and then set things right?
// match ev {
// WorkspaceActivated { .. } => (),
// WorkspacesChanged { .. } => (),
// WorkspaceUrgencyChanged { .. } => (),
// WindowsChanged { .. } => (),
// WindowOpenedOrChanged { .. } => (),
// WindowUrgencyChanged { .. } => (),
// WindowClosed { .. } => (),
// _ => (),
// }
for (&wid, window) in state.windows.windows.iter() {
let (active, inactive): (Vec<_>, Vec<_>) = state
.workspaces
.workspaces
.iter()
.map(|(wsid, ws)| (wsid, ws.is_active))
.partition(|(_, a)| *a);
if let Some(wsid) = window.workspace_id {
if let Some(&window_tag) = tags.windows.get(&wid) {
if let Some(&tag_enabled) = tags.tags.get(&window_tag) {
if tag_enabled && inactive.contains(&(&wsid, false)) {
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Index(0),
focus: false,
}),
)
.await?;
tracing::debug!("making visible {}", wid);
} else if !tag_enabled && active.contains(&(&wsid, true)) {
let hidden = *inactive.first().unwrap().0;
tell(
&mut socket,
Request::Action(Action::MoveWindowToWorkspace {
window_id: Some(wid),
reference: WorkspaceReferenceArg::Id(hidden),
focus: false,
}),
)
.await?;
tracing::debug!("making hidden {}", wid);
}
} else {
tags.windows.insert(wid, 0);
}
}
}
}
}
tracing::error!("Manager loop ended");
unreachable!("Manager loop ended");
Ok(())
}
#[allow(unreachable_code)]
async fn ipc_listener(tx: channel::Sender<TagCmd>) -> Result<()> {
tracing::debug!("creating niri-tag socket");
let sock_path = format!("/run/user/{}/niri-tag.sock", geteuid());
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old niri-tag socket");
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing niri-tag socket connection");
let listen = UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen to niri-tag socket: {}", f))?;
let mut buf = String::new();
tracing::debug!(
"this is what a proper tagcmd call looks like: {}",
serde_json::to_string(&TagCmd::Add(1))?
);
loop {
let (raw_sock, _addr) = listen.accept().await?;
let mut socket = BufReader::new(raw_sock);
tracing::debug!("awaiting ipc socket");
let _ = socket.read_line(&mut buf).await?;
tracing::debug!("forwarding ipc command");
tx.send(serde_json::from_str(&buf)?).await?;
socket.close().await?;
}
tracing::error!("IPC loop ended");
unreachable!("IPC loop ended");
Ok(())
}
fn main() -> Result<()> {
// let systemd know we're ready
let _ = libsystemd::daemon::notify(false, &[libsystemd::daemon::NotifyState::Ready])?;
// debug stuff
tracing_subscriber::fmt()
.with_max_level(tracing::Level::TRACE)
.init();
let span = tracing::span!(tracing::Level::TRACE, "main");
let _ = span.enter();
// spawn socket listener for niri event stream
let (event_tx, event_rx) = smol::channel::unbounded();
smol::spawn(event_listener(event_tx)).detach();
// spawn socket listener for ipc
let (ipc_tx, ipc_rx) = smol::channel::unbounded();
smol::spawn(ipc_listener(ipc_tx)).detach();
// begin managing niri tags
smol::block_on(manage_tags(event_rx, ipc_rx))
}