niri-tag/daemon/ipc.rs

166 lines
5.7 KiB
Rust

use std::{
collections::{BTreeMap, HashMap},
net::SocketAddr,
os::linux::net::SocketAddrExt,
path::PathBuf,
str::FromStr,
};
use crate::socket::{create_niri_socket, tell};
use anyhow::{Error, Result, anyhow};
use microxdg::Xdg;
use niri_ipc::{Event, Request};
use niri_tag::{TagCmd, TagEvent, TagState};
use nix::unistd::geteuid;
use smol::{
channel::{self},
future,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
net::unix::{UnixListener, UnixStream},
stream::StreamExt,
};
/// Subscribes to niri's event stream and forwards every decoded [`Event`] to `tx`.
///
/// Each event is read as one line from the niri socket and decoded with `serde_json`.
///
/// # Errors
/// Returns an error if the niri socket cannot be created, the event-stream request cannot be
/// sent, a line cannot be read or decoded, the stream reaches EOF, or `tx`'s receiver has been
/// dropped. Otherwise this never returns.
pub async fn event_consumer(tx: channel::Sender<Event>) -> Result<()> {
    tracing::debug!("creating event consumer");
    let mut socket = create_niri_socket().await?;
    tracing::debug!("requesting event stream");
    let mut buf = String::new();
    tell(&mut socket, Request::EventStream).await?;
    tracing::debug!("beginning event loop");
    loop {
        // `read_line` appends to `buf`; without clearing it, the previous event would be
        // re-parsed together with the new one and decoding would fail on trailing data.
        buf.clear();
        if socket.read_line(&mut buf).await? == 0 {
            // 0 bytes read means the peer closed the stream; no more events will arrive.
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "niri event stream closed",
            )
            .into());
        }
        let event: Event = serde_json::from_str(&buf)?;
        tracing::trace!("event: {:?}", event);
        tx.send(event).await?;
    }
}
/// Returns the directory in which the daemon's IPC sockets are created.
///
/// Uses the runtime directory reported by `microxdg` (presumably `$XDG_RUNTIME_DIR`) when one
/// is available, and falls back to `/run/user/<effective uid>` otherwise.
///
/// # Errors
/// Propagates any error raised by `Xdg::new` or `Xdg::runtime`.
fn get_run_path() -> Result<PathBuf> {
    let xdg = Xdg::new()?;
    match xdg.runtime()? {
        Some(dir) => Ok(dir),
        // Build the fallback lazily: only query the euid and format the path when needed.
        None => Ok(PathBuf::from_str(&format!("/run/user/{}", geteuid()))?),
    }
}
async fn create_provider_socket(name: &'static str, socket: &'static str) -> Result<UnixListener> {
let mut sock_path = get_run_path()?;
sock_path.push(format!("{}.sock", socket));
if smol::fs::metadata(&sock_path).await.is_ok() {
tracing::debug!("removing old {} socket", name);
smol::fs::remove_file(&sock_path).await?;
}
tracing::debug!("establishing {} socket", name);
UnixListener::bind(&sock_path)
.inspect_err(|f| tracing::error!("failed to listen on {} socket: {}", name, f))
.map_err(|e| anyhow!(e))
}
/// Accepts connections on the `niri-tag` socket and forwards each received [`TagCmd`] to `tx`.
///
/// Every connection is expected to carry a single newline-terminated JSON command. The line is
/// read, decoded, forwarded, and the connection is then closed. A client whose line cannot be
/// read or decoded, or whose socket fails to close, is logged and skipped, so one bad client
/// no longer stops the provider.
///
/// NOTE(review): connections are served one at a time, so a client that connects but never
/// sends a line stalls all later IPC clients.
///
/// # Errors
/// Returns an error if the listening socket cannot be created, `accept` fails, or `tx`'s
/// receiver has been dropped. Otherwise this never returns.
pub async fn ipc_provider(tx: channel::Sender<TagCmd>) -> Result<()> {
    tracing::debug!("creating ipc provider");
    let mut buf = String::new();
    let listen = create_provider_socket("ipc provider", "niri-tag").await?;
    tracing::debug!("beginning ipc provider loop");
    loop {
        let (raw_sock, _addr) = listen.accept().await?;
        let mut socket = BufReader::new(raw_sock);
        tracing::debug!("awaiting ipc socket");
        // `read_line` appends; reset the buffer so a previous client's command is not
        // re-parsed together with this one.
        buf.clear();
        match socket.read_line(&mut buf).await {
            Err(e) => tracing::warn!("failed to read ipc command: {}", e),
            Ok(_) => match serde_json::from_str::<TagCmd>(&buf) {
                Ok(cmd) => {
                    tracing::debug!("forwarding ipc command {}", buf);
                    // A closed receiver means nobody consumes commands anymore: stop.
                    tx.send(cmd).await?;
                }
                Err(e) => tracing::warn!("dropping malformed ipc command {:?}: {}", buf, e),
            },
        }
        if let Err(e) = socket.close().await {
            tracing::warn!("failed to close ipc socket: {}", e);
        }
    }
}
/// One outcome of the `future::or` race in `event_provider`: either an event to broadcast or a
/// newly accepted client connection.
#[derive(Debug)]
enum EventsReceivable {
    /// A tag event received on `event_provider`'s `rx` channel.
    Event(TagEvent),
    /// A newly accepted client: its peer address rendered with `{:?}`, and its stream.
    Conn(String, UnixStream),
}
#[allow(unreachable_code)]
pub async fn event_provider(
rx: channel::Receiver<TagEvent>,
fullstate_tx: channel::Sender<channel::Sender<HashMap<u8, TagState>>>,
) -> Result<()> {
tracing::debug!("creating event provider");
let listen = create_provider_socket("event provider", "niri-tag-events").await?;
let mut sockets = BTreeMap::new();
tracing::debug!("beginning event provider loop");
loop {
use EventsReceivable::*;
let recvd: EventsReceivable = future::or(
async { rx.recv().await.map(Event).map_err(|e| anyhow!(e)) },
async {
listen
.accept()
.await
.map(|(s, a)| Conn(format!("{:?}", a), s))
.map_err(|e| anyhow!(e))
},
)
.await?;
tracing::debug!("event provider received {:?}", recvd);
enum Res {
Ok,
BadSockets(Vec<String>),
}
let res = match recvd {
Conn(addr, mut socket) => {
tracing::debug!("received a new event provider connection");
sockets.insert(addr, socket.clone());
let (t, r) = smol::channel::bounded(1);
tracing::debug!("sending fullstate request");
fullstate_tx.send(t).await?;
match r.recv().await {
Ok(fullstate) => {
tracing::debug!("received fullstate, sending");
let data = serde_json::to_string(&TagEvent::TagFullState(fullstate))?;
if let Err(e) = socket.write_all(&[data.as_bytes(), b"\n"].concat()).await {
tracing::error!("Failed to send fullstate to socket: {}", e);
}
}
Err(e) => tracing::error!("Failed to receive fullstate: {}", e),
}
Res::Ok
}
Event(e) => {
let data = serde_json::to_string(&e).map_err(|e| anyhow!(e))?;
let conns = smol::stream::iter(sockets.iter_mut());
let send_all: Vec<(&String, Result<(), _>)> = conns
.then(async |(a, s)| (a, s.write_all(&[data.as_bytes(), b"\n"].concat()).await))
.collect()
.await;
let bad = send_all
.into_iter()
.fold(Vec::new(), |mut acc, (addr, res)| {
if let Err(e) = res {
tracing::warn!("error from event provider client {}: {}", addr, e);
acc.push(addr.to_owned());
}
acc
});
if !bad.is_empty() {
Res::BadSockets(bad)
} else {
Res::Ok
}
}
};
if let Res::BadSockets(bad) = res {
bad.into_iter().for_each(|b| {
sockets.remove(&b);
});
}
}
Ok(())
}