local.rs (13038B)
1 //! Local dotfile installation utilities 2 //! 3 //! ``` 4 //! copy_file("foo", "~/foo"); 5 //! link_file("bar", "~/bar"); 6 //! run_command("echo 'Hello world'"); 7 //! ``` 8 9 use anyhow::{bail, Context, Result}; 10 use shellexpand::tilde; 11 use std::fs; 12 #[cfg(target_family = "unix")] 13 use std::os::unix::fs::symlink; 14 use std::path::{PathBuf, absolute}; 15 use std::process::Command; 16 17 /// Copies the contents of a file to another file 18 /// 19 /// Tildes are expanded if present and the destination file is overwritten if 20 /// necessary. 21 /// 22 /// ``` 23 /// copy_file("foo", "~/foo"); 24 /// ``` 25 pub fn copy_file(src: &str, dst: &str) -> Result<()> { 26 let src_abs = absolute(src).with_context(|| { 27 format!("Failed to make {} absolute", src) 28 })?; 29 let dst_abs = absolute(dst).with_context(|| { 30 format!("Failed to make {} absolute", dst) 31 })?; 32 if src_abs == dst_abs { return Ok(()); } 33 34 let _dst = prepare_path(dst)?; 35 fs::copy(src, _dst)?; 36 Ok(()) 37 } 38 39 /// Creates a symbolic link to a file 40 /// 41 /// Tildes are expanded if present and the destination file is overwritten if 42 /// necessary. On non-Unix platforms, a hard link will be created instead. 43 /// 44 /// ``` 45 /// link_file("bar", "~/bar"); 46 /// ``` 47 #[cfg(target_family = "unix")] 48 pub fn link_file(src: &str, dst: &str) -> Result<()> { 49 let src_abs = absolute(src).with_context(|| { 50 format!("Failed to make {} absolute", src) 51 })?; 52 let dst_abs = absolute(dst).with_context(|| { 53 format!("Failed to make {} absolute", dst) 54 })?; 55 if src_abs == dst_abs { return Ok(()); } 56 57 let _dst = prepare_path(dst)?; 58 symlink(src_abs, _dst)?; 59 Ok(()) 60 } 61 #[cfg(not(target_family = "unix"))] 62 pub fn link_file(src: &str, dst: &str) -> Result<()> { 63 let src_abs = absolute(src).with_context(|| { 64 format!("Failed to make {} absolute", src) 65 })?; 66 let dst_abs = absolute(dst).with_context(|| { 67 format!("Failed to make {} absolute", dst) 68 })?; 69 if src_abs == dst_abs { return Ok(()); } 70 71 let _dst = prepare_path(dst)?; 72 fs::hard_link(src, _dst)?; 73 Ok(()) 74 } 75 76 /// Creates the parent directories of a path, deletes the file if it exists, and 77 /// returns the path with tildes expanded 78 /// 79 /// ``` 80 /// prepare_path("~/foo"); 81 /// ``` 82 fn prepare_path(path: &str) -> Result<PathBuf> { 83 let _dst: PathBuf = (&tilde(path).to_mut()).into(); 84 if let Some(_path) = _dst.parent() { 85 fs::create_dir_all(_path).with_context(|| { 86 format!("Failed to create parent directories of {}", path) 87 })?; 88 } 89 if fs::symlink_metadata(&_dst).is_ok() { 90 // Check for existing files, including broken symlinks 91 fs::remove_file(&_dst).with_context(|| { 92 format!("Failed to remove existing file at {}", path) 93 })?; 94 } 95 Ok(_dst) 96 } 97 98 /// Executes a command using `sh` on Unix and `cmd` on Windows 99 /// 100 /// ``` 101 /// run_command("echo 'Hello world'"); 102 /// ``` 103 pub fn run_command(command: &str) -> Result<()> 104 { 105 let mut cmd; 106 if cfg!(target_family = "unix") { 107 cmd = Command::new("sh"); 108 cmd.args(["-c", command]); 109 } else { 110 cmd = Command::new("cmd.exe"); 111 cmd.args(["/C", command]); 112 } 113 114 let status = cmd.status().with_context(|| { 115 format!("Failed to execute {:?}", cmd) 116 })?; 117 if !status.success() { 118 bail!("Process terminated unsuccessfully: {}", status); 119 } 120 Ok(()) 121 } 122 123 #[cfg(test)] 124 mod tests { 125 use super::*; 126 use crate::test_utils::{setup_integration, write_file}; 127 128 #[test] 129 fn test_copy_file_create_dirs() { 130 let tmp = setup_integration("test_copy_file_create_dirs"); 131 132 let src = &tmp.local.join("foo"); 133 let dst = &tmp.local.join("dir1").join("dir2").join("bar"); 134 write_file(src, "old contents of foo"); 135 136 let result = copy_file(src.to_str().unwrap(), dst.to_str().unwrap()); 137 138 write_file(src, "new contents of foo"); 139 let contents = fs::read_to_string(dst).unwrap(); 140 assert_eq!(result.is_ok(), true); 141 assert_eq!(contents, "old contents of foo"); 142 } 143 144 #[test] 145 fn test_copy_file_same_file() { 146 let tmp = setup_integration("test_copy_file_same_file"); 147 148 let src = &tmp.local.join("foo"); 149 let dst = &tmp.local.join("foo"); 150 write_file(src, "contents of foo"); 151 152 let result = copy_file(src.to_str().unwrap(), dst.to_str().unwrap()); 153 154 let contents = fs::read_to_string(dst).unwrap(); 155 assert_eq!(result.is_ok(), true); 156 assert_eq!(contents, "contents of foo"); 157 } 158 159 #[test] 160 fn test_copy_file_existing_file() { 161 let tmp = setup_integration("test_copy_file_existing_file"); 162 163 let src = &tmp.local.join("foo"); 164 let dst = &tmp.local.join("bar"); 165 write_file(src, "old contents of foo"); 166 write_file(dst, "old contents of bar"); 167 168 let result = copy_file(src.to_str().unwrap(), dst.to_str().unwrap()); 169 170 write_file(src, "new contents of foo"); 171 let contents = fs::read_to_string(dst).unwrap(); 172 assert_eq!(result.is_ok(), true); 173 assert_eq!(contents, "old contents of foo"); 174 } 175 176 #[test] 177 #[cfg(target_family = "unix")] 178 fn test_copy_file_existing_broken_symlink() { 179 let tmp = setup_integration("test_copy_file_existing_broken_symlink"); 180 181 let src = &tmp.local.join("foo"); 182 let dst = &tmp.local.join("bar"); 183 write_file(src, "old contents of foo"); 184 symlink("missing", dst).unwrap(); 185 186 let result = copy_file(src.to_str().unwrap(), dst.to_str().unwrap()); 187 188 write_file(src, "new contents of foo"); 189 let contents = fs::read_to_string(dst).unwrap(); 190 assert_eq!(result.is_ok(), true); 191 assert_eq!(contents, "old contents of foo"); 192 } 193 194 #[test] 195 #[cfg(target_family = "unix")] 196 fn test_copy_file_tilde_expansion() { 197 let tmp = setup_integration("test_copy_file_tilde_expansion"); 198 199 let src = &tmp.local.join("foo"); 200 let dst = &tmp.home.join("dir").join("bar"); 201 let dst_tilde = "~/test_copy_file_tilde_expansion/dir/bar"; 202 write_file(src, "old contents of foo"); 203 204 let result = copy_file(src.to_str().unwrap(), dst_tilde); 205 206 write_file(src, "new contents of foo"); 207 let contents = fs::read_to_string(dst).unwrap(); 208 assert_eq!(result.is_ok(), true); 209 assert_eq!(contents, "old contents of foo"); 210 } 211 212 #[test] 213 fn test_link_file_create_dirs() { 214 let tmp = setup_integration("test_link_file_create_dirs"); 215 216 let src = &tmp.local.join("foo"); 217 let dst = &tmp.local.join("dir1").join("dir2").join("bar"); 218 write_file(src, "old contents of foo"); 219 220 let result = link_file(src.to_str().unwrap(), dst.to_str().unwrap()); 221 222 write_file(src, "new contents of foo"); 223 let contents = fs::read_to_string(dst).unwrap(); 224 assert_eq!(result.is_ok(), true); 225 assert_eq!(contents, "new contents of foo"); 226 } 227 228 #[test] 229 fn test_link_file_same_file() { 230 let tmp = setup_integration("test_link_file_same_file"); 231 232 let src = &tmp.local.join("foo"); 233 let dst = &tmp.local.join("foo"); 234 write_file(src, "contents of foo"); 235 236 let result = link_file(src.to_str().unwrap(), dst.to_str().unwrap()); 237 238 let contents = fs::read_to_string(dst).unwrap(); 239 assert_eq!(result.is_ok(), true); 240 assert_eq!(contents, "contents of foo"); 241 } 242 243 #[test] 244 fn test_link_file_existing_file() { 245 let tmp = setup_integration("test_link_file_existing_file"); 246 247 let src = &tmp.local.join("foo"); 248 let dst = &tmp.local.join("bar"); 249 write_file(src, "old contents of foo"); 250 write_file(dst, "old contents of bar"); 251 252 let result = link_file(src.to_str().unwrap(), dst.to_str().unwrap()); 253 254 write_file(src, "new contents of foo"); 255 let contents = fs::read_to_string(dst).unwrap(); 256 assert_eq!(result.is_ok(), true); 257 assert_eq!(contents, "new contents of foo"); 258 } 259 260 #[test] 261 #[cfg(target_family = "unix")] 262 fn test_link_file_existing_broken_symlink() { 263 let tmp = setup_integration("test_link_file_existing_broken_symlink"); 264 265 let src = &tmp.local.join("foo"); 266 let dst = &tmp.local.join("bar"); 267 write_file(src, "old contents of foo"); 268 symlink("missing", dst).unwrap(); 269 270 let result = link_file(src.to_str().unwrap(), dst.to_str().unwrap()); 271 272 write_file(src, "new contents of foo"); 273 let contents = fs::read_to_string(dst).unwrap(); 274 assert_eq!(result.is_ok(), true); 275 assert_eq!(contents, "new contents of foo"); 276 } 277 278 #[test] 279 #[cfg(target_family = "unix")] 280 fn test_link_file_tilde_expansion() { 281 let tmp = setup_integration("test_link_file_tilde_expansion"); 282 283 let src = &tmp.local.join("foo"); 284 let dst = &tmp.home.join("dir").join("bar"); 285 let dst_tilde = "~/test_link_file_tilde_expansion/dir/bar"; 286 write_file(src, "old contents of foo"); 287 288 let result = link_file(src.to_str().unwrap(), dst_tilde); 289 290 write_file(src, "new contents of foo"); 291 let contents = fs::read_to_string(dst).unwrap(); 292 assert_eq!(result.is_ok(), true); 293 assert_eq!(contents, "new contents of foo"); 294 } 295 296 #[test] 297 #[cfg(target_family = "unix")] 298 fn test_link_file_relative_source() { 299 let dir = PathBuf::from("tests/.temp/ssh/test_link_file_relative_source"); 300 fs::create_dir_all(&dir).unwrap(); 301 302 let src = absolute(&dir.join("foo")).unwrap(); 303 let src_rel = "tests/.temp/ssh/test_link_file_relative_source/foo"; 304 let dst = &dir.join("dir1").join("dir2").join("bar"); 305 write_file(&src, "old contents of foo"); 306 307 let result = link_file(src_rel, dst.to_str().unwrap()); 308 309 write_file(&src, "new contents of foo"); 310 let contents = fs::read_to_string(dst).unwrap(); 311 let link = fs::read_link(dst).unwrap(); 312 assert_eq!(result.is_ok(), true); 313 assert_eq!(contents, "new contents of foo"); 314 assert_eq!(link, src); // src changed to absolute path 315 316 fs::remove_dir_all(&dir).unwrap(); 317 } 318 319 #[test] 320 #[cfg(target_family = "unix")] 321 fn test_run_command_successful() { 322 let tmp = setup_integration("test_run_command_successful"); 323 324 let src = &tmp.local.join("foo"); 325 write_file(src, "exit 0"); 326 327 let result = run_command(&format!("sh {}", src.to_str().unwrap())); 328 329 assert_eq!(result.is_ok(), true); 330 } 331 332 #[test] 333 #[cfg(target_family = "windows")] 334 fn test_run_command_successful() { 335 let tmp = setup_integration("test_run_command_successful"); 336 337 let src = &tmp.local.join("foo.bat"); 338 write_file(src, "exit 0"); 339 340 let result = run_command(src.to_str().unwrap()); 341 342 assert_eq!(result.is_ok(), true); 343 } 344 345 #[test] 346 #[cfg(target_family = "unix")] 347 fn test_run_command_failure() { 348 let tmp = setup_integration("test_run_command_failure"); 349 350 let src = &tmp.local.join("foo"); 351 write_file(src, "exit 2"); 352 353 let result = run_command(&format!("sh {}", src.to_str().unwrap())); 354 355 assert_eq!(result.is_ok(), false); 356 assert_eq!(result.unwrap_err().to_string(), 357 "Process terminated unsuccessfully: exit status: 2"); 358 } 359 360 #[test] 361 #[cfg(target_family = "windows")] 362 fn test_run_command_failure() { 363 let tmp = setup_integration("test_run_command_failure"); 364 365 let src = &tmp.local.join("foo.bat"); 366 write_file(src, "exit 1"); 367 368 let result = run_command(src.to_str().unwrap()); 369 370 assert_eq!(result.is_ok(), false); 371 assert_eq!(result.unwrap_err().to_string(), 372 "Process terminated unsuccessfully: exit code: 1"); 373 } 374 375 #[test] 376 #[cfg(target_family = "unix")] 377 fn test_run_command_arguments() { 378 let tmp = setup_integration("test_run_command_arguments"); 379 380 let src = &tmp.local.join("foo"); 381 let dst = &tmp.local.join("bar"); 382 write_file(src, &format!("echo $@ > {}", dst.to_str().unwrap())); 383 384 let result = run_command(&format!("sh {} arg1 arg2", 385 src.to_str().unwrap())); 386 387 let contents = fs::read_to_string(dst).unwrap(); 388 assert_eq!(result.is_ok(), true); 389 assert_eq!(contents, "arg1 arg2\n"); 390 } 391 392 #[test] 393 #[cfg(target_family = "windows")] 394 fn test_run_command_arguments() { 395 let tmp = setup_integration("test_run_command_arguments"); 396 397 let src = &tmp.local.join("foo.bat"); 398 let dst = &tmp.local.join("bar"); 399 write_file(src, &format!("echo %* > {}", dst.to_str().unwrap())); 400 401 let result = run_command(&format!("{} arg1 arg2", 402 src.to_str().unwrap())); 403 404 let contents = fs::read_to_string(dst).unwrap(); 405 assert_eq!(result.is_ok(), true); 406 assert_eq!(contents, "arg1 arg2 \r\n"); 407 } 408 }