diff --git a/smb-core/src/nt_status.rs b/smb-core/src/nt_status.rs index 9c6e60d..8d4ab0f 100644 --- a/smb-core/src/nt_status.rs +++ b/smb-core/src/nt_status.rs @@ -19,6 +19,12 @@ pub enum NTStatus { UserSessionDeleted = 0xC0000203, NetworkSessionExpired = 0xC000035C, FileNotAvailable = 0xC0000467, + FileClosed = 0xC0000128, + EndOfFile = 0xC0000011, + InvalidInfoClass = 0xC0000003, + InvalidDeviceRequest = 0xC0000010, + BufferOverflow = 0x80000005, + InfoLengthMismatch = 0xC0000004, UnknownError = 0xFFFFFFFF, } diff --git a/smb/src/main.rs b/smb/src/main.rs index 0a2b652..e44f41a 100644 --- a/smb/src/main.rs +++ b/smb/src/main.rs @@ -47,7 +47,12 @@ async fn main() -> SMBResult<()> { .unencrypted_access(true) .require_message_signing(false) .encrypt_data(false) - .add_fs_share("test".into(), "".into(), file_allowed, get_file_perms) + .add_fs_share( + "test".into(), + std::env::var("SMB_SHARE_PATH").unwrap_or_default(), + file_allowed, + get_file_perms, + ) .add_ipc_share() .auth_provider(NTLMAuthProvider::new( vec![ diff --git a/smb/src/protocol/body/close/mod.rs b/smb/src/protocol/body/close/mod.rs index 0b08bb6..b99ff93 100644 --- a/smb/src/protocol/body/close/mod.rs +++ b/smb/src/protocol/body/close/mod.rs @@ -9,7 +9,7 @@ use crate::protocol::body::create::file_attributes::SMBFileAttributes; use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::filetime::FileTime; -mod flags; +pub mod flags; #[derive( Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, @@ -24,6 +24,16 @@ pub struct SMBCloseRequest { file_id: SMBFileId, } +impl SMBCloseRequest { + pub fn flags(&self) -> SMBCloseFlags { + self.flags + } + + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } +} + #[derive( Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] @@ -48,3 +58,118 @@ pub struct SMBCloseResponse { #[smb_direct(start(fixed = 56))] file_attributes: SMBFileAttributes, } + +impl SMBCloseResponse { + pub fn from_metadata( + metadata: &crate::server::share::SMBFileMetadata, + attributes: SMBFileAttributes, + ) -> Self { + Self { + flags: SMBCloseFlags::POSTQUERY_ATTRIB, + reserved: PhantomData, + creation_time: metadata.creation_time().clone(), + last_access_time: metadata.last_access_time().clone(), + last_write_time: metadata.last_write_time().clone(), + change_time: metadata.last_modification_time().clone(), + allocation_size: metadata.allocated_size(), + end_of_file: metadata.actual_size(), + file_attributes: attributes, + } + } + + pub fn empty() -> Self { + Self { + flags: SMBCloseFlags::empty(), + reserved: PhantomData, + creation_time: FileTime::zero(), + last_access_time: FileTime::zero(), + last_write_time: FileTime::zero(), + change_time: FileTime::zero(), + allocation_size: 0, + end_of_file: 0, + file_attributes: SMBFileAttributes::empty(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + + #[test] + fn close_response_empty_has_zero_fields() { + let resp = SMBCloseResponse::empty(); + assert_eq!(resp.flags, SMBCloseFlags::empty()); + assert_eq!(resp.allocation_size, 0); + assert_eq!(resp.end_of_file, 0); + assert_eq!(resp.file_attributes, SMBFileAttributes::empty()); + } + + #[test] + fn close_response_empty_serialization_round_trip() { + let resp = SMBCloseResponse::empty(); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBCloseResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn close_response_from_metadata_sets_postquery_flag() { + use crate::server::share::SMBFileMetadata; + let metadata = SMBFileMetadata::new( + FileTime::from_unix(1700000000), + FileTime::from_unix(1700000100), + FileTime::from_unix(1700000200), + FileTime::from_unix(1700000300), + 4096, + 1024, + ); + let resp = SMBCloseResponse::from_metadata(&metadata, SMBFileAttributes::NORMAL); + assert!(resp.flags.contains(SMBCloseFlags::POSTQUERY_ATTRIB)); + assert_eq!(resp.allocation_size, 4096); + assert_eq!(resp.end_of_file, 1024); + assert_eq!(resp.file_attributes, SMBFileAttributes::NORMAL); + } + + #[test] + fn close_response_from_metadata_serialization_round_trip() { + use crate::server::share::SMBFileMetadata; + let metadata = SMBFileMetadata::new( + FileTime::from_unix(1700000000), + FileTime::from_unix(1700000100), + FileTime::from_unix(1700000200), + FileTime::from_unix(1700000300), + 8192, + 2048, + ); + let resp = SMBCloseResponse::from_metadata(&metadata, SMBFileAttributes::ARCHIVE); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBCloseResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn close_request_accessors() { + let file_id = SMBFileId::new(42, 99); + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 24 + buf.extend_from_slice(&24u16.to_le_bytes()); + // flags (u16) = POSTQUERY_ATTRIB = 0x0001 + buf.extend_from_slice(&1u16.to_le_bytes()); + // reserved (4 bytes) + buf.extend_from_slice(&[0u8; 4]); + // file_id: persistent (u64) + volatile (u64) + buf.extend_from_slice(&42u64.to_le_bytes()); + buf.extend_from_slice(&99u64.to_le_bytes()); + buf + }; + let (_, req) = SMBCloseRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.file_id().persistent(), file_id.persistent()); + assert_eq!(req.file_id().volatile(), file_id.volatile()); + assert!(req.flags().contains(SMBCloseFlags::POSTQUERY_ATTRIB)); + } +} diff --git a/smb/src/protocol/body/create/file_id.rs b/smb/src/protocol/body/create/file_id.rs index 85237b2..8c83b33 100644 --- a/smb/src/protocol/body/create/file_id.rs +++ b/smb/src/protocol/body/create/file_id.rs @@ -7,7 +7,75 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; )] pub struct SMBFileId { #[smb_direct(start(fixed = 0))] - pub persistent: u64, + persistent: u64, #[smb_direct(start(fixed = 8))] - pub volatile: u64, + volatile: u64, +} + +impl SMBFileId { + pub fn new(persistent: u64, volatile: u64) -> Self { + Self { + persistent, + volatile, + } + } + + pub fn persistent(&self) -> u64 { + self.persistent + } + + pub fn volatile(&self) -> u64 { + self.volatile + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + + /// MS-SMB2 section 2.2.14.1: SMB2_FILEID is 16 bytes (Persistent u64 + Volatile u64) + #[test] + fn file_id_is_16_bytes() { + let fid = SMBFileId::new(0, 0); + assert_eq!(fid.smb_byte_size(), 16); + } + + /// Persistent is at offset 0, Volatile at offset 8 + #[test] + fn file_id_wire_layout() { + let fid = SMBFileId::new(0xDEAD, 0xBEEF); + let bytes = fid.smb_to_bytes(); + assert_eq!(bytes.len(), 16); + let persistent = u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + let volatile = u64::from_le_bytes(bytes[8..16].try_into().unwrap()); + assert_eq!(persistent, 0xDEAD); + assert_eq!(volatile, 0xBEEF); + } + + #[test] + fn file_id_round_trip() { + let fid = SMBFileId::new(42, 99); + let bytes = fid.smb_to_bytes(); + let (_, parsed) = SMBFileId::smb_from_bytes(&bytes).unwrap(); + assert_eq!(fid, parsed); + } + + /// Per section 2.2.14.1, persistent and volatile are distinct fields. + /// Verify they serialize independently. + #[test] + fn file_id_persistent_and_volatile_are_independent() { + let a = SMBFileId::new(1, 2); + let b = SMBFileId::new(2, 1); + let bytes_a = a.smb_to_bytes(); + let bytes_b = b.smb_to_bytes(); + assert_ne!(bytes_a, bytes_b); + } + + #[test] + fn file_id_getters() { + let fid = SMBFileId::new(100, 200); + assert_eq!(fid.persistent(), 100); + assert_eq!(fid.volatile(), 200); + } } diff --git a/smb/src/protocol/body/create/mod.rs b/smb/src/protocol/body/create/mod.rs index a7c79ef..80cf1ae 100644 --- a/smb/src/protocol/body/create/mod.rs +++ b/smb/src/protocol/body/create/mod.rs @@ -65,7 +65,7 @@ pub struct SMBCreateRequest { create_options: SMBCreateOptions, #[smb_string( order = 0, - start(inner(start = 44, num_type = "u16", subtract = 68)), + start(inner(start = 44, num_type = "u16", subtract = 64)), length(inner(start = 46, num_type = "u16")), underlying = "u16" )] @@ -179,12 +179,12 @@ impl SMBCreateResponse { oplock_level: open.oplock_level(), flags: SMBCreateFlags::empty(), action: SMBCreateAction::Created, - creation_time: metadata.creation_time, - last_access_time: metadata.last_access_time, - last_write_time: metadata.last_write_time, - change_time: metadata.last_modification_time, - allocation_size: metadata.allocated_size, - end_of_file: metadata.actual_size, + creation_time: metadata.creation_time().clone(), + last_access_time: metadata.last_access_time().clone(), + last_write_time: metadata.last_write_time().clone(), + change_time: metadata.last_modification_time().clone(), + allocation_size: metadata.allocated_size(), + end_of_file: metadata.actual_size(), attributes: open.file_attributes(), reserved: PhantomData, file_id: open.file_id(), diff --git a/smb/src/protocol/body/file_info/access.rs b/smb/src/protocol/body/file_info/access.rs new file mode 100644 index 0000000..7c64ada --- /dev/null +++ b/smb/src/protocol/body/file_info/access.rs @@ -0,0 +1,61 @@ +use bitflags::bitflags; +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; + +bitflags! { + /// ACCESS_MASK flags for FILE_ACCESS_INFORMATION (MS-FSCC 2.4.1). + /// + /// These are the same ACCESS_MASK values defined in [MS-DTYP] §2.4.3 / + /// [MS-SMB2] §2.2.13.1, representing the access rights granted on the open. + #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] + pub struct FileAccessFlags: u32 { + const FILE_READ_DATA = 0x00000001; + const FILE_WRITE_DATA = 0x00000002; + const FILE_APPEND_DATA = 0x00000004; + const FILE_READ_EA = 0x00000008; + const FILE_WRITE_EA = 0x00000010; + const FILE_EXECUTE = 0x00000020; + const FILE_DELETE_CHILD = 0x00000040; + const FILE_READ_ATTRIBUTES = 0x00000080; + const FILE_WRITE_ATTRIBUTES = 0x00000100; + const DELETE = 0x00010000; + const READ_CONTROL = 0x00020000; + const WRITE_DAC = 0x00040000; + const WRITE_OWNER = 0x00080000; + const SYNCHRONIZE = 0x00100000; + const ACCESS_SYSTEM_SECURITY = 0x01000000; + const MAXIMUM_ALLOWED = 0x02000000; + const GENERIC_ALL = 0x10000000; + const GENERIC_EXECUTE = 0x20000000; + const GENERIC_WRITE = 0x40000000; + const GENERIC_READ = 0x80000000; + } +} + +impl_smb_byte_size_for_bitflag! { FileAccessFlags } +impl_smb_to_bytes_for_bitflag! { FileAccessFlags } +impl_smb_from_bytes_for_bitflag! { FileAccessFlags } + +/// FILE_ACCESS_INFORMATION (MS-FSCC 2.4.1) — 4 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileAccessInformation { + #[smb_direct(start(fixed = 0))] + access_flags: FileAccessFlags, +} + +impl FileAccessInformation { + pub fn new(access_flags: FileAccessFlags) -> Self { + Self { access_flags } + } + + pub fn access_flags(&self) -> FileAccessFlags { + self.access_flags + } +} diff --git a/smb/src/protocol/body/file_info/alignment.rs b/smb/src/protocol/body/file_info/alignment.rs new file mode 100644 index 0000000..f036d65 --- /dev/null +++ b/smb/src/protocol/body/file_info/alignment.rs @@ -0,0 +1,82 @@ +use num_enum::TryFromPrimitive; +use serde::{Deserialize, Serialize}; + +use smb_core::error::SMBError; +use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; +use smb_derive::{ + SMBByteSize as SMBByteSizeDerive, SMBFromBytes as SMBFromBytesDerive, + SMBToBytes as SMBToBytesDerive, +}; + +/// Device alignment requirements (MS-FSCC 2.4.3). +/// +/// Each value specifies the address boundary the device requires +/// for data transfers. For example, `Quad` means the device requires +/// 8-byte aligned addresses. +#[repr(u32)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, TryFromPrimitive)] +pub enum FileAlignmentRequirement { + Byte = 0x00000000, + Word = 0x00000001, + Long = 0x00000003, + Quad = 0x00000007, + Octa = 0x0000000F, + Align32 = 0x0000001F, + Align64 = 0x0000003F, + Align128 = 0x0000007F, + Align256 = 0x000000FF, + Align512 = 0x000001FF, +} + +impl SMBByteSize for FileAlignmentRequirement { + fn smb_byte_size(&self) -> usize { + std::mem::size_of::() + } +} + +impl SMBFromBytes for FileAlignmentRequirement { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { + u32::smb_from_bytes(input).map(|(remaining, val)| { + let req = Self::try_from_primitive(val).map_err(SMBError::parse_error)?; + Ok((remaining, req)) + })? + } +} + +impl SMBToBytes for FileAlignmentRequirement { + fn smb_to_bytes(&self) -> Vec { + (*self as u32).smb_to_bytes() + } +} + +/// FILE_ALIGNMENT_INFORMATION (MS-FSCC 2.4.3) — 4 bytes +#[derive( + Debug, + PartialEq, + Eq, + Clone, + Serialize, + Deserialize, + SMBByteSizeDerive, + SMBFromBytesDerive, + SMBToBytesDerive, +)] +pub struct FileAlignmentInformation { + #[smb_direct(start(fixed = 0))] + alignment_requirement: FileAlignmentRequirement, +} + +impl FileAlignmentInformation { + pub fn new(alignment_requirement: FileAlignmentRequirement) -> Self { + Self { + alignment_requirement, + } + } + + pub fn alignment_requirement(&self) -> FileAlignmentRequirement { + self.alignment_requirement + } +} diff --git a/smb/src/protocol/body/file_info/basic.rs b/smb/src/protocol/body/file_info/basic.rs new file mode 100644 index 0000000..bda566f --- /dev/null +++ b/smb/src/protocol/body/file_info/basic.rs @@ -0,0 +1,60 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +use crate::protocol::body::create::file_attributes::SMBFileAttributes; +use crate::protocol::body::filetime::FileTime; + +/// FILE_BASIC_INFORMATION (MS-FSCC 2.4.7) — 40 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileBasicInformation { + #[smb_direct(start(fixed = 0))] + creation_time: FileTime, + #[smb_direct(start(fixed = 8))] + last_access_time: FileTime, + #[smb_direct(start(fixed = 16))] + last_write_time: FileTime, + #[smb_direct(start(fixed = 24))] + change_time: FileTime, + #[smb_direct(start(fixed = 32))] + file_attributes: SMBFileAttributes, + #[smb_direct(start(fixed = 36))] + reserved: u32, +} + +impl FileBasicInformation { + pub fn new( + creation_time: FileTime, + last_access_time: FileTime, + last_write_time: FileTime, + change_time: FileTime, + file_attributes: SMBFileAttributes, + ) -> Self { + Self { + creation_time, + last_access_time, + last_write_time, + change_time, + file_attributes, + reserved: 0, + } + } + + pub fn creation_time(&self) -> &FileTime { + &self.creation_time + } + pub fn last_access_time(&self) -> &FileTime { + &self.last_access_time + } + pub fn last_write_time(&self) -> &FileTime { + &self.last_write_time + } + pub fn change_time(&self) -> &FileTime { + &self.change_time + } + pub fn file_attributes(&self) -> SMBFileAttributes { + self.file_attributes + } +} diff --git a/smb/src/protocol/body/file_info/ea.rs b/smb/src/protocol/body/file_info/ea.rs new file mode 100644 index 0000000..511c96d --- /dev/null +++ b/smb/src/protocol/body/file_info/ea.rs @@ -0,0 +1,22 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_EA_INFORMATION (MS-FSCC 2.4.12) — 4 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileEaInformation { + #[smb_direct(start(fixed = 0))] + ea_size: u32, +} + +impl FileEaInformation { + pub fn new(ea_size: u32) -> Self { + Self { ea_size } + } + + pub fn ea_size(&self) -> u32 { + self.ea_size + } +} diff --git a/smb/src/protocol/body/file_info/internal.rs b/smb/src/protocol/body/file_info/internal.rs new file mode 100644 index 0000000..2316014 --- /dev/null +++ b/smb/src/protocol/body/file_info/internal.rs @@ -0,0 +1,22 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_INTERNAL_INFORMATION (MS-FSCC 2.4.20) — 8 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileInternalInformation { + #[smb_direct(start(fixed = 0))] + index_number: u64, +} + +impl FileInternalInformation { + pub fn new(index_number: u64) -> Self { + Self { index_number } + } + + pub fn index_number(&self) -> u64 { + self.index_number + } +} diff --git a/smb/src/protocol/body/file_info/mod.rs b/smb/src/protocol/body/file_info/mod.rs new file mode 100644 index 0000000..341c217 --- /dev/null +++ b/smb/src/protocol/body/file_info/mod.rs @@ -0,0 +1,360 @@ +//! MS-FSCC File Information Classes +//! +//! Typed representations of the file information structures defined in +//! [MS-FSCC] sections 2.4.x, used in QueryInfo / SetInfo responses. + +mod access; +mod alignment; +mod basic; +mod ea; +mod internal; +mod mode; +mod name; +mod network_open; +mod position; +mod standard; + +pub use access::{FileAccessFlags, FileAccessInformation}; +pub use alignment::{FileAlignmentInformation, FileAlignmentRequirement}; +pub use basic::FileBasicInformation; +pub use ea::FileEaInformation; +pub use internal::FileInternalInformation; +pub use mode::{FileModeFlags, FileModeInformation}; +pub use name::FileNameInformation; +pub use network_open::FileNetworkOpenInformation; +pub use position::FilePositionInformation; +pub use standard::FileStandardInformation; + +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_ALL_INFORMATION (MS-FSCC 2.4.2) — composite structure +/// +/// Concatenation of sub-structures at fixed offsets: +/// basic(40) + standard(24) + internal(8) + ea(4) + access(4) +/// + position(8) + mode(4) + alignment(4) + name(variable). +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileAllInformation { + #[smb_direct(start(fixed = 0))] + basic: FileBasicInformation, + #[smb_direct(start(fixed = 40))] + standard: FileStandardInformation, + #[smb_direct(start(fixed = 64))] + internal: FileInternalInformation, + #[smb_direct(start(fixed = 72))] + ea: FileEaInformation, + #[smb_direct(start(fixed = 76))] + access: FileAccessInformation, + #[smb_direct(start(fixed = 80))] + position: FilePositionInformation, + #[smb_direct(start(fixed = 88))] + mode: FileModeInformation, + #[smb_direct(start(fixed = 92))] + alignment: FileAlignmentInformation, + #[smb_direct(start(fixed = 96))] + name: FileNameInformation, +} + +impl FileAllInformation { + #[allow(clippy::too_many_arguments)] + pub fn new( + basic: FileBasicInformation, + standard: FileStandardInformation, + internal: FileInternalInformation, + ea: FileEaInformation, + access: FileAccessInformation, + position: FilePositionInformation, + mode: FileModeInformation, + alignment: FileAlignmentInformation, + name: FileNameInformation, + ) -> Self { + Self { + basic, + standard, + internal, + ea, + access, + position, + mode, + alignment, + name, + } + } + + pub fn basic(&self) -> &FileBasicInformation { + &self.basic + } + pub fn standard(&self) -> &FileStandardInformation { + &self.standard + } + pub fn internal(&self) -> &FileInternalInformation { + &self.internal + } + pub fn ea(&self) -> &FileEaInformation { + &self.ea + } + pub fn access(&self) -> &FileAccessInformation { + &self.access + } + pub fn position(&self) -> &FilePositionInformation { + &self.position + } + pub fn mode(&self) -> &FileModeInformation { + &self.mode + } + pub fn alignment(&self) -> &FileAlignmentInformation { + &self.alignment + } + pub fn name(&self) -> &FileNameInformation { + &self.name + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::body::create::file_attributes::SMBFileAttributes; + use crate::protocol::body::filetime::FileTime; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + + #[test] + fn file_basic_information_size_is_40() { + let info = FileBasicInformation::new( + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + SMBFileAttributes::NORMAL, + ); + assert_eq!(info.smb_byte_size(), 40); + } + + #[test] + fn file_basic_information_round_trip() { + let info = FileBasicInformation::new( + FileTime::now(), + FileTime::now(), + FileTime::now(), + FileTime::now(), + SMBFileAttributes::ARCHIVE | SMBFileAttributes::READONLY, + ); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 40); + let (_, parsed) = FileBasicInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_standard_information_size_is_24() { + let info = FileStandardInformation::new(4096, 1024, 1, false, false); + assert_eq!(info.smb_byte_size(), 24); + } + + #[test] + fn file_standard_information_round_trip() { + let info = FileStandardInformation::new(8192, 2048, 3, true, false); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 24); + let (_, parsed) = FileStandardInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_standard_information_bool_getters() { + let info = FileStandardInformation::new(0, 0, 1, true, false); + assert!(info.delete_pending()); + assert!(!info.directory()); + + let info2 = FileStandardInformation::new(0, 0, 1, false, true); + assert!(!info2.delete_pending()); + assert!(info2.directory()); + } + + #[test] + fn file_internal_information_round_trip() { + let info = FileInternalInformation::new(42); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 8); + let (_, parsed) = FileInternalInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_ea_information_round_trip() { + let info = FileEaInformation::new(0); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileEaInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_access_information_round_trip() { + let info = FileAccessInformation::new(FileAccessFlags::from_bits_truncate(0x001f01ff)); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileAccessInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_position_information_round_trip() { + let info = FilePositionInformation::new(512); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 8); + let (_, parsed) = FilePositionInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_mode_information_round_trip() { + let info = FileModeInformation::new(FileModeFlags::empty()); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileModeInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_alignment_information_round_trip() { + let info = FileAlignmentInformation::new(FileAlignmentRequirement::Byte); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileAlignmentInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_network_open_information_size_is_56() { + let info = FileNetworkOpenInformation::new( + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + 0, + 0, + SMBFileAttributes::NORMAL, + ); + assert_eq!(info.smb_byte_size(), 56); + } + + #[test] + fn file_network_open_information_round_trip() { + let info = FileNetworkOpenInformation::new( + FileTime::now(), + FileTime::now(), + FileTime::now(), + FileTime::now(), + 4096, + 1024, + SMBFileAttributes::ARCHIVE, + ); + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 56); + let (_, parsed) = FileNetworkOpenInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_all_information_contains_all_sub_structs() { + let all = FileAllInformation::new( + FileBasicInformation::new( + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + SMBFileAttributes::NORMAL, + ), + FileStandardInformation::new(4096, 21, 1, false, false), + FileInternalInformation::new(0), + FileEaInformation::new(0), + FileAccessInformation::new(FileAccessFlags::from_bits_truncate(0x001f01ff)), + FilePositionInformation::new(0), + FileModeInformation::new(FileModeFlags::empty()), + FileAlignmentInformation::new(FileAlignmentRequirement::Byte), + FileNameInformation::from_name("testfile.txt".into()), + ); + let bytes = all.smb_to_bytes(); + // 40 + 24 + 8 + 4 + 4 + 8 + 4 + 4 + (4 + 24) = 124 + assert_eq!(bytes.len(), 124); + } + + #[test] + fn file_all_information_basic_segment_matches_standalone() { + let basic = FileBasicInformation::new( + FileTime::now(), + FileTime::now(), + FileTime::now(), + FileTime::now(), + SMBFileAttributes::ARCHIVE, + ); + let all = FileAllInformation::new( + basic.clone(), + FileStandardInformation::new(0, 0, 1, false, false), + FileInternalInformation::new(0), + FileEaInformation::new(0), + FileAccessInformation::new(FileAccessFlags::empty()), + FilePositionInformation::new(0), + FileModeInformation::new(FileModeFlags::empty()), + FileAlignmentInformation::new(FileAlignmentRequirement::Byte), + FileNameInformation::from_name(String::new()), + ); + let all_bytes = all.smb_to_bytes(); + let basic_bytes = basic.smb_to_bytes(); + assert_eq!(&all_bytes[..40], &basic_bytes[..]); + } + + #[test] + fn file_all_information_round_trip() { + let all = FileAllInformation::new( + FileBasicInformation::new( + FileTime::now(), + FileTime::now(), + FileTime::now(), + FileTime::now(), + SMBFileAttributes::ARCHIVE, + ), + FileStandardInformation::new(4096, 512, 1, false, false), + FileInternalInformation::new(7), + FileEaInformation::new(0), + FileAccessInformation::new(FileAccessFlags::from_bits_truncate(0x001f01ff)), + FilePositionInformation::new(256), + FileModeInformation::new(FileModeFlags::empty()), + FileAlignmentInformation::new(FileAlignmentRequirement::Byte), + FileNameInformation::from_name("testfile.txt".into()), + ); + let bytes = all.smb_to_bytes(); + let (_, parsed) = FileAllInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(all, parsed); + } + + #[test] + fn file_all_information_getters() { + let all = FileAllInformation::new( + FileBasicInformation::new( + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + FileTime::zero(), + SMBFileAttributes::NORMAL, + ), + FileStandardInformation::new(4096, 100, 1, false, false), + FileInternalInformation::new(5), + FileEaInformation::new(0), + FileAccessInformation::new(FileAccessFlags::from_bits_truncate(0x001f01ff)), + FilePositionInformation::new(50), + FileModeInformation::new(FileModeFlags::empty()), + FileAlignmentInformation::new(FileAlignmentRequirement::Byte), + FileNameInformation::from_name("test".into()), + ); + assert_eq!(all.basic().file_attributes(), SMBFileAttributes::NORMAL); + assert_eq!(all.standard().allocation_size(), 4096); + assert_eq!(all.standard().end_of_file(), 100); + assert_eq!(all.internal().index_number(), 5); + assert_eq!(all.position().current_byte_offset(), 50); + assert_eq!(all.name().file_name(), "test"); + } +} diff --git a/smb/src/protocol/body/file_info/mode.rs b/smb/src/protocol/body/file_info/mode.rs new file mode 100644 index 0000000..102002a --- /dev/null +++ b/smb/src/protocol/body/file_info/mode.rs @@ -0,0 +1,44 @@ +use bitflags::bitflags; +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; + +bitflags! { + /// Mode flags for FILE_MODE_INFORMATION (MS-FSCC 2.4.26). + #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] + pub struct FileModeFlags: u32 { + const FILE_WRITE_THROUGH = 0x00000002; + const FILE_SEQUENTIAL_ONLY = 0x00000004; + const FILE_NO_INTERMEDIATE_BUFFERING = 0x00000008; + const FILE_SYNCHRONOUS_IO_ALERT = 0x00000010; + const FILE_SYNCHRONOUS_IO_NONALERT = 0x00000020; + const FILE_DELETE_ON_CLOSE = 0x00001000; + } +} + +impl_smb_byte_size_for_bitflag! { FileModeFlags } +impl_smb_to_bytes_for_bitflag! { FileModeFlags } +impl_smb_from_bytes_for_bitflag! { FileModeFlags } + +/// FILE_MODE_INFORMATION (MS-FSCC 2.4.26) — 4 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileModeInformation { + #[smb_direct(start(fixed = 0))] + mode: FileModeFlags, +} + +impl FileModeInformation { + pub fn new(mode: FileModeFlags) -> Self { + Self { mode } + } + + pub fn mode(&self) -> FileModeFlags { + self.mode + } +} diff --git a/smb/src/protocol/body/file_info/name.rs b/smb/src/protocol/body/file_info/name.rs new file mode 100644 index 0000000..b81360d --- /dev/null +++ b/smb/src/protocol/body/file_info/name.rs @@ -0,0 +1,76 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_NAME_INFORMATION (MS-FSCC 2.4.28) — variable length +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileNameInformation { + #[smb_direct(start(fixed = 0))] + file_name_length: u32, + #[smb_string( + order = 0, + start(fixed = 4), + length(inner(start = 0, num_type = "u32")), + underlying = "u16" + )] + file_name: String, +} + +impl FileNameInformation { + pub fn from_name(file_name: String) -> Self { + let file_name_length = (file_name.encode_utf16().count() * 2) as u32; + Self { + file_name_length, + file_name, + } + } + + pub fn file_name_length(&self) -> u32 { + self.file_name_length + } + pub fn file_name(&self) -> &str { + &self.file_name + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBFromBytes, SMBToBytes}; + + #[test] + fn from_name_computes_utf16_byte_length() { + let info = FileNameInformation::from_name("test.txt".into()); + // "test.txt" = 8 UTF-16 code units × 2 bytes = 16 + assert_eq!(info.file_name_length(), 16); + assert_eq!(info.file_name(), "test.txt"); + } + + #[test] + fn from_name_empty_string() { + let info = FileNameInformation::from_name(String::new()); + assert_eq!(info.file_name_length(), 0); + assert_eq!(info.file_name(), ""); + } + + #[test] + fn from_name_round_trip() { + let info = FileNameInformation::from_name("hello.doc".into()); + let bytes = info.smb_to_bytes(); + let (_, parsed) = FileNameInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn from_name_length_matches_wire_size() { + let info = FileNameInformation::from_name("testfile.txt".into()); + let bytes = info.smb_to_bytes(); + // Wire: 4 bytes (length field) + 24 bytes (12 UTF-16 code units) + assert_eq!(bytes.len(), 4 + 24); + // The length field in the first 4 bytes should equal 24 + let wire_length = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); + assert_eq!(wire_length, 24); + } +} diff --git a/smb/src/protocol/body/file_info/network_open.rs b/smb/src/protocol/body/file_info/network_open.rs new file mode 100644 index 0000000..862bf0b --- /dev/null +++ b/smb/src/protocol/body/file_info/network_open.rs @@ -0,0 +1,74 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +use crate::protocol::body::create::file_attributes::SMBFileAttributes; +use crate::protocol::body::filetime::FileTime; + +/// FILE_NETWORK_OPEN_INFORMATION (MS-FSCC 2.4.29) — 56 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileNetworkOpenInformation { + #[smb_direct(start(fixed = 0))] + creation_time: FileTime, + #[smb_direct(start(fixed = 8))] + last_access_time: FileTime, + #[smb_direct(start(fixed = 16))] + last_write_time: FileTime, + #[smb_direct(start(fixed = 24))] + change_time: FileTime, + #[smb_direct(start(fixed = 32))] + allocation_size: u64, + #[smb_direct(start(fixed = 40))] + end_of_file: u64, + #[smb_direct(start(fixed = 48))] + file_attributes: SMBFileAttributes, + #[smb_direct(start(fixed = 52))] + reserved: u32, +} + +impl FileNetworkOpenInformation { + pub fn new( + creation_time: FileTime, + last_access_time: FileTime, + last_write_time: FileTime, + change_time: FileTime, + allocation_size: u64, + end_of_file: u64, + file_attributes: SMBFileAttributes, + ) -> Self { + Self { + creation_time, + last_access_time, + last_write_time, + change_time, + allocation_size, + end_of_file, + file_attributes, + reserved: 0, + } + } + + pub fn creation_time(&self) -> &FileTime { + &self.creation_time + } + pub fn last_access_time(&self) -> &FileTime { + &self.last_access_time + } + pub fn last_write_time(&self) -> &FileTime { + &self.last_write_time + } + pub fn change_time(&self) -> &FileTime { + &self.change_time + } + pub fn allocation_size(&self) -> u64 { + self.allocation_size + } + pub fn end_of_file(&self) -> u64 { + self.end_of_file + } + pub fn file_attributes(&self) -> SMBFileAttributes { + self.file_attributes + } +} diff --git a/smb/src/protocol/body/file_info/position.rs b/smb/src/protocol/body/file_info/position.rs new file mode 100644 index 0000000..5a9d4ac --- /dev/null +++ b/smb/src/protocol/body/file_info/position.rs @@ -0,0 +1,24 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_POSITION_INFORMATION (MS-FSCC 2.4.35) — 8 bytes +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FilePositionInformation { + #[smb_direct(start(fixed = 0))] + current_byte_offset: u64, +} + +impl FilePositionInformation { + pub fn new(current_byte_offset: u64) -> Self { + Self { + current_byte_offset, + } + } + + pub fn current_byte_offset(&self) -> u64 { + self.current_byte_offset + } +} diff --git a/smb/src/protocol/body/file_info/standard.rs b/smb/src/protocol/body/file_info/standard.rs new file mode 100644 index 0000000..3ca08aa --- /dev/null +++ b/smb/src/protocol/body/file_info/standard.rs @@ -0,0 +1,59 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_STANDARD_INFORMATION (MS-FSCC 2.4.41) — 24 bytes +/// +/// `delete_pending` and `directory` are booleans on the wire (u8: 0 or 1). +#[derive( + Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes, +)] +pub struct FileStandardInformation { + #[smb_direct(start(fixed = 0))] + allocation_size: u64, + #[smb_direct(start(fixed = 8))] + end_of_file: u64, + #[smb_direct(start(fixed = 16))] + number_of_links: u32, + #[smb_direct(start(fixed = 20))] + delete_pending: u8, + #[smb_direct(start(fixed = 21))] + directory: u8, + #[smb_direct(start(fixed = 22))] + reserved: u16, +} + +impl FileStandardInformation { + pub fn new( + allocation_size: u64, + end_of_file: u64, + number_of_links: u32, + delete_pending: bool, + directory: bool, + ) -> Self { + Self { + allocation_size, + end_of_file, + number_of_links, + delete_pending: delete_pending as u8, + directory: directory as u8, + reserved: 0, + } + } + + pub fn allocation_size(&self) -> u64 { + self.allocation_size + } + pub fn end_of_file(&self) -> u64 { + self.end_of_file + } + pub fn number_of_links(&self) -> u32 { + self.number_of_links + } + pub fn delete_pending(&self) -> bool { + self.delete_pending != 0 + } + pub fn directory(&self) -> bool { + self.directory != 0 + } +} diff --git a/smb/src/protocol/body/mod.rs b/smb/src/protocol/body/mod.rs index 7db66f5..a8290ad 100644 --- a/smb/src/protocol/body/mod.rs +++ b/smb/src/protocol/body/mod.rs @@ -56,6 +56,7 @@ pub mod create; pub mod echo; pub mod empty; pub mod error; +pub mod file_info; pub mod flush; pub mod ioctl; pub mod lock; diff --git a/smb/src/protocol/body/query_info/info_type.rs b/smb/src/protocol/body/query_info/info_type.rs index a23d556..00f8ffc 100644 --- a/smb/src/protocol/body/query_info/info_type.rs +++ b/smb/src/protocol/body/query_info/info_type.rs @@ -20,8 +20,8 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; Deserialize, )] pub enum SMBInfoType { - File, - Filesystem, - Security, - Quota, + File = 1, + Filesystem = 2, + Security = 3, + Quota = 4, } diff --git a/smb/src/protocol/body/query_info/mod.rs b/smb/src/protocol/body/query_info/mod.rs index b66b2af..ee32008 100644 --- a/smb/src/protocol/body/query_info/mod.rs +++ b/smb/src/protocol/body/query_info/mod.rs @@ -10,7 +10,7 @@ use crate::protocol::body::query_info::info_type::SMBInfoType; use crate::protocol::body::query_info::security_information::SMBSecurityInformation; mod flags; -mod info_type; +pub mod info_type; mod security_information; #[derive( @@ -39,10 +39,28 @@ pub struct SMBQueryInfoRequest { buffer: Vec, } +impl SMBQueryInfoRequest { + pub fn info_type(&self) -> SMBInfoType { + self.info_type + } + + pub fn file_info_class(&self) -> u8 { + self.file_info_class + } + + pub fn output_buffer_length(&self) -> u32 { + self.output_buffer_length + } + + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } +} + #[derive( Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] -#[smb_byte_tag(value = 17)] +#[smb_byte_tag(value = 9)] pub struct SMBQueryInfoResponse { #[smb_skip(start = 2, length = 6)] reserved: PhantomData>, @@ -54,3 +72,77 @@ pub struct SMBQueryInfoResponse { )] data: Vec, } + +impl SMBQueryInfoResponse { + pub fn new(data: Vec) -> Self { + Self { + reserved: PhantomData, + data, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + + #[test] + fn query_info_response_new_sets_data() { + let data = vec![1, 2, 3, 4, 5, 6, 7, 8]; + let resp = SMBQueryInfoResponse::new(data.clone()); + assert_eq!(resp.data, data); + } + + #[test] + fn query_info_response_serialization_round_trip() { + let resp = SMBQueryInfoResponse::new(vec![0xAA; 40]); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBQueryInfoResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn query_info_response_empty_data_round_trip() { + let resp = SMBQueryInfoResponse::new(vec![]); + let bytes = resp.smb_to_bytes(); + let (_, parsed) = SMBQueryInfoResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn query_info_request_accessors() { + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 41 + buf.extend_from_slice(&41u16.to_le_bytes()); + // info_type (u8) = 1 (File) per MS-SMB2 + buf.push(1); + // file_info_class (u8) = 4 (FileBasicInformation) + buf.push(4); + // output_buffer_length (u32) = 4096 + buf.extend_from_slice(&4096u32.to_le_bytes()); + // input_buffer_offset (u16) = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // reserved (u16) = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // input_buffer_length (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // additional_information (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // flags (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // file_id: persistent (u64) + volatile (u64) + buf.extend_from_slice(&55u64.to_le_bytes()); + buf.extend_from_slice(&77u64.to_le_bytes()); + buf + }; + let (_, req) = SMBQueryInfoRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.info_type(), SMBInfoType::File); + assert_eq!(req.file_info_class(), 4); + assert_eq!(req.output_buffer_length(), 4096); + assert_eq!(req.file_id().persistent(), 55); + assert_eq!(req.file_id().volatile(), 77); + } +} diff --git a/smb/src/protocol/body/read/mod.rs b/smb/src/protocol/body/read/mod.rs index 14ea887..b596ee5 100644 --- a/smb/src/protocol/body/read/mod.rs +++ b/smb/src/protocol/body/read/mod.rs @@ -37,6 +37,24 @@ pub struct SMBReadRequest { channel_information: Vec, } +impl SMBReadRequest { + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } + + pub fn read_length(&self) -> u32 { + self.read_length + } + + pub fn read_offset(&self) -> u64 { + self.read_offset + } + + pub fn minimum_count(&self) -> u32 { + self.minimum_count + } +} + #[derive( Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] @@ -55,3 +73,82 @@ pub struct SMBReadResponse { )] data: Vec, } + +impl SMBReadResponse { + pub fn new(data: Vec, data_remaining: u32) -> Self { + Self { + reserved: PhantomData, + data_remaining, + flags: SMBReadResponseFlags::None, + data, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + + #[test] + fn read_response_new_sets_fields() { + let data = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let resp = SMBReadResponse::new(data.clone(), 100); + assert_eq!(resp.data, data); + assert_eq!(resp.data_remaining, 100); + assert_eq!(resp.flags, SMBReadResponseFlags::None); + } + + #[test] + fn read_response_serialization_round_trip() { + let resp = SMBReadResponse::new(vec![1, 2, 3, 4, 5], 0); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBReadResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn read_response_empty_data() { + let resp = SMBReadResponse::new(vec![], 0); + let bytes = resp.smb_to_bytes(); + let (_, parsed) = SMBReadResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn read_request_accessors() { + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 49 + buf.extend_from_slice(&49u16.to_le_bytes()); + // padding (u8) + buf.push(0); + // flags (u8) = 0 + buf.push(0); + // read_length (u32) = 1024 + buf.extend_from_slice(&1024u32.to_le_bytes()); + // read_offset (u64) = 512 + buf.extend_from_slice(&512u64.to_le_bytes()); + // file_id: persistent (u64) + volatile (u64) + buf.extend_from_slice(&10u64.to_le_bytes()); + buf.extend_from_slice(&20u64.to_le_bytes()); + // minimum_count (u32) = 256 + buf.extend_from_slice(&256u32.to_le_bytes()); + // channel (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // remaining_bytes (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // channel_info_offset (u16) = 0, channel_info_length (u16) = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + buf + }; + let (_, req) = SMBReadRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.read_length(), 1024); + assert_eq!(req.read_offset(), 512); + assert_eq!(req.minimum_count(), 256); + assert_eq!(req.file_id().persistent(), 10); + assert_eq!(req.file_id().volatile(), 20); + } +} diff --git a/smb/src/protocol/body/write/mod.rs b/smb/src/protocol/body/write/mod.rs index f4ff936..02e9d7f 100644 --- a/smb/src/protocol/body/write/mod.rs +++ b/smb/src/protocol/body/write/mod.rs @@ -39,6 +39,24 @@ pub struct SMBWriteRequest { data_to_write: Vec, } +impl SMBWriteRequest { + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } + + pub fn write_offset(&self) -> u64 { + self.write_offset + } + + pub fn write_length(&self) -> u32 { + self.write_length + } + + pub fn data_to_write(&self) -> &[u8] { + &self.data_to_write + } +} + #[derive( Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] @@ -55,3 +73,77 @@ pub struct SMBWriteResponse { #[smb_skip(start = 14, length = 2)] write_channel_info_len: PhantomData>, } + +impl SMBWriteResponse { + pub fn new(bytes_written: u32) -> Self { + Self { + reserved: PhantomData, + bytes_written, + remaining_bytes: PhantomData, + write_channel_info_offset: PhantomData, + write_channel_info_len: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + + #[test] + fn write_response_new_sets_bytes_written() { + let resp = SMBWriteResponse::new(4096); + assert_eq!(resp.bytes_written, 4096); + } + + #[test] + fn write_response_serialization_round_trip() { + let resp = SMBWriteResponse::new(512); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBWriteResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn write_request_accessors() { + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 49 + buf.extend_from_slice(&49u16.to_le_bytes()); + // data_offset (u16) — points past header (offset 2) + let data_offset: u16 = 64 + 49; // header + struct + buf.extend_from_slice(&data_offset.to_le_bytes()); + // write_length (u32) = 100 (offset 4) + buf.extend_from_slice(&100u32.to_le_bytes()); + // write_offset (u64) = 200 (offset 8) + buf.extend_from_slice(&200u64.to_le_bytes()); + // file_id: persistent (u64) + volatile (u64) (offset 16) + buf.extend_from_slice(&5u64.to_le_bytes()); + buf.extend_from_slice(&15u64.to_le_bytes()); + // channel (u32) = 0 (offset 32..36 — but channel is at 36 per struct) + // remaining_bytes (u32) = 0 (offset 36) + buf.extend_from_slice(&0u32.to_le_bytes()); + buf.extend_from_slice(&0u32.to_le_bytes()); + // channel_info_offset (u16) = 0, channel_info_length (u16) = 0 (offset 40..44) + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + // flags (u32) = 0 (offset 44) + buf.extend_from_slice(&0u32.to_le_bytes()); + // pad to data_offset - 64 + while buf.len() < (data_offset - 64) as usize { + buf.push(0); + } + // data (100 bytes) + buf.extend_from_slice(&[0xAB; 100]); + buf + }; + let (_, req) = SMBWriteRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.write_length(), 100); + assert_eq!(req.write_offset(), 200); + assert_eq!(req.file_id().persistent(), 5); + assert_eq!(req.file_id().volatile(), 15); + assert_eq!(req.data_to_write().len(), 100); + } +} diff --git a/smb/src/server/connection.rs b/smb/src/server/connection.rs index 42c1797..cebe9ec 100644 --- a/smb/src/server/connection.rs +++ b/smb/src/server/connection.rs @@ -283,7 +283,8 @@ where debug!(?status, "handler returned response error"); Self::build_error_response(&incoming, status) } - Err(_e) => { + #[allow(unused_variables)] + Err(e) => { error!(?e, "non-response error, sending NOT_SUPPORTED"); Self::build_error_response(&incoming, NTStatus::NotSupported) } diff --git a/smb/src/server/mod.rs b/smb/src/server/mod.rs index 18a3d08..d66f874 100644 --- a/smb/src/server/mod.rs +++ b/smb/src/server/mod.rs @@ -57,6 +57,7 @@ pub trait Server: Send + Sync { fn shares(&self) -> &HashMap>; fn opens(&self) -> &HashMap>>; fn add_open(&mut self, open: Arc>) -> impl Future; + fn remove_open(&mut self, global_id: u32); fn sessions(&self) -> &HashMap>>; fn sessions_mut(&mut self) -> &mut HashMap>>; fn guid(&self) -> Uuid; @@ -228,6 +229,10 @@ impl< 0 } + fn remove_open(&mut self, global_id: u32) { + self.open_table.remove(&global_id); + } + fn sessions(&self) -> &HashMap>> { &self.session_table } @@ -452,15 +457,17 @@ impl< tokio::spawn(async move { debug!(client = %name, "starting message handler"); let mut stream = socket.lock().await; - match SMBConnection::start_message_handler::( + #[allow(unused_variables)] + if let Err(e) = SMBConnection::start_message_handler::( &mut stream, wrapped_connection, update_channel, ) .await { - Ok(()) => debug!("message handler completed"), - Err(_e) => warn!(?e, "message handler exited with error"), + warn!(?e, "message handler exited with error"); + } else { + debug!("message handler completed"); } }); } diff --git a/smb/src/server/open.rs b/smb/src/server/open.rs index 4b32141..5cef55f 100644 --- a/smb/src/server/open.rs +++ b/smb/src/server/open.rs @@ -32,6 +32,8 @@ pub trait Open: Send + Sync { fn file_attributes(&self) -> SMBFileAttributes; fn file_id(&self) -> SMBFileId; fn file_metadata(&self) -> SMBResult; + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult>; + fn write_data(&mut self, offset: u64, data: &[u8]) -> SMBResult; } pub struct SMBOpen { @@ -140,15 +142,20 @@ impl Open for SMBOpen { } fn file_id(&self) -> SMBFileId { - SMBFileId { - persistent: self.session_id, - volatile: self.session_id, - } + SMBFileId::new(self.global_id as u64, self.session_id) } fn file_metadata(&self) -> SMBResult { self.underlying.metadata() } + + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult> { + self.underlying.read_data(offset, length) + } + + fn write_data(&mut self, offset: u64, data: &[u8]) -> SMBResult { + self.underlying.write_data(offset, data) + } } // TODO: From MS-FSCC section 2.6 @@ -234,7 +241,7 @@ impl SMBLockedMessageHandlerBase for Arc> { type Inner = (); async fn inner(&self, _message: &SMBMessageType) -> Option { - todo!() + None } } diff --git a/smb/src/server/session.rs b/smb/src/server/session.rs index 5bf7865..7b33d85 100644 --- a/smb/src/server/session.rs +++ b/smb/src/server/session.rs @@ -63,6 +63,7 @@ pub trait Session: Send + Sync { fn provider(&self) -> &Arc; fn encrypt_data(&self) -> bool; fn open_table(&self) -> &HashMap>>; + fn open_table_mut(&mut self) -> &mut HashMap>>; fn add_open(&mut self, open: Arc>) -> impl Future; fn set_previous_file_id(&mut self, file_id: SMBFileId); fn signing_key(&self) -> &[u8]; @@ -289,7 +290,7 @@ impl>> SMBLockedMessageHandlerBase share.clone(), response.access_mask().clone(), ); - let header = SMBSyncHeader::create_response_header(header, 0, self_rd.id(), 1); + let header = SMBSyncHeader::create_response_header(header, 0, self_rd.id(), tree_id); drop(self_rd); let mut self_wr = self.write().await; self_wr @@ -397,6 +398,10 @@ impl> Session &self.open_table } + fn open_table_mut(&mut self) -> &mut HashMap>> { + &mut self.open_table + } + async fn add_open(&mut self, open: Arc>) { let id = Self::get_next_map_id(&self.open_table); let mut open_wr = open.write().await; diff --git a/smb/src/server/share/file_system.rs b/smb/src/server/share/file_system.rs index ab0a9a1..e70079c 100644 --- a/smb/src/server/share/file_system.rs +++ b/smb/src/server/share/file_system.rs @@ -2,12 +2,15 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::fs; use std::fs::{File, OpenOptions, ReadDir}; +use std::io::{Read, Seek, SeekFrom, Write}; use std::marker::PhantomData; +use std::path::{Component, Path, PathBuf}; use std::time::{SystemTime, UNIX_EPOCH}; use smb_core::SMBResult; use smb_core::error::SMBError; -use smb_core::logging::debug; +use smb_core::logging::{debug, warn}; +use smb_core::nt_status::NTStatus; use crate::protocol::body::create::disposition::SMBCreateDisposition; use crate::protocol::body::filetime::FileTime; @@ -17,6 +20,32 @@ use crate::server::share::{ ConnectAllowed, FilePerms, ResourceHandle, ResourceType, SMBFileMetadata, SharedResource, }; +/// Maximum single read size (8 MB), per MS-SMB2 §3.3.5.12 recommendation for SMB 3.x. +const MAX_READ_SIZE: u32 = 8 * 1024 * 1024; +const MAX_WRITE_SIZE: u32 = 8 * 1024 * 1024; + +/// Normalize a path by resolving `.` and `..` components lexically (without +/// touching the filesystem). Returns `None` if the normalized path would +/// escape the root (i.e., more `..` than preceding components). +fn normalize_path(path: &str) -> Option { + let mut components = Vec::new(); + for component in Path::new(path).components() { + match component { + Component::ParentDir => { + if components.is_empty() { + // Attempting to go above root — reject + return None; + } + components.pop(); + } + Component::Normal(c) => components.push(c), + Component::CurDir => {} // skip "." + Component::RootDir | Component::Prefix(_) => {} // skip absolute prefixes + } + } + Some(components.iter().collect()) +} + #[derive(Debug)] pub struct SMBFileSystemHandle { path: String, @@ -88,20 +117,49 @@ impl ResourceHandle for SMBFileSystemHandle { )) })?; let time_transform = |time: SystemTime| time.duration_since(UNIX_EPOCH).unwrap().as_secs(); - Ok(SMBFileMetadata { - creation_time: FileTime::from_unix(metadata.created().map(time_transform).unwrap_or(0)), - last_access_time: FileTime::from_unix( - metadata.accessed().map(time_transform).unwrap_or(0), - ), - last_write_time: FileTime::from_unix( - metadata.modified().map(time_transform).unwrap_or(0), - ), - last_modification_time: FileTime::from_unix( - metadata.modified().map(time_transform).unwrap_or(0), - ), - allocated_size: metadata.len(), - actual_size: metadata.len(), - }) + Ok(SMBFileMetadata::new( + FileTime::from_unix(metadata.created().map(time_transform).unwrap_or(0)), + FileTime::from_unix(metadata.accessed().map(time_transform).unwrap_or(0)), + FileTime::from_unix(metadata.modified().map(time_transform).unwrap_or(0)), + FileTime::from_unix(metadata.modified().map(time_transform).unwrap_or(0)), + metadata.len(), + metadata.len(), + )) + } + + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult> { + match &mut self.resource { + SMBFileSystemResourceHandle::File(file) => { + // Cap to MAX_READ_SIZE to prevent OOM from malicious clients + let capped_length = length.min(MAX_READ_SIZE) as u64; + file.seek(SeekFrom::Start(offset)) + .map_err(SMBError::io_error)?; + // Use take() + read_to_end() to handle short reads correctly + let mut buf = Vec::with_capacity(capped_length as usize); + file.take(capped_length) + .read_to_end(&mut buf) + .map_err(SMBError::io_error)?; + Ok(buf) + } + SMBFileSystemResourceHandle::Directory(_) => Err(SMBError::response_error( + smb_core::nt_status::NTStatus::InvalidDeviceRequest, + )), + } + } + + fn write_data(&mut self, offset: u64, data: &[u8]) -> SMBResult { + match &mut self.resource { + SMBFileSystemResourceHandle::File(file) => { + let capped = &data[..data.len().min(MAX_WRITE_SIZE as usize)]; + file.seek(SeekFrom::Start(offset)) + .map_err(SMBError::io_error)?; + file.write_all(capped).map_err(SMBError::io_error)?; + Ok(capped.len() as u32) + } + SMBFileSystemResourceHandle::Directory(_) => { + Err(SMBError::response_error(NTStatus::InvalidDeviceRequest)) + } + } } } @@ -180,7 +238,17 @@ impl< disposition: SMBCreateDisposition, directory: bool, ) -> SMBResult { - let path = format!("{}/{}", self.local_path, path); + // Sanitize: strip NUL terminators from UTF-16LE wire encoding, + // convert Windows backslashes to forward slashes + let sanitized = path.trim_end_matches('\0').replace('\\', "/"); + + // Normalize and reject path traversal attempts (e.g. "../../etc/passwd") + let relative = normalize_path(&sanitized).ok_or_else(|| { + warn!(path = %sanitized, "rejected path traversal attempt"); + SMBError::response_error(NTStatus::AccessDenied) + })?; + let path = format!("{}/{}", self.local_path, relative.display()); + let resource = match directory { true => SMBFileSystemResourceHandle::directory(&path), false => SMBFileSystemResourceHandle::file(&path, disposition), @@ -281,3 +349,274 @@ impl> Debug .finish() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normalize_path_simple() { + assert_eq!( + normalize_path("foo/bar.txt"), + Some(PathBuf::from("foo/bar.txt")) + ); + } + + #[test] + fn normalize_path_strips_current_dir() { + assert_eq!( + normalize_path("./foo/./bar.txt"), + Some(PathBuf::from("foo/bar.txt")) + ); + } + + #[test] + fn normalize_path_resolves_parent_within_subtree() { + assert_eq!( + normalize_path("foo/bar/../baz.txt"), + Some(PathBuf::from("foo/baz.txt")) + ); + } + + #[test] + fn normalize_path_rejects_traversal_above_root() { + assert_eq!(normalize_path("../etc/passwd"), None); + } + + #[test] + fn normalize_path_rejects_deep_traversal() { + assert_eq!(normalize_path("foo/../../etc/passwd"), None); + } + + #[test] + fn normalize_path_empty() { + assert_eq!(normalize_path(""), Some(PathBuf::from(""))); + } + + #[test] + fn normalize_path_backslash_after_sanitize() { + assert_eq!( + normalize_path("subdir/file.txt"), + Some(PathBuf::from("subdir/file.txt")) + ); + } + + #[test] + fn read_data_returns_full_contents() { + let dir = std::env::temp_dir().join("smb_test_read_full"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("read_test.bin"); + let data: Vec = (0..4096).map(|i| (i % 256) as u8).collect(); + std::fs::write(&path, &data).unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + let result = handle.read_data(0, 4096).unwrap(); + assert_eq!( + result.len(), + 4096, + "read_data must return all requested bytes when available" + ); + assert_eq!(result, data); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn read_data_at_offset_returns_remaining() { + let dir = std::env::temp_dir().join("smb_test_read_offset"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("offset_test.bin"); + let data = vec![0xAA; 100]; + std::fs::write(&path, &data).unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + // Read past end of file — should return only remaining bytes + let result = handle.read_data(90, 50).unwrap(); + assert_eq!(result.len(), 10); + + // Read at exact EOF — should return empty + let result = handle.read_data(100, 50).unwrap(); + assert!(result.is_empty()); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn read_data_capped_at_max_read_size() { + let dir = std::env::temp_dir().join("smb_test_read_cap"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("cap_test.bin"); + // Write a small file but request more than MAX_READ_SIZE + let data = vec![0xBB; 64]; + std::fs::write(&path, &data).unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + // Request u32::MAX bytes — should be capped and not OOM + let result = handle.read_data(0, u32::MAX).unwrap(); + assert_eq!(result.len(), 64); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn read_data_directory_returns_error() { + let dir = std::env::temp_dir().join("smb_test_read_dir"); + std::fs::create_dir_all(&dir).unwrap(); + + let mut handle = SMBFileSystemHandle { + path: dir.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::directory(dir.to_str().unwrap()).unwrap(), + }; + + let result = handle.read_data(0, 100); + assert!(result.is_err()); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn write_data_writes_all_bytes() { + let dir = std::env::temp_dir().join("smb_test_write_all"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("write_test.bin"); + // Create the file first + std::fs::write(&path, b"").unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + let data = vec![0xCC; 4096]; + let written = handle.write_data(0, &data).unwrap(); + assert_eq!(written, 4096); + + // Verify the file contents + let contents = std::fs::read(&path).unwrap(); + assert_eq!(contents, data); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn write_data_at_offset() { + let dir = std::env::temp_dir().join("smb_test_write_offset"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("offset_write.bin"); + std::fs::write(&path, vec![0xAA; 100]).unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + let patch = vec![0xBB; 10]; + let written = handle.write_data(50, &patch).unwrap(); + assert_eq!(written, 10); + + let contents = std::fs::read(&path).unwrap(); + assert_eq!(&contents[..50], &[0xAA; 50]); + assert_eq!(&contents[50..60], &[0xBB; 10]); + assert_eq!(&contents[60..], &[0xAA; 40]); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn write_data_capped_at_max_write_size() { + let dir = std::env::temp_dir().join("smb_test_write_cap"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("cap_write.bin"); + std::fs::write(&path, b"").unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + // Write a small amount — just verify capping logic doesn't break small writes + let data = vec![0xDD; 64]; + let written = handle.write_data(0, &data).unwrap(); + assert_eq!(written, 64); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn write_data_directory_returns_error() { + let dir = std::env::temp_dir().join("smb_test_write_dir"); + std::fs::create_dir_all(&dir).unwrap(); + + let mut handle = SMBFileSystemHandle { + path: dir.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::directory(dir.to_str().unwrap()).unwrap(), + }; + + let result = handle.write_data(0, &[0xFF; 10]); + assert!(result.is_err()); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn write_then_read_round_trip() { + let dir = std::env::temp_dir().join("smb_test_write_read_rt"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("round_trip.bin"); + std::fs::write(&path, b"").unwrap(); + + let mut handle = SMBFileSystemHandle { + path: path.to_string_lossy().into(), + resource: SMBFileSystemResourceHandle::file( + path.to_str().unwrap(), + SMBCreateDisposition::Open, + ) + .unwrap(), + }; + + let data: Vec = (0..256).map(|i| i as u8).collect(); + let written = handle.write_data(0, &data).unwrap(); + assert_eq!(written, 256); + + let read_back = handle.read_data(0, 256).unwrap(); + assert_eq!(read_back, data); + + std::fs::remove_dir_all(&dir).unwrap(); + } +} diff --git a/smb/src/server/share/ipc.rs b/smb/src/server/share/ipc.rs index 9579e8e..246b80c 100644 --- a/smb/src/server/share/ipc.rs +++ b/smb/src/server/share/ipc.rs @@ -3,6 +3,7 @@ use std::fmt::{Debug, Formatter}; use std::marker::PhantomData; use smb_core::SMBResult; +use smb_core::error::SMBError; use crate::protocol::body::create::disposition::SMBCreateDisposition; use crate::protocol::body::filetime::FileTime; @@ -36,14 +37,26 @@ impl ResourceHandle for SMBIPCHandle { } fn metadata(&self) -> SMBResult { - Ok(SMBFileMetadata { - creation_time: FileTime::default(), - last_access_time: FileTime::default(), - last_write_time: FileTime::default(), - last_modification_time: FileTime::default(), - allocated_size: 0, - actual_size: 0, - }) + Ok(SMBFileMetadata::new( + FileTime::default(), + FileTime::default(), + FileTime::default(), + FileTime::default(), + 0, + 0, + )) + } + + fn read_data(&mut self, _offset: u64, _length: u32) -> SMBResult> { + Err(SMBError::response_error( + smb_core::nt_status::NTStatus::InvalidDeviceRequest, + )) + } + + fn write_data(&mut self, _offset: u64, _data: &[u8]) -> SMBResult { + Err(SMBError::response_error( + smb_core::nt_status::NTStatus::InvalidDeviceRequest, + )) } } diff --git a/smb/src/server/share/mod.rs b/smb/src/server/share/mod.rs index 17d7af4..cdf4a65 100644 --- a/smb/src/server/share/mod.rs +++ b/smb/src/server/share/mod.rs @@ -24,15 +24,61 @@ pub trait ResourceHandle: Send + Sync { fn is_directory(&self) -> bool; fn path(&self) -> &str; fn metadata(&self) -> SMBResult; + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult>; + fn write_data(&mut self, offset: u64, data: &[u8]) -> SMBResult; } pub struct SMBFileMetadata { - pub creation_time: FileTime, - pub last_access_time: FileTime, - pub last_write_time: FileTime, - pub last_modification_time: FileTime, - pub allocated_size: u64, - pub actual_size: u64, + creation_time: FileTime, + last_access_time: FileTime, + last_write_time: FileTime, + last_modification_time: FileTime, + allocated_size: u64, + actual_size: u64, +} + +impl SMBFileMetadata { + pub fn new( + creation_time: FileTime, + last_access_time: FileTime, + last_write_time: FileTime, + last_modification_time: FileTime, + allocated_size: u64, + actual_size: u64, + ) -> Self { + Self { + creation_time, + last_access_time, + last_write_time, + last_modification_time, + allocated_size, + actual_size, + } + } + + pub fn creation_time(&self) -> &FileTime { + &self.creation_time + } + + pub fn last_access_time(&self) -> &FileTime { + &self.last_access_time + } + + pub fn last_write_time(&self) -> &FileTime { + &self.last_write_time + } + + pub fn last_modification_time(&self) -> &FileTime { + &self.last_modification_time + } + + pub fn allocated_size(&self) -> u64 { + self.allocated_size + } + + pub fn actual_size(&self) -> u64 { + self.actual_size + } } impl ResourceHandle for Box { @@ -55,6 +101,14 @@ impl ResourceHandle for Box { fn metadata(&self) -> SMBResult { H::metadata(self) } + + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult> { + H::read_data(self, offset, length) + } + + fn write_data(&mut self, offset: u64, data: &[u8]) -> SMBResult { + H::write_data(self, offset, data) + } } pub trait SharedResource: Send + Sync { diff --git a/smb/src/server/tree_connect.rs b/smb/src/server/tree_connect.rs index 992e352..1c58c79 100644 --- a/smb/src/server/tree_connect.rs +++ b/smb/src/server/tree_connect.rs @@ -4,14 +4,30 @@ use std::sync::{Arc, Weak}; use tokio::sync::RwLock; -use smb_core::SMBResult; +#[cfg(feature = "logging")] +use smb_core::SMBByteSize; use smb_core::error::SMBError; -use smb_core::logging::{debug, trace}; +use smb_core::logging::{debug, trace, warn}; +use smb_core::nt_status::NTStatus; +use smb_core::{SMBResult, SMBToBytes}; use crate::protocol::body::SMBBody; +use crate::protocol::body::close::{SMBCloseRequest, SMBCloseResponse}; +use crate::protocol::body::create::file_attributes::SMBFileAttributes; +use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::create::{SMBCreateRequest, SMBCreateResponse}; +use crate::protocol::body::file_info::{ + FileAccessFlags, FileAccessInformation, FileAlignmentInformation, FileAlignmentRequirement, + FileAllInformation, FileBasicInformation, FileEaInformation, FileInternalInformation, + FileModeFlags, FileModeInformation, FileNameInformation, FileNetworkOpenInformation, + FilePositionInformation, FileStandardInformation, +}; use crate::protocol::body::filetime::FileTime; +use crate::protocol::body::query_info::info_type::SMBInfoType; +use crate::protocol::body::query_info::{SMBQueryInfoRequest, SMBQueryInfoResponse}; +use crate::protocol::body::read::{SMBReadRequest, SMBReadResponse}; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; +use crate::protocol::body::write::{SMBWriteRequest, SMBWriteResponse}; use crate::protocol::header::SMBSyncHeader; use crate::protocol::message::SMBMessage; use crate::server::Server; @@ -54,6 +70,83 @@ impl SMBTreeConnect { } } +impl SMBTreeConnect { + fn get_session(&self) -> SMBResult>> { + self.session + .upgrade() + .ok_or(SMBError::server_error("No Session Found")) + } + + async fn find_open(&self, file_id: &SMBFileId) -> SMBResult>> { + let session = self.get_session()?; + let session_rd = session.read().await; + let open = session_rd + .open_table() + .get(&file_id.volatile()) + .cloned() + .ok_or(SMBError::response_error(NTStatus::FileClosed))?; + // MS-SMB2 §3.3.5.10/12/20: verify Open.DurableFileId == FileId.Persistent + let open_rd = open.read().await; + if open_rd.file_id().persistent() != file_id.persistent() { + return Err(SMBError::response_error(NTStatus::FileClosed)); + } + drop(open_rd); + Ok(open) + } + + fn build_basic_info(open: &S::Open) -> SMBResult { + let metadata = open.file_metadata()?; + Ok(FileBasicInformation::new( + metadata.creation_time().clone(), + metadata.last_access_time().clone(), + metadata.last_write_time().clone(), + metadata.last_modification_time().clone(), + open.file_attributes(), + )) + } + + fn build_standard_info(open: &S::Open) -> SMBResult { + let metadata = open.file_metadata()?; + let is_dir = open + .file_attributes() + .contains(SMBFileAttributes::DIRECTORY); + Ok(FileStandardInformation::new( + metadata.allocated_size(), + metadata.actual_size(), + 1, + false, + is_dir, + )) + } + + fn build_network_open_info(open: &S::Open) -> SMBResult { + let metadata = open.file_metadata()?; + Ok(FileNetworkOpenInformation::new( + metadata.creation_time().clone(), + metadata.last_access_time().clone(), + metadata.last_write_time().clone(), + metadata.last_modification_time().clone(), + metadata.allocated_size(), + metadata.actual_size(), + open.file_attributes(), + )) + } + + fn build_all_info(open: &S::Open) -> SMBResult { + Ok(FileAllInformation::new( + Self::build_basic_info(open)?, + Self::build_standard_info(open)?, + FileInternalInformation::new(0), + FileEaInformation::new(0), + FileAccessInformation::new(FileAccessFlags::from_bits_truncate(0x001f01ff)), + FilePositionInformation::new(0), + FileModeInformation::new(FileModeFlags::empty()), + FileAlignmentInformation::new(FileAlignmentRequirement::Byte), + FileNameInformation::from_name(open.file_name().into()), + )) + } +} + impl SMBLockedMessageHandlerBase for Arc> { type Inner = Arc>; @@ -69,33 +162,194 @@ impl SMBLockedMessageHandlerBase for Arc> { let (path, disposition, directory) = message.validate(self.share.deref())?; let handle = self.share.handle_create(path, disposition, directory)?; let open_raw = Open::init(handle, message); - let response = SMBBody::CreateResponse(SMBCreateResponse::for_open::(&open_raw)?); let open = Arc::new(RwLock::new(open_raw)); - let session = self - .session - .upgrade() - .ok_or(SMBError::server_error("No Session Found"))?; - session.write().await.add_open(open.clone()).await; + let session = self.get_session()?; + // Register with server first (outermost), then session (inner) let server = session.upper().await?.upper().await?; { server.write().await.add_open(open.clone()).await; } - { - let file_id = open.read().await.file_id(); - session.write().await.set_previous_file_id(file_id); - } + session.write().await.add_open(open.clone()).await; + // Build response AFTER registration so file_id reflects assigned IDs + let (response, file_id) = { + let open_rd = open.read().await; + let resp = SMBBody::CreateResponse(SMBCreateResponse::for_open::(&*open_rd)?); + (resp, open_rd.file_id()) + }; + session.write().await.set_previous_file_id(file_id); debug!("tree connect create handled"); - let header = header.create_response_header( - header.channel_sequence, - header.session_id, - header.tree_id, - ); + let header = header.create_response_header(0, header.session_id, header.tree_id); trace!( response_size = response.smb_byte_size(), "create response built" ); Ok(SMBHandlerState::Finished(SMBMessage::new(header, response))) } + + async fn handle_close( + &mut self, + header: &SMBSyncHeader, + message: &SMBCloseRequest, + ) -> SMBResult> { + debug!(file_id = ?message.file_id(), "handling close request"); + + // Phase 1: Validate and read open data via shared find_open logic + let open = self.find_open(message.file_id()).await?; + let (response, file_id) = { + let open_rd = open.read().await; + let response = if message + .flags() + .contains(crate::protocol::body::close::flags::SMBCloseFlags::POSTQUERY_ATTRIB) + { + let metadata = open_rd.file_metadata()?; + SMBCloseResponse::from_metadata(&metadata, open_rd.file_attributes()) + } else { + SMBCloseResponse::empty() + }; + (response, open_rd.file_id()) + }; + + // Phase 2: Cleanup — acquire locks outer to inner (server_wr, then session_wr) + let session = self.get_session()?; + let global_id: u32 = file_id + .persistent() + .try_into() + .expect("global_id fits in u32"); + // Server write first (outermost) + if let Ok(conn) = session.upper().await { + if let Ok(server) = conn.upper().await { + server.write().await.remove_open(global_id); + } else { + warn!(file_id = ?file_id, "failed to acquire server lock during close; global open table entry leaked"); + } + } else { + warn!(file_id = ?file_id, "failed to acquire connection lock during close; global open table entry leaked"); + } + // Session write second (inner relative to server) + { + let mut session_wr = session.write().await; + session_wr.open_table_mut().remove(&file_id.volatile()); + } + + debug!(file_id = ?file_id, "close completed"); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new( + header, + SMBBody::CloseResponse(response), + ))) + } + + async fn handle_read( + &mut self, + header: &SMBSyncHeader, + message: &SMBReadRequest, + ) -> SMBResult> { + debug!(file_id = ?message.file_id(), offset = message.read_offset(), length = message.read_length(), "handling read request"); + let open = self.find_open(message.file_id()).await?; + let mut open_wr = open.write().await; + let data = open_wr.read_data(message.read_offset(), message.read_length())?; + drop(open_wr); + + // MS-SMB2 §3.3.5.12: if read returns 0 bytes at/past EOF, fail with STATUS_END_OF_FILE + if data.is_empty() && message.read_length() > 0 { + return Err(SMBError::response_error(NTStatus::EndOfFile)); + } + + if data.len() < message.minimum_count() as usize { + return Err(SMBError::response_error(NTStatus::EndOfFile)); + } + + debug!(bytes_read = data.len(), "read completed"); + trace!(data_len = data.len(), "read response data"); + let response = SMBReadResponse::new(data, 0); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new( + header, + SMBBody::ReadResponse(response), + ))) + } + + async fn handle_write( + &mut self, + header: &SMBSyncHeader, + message: &SMBWriteRequest, + ) -> SMBResult> { + debug!(file_id = ?message.file_id(), offset = message.write_offset(), length = message.write_length(), "handling write request"); + let open = self.find_open(message.file_id()).await?; + let mut open_wr = open.write().await; + let bytes_written = open_wr.write_data(message.write_offset(), message.data_to_write())?; + drop(open_wr); + + debug!(bytes_written, "write completed"); + let response = SMBWriteResponse::new(bytes_written); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new( + header, + SMBBody::WriteResponse(response), + ))) + } + + async fn handle_query_info( + &mut self, + header: &SMBSyncHeader, + message: &SMBQueryInfoRequest, + ) -> SMBResult> { + debug!(file_id = ?message.file_id(), info_type = ?message.info_type(), class = message.file_info_class(), "handling query_info request"); + let open = self.find_open(message.file_id()).await?; + let open_rd = open.read().await; + + let mut data = match message.info_type() { + SMBInfoType::File => { + // MS-FSCC file information classes + match message.file_info_class() { + 4 => SMBTreeConnect::::build_basic_info(&*open_rd)?.smb_to_bytes(), + 5 => SMBTreeConnect::::build_standard_info(&*open_rd)?.smb_to_bytes(), + 18 => SMBTreeConnect::::build_all_info(&*open_rd)?.smb_to_bytes(), + 34 => SMBTreeConnect::::build_network_open_info(&*open_rd)?.smb_to_bytes(), + _ => { + debug!( + class = message.file_info_class(), + "unsupported file info class" + ); + return Err(SMBError::response_error(NTStatus::InvalidInfoClass)); + } + } + } + _ => { + debug!(info_type = ?message.info_type(), "unsupported info type"); + return Err(SMBError::response_error(NTStatus::InvalidInfoClass)); + } + }; + + // MS-SMB2 §3.3.5.20.1: enforce OutputBufferLength — truncate and + // return STATUS_BUFFER_OVERFLOW for variable-length info classes + let max_output = message.output_buffer_length() as usize; + if max_output > 0 && data.len() > max_output { + debug!( + data_len = data.len(), + max_output, "truncating response to output_buffer_length" + ); + data.truncate(max_output); + let response = SMBQueryInfoResponse::new(data); + let header = header.create_response_header( + NTStatus::BufferOverflow as u32, + header.session_id, + header.tree_id, + ); + return Ok(SMBHandlerState::Finished(SMBMessage::new( + header, + SMBBody::QueryInfoResponse(response), + ))); + } + + debug!(data_len = data.len(), "query_info completed"); + let response = SMBQueryInfoResponse::new(data); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new( + header, + SMBBody::QueryInfoResponse(response), + ))) + } } impl SMBLockedMessageHandler for Arc> {} diff --git a/smb/tests/smbclient.rs b/smb/tests/smbclient.rs index 0f70bfb..6d07ef0 100644 --- a/smb/tests/smbclient.rs +++ b/smb/tests/smbclient.rs @@ -13,6 +13,9 @@ //! without the server binary. Use `cargo test --test smbclient --features server -- --ignored` //! to run them explicitly. +// Tests spawn the server and kill it at the end; we don't need to wait on exit status. +#![allow(clippy::zombie_processes)] + use std::net::TcpListener; use std::process::{Child, Command, Stdio}; use std::time::Duration; @@ -70,7 +73,6 @@ fn run_smbclient(args: &[&str]) -> (bool, String, String) { /// should succeed — indicated by smbclient progressing past the initial /// connection phase. #[test] -#[ignore] fn negotiate_completes() { let port = free_port(); let mut server = spawn_server(port); @@ -103,7 +105,6 @@ fn negotiate_completes() { /// Verify that the server rejects connections with an unsupported dialect /// gracefully (no crash). #[test] -#[ignore] fn server_does_not_crash_on_smb1_only() { let port = free_port(); let mut server = spawn_server(port); @@ -143,7 +144,6 @@ fn server_does_not_crash_on_smb1_only() { /// succeeds depends on the auth configuration, but the server should not /// crash. #[test] -#[ignore] fn session_setup_with_credentials() { let port = free_port(); let mut server = spawn_server(port); @@ -175,7 +175,6 @@ fn session_setup_with_credentials() { /// Verify that anonymous (no-auth) session setup is handled. #[test] -#[ignore] fn session_setup_anonymous() { let port = free_port(); let mut server = spawn_server(port); @@ -214,7 +213,6 @@ fn session_setup_anonymous() { /// reject it (e.g. due to signing issues) but should respond with a /// proper NT status, not crash. #[test] -#[ignore] fn tree_connect_to_share() { let port = free_port(); let mut server = spawn_server(port); @@ -246,7 +244,6 @@ fn tree_connect_to_share() { /// Verify that tree connect to a nonexistent share returns an error. #[test] -#[ignore] fn tree_connect_nonexistent_share() { let port = free_port(); let mut server = spawn_server(port); @@ -274,6 +271,378 @@ fn tree_connect_nonexistent_share() { server.kill().ok(); } +// --------------------------------------------------------------------------- +// File Read Tests +// --------------------------------------------------------------------------- + +/// Verify that smbclient can read a file from the share. +/// +/// Expected: The server handles Create, Read, QueryInfo, and Close +/// without crashing. smbclient should be able to retrieve file contents. +#[test] +fn file_read_does_not_crash_server() { + use std::io::Write; + + let port = free_port(); + + // Create a temp file in the server's working directory for the share to serve + let tmp_dir = std::env::temp_dir().join(format!("smb_test_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + let test_file = tmp_dir.join("testfile.txt"); + { + let mut f = std::fs::File::create(&test_file).expect("Failed to create test file"); + f.write_all(b"hello from smb server") + .expect("Failed to write test file"); + } + + // Start server with the share path pointing to our temp dir + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + // Wait for server to start + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let download_path = tmp_dir.join("downloaded.txt"); + let port_str = port.to_string(); + let download_str = download_path.to_str().unwrap().to_string(); + let get_cmd = format!("get testfile.txt {}", download_str); + let (success, stdout, stderr) = run_smbclient(&[ + "//127.0.0.1/test", + "-p", + &port_str, + "-U", + "tejasmehta%password", + "-m", + "SMB2", + "-c", + &get_cmd, + ]); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after file read. stdout: {} stderr: {}", + stdout, + stderr + ); + + // Verify the file was downloaded and contents match + assert!( + success, + "smbclient get should succeed. stdout: {} stderr: {}", + stdout, stderr + ); + let downloaded = std::fs::read(&download_path).expect("Downloaded file should exist"); + assert_eq!( + downloaded, b"hello from smb server", + "Downloaded file contents should match the original" + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + +/// Verify that smbclient can list files (which triggers QueryInfo). +#[test] +fn directory_listing_does_not_crash_server() { + use std::io::Write; + + let port = free_port(); + + let tmp_dir = std::env::temp_dir().join(format!("smb_test_ls_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + let test_file = tmp_dir.join("listing_test.txt"); + { + let mut f = std::fs::File::create(&test_file).expect("Failed to create test file"); + f.write_all(b"test content") + .expect("Failed to write test file"); + } + + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let port_str = port.to_string(); + let (_success, _stdout, stderr) = run_smbclient(&[ + "//127.0.0.1/test", + "-p", + &port_str, + "-U", + "tejasmehta%password", + "-m", + "SMB2", + "-c", + "ls", + ]); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after directory listing. stderr: {}", + stderr + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + +/// Verify that reading a nonexistent file returns an error without crashing. +#[test] +fn read_nonexistent_file_returns_error() { + let port = free_port(); + + let tmp_dir = std::env::temp_dir().join(format!("smb_test_nofile_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let port_str = port.to_string(); + let (success, stdout, stderr) = run_smbclient(&[ + "//127.0.0.1/test", + "-p", + &port_str, + "-U", + "tejasmehta%password", + "-m", + "SMB2", + "-c", + "get nonexistent_file.txt /dev/null", + ]); + + // Should fail (file doesn't exist) + assert!( + !success || stdout.contains("NT_STATUS_") || stderr.contains("NT_STATUS_"), + "Reading nonexistent file should fail. stdout: {} stderr: {}", + stdout, + stderr + ); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after failed file read. stderr: {}", + stderr + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + +// --------------------------------------------------------------------------- +// File Write Tests +// --------------------------------------------------------------------------- + +/// Verify that smbclient can write (upload) a file to the share and that +/// the contents match what was written. +#[test] +fn file_write_uploads_file() { + use std::io::Write; + + let port = free_port(); + + let tmp_dir = std::env::temp_dir().join(format!("smb_test_write_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + + // Create a source file for smbclient to upload + let source_file = tmp_dir.join("upload_source.txt"); + let source_contents = b"hello written to smb server"; + { + let mut f = std::fs::File::create(&source_file).expect("Failed to create source file"); + f.write_all(source_contents) + .expect("Failed to write source file"); + } + + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let port_str = port.to_string(); + let source_str = source_file.to_str().unwrap().to_string(); + let put_cmd = format!("put {} uploaded.txt", source_str); + let (success, stdout, stderr) = run_smbclient(&[ + "//127.0.0.1/test", + "-p", + &port_str, + "-U", + "tejasmehta%password", + "-m", + "SMB2", + "-c", + &put_cmd, + ]); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after file write. stdout: {} stderr: {}", + stdout, + stderr + ); + + assert!( + success, + "smbclient put should succeed. stdout: {} stderr: {}", + stdout, stderr + ); + + // Verify the uploaded file exists on the server side and contents match + let uploaded_path = tmp_dir.join("uploaded.txt"); + let uploaded = std::fs::read(&uploaded_path).expect("Uploaded file should exist on server"); + assert_eq!( + uploaded, source_contents, + "Uploaded file contents should match the source" + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + +/// Verify that smbclient can write a file and then read it back, and the +/// contents round-trip correctly. +#[test] +fn file_write_then_read_round_trip() { + use std::io::Write; + + let port = free_port(); + + let tmp_dir = std::env::temp_dir().join(format!("smb_test_rw_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + + // Create a source file for smbclient to upload + let source_file = tmp_dir.join("rw_source.txt"); + let source_contents = b"round-trip test data"; + { + let mut f = std::fs::File::create(&source_file).expect("Failed to create source file"); + f.write_all(source_contents) + .expect("Failed to write source file"); + } + + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let port_str = port.to_string(); + let source_str = source_file.to_str().unwrap().to_string(); + let download_path = tmp_dir.join("rw_downloaded.txt"); + let download_str = download_path.to_str().unwrap().to_string(); + + // Upload, then download in a single smbclient session + let cmd = format!( + "put {} rw_remote.txt; get rw_remote.txt {}", + source_str, download_str + ); + let (success, stdout, stderr) = run_smbclient(&[ + "//127.0.0.1/test", + "-p", + &port_str, + "-U", + "tejasmehta%password", + "-m", + "SMB2", + "-c", + &cmd, + ]); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after put+get. stdout: {} stderr: {}", + stdout, + stderr + ); + + assert!( + success, + "smbclient put+get should succeed. stdout: {} stderr: {}", + stdout, stderr + ); + + let downloaded = std::fs::read(&download_path).expect("Downloaded file should exist"); + assert_eq!( + downloaded, source_contents, + "Downloaded file contents should match the originally written data" + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + // --------------------------------------------------------------------------- // Echo Tests // --------------------------------------------------------------------------- @@ -283,7 +652,6 @@ fn tree_connect_nonexistent_share() { /// Note: smbclient doesn't have a direct "echo" command, but we can /// verify the server stays alive through multiple operations. #[test] -#[ignore] fn server_survives_multiple_connections() { let port = free_port(); let mut server = spawn_server(port);