diff --git a/apps/smoo-gadget-cli/src/main.rs b/apps/smoo-gadget-cli/src/main.rs index 4f2940f..f606f17 100644 --- a/apps/smoo-gadget-cli/src/main.rs +++ b/apps/smoo-gadget-cli/src/main.rs @@ -14,10 +14,14 @@ use smoo_proto::{Ident, OpCode, Request, Response, SMOO_STATUS_REQUEST, SMOO_STA use std::{ collections::{HashMap, HashSet, VecDeque}, convert::Infallible, + ffi::{CString, OsStr}, fs::File, io, net::SocketAddr, os::fd::{FromRawFd, IntoRawFd, OwnedFd}, + os::unix::fs::FileTypeExt, + os::fd::AsRawFd, + os::unix::process::CommandExt, path::{Path, PathBuf}, sync::{ atomic::{AtomicU64, Ordering}, @@ -25,6 +29,7 @@ use std::{ }, time::{Duration, Instant}, }; +use std::io::Write; use tokio::{ io::AsyncReadExt, signal, @@ -37,8 +42,9 @@ use tokio::{ task::JoinHandle, }; use tokio_util::sync::CancellationToken; -use tracing::{debug, info, trace, warn}; +use tracing::{debug, error, info, trace, warn}; use tracing_subscriber::prelude::*; +use tracing_subscriber::fmt::MakeWriter; use usb_gadget::{ function::custom::{ CtrlReceiver, CtrlReq, CtrlSender, Custom, Endpoint, EndpointDirection, Event, Interface, @@ -47,6 +53,33 @@ use usb_gadget::{ Class, Config, Gadget, Id, RegGadget, Strings, }; +struct KmsgWriter { + file: File, +} + +impl Write for KmsgWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.file.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.file.flush() + } +} + +struct KmsgMakeWriter; + +impl<'a> MakeWriter<'a> for KmsgMakeWriter { + type Writer = Box; + + fn make_writer(&'a self) -> Self::Writer { + match File::options().write(true).open("/dev/kmsg") { + Ok(file) => Box::new(KmsgWriter { file }), + Err(_) => Box::new(io::sink()), + } + } +} + const SMOO_CLASS: u8 = 0xFF; const SMOO_SUBCLASS: u8 = 0x53; const SMOO_PROTOCOL: u8 = 0x4D; @@ -92,6 +125,12 @@ struct Args { /// Expose Prometheus metrics on this TCP port (0 disables). #[arg(long, default_value_t = 0)] metrics_port: u16, + /// Run as the initramfs PID1 wrapper (auto-enabled when argv0 == /init). + #[arg(long)] + pid1: bool, + /// Internal flag for the forked gadget child. + #[arg(long, hide = true)] + pid1_child: bool, } #[derive(Clone, Copy, Debug, ValueEnum)] @@ -113,14 +152,24 @@ impl From for DmaHeap { #[tokio::main(flavor = "multi_thread")] async fn main() -> Result<()> { - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); + let result = main_impl().await; + if let Err(err) = &result { + error!(error = ?err, "smoo-gadget-cli exiting with error"); + } + result +} - let args = Args::parse(); +async fn main_impl() -> Result<()> { + let mut args = Args::parse(); + let argv0 = std::env::args().next().unwrap_or_default(); + let auto_pid1 = argv0 == "/init"; + if (args.pid1 || auto_pid1) && !args.pid1_child { + args.pid1 = true; + init_logging(true); + run_pid1().context("pid1 initramfs flow")?; + return Ok(()); + } + init_logging(args.pid1_child); let metrics_shutdown = CancellationToken::new(); let metrics_task = spawn_metrics_listener(args.metrics_port, metrics_shutdown.clone())?; let mut ublk = SmooUblk::new().context("init ublk")?; @@ -226,6 +275,562 @@ async fn main() -> Result<()> { result } +fn init_logging(pid1: bool) { + let filter = + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()); + if pid1 { + tracing_subscriber::registry() + .with(filter) + .with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .without_time() + .with_writer(KmsgMakeWriter), + ) + .init(); + } else { + tracing_subscriber::registry() + .with(filter) + .with(tracing_subscriber::fmt::layer()) + .init(); + } +} + +fn run_pid1() -> Result<()> { + ensure!( + unsafe { libc::getpid() } == 1, + "pid1 mode requires PID 1" + ); + + info!("pid1: starting smoo initramfs flow"); + mount_fs(Some("proc"), "/proc", Some("proc"), 0, None).ok(); + mount_fs(Some("sysfs"), "/sys", Some("sysfs"), 0, None).ok(); + mount_fs(Some("devtmpfs"), "/dev", Some("devtmpfs"), 0, None).ok(); + mount_fs(Some("tmpfs"), "/run", Some("tmpfs"), 0, None).ok(); + debug!("pid1: mounted proc/sys/dev/run"); + + let default_modules = [ + "configfs", + "ublk", + "ublk_drv", + "overlay", + "erofs", + "libcomposite", + "usb_f_fs", + ]; + let modules = load_modules_from_dir("/etc/modules-load.d") + .filter(|list| !list.is_empty()) + .unwrap_or_else(|| default_modules.iter().map(|s| s.to_string()).collect()); + match ModuleIndex::load() { + Ok(module_index) => { + for module in modules { + if let Err(err) = module_index.load_module_by_name(&module) { + warn!("module load failed for {module}: {err:#}"); + } + } + } + Err(err) => { + warn!("module index unavailable: {err:#}"); + } + } + + std::fs::create_dir_all("/sys/kernel/config").ok(); + mount_fs(Some("configfs"), "/sys/kernel/config", Some("configfs"), 0, None).ok(); + debug!("pid1: mounted configfs"); + + let udc_wait_secs = 15; + info!("pid1: waiting for UDC (timeout {udc_wait_secs}s)"); + if !wait_for_udc(Duration::from_secs(udc_wait_secs))? { + error!("pid1: fatal UDC not ready after {udc_wait_secs}s"); + return Err(anyhow!("UDC not ready after {udc_wait_secs}s")); + } + + info!("pid1: spawning gadget child"); + let mut child = spawn_gadget_child().context("spawn gadget child")?; + info!("pid1: gadget child pid {}", child.id()); + let ublk_dev = "/dev/ublkb0"; + let wait_secs = 30; + debug!("pid1: waiting for block device {ublk_dev} (timeout {wait_secs}s)"); + if !wait_for_block_device(ublk_dev, Duration::from_secs(wait_secs), &mut child)? { + error!("pid1: fatal timeout waiting for {ublk_dev} after {wait_secs}s"); + return Err(anyhow!("timed out waiting for {ublk_dev}")); + } + info!("pid1: found ublk device {ublk_dev}"); + + std::fs::create_dir_all("/lower").ok(); + std::fs::create_dir_all("/upper").ok(); + std::fs::create_dir_all("/newroot").ok(); + + debug!("pid1: mounting lower erofs from {ublk_dev}"); + mount_fs( + Some(ublk_dev), + "/lower", + Some("erofs"), + libc::MS_RDONLY as libc::c_ulong, + None, + ) + .context("mount erofs lower")?; + debug!("pid1: mounted lower EROFS"); + debug!("pid1: mounting upper tmpfs"); + mount_fs(Some("tmpfs"), "/upper", Some("tmpfs"), 0, None).context("mount tmpfs upper")?; + std::fs::create_dir_all("/upper/upper").ok(); + std::fs::create_dir_all("/upper/work").ok(); + if !filesystem_available("overlay")? { + return Err(anyhow!("overlayfs not available in kernel")); + } + debug!("pid1: mounting overlay root"); + mount_fs( + Some("overlay"), + "/newroot", + Some("overlay"), + 0, + Some("lowerdir=/lower,upperdir=/upper/upper,workdir=/upper/work"), + ) + .context("mount overlay root")?; + debug!("pid1: mounted overlay root"); + + // Avoid EINVAL from pivot_root on shared mount trees. + debug!("pid1: making / private"); + mount_fs(None, "/", None, libc::MS_PRIVATE | libc::MS_REC, None) + .context("make / private")?; + + if matches!(cmdline_value("smoo.break").as_deref(), Some("1")) { + debug_shell("smoo.break requested")?; + } + + std::fs::create_dir_all("/newroot/proc").ok(); + std::fs::create_dir_all("/newroot/sys").ok(); + std::fs::create_dir_all("/newroot/dev").ok(); + std::fs::create_dir_all("/newroot/run").ok(); + + debug!("pid1: moving proc/sys/dev/run into newroot"); + move_mount("/proc", "/newroot/proc").ok(); + move_mount("/sys", "/newroot/sys").ok(); + move_mount("/dev", "/newroot/dev").ok(); + move_mount("/run", "/newroot/run").ok(); + debug!("pid1: moved proc/sys/dev/run into newroot"); + + std::env::set_current_dir("/newroot").ok(); + debug!("pid1: moving newroot to /"); + mount_fs(Some("/newroot"), "/", None, libc::MS_MOVE as libc::c_ulong, None) + .context("move newroot to /")?; + debug!("pid1: chrooting to new root"); + chroot_to(".").context("chroot to new root")?; + std::env::set_current_dir("/").ok(); + info!("pid1: switched root"); + + info!("pid1: exec /sbin/init"); + let err = std::process::Command::new("/sbin/init").exec(); + Err(anyhow!("exec /sbin/init failed: {err}")) +} + +fn cmdline_value(key: &str) -> Option { + let data = std::fs::read_to_string("/proc/cmdline").ok()?; + for token in data.split_whitespace() { + if let Some(value) = token.strip_prefix(&format!("{key}=")) { + return Some(value.to_string()); + } + } + None +} + +fn cmdline_u16(key: &str) -> Option { + let raw = cmdline_value(key)?; + match raw.parse::() { + Ok(value) => Some(value), + Err(_) => { + warn!("pid1: invalid {key} value '{raw}'"); + None + } + } +} + +fn load_modules_from_dir(path: &str) -> Option> { + let mut modules = Vec::new(); + let entries = std::fs::read_dir(path).ok()?; + for entry in entries.filter_map(Result::ok) { + let name = entry.file_name(); + if name.to_string_lossy().ends_with(".conf") { + if let Ok(contents) = std::fs::read_to_string(entry.path()) { + for line in contents.lines() { + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + modules.push(trimmed.to_string()); + } + } + } + } + Some(modules) +} + +struct ModuleIndex { + base_dir: PathBuf, + name_to_path: HashMap, + path_to_deps: HashMap>, + aliases: Vec<(String, String)>, +} + +impl ModuleIndex { + fn load() -> Result { + let release = std::fs::read_to_string("/proc/sys/kernel/osrelease") + .context("read /proc/sys/kernel/osrelease")?; + let base_dir = PathBuf::from("/lib/modules").join(release.trim()); + let dep_path = base_dir.join("modules.dep"); + let alias_path = base_dir.join("modules.alias"); + + let mut name_to_path = HashMap::new(); + let mut path_to_deps = HashMap::new(); + let dep_contents = + std::fs::read_to_string(&dep_path).with_context(|| format!("read {}", dep_path.display()))?; + for line in dep_contents.lines() { + let (path, deps) = match line.split_once(':') { + Some(parts) => parts, + None => continue, + }; + let path = path.trim().to_string(); + let deps = deps + .split_whitespace() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect::>(); + let name = module_name_from_path(&path); + name_to_path.entry(name).or_insert_with(|| path.clone()); + path_to_deps.insert(path, deps); + } + + let mut aliases = Vec::new(); + if let Ok(alias_contents) = std::fs::read_to_string(&alias_path) { + for line in alias_contents.lines() { + let line = line.trim(); + if !line.starts_with("alias ") { + continue; + } + let mut parts = line.split_whitespace(); + let _ = parts.next(); + if let (Some(pattern), Some(target)) = (parts.next(), parts.next()) { + aliases.push((pattern.to_string(), target.to_string())); + } + } + } + + Ok(Self { + base_dir, + name_to_path, + path_to_deps, + aliases, + }) + } + + fn load_module_by_name(&self, name: &str) -> Result<()> { + let path = self + .resolve_module_path(name) + .ok_or_else(|| anyhow!("module {name} not found"))?; + let mut loaded = HashSet::new(); + let mut stack = HashSet::new(); + self.load_module_recursive(&path, &mut loaded, &mut stack) + } + + fn resolve_module_path(&self, name: &str) -> Option { + if let Some(path) = self.name_to_path.get(name) { + return Some(path.clone()); + } + for (pattern, target) in &self.aliases { + if glob_match(pattern, name) { + if let Some(path) = self.name_to_path.get(target) { + return Some(path.clone()); + } + } + } + None + } + + fn load_module_recursive( + &self, + rel_path: &str, + loaded: &mut HashSet, + stack: &mut HashSet, + ) -> Result<()> { + if loaded.contains(rel_path) { + return Ok(()); + } + if !stack.insert(rel_path.to_string()) { + return Err(anyhow!("dependency cycle at {rel_path}")); + } + + let deps = self + .path_to_deps + .get(rel_path) + .cloned() + .unwrap_or_default(); + for dep in deps { + self.load_module_recursive(&dep, loaded, stack)?; + } + + let path = self.base_dir.join(rel_path); + let params = CString::new("")?; + let file = File::open(&path).with_context(|| format!("open {}", path.display()))?; + let fd = file.as_raw_fd(); + let res = if is_compressed_module(&path) { + // Let the kernel handle module decompression when supported. + finit_module(fd, ¶ms, MODULE_INIT_COMPRESSED_FILE) + } else { + finit_module(fd, ¶ms, 0) + }; + if res != 0 { + let err = io::Error::last_os_error(); + if err.raw_os_error() != Some(libc::EEXIST) { + warn!("pid1: module {} load failed: {}", path.display(), err); + return Err(err).with_context(|| format!("finit_module {}", path.display())); + } + } + info!("pid1: module {} load ok", path.display()); + + loaded.insert(rel_path.to_string()); + stack.remove(rel_path); + Ok(()) + } +} + +fn finit_module(fd: libc::c_int, params: &CString, flags: libc::c_int) -> libc::c_long { + unsafe { + libc::syscall( + libc::SYS_finit_module, + fd, + params.as_ptr(), + flags, + ) + } +} + +const MODULE_INIT_COMPRESSED_FILE: libc::c_int = 4; + +fn module_name_from_path(path: &str) -> String { + let filename = Path::new(path) + .file_name() + .and_then(|s| s.to_str()) + .unwrap_or(path); + let mut name = filename.to_string(); + for suffix in [".xz", ".zst", ".gz"] { + if let Some(stripped) = name.strip_suffix(suffix) { + name = stripped.to_string(); + } + } + if let Some(stripped) = name.strip_suffix(".ko") { + name = stripped.to_string(); + } + name +} + +fn is_compressed_module(path: &Path) -> bool { + matches!( + path.extension().and_then(|s| s.to_str()), + Some("xz") | Some("zst") | Some("gz") + ) +} + +fn glob_match(pattern: &str, text: &str) -> bool { + let (mut pi, mut ti) = (0usize, 0usize); + let (mut star_pi, mut star_ti) = (None, None); + let p = pattern.as_bytes(); + let t = text.as_bytes(); + + while ti < t.len() { + if pi < p.len() && (p[pi] == b'?' || p[pi] == t[ti]) { + pi += 1; + ti += 1; + continue; + } + if pi < p.len() && p[pi] == b'*' { + star_pi = Some(pi); + star_ti = Some(ti); + pi += 1; + continue; + } + if let (Some(sp), Some(st)) = (star_pi, star_ti) { + pi = sp + 1; + ti = st + 1; + star_ti = Some(ti); + continue; + } + return false; + } + while pi < p.len() && p[pi] == b'*' { + pi += 1; + } + pi == p.len() +} + +fn debug_shell(reason: &str) -> Result<()> { + warn!("pid1: dropping to shell ({reason})"); + for dev in ["/dev/ttyMSM0", "/dev/console"] { + if let Ok(meta) = std::fs::metadata(dev) { + if meta.file_type().is_char_device() { + let file = std::fs::OpenOptions::new() + .read(true) + .write(true) + .open(dev) + .with_context(|| format!("open {dev}"))?; + let _ = unsafe { libc::setsid() }; + let err = std::process::Command::new("/bin/sh") + .arg("-i") + .stdin(file.try_clone()?) + .stdout(file.try_clone()?) + .stderr(file) + .exec(); + return Err(anyhow!("exec /bin/sh failed: {err}")); + } + } + } + Err(anyhow!("no console device available for debug shell")) +} + +fn wait_for_udc(timeout: Duration) -> Result { + let start = Instant::now(); + let mut warned_missing = false; + let mut ticks: u32 = 0; + loop { + if Path::new("/sys/class/udc").exists() { + if let Ok(entries) = std::fs::read_dir("/sys/class/udc") { + if let Some(entry) = entries.filter_map(Result::ok).next() { + let name = entry.file_name().to_string_lossy().to_string(); + info!("pid1: UDC ready ({name})"); + return Ok(true); + } + } + } else if !warned_missing { + warn!("pid1: /sys/class/udc missing"); + warned_missing = true; + } + ticks = ticks.wrapping_add(1); + if ticks % 5 == 0 { + debug!("pid1: UDC not ready yet"); + } + if start.elapsed() >= timeout { + warn!("pid1: UDC wait timed out"); + return Ok(false); + } + std::thread::sleep(Duration::from_secs(1)); + } +} + +fn wait_for_block_device(path: &str, timeout: Duration, child: &mut std::process::Child) -> Result { + let start = Instant::now(); + let mut ticks: u32 = 0; + loop { + if let Ok(meta) = std::fs::metadata(path) { + if meta.file_type().is_block_device() { + return Ok(true); + } + } + if let Ok(Some(status)) = child.try_wait() { + error!("pid1: gadget child exited while waiting for {path}: {status}"); + return Err(anyhow!("gadget child exited: {status}")); + } + ticks = ticks.wrapping_add(1); + if ticks % 5 == 0 { + debug!("pid1: waiting for {path}"); + } + if start.elapsed() >= timeout { + return Ok(false); + } + std::thread::sleep(Duration::from_secs(1)); + } +} + +fn spawn_gadget_child() -> Result { + let exe = std::env::current_exe().context("locate self")?; + let mut child_args: Vec<_> = std::env::args_os().collect(); + child_args.retain(|arg| { + arg != OsStr::new("--pid1") + && arg != OsStr::new("--pid1-child") + && !arg.to_string_lossy().starts_with("--queue-depth") + && !arg.to_string_lossy().starts_with("--queue-count") + }); + if let Some(queue_depth) = cmdline_u16("smoo.queue_depth") + .or_else(|| cmdline_u16("smoo.queue_size")) + { + child_args.push(OsStr::new("--queue-depth").to_os_string()); + child_args.push(OsStr::new(&queue_depth.to_string()).to_os_string()); + info!("pid1: using queue depth {queue_depth} from cmdline"); + } + if let Some(queue_count) = cmdline_u16("smoo.queue_count") { + child_args.push(OsStr::new("--queue-count").to_os_string()); + child_args.push(OsStr::new(&queue_count.to_string()).to_os_string()); + info!("pid1: using queue count {queue_count} from cmdline"); + } + child_args.push(OsStr::new("--pid1-child").to_os_string()); + let mut cmd = std::process::Command::new(exe); + if let Some(log_level) = cmdline_value("smoo.log") { + cmd.env("RUST_LOG", log_level); + info!("pid1: set RUST_LOG from smoo.log"); + } + debug!( + "pid1: spawning gadget child exe={:?} args={:?}", + cmd.get_program(), + child_args + ); + cmd.args(child_args.iter().skip(1)); + cmd.stdin(std::process::Stdio::null()); + cmd.spawn().context("spawn gadget process") +} + +fn filesystem_available(name: &str) -> Result { + let data = std::fs::read_to_string("/proc/filesystems").context("read /proc/filesystems")?; + Ok(data.lines().any(|line| line.split_whitespace().last() == Some(name))) +} + +fn move_mount(src: &str, dst: &str) -> Result<()> { + mount_fs( + Some(src), + dst, + None, + libc::MS_MOVE as libc::c_ulong, + None, + ) + .with_context(|| format!("move mount {src} -> {dst}")) +} + +fn chroot_to(path: &str) -> Result<()> { + let path = CString::new(path)?; + let res = unsafe { libc::chroot(path.as_ptr()) }; + if res != 0 { + return Err(io::Error::last_os_error()).context("chroot syscall failed"); + } + Ok(()) +} + +fn mount_fs( + source: Option<&str>, + target: &str, + fstype: Option<&str>, + flags: libc::c_ulong, + data: Option<&str>, +) -> Result<()> { + let target = CString::new(target)?; + let source = source.map(CString::new).transpose()?; + let fstype = fstype.map(CString::new).transpose()?; + let data = data.map(CString::new).transpose()?; + let data_ptr = data + .as_ref() + .map(|s| s.as_ptr() as *const libc::c_void) + .unwrap_or(std::ptr::null()); + let res = unsafe { + libc::mount( + source.as_ref().map(|s| s.as_ptr()).unwrap_or(std::ptr::null()), + target.as_ptr(), + fstype.as_ref().map(|s| s.as_ptr()).unwrap_or(std::ptr::null()), + flags, + data_ptr, + ) + }; + if res != 0 { + return Err(io::Error::last_os_error()).context("mount failed"); + } + Ok(()) +} + fn spawn_metrics_listener( port: u16, shutdown: CancellationToken,