diff --git a/Cargo.lock b/Cargo.lock index dd0ec6dd18e..1f16ec1f405 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1749,6 +1749,26 @@ dependencies = [ "windows-link", ] +[[package]] +name = "liblzma" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6033b77c21d1f56deeae8014eb9fbe7bdf1765185a6c508b5ca82eeaed7f899" +dependencies = [ + "liblzma-sys", +] + +[[package]] +name = "liblzma-sys" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f2db66f3268487b5033077f266da6777d057949b8f93c8ad82e441df25e6186" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "libm" version = "0.2.16" @@ -1815,17 +1835,6 @@ dependencies = [ "twox-hash", ] -[[package]] -name = "lzma-sys" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" -dependencies = [ - "cc", - "libc", - "pkg-config", -] - [[package]] name = "mac_address" version = "1.1.8" @@ -3230,9 +3239,10 @@ dependencies = [ "indexmap", "itertools 0.14.0", "libc", + "liblzma", + "liblzma-sys", "libsqlite3-sys", "libz-rs-sys", - "lzma-sys", "mac_address", "malachite-bigint", "md-5", @@ -3291,7 +3301,6 @@ dependencies = [ "x509-cert", "x509-parser", "xml", - "xz2", ] [[package]] @@ -4823,15 +4832,6 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8aa498d22c9bbaf482329839bc5620c46be275a19a812e9a22a2b07529a642a" -[[package]] -name = "xz2" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" -dependencies = [ - "lzma-sys", -] - [[package]] name = "zerocopy" version = "0.8.34" diff --git a/Lib/tarfile.py b/Lib/tarfile.py old mode 100755 new mode 100644 index 04fda115971..c7e9f7d681a --- a/Lib/tarfile.py +++ b/Lib/tarfile.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 #------------------------------------------------------------------- # tarfile.py #------------------------------------------------------------------- @@ -46,7 +45,6 @@ import struct import copy import re -import warnings try: import pwd @@ -69,7 +67,7 @@ "DEFAULT_FORMAT", "open","fully_trusted_filter", "data_filter", "tar_filter", "FilterError", "AbsoluteLinkError", "OutsideDestinationError", "SpecialFileError", "AbsolutePathError", - "LinkOutsideDestinationError"] + "LinkOutsideDestinationError", "LinkFallbackError"] #--------------------------------------------------------- @@ -341,7 +339,7 @@ class _Stream: """ def __init__(self, name, mode, comptype, fileobj, bufsize, - compresslevel): + compresslevel, preset): """Construct a _Stream object. """ self._extfileobj = True @@ -355,7 +353,7 @@ def __init__(self, name, mode, comptype, fileobj, bufsize, fileobj = _StreamProxy(fileobj) comptype = fileobj.getcomptype() - self.name = name or "" + self.name = os.fspath(name) if name is not None else "" self.mode = mode self.comptype = comptype self.fileobj = fileobj @@ -395,17 +393,23 @@ def __init__(self, name, mode, comptype, fileobj, bufsize, import lzma except ImportError: raise CompressionError("lzma module is not available") from None - - # XXX: RUSTPYTHON; xz is not supported yet - raise CompressionError("lzma module is not available") from None - if mode == "r": self.dbuf = b"" self.cmp = lzma.LZMADecompressor() self.exception = lzma.LZMAError else: - self.cmp = lzma.LZMACompressor() - + self.cmp = lzma.LZMACompressor(preset=preset) + elif comptype == "zst": + try: + from compression import zstd + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + if mode == "r": + self.dbuf = b"" + self.cmp = zstd.ZstdDecompressor() + self.exception = zstd.ZstdError + else: + self.cmp = zstd.ZstdCompressor() elif comptype != "tar": raise CompressionError("unknown compression type %r" % comptype) @@ -597,6 +601,8 @@ def getcomptype(self): return "bz2" elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")): return "xz" + elif self.buf.startswith(b"\x28\xb5\x2f\xfd"): + return "zst" else: return "tar" @@ -641,6 +647,10 @@ def __init__(self, fileobj, offset, size, name, blockinfo=None): def flush(self): pass + @property + def mode(self): + return 'rb' + def readable(self): return True @@ -756,10 +766,22 @@ def __init__(self, tarinfo, path): super().__init__(f'{tarinfo.name!r} would link to {path!r}, ' + 'which is outside the destination') +class LinkFallbackError(FilterError): + def __init__(self, tarinfo, path): + self.tarinfo = tarinfo + self._path = path + super().__init__(f'link {tarinfo.name!r} would be extracted as a ' + + f'copy of {path!r}, which was rejected') + +# Errors caused by filters -- both "fatal" and "non-fatal" -- that +# we consider to be issues with the argument, rather than a bug in the +# filter function +_FILTER_ERRORS = (FilterError, OSError, ExtractError) + def _get_filtered_attrs(member, dest_path, for_data=True): new_attrs = {} name = member.name - dest_path = os.path.realpath(dest_path) + dest_path = os.path.realpath(dest_path, strict=os.path.ALLOW_MISSING) # Strip leading / (tar's directory separator) from filenames. # Include os.sep (target OS directory separator) as well. if name.startswith(('/', os.sep)): @@ -769,7 +791,8 @@ def _get_filtered_attrs(member, dest_path, for_data=True): # For example, 'C:/foo' on Windows. raise AbsolutePathError(member) # Ensure we stay in the destination - target_path = os.path.realpath(os.path.join(dest_path, name)) + target_path = os.path.realpath(os.path.join(dest_path, name), + strict=os.path.ALLOW_MISSING) if os.path.commonpath([target_path, dest_path]) != dest_path: raise OutsideDestinationError(member, target_path) # Limit permissions (no high bits, and go-w) @@ -807,6 +830,9 @@ def _get_filtered_attrs(member, dest_path, for_data=True): if member.islnk() or member.issym(): if os.path.isabs(member.linkname): raise AbsoluteLinkError(member) + normalized = os.path.normpath(member.linkname) + if normalized != member.linkname: + new_attrs['linkname'] = normalized if member.issym(): target_path = os.path.join(dest_path, os.path.dirname(name), @@ -814,7 +840,8 @@ def _get_filtered_attrs(member, dest_path, for_data=True): else: target_path = os.path.join(dest_path, member.linkname) - target_path = os.path.realpath(target_path) + target_path = os.path.realpath(target_path, + strict=os.path.ALLOW_MISSING) if os.path.commonpath([target_path, dest_path]) != dest_path: raise LinkOutsideDestinationError(member, target_path) return new_attrs @@ -847,6 +874,9 @@ def data_filter(member, dest_path): # Sentinel for replace() defaults, meaning "don't change the attribute" _KEEP = object() +# Header length is digits followed by a space. +_header_length_prefix_re = re.compile(br"([0-9]{1,20}) ") + class TarInfo(object): """Informational class which holds the details about an archive member given by a tar header block. @@ -877,7 +907,7 @@ class TarInfo(object): pax_headers = ('A dictionary containing key-value pairs of an ' 'associated pax extended header.'), sparse = 'Sparse member information.', - tarfile = None, + _tarfile = None, _sparse_structs = None, _link_target = None, ) @@ -906,6 +936,24 @@ def __init__(self, name=""): self.sparse = None # sparse member information self.pax_headers = {} # pax header information + @property + def tarfile(self): + import warnings + warnings.warn( + 'The undocumented "tarfile" attribute of TarInfo objects ' + + 'is deprecated and will be removed in Python 3.16', + DeprecationWarning, stacklevel=2) + return self._tarfile + + @tarfile.setter + def tarfile(self, tarfile): + import warnings + warnings.warn( + 'The undocumented "tarfile" attribute of TarInfo objects ' + + 'is deprecated and will be removed in Python 3.16', + DeprecationWarning, stacklevel=2) + self._tarfile = tarfile + @property def path(self): 'In pax headers, "name" is called "path".' @@ -1200,7 +1248,7 @@ def _create_pax_generic_header(cls, pax_headers, type, encoding): for keyword, value in pax_headers.items(): keyword = keyword.encode("utf-8") if binary: - # Try to restore the original byte representation of `value'. + # Try to restore the original byte representation of 'value'. # Needless to say, that the encoding must match the string. value = value.encode(encoding, "surrogateescape") else: @@ -1416,37 +1464,59 @@ def _proc_pax(self, tarfile): else: pax_headers = tarfile.pax_headers.copy() - # Check if the pax header contains a hdrcharset field. This tells us - # the encoding of the path, linkpath, uname and gname fields. Normally, - # these fields are UTF-8 encoded but since POSIX.1-2008 tar - # implementations are allowed to store them as raw binary strings if - # the translation to UTF-8 fails. - match = re.search(br"\d+ hdrcharset=([^\n]+)\n", buf) - if match is not None: - pax_headers["hdrcharset"] = match.group(1).decode("utf-8") - - # For the time being, we don't care about anything other than "BINARY". - # The only other value that is currently allowed by the standard is - # "ISO-IR 10646 2000 UTF-8" in other words UTF-8. - hdrcharset = pax_headers.get("hdrcharset") - if hdrcharset == "BINARY": - encoding = tarfile.encoding - else: - encoding = "utf-8" - # Parse pax header information. A record looks like that: # "%d %s=%s\n" % (length, keyword, value). length is the size # of the complete record including the length field itself and - # the newline. keyword and value are both UTF-8 encoded strings. - regex = re.compile(br"(\d+) ([^=]+)=") + # the newline. pos = 0 - while match := regex.match(buf, pos): - length, keyword = match.groups() - length = int(length) - if length == 0: + encoding = None + raw_headers = [] + while len(buf) > pos and buf[pos] != 0x00: + if not (match := _header_length_prefix_re.match(buf, pos)): raise InvalidHeaderError("invalid header") - value = buf[match.end(2) + 1:match.start(1) + length - 1] + try: + length = int(match.group(1)) + except ValueError: + raise InvalidHeaderError("invalid header") + # Headers must be at least 5 bytes, shortest being '5 x=\n'. + # Value is allowed to be empty. + if length < 5: + raise InvalidHeaderError("invalid header") + if pos + length > len(buf): + raise InvalidHeaderError("invalid header") + + header_value_end_offset = match.start(1) + length - 1 # Last byte of the header + keyword_and_value = buf[match.end(1) + 1:header_value_end_offset] + raw_keyword, equals, raw_value = keyword_and_value.partition(b"=") + # Check the framing of the header. The last character must be '\n' (0x0A) + if not raw_keyword or equals != b"=" or buf[header_value_end_offset] != 0x0A: + raise InvalidHeaderError("invalid header") + raw_headers.append((length, raw_keyword, raw_value)) + + # Check if the pax header contains a hdrcharset field. This tells us + # the encoding of the path, linkpath, uname and gname fields. Normally, + # these fields are UTF-8 encoded but since POSIX.1-2008 tar + # implementations are allowed to store them as raw binary strings if + # the translation to UTF-8 fails. For the time being, we don't care about + # anything other than "BINARY". The only other value that is currently + # allowed by the standard is "ISO-IR 10646 2000 UTF-8" in other words UTF-8. + # Note that we only follow the initial 'hdrcharset' setting to preserve + # the initial behavior of the 'tarfile' module. + if raw_keyword == b"hdrcharset" and encoding is None: + if raw_value == b"BINARY": + encoding = tarfile.encoding + else: # This branch ensures only the first 'hdrcharset' header is used. + encoding = "utf-8" + + pos += length + + # If no explicit hdrcharset is set, we use UTF-8 as a default. + if encoding is None: + encoding = "utf-8" + + # After parsing the raw headers we can decode them to text. + for length, raw_keyword, raw_value in raw_headers: # Normally, we could just use "utf-8" as the encoding and "strict" # as the error handler, but we better not take the risk. For # example, GNU tar <= 1.23 is known to store filenames it cannot @@ -1454,17 +1524,16 @@ def _proc_pax(self, tarfile): # hdrcharset=BINARY header). # We first try the strict standard encoding, and if that fails we # fall back on the user's encoding and error handler. - keyword = self._decode_pax_field(keyword, "utf-8", "utf-8", + keyword = self._decode_pax_field(raw_keyword, "utf-8", "utf-8", tarfile.errors) if keyword in PAX_NAME_FIELDS: - value = self._decode_pax_field(value, encoding, tarfile.encoding, + value = self._decode_pax_field(raw_value, encoding, tarfile.encoding, tarfile.errors) else: - value = self._decode_pax_field(value, "utf-8", "utf-8", + value = self._decode_pax_field(raw_value, "utf-8", "utf-8", tarfile.errors) pax_headers[keyword] = value - pos += length # Fetch the next header. try: @@ -1479,7 +1548,7 @@ def _proc_pax(self, tarfile): elif "GNU.sparse.size" in pax_headers: # GNU extended sparse format version 0.0. - self._proc_gnusparse_00(next, pax_headers, buf) + self._proc_gnusparse_00(next, raw_headers) elif pax_headers.get("GNU.sparse.major") == "1" and pax_headers.get("GNU.sparse.minor") == "0": # GNU extended sparse format version 1.0. @@ -1501,15 +1570,24 @@ def _proc_pax(self, tarfile): return next - def _proc_gnusparse_00(self, next, pax_headers, buf): + def _proc_gnusparse_00(self, next, raw_headers): """Process a GNU tar extended sparse header, version 0.0. """ offsets = [] - for match in re.finditer(br"\d+ GNU.sparse.offset=(\d+)\n", buf): - offsets.append(int(match.group(1))) numbytes = [] - for match in re.finditer(br"\d+ GNU.sparse.numbytes=(\d+)\n", buf): - numbytes.append(int(match.group(1))) + for _, keyword, value in raw_headers: + if keyword == b"GNU.sparse.offset": + try: + offsets.append(int(value.decode())) + except ValueError: + raise InvalidHeaderError("invalid header") + + elif keyword == b"GNU.sparse.numbytes": + try: + numbytes.append(int(value.decode())) + except ValueError: + raise InvalidHeaderError("invalid header") + next.sparse = list(zip(offsets, numbytes)) def _proc_gnusparse_01(self, next, pax_headers): @@ -1569,6 +1647,9 @@ def _block(self, count): """Round up a byte count by BLOCKSIZE and return it, e.g. _block(834) => 1024. """ + # Only non-negative offsets are allowed + if count < 0: + raise InvalidHeaderError("invalid offset") blocks, remainder = divmod(count, BLOCKSIZE) if remainder: blocks += 1 @@ -1645,14 +1726,14 @@ class TarFile(object): def __init__(self, name=None, mode="r", fileobj=None, format=None, tarinfo=None, dereference=None, ignore_zeros=None, encoding=None, errors="surrogateescape", pax_headers=None, debug=None, - errorlevel=None, copybufsize=None): - """Open an (uncompressed) tar archive `name'. `mode' is either 'r' to + errorlevel=None, copybufsize=None, stream=False): + """Open an (uncompressed) tar archive 'name'. 'mode' is either 'r' to read from an existing archive, 'a' to append data to an existing - file or 'w' to create a new file overwriting an existing one. `mode' + file or 'w' to create a new file overwriting an existing one. 'mode' defaults to 'r'. - If `fileobj' is given, it is used for reading or writing data. If it - can be determined, `mode' is overridden by `fileobj's mode. - `fileobj' is not closed, when TarFile is closed. + If 'fileobj' is given, it is used for reading or writing data. If it + can be determined, 'mode' is overridden by 'fileobj's mode. + 'fileobj' is not closed, when TarFile is closed. """ modes = {"r": "rb", "a": "r+b", "w": "wb", "x": "xb"} if mode not in modes: @@ -1677,6 +1758,8 @@ def __init__(self, name=None, mode="r", fileobj=None, format=None, self.name = os.path.abspath(name) if name else None self.fileobj = fileobj + self.stream = stream + # Init attributes. if format is not None: self.format = format @@ -1709,6 +1792,8 @@ def __init__(self, name=None, mode="r", fileobj=None, format=None, # current position in the archive file self.inodes = {} # dictionary caching the inodes of # archive members already added + self._unames = {} # Cached mappings of uid -> uname + self._gnames = {} # Cached mappings of gid -> gname try: if self.mode == "r": @@ -1764,11 +1849,13 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): 'r:gz' open for reading with gzip compression 'r:bz2' open for reading with bzip2 compression 'r:xz' open for reading with lzma compression + 'r:zst' open for reading with zstd compression 'a' or 'a:' open for appending, creating the file if necessary 'w' or 'w:' open for writing without compression 'w:gz' open for writing with gzip compression 'w:bz2' open for writing with bzip2 compression 'w:xz' open for writing with lzma compression + 'w:zst' open for writing with zstd compression 'x' or 'x:' create a tarfile exclusively without compression, raise an exception if the file is already created @@ -1778,16 +1865,20 @@ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): if the file is already created 'x:xz' create an lzma compressed tarfile, raise an exception if the file is already created + 'x:zst' create a zstd compressed tarfile, raise an exception + if the file is already created 'r|*' open a stream of tar blocks with transparent compression 'r|' open an uncompressed stream of tar blocks for reading 'r|gz' open a gzip compressed stream of tar blocks 'r|bz2' open a bzip2 compressed stream of tar blocks 'r|xz' open an lzma compressed stream of tar blocks + 'r|zst' open a zstd compressed stream of tar blocks 'w|' open an uncompressed stream for writing 'w|gz' open a gzip compressed stream for writing 'w|bz2' open a bzip2 compressed stream for writing 'w|xz' open an lzma compressed stream for writing + 'w|zst' open a zstd compressed stream for writing """ if not name and not fileobj: @@ -1832,10 +1923,17 @@ def not_compressed(comptype): if filemode not in ("r", "w"): raise ValueError("mode must be 'r' or 'w'") + if "compresslevel" in kwargs and comptype not in ("gz", "bz2"): + raise ValueError( + "compresslevel is only valid for w|gz and w|bz2 modes" + ) + if "preset" in kwargs and comptype not in ("xz",): + raise ValueError("preset is only valid for w|xz mode") compresslevel = kwargs.pop("compresslevel", 9) + preset = kwargs.pop("preset", None) stream = _Stream(name, filemode, comptype, fileobj, bufsize, - compresslevel) + compresslevel, preset) try: t = cls(name, filemode, stream, **kwargs) except: @@ -1931,9 +2029,6 @@ def xzopen(cls, name, mode="r", fileobj=None, preset=None, **kwargs): except ImportError: raise CompressionError("lzma module is not available") from None - # XXX: RUSTPYTHON; xz is not supported yet - raise CompressionError("lzma module is not available") from None - fileobj = LZMAFile(fileobj or name, mode, preset=preset) try: @@ -1949,12 +2044,48 @@ def xzopen(cls, name, mode="r", fileobj=None, preset=None, **kwargs): t._extfileobj = False return t + @classmethod + def zstopen(cls, name, mode="r", fileobj=None, level=None, options=None, + zstd_dict=None, **kwargs): + """Open zstd compressed tar archive name for reading or writing. + Appending is not allowed. + """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from compression.zstd import ZstdFile, ZstdError + except ImportError: + raise CompressionError("compression.zstd module is not available") from None + + fileobj = ZstdFile( + fileobj or name, + mode, + level=level, + options=options, + zstd_dict=zstd_dict + ) + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except (ZstdError, EOFError) as e: + fileobj.close() + if mode == 'r': + raise ReadError("not a zstd file") from e + raise + except Exception: + fileobj.close() + raise + t._extfileobj = False + return t + # All *open() methods are registered here. OPEN_METH = { "tar": "taropen", # uncompressed tar "gz": "gzopen", # gzip compressed tar "bz2": "bz2open", # bzip2 compressed tar - "xz": "xzopen" # lzma compressed tar + "xz": "xzopen", # lzma compressed tar + "zst": "zstopen", # zstd compressed tar } #-------------------------------------------------------------------------- @@ -1982,7 +2113,7 @@ def close(self): self.fileobj.close() def getmember(self, name): - """Return a TarInfo object for member `name'. If `name' can not be + """Return a TarInfo object for member 'name'. If 'name' can not be found in the archive, KeyError is raised. If a member occurs more than once in the archive, its last occurrence is assumed to be the most up-to-date version. @@ -2010,9 +2141,9 @@ def getnames(self): def gettarinfo(self, name=None, arcname=None, fileobj=None): """Create a TarInfo object from the result of os.stat or equivalent - on an existing file. The file is either named by `name', or - specified as a file object `fileobj' with a file descriptor. If - given, `arcname' specifies an alternative name for the file in the + on an existing file. The file is either named by 'name', or + specified as a file object 'fileobj' with a file descriptor. If + given, 'arcname' specifies an alternative name for the file in the archive, otherwise, the name is taken from the 'name' attribute of 'fileobj', or the 'name' argument. The name should be a text string. @@ -2036,7 +2167,7 @@ def gettarinfo(self, name=None, arcname=None, fileobj=None): # Now, fill the TarInfo object with # information specific for the file. tarinfo = self.tarinfo() - tarinfo.tarfile = self # Not needed + tarinfo._tarfile = self # To be removed in 3.16. # Use os.stat or os.lstat, depending on if symlinks shall be resolved. if fileobj is None: @@ -2090,16 +2221,23 @@ def gettarinfo(self, name=None, arcname=None, fileobj=None): tarinfo.mtime = statres.st_mtime tarinfo.type = type tarinfo.linkname = linkname + + # Calls to pwd.getpwuid() and grp.getgrgid() tend to be expensive. To + # speed things up, cache the resolved usernames and group names. if pwd: - try: - tarinfo.uname = pwd.getpwuid(tarinfo.uid)[0] - except KeyError: - pass + if tarinfo.uid not in self._unames: + try: + self._unames[tarinfo.uid] = pwd.getpwuid(tarinfo.uid)[0] + except KeyError: + self._unames[tarinfo.uid] = '' + tarinfo.uname = self._unames[tarinfo.uid] if grp: - try: - tarinfo.gname = grp.getgrgid(tarinfo.gid)[0] - except KeyError: - pass + if tarinfo.gid not in self._gnames: + try: + self._gnames[tarinfo.gid] = grp.getgrgid(tarinfo.gid)[0] + except KeyError: + self._gnames[tarinfo.gid] = '' + tarinfo.gname = self._gnames[tarinfo.gid] if type in (CHRTYPE, BLKTYPE): if hasattr(os, "major") and hasattr(os, "minor"): @@ -2108,11 +2246,15 @@ def gettarinfo(self, name=None, arcname=None, fileobj=None): return tarinfo def list(self, verbose=True, *, members=None): - """Print a table of contents to sys.stdout. If `verbose' is False, only - the names of the members are printed. If it is True, an `ls -l'-like - output is produced. `members' is optional and must be a subset of the + """Print a table of contents to sys.stdout. If 'verbose' is False, only + the names of the members are printed. If it is True, an 'ls -l'-like + output is produced. 'members' is optional and must be a subset of the list returned by getmembers(). """ + # Convert tarinfo type to stat type. + type2mode = {REGTYPE: stat.S_IFREG, SYMTYPE: stat.S_IFLNK, + FIFOTYPE: stat.S_IFIFO, CHRTYPE: stat.S_IFCHR, + DIRTYPE: stat.S_IFDIR, BLKTYPE: stat.S_IFBLK} self._check() if members is None: @@ -2122,7 +2264,8 @@ def list(self, verbose=True, *, members=None): if tarinfo.mode is None: _safe_print("??????????") else: - _safe_print(stat.filemode(tarinfo.mode)) + modetype = type2mode.get(tarinfo.type, 0) + _safe_print(stat.filemode(modetype | tarinfo.mode)) _safe_print("%s/%s" % (tarinfo.uname or tarinfo.uid, tarinfo.gname or tarinfo.gid)) if tarinfo.ischr() or tarinfo.isblk(): @@ -2146,11 +2289,11 @@ def list(self, verbose=True, *, members=None): print() def add(self, name, arcname=None, recursive=True, *, filter=None): - """Add the file `name' to the archive. `name' may be any type of file - (directory, fifo, symbolic link, etc.). If given, `arcname' + """Add the file 'name' to the archive. 'name' may be any type of file + (directory, fifo, symbolic link, etc.). If given, 'arcname' specifies an alternative name for the file in the archive. Directories are added recursively by default. This can be avoided by - setting `recursive' to False. `filter' is a function + setting 'recursive' to False. 'filter' is a function that expects a TarInfo object argument and returns the changed TarInfo object, if it returns None the TarInfo object will be excluded from the archive. @@ -2197,13 +2340,16 @@ def add(self, name, arcname=None, recursive=True, *, filter=None): self.addfile(tarinfo) def addfile(self, tarinfo, fileobj=None): - """Add the TarInfo object `tarinfo' to the archive. If `fileobj' is - given, it should be a binary file, and tarinfo.size bytes are read - from it and added to the archive. You can create TarInfo objects - directly, or by using gettarinfo(). + """Add the TarInfo object 'tarinfo' to the archive. If 'tarinfo' represents + a non zero-size regular file, the 'fileobj' argument should be a binary file, + and tarinfo.size bytes are read from it and added to the archive. + You can create TarInfo objects directly, or by using gettarinfo(). """ self._check("awx") + if fileobj is None and tarinfo.isreg() and tarinfo.size != 0: + raise ValueError("fileobj not provided for non zero-size regular file") + tarinfo = copy.copy(tarinfo) buf = tarinfo.tobuf(self.format, self.encoding, self.errors) @@ -2225,12 +2371,7 @@ def _get_filter_function(self, filter): if filter is None: filter = self.extraction_filter if filter is None: - warnings.warn( - 'Python 3.14 will, by default, filter extracted tar ' - + 'archives and reject files or modify their metadata. ' - + 'Use the filter argument to control this behavior.', - DeprecationWarning) - return fully_trusted_filter + return data_filter if isinstance(filter, str): raise TypeError( 'String names are not supported for ' @@ -2248,12 +2389,12 @@ def extractall(self, path=".", members=None, *, numeric_owner=False, filter=None): """Extract all members from the archive to the current working directory and set owner, modification time and permissions on - directories afterwards. `path' specifies a different directory - to extract to. `members' is optional and must be a subset of the - list returned by getmembers(). If `numeric_owner` is True, only + directories afterwards. 'path' specifies a different directory + to extract to. 'members' is optional and must be a subset of the + list returned by getmembers(). If 'numeric_owner' is True, only the numbers for user/group names are used and not the names. - The `filter` function will be called on each member just + The 'filter' function will be called on each member just before extraction. It can return a changed TarInfo or None to skip the member. String names of common filters are accepted. @@ -2265,81 +2406,124 @@ def extractall(self, path=".", members=None, *, numeric_owner=False, members = self for member in members: - tarinfo = self._get_extract_tarinfo(member, filter_function, path) + tarinfo, unfiltered = self._get_extract_tarinfo( + member, filter_function, path) if tarinfo is None: continue if tarinfo.isdir(): # For directories, delay setting attributes until later, # since permissions can interfere with extraction and # extracting contents can reset mtime. - directories.append(tarinfo) + directories.append(unfiltered) self._extract_one(tarinfo, path, set_attrs=not tarinfo.isdir(), - numeric_owner=numeric_owner) + numeric_owner=numeric_owner, + filter_function=filter_function) # Reverse sort directories. directories.sort(key=lambda a: a.name, reverse=True) + # Set correct owner, mtime and filemode on directories. - for tarinfo in directories: - dirpath = os.path.join(path, tarinfo.name) + for unfiltered in directories: try: + # Need to re-apply any filter, to take the *current* filesystem + # state into account. + try: + tarinfo = filter_function(unfiltered, path) + except _FILTER_ERRORS as exc: + self._log_no_directory_fixup(unfiltered, repr(exc)) + continue + if tarinfo is None: + self._log_no_directory_fixup(unfiltered, + 'excluded by filter') + continue + dirpath = os.path.join(path, tarinfo.name) + try: + lstat = os.lstat(dirpath) + except FileNotFoundError: + self._log_no_directory_fixup(tarinfo, 'missing') + continue + if not stat.S_ISDIR(lstat.st_mode): + # This is no longer a directory; presumably a later + # member overwrote the entry. + self._log_no_directory_fixup(tarinfo, 'not a directory') + continue self.chown(tarinfo, dirpath, numeric_owner=numeric_owner) self.utime(tarinfo, dirpath) self.chmod(tarinfo, dirpath) except ExtractError as e: self._handle_nonfatal_error(e) + def _log_no_directory_fixup(self, member, reason): + self._dbg(2, "tarfile: Not fixing up directory %r (%s)" % + (member.name, reason)) + def extract(self, member, path="", set_attrs=True, *, numeric_owner=False, filter=None): """Extract a member from the archive to the current working directory, using its full name. Its file information is extracted as accurately - as possible. `member' may be a filename or a TarInfo object. You can - specify a different directory using `path'. File attributes (owner, - mtime, mode) are set unless `set_attrs' is False. If `numeric_owner` + as possible. 'member' may be a filename or a TarInfo object. You can + specify a different directory using 'path'. File attributes (owner, + mtime, mode) are set unless 'set_attrs' is False. If 'numeric_owner' is True, only the numbers for user/group names are used and not the names. - The `filter` function will be called before extraction. + The 'filter' function will be called before extraction. It can return a changed TarInfo or None to skip the member. String names of common filters are accepted. """ filter_function = self._get_filter_function(filter) - tarinfo = self._get_extract_tarinfo(member, filter_function, path) + tarinfo, unfiltered = self._get_extract_tarinfo( + member, filter_function, path) if tarinfo is not None: self._extract_one(tarinfo, path, set_attrs, numeric_owner) def _get_extract_tarinfo(self, member, filter_function, path): - """Get filtered TarInfo (or None) from member, which might be a str""" + """Get (filtered, unfiltered) TarInfos from *member* + + *member* might be a string. + + Return (None, None) if not found. + """ + if isinstance(member, str): - tarinfo = self.getmember(member) + unfiltered = self.getmember(member) else: - tarinfo = member + unfiltered = member - unfiltered = tarinfo + filtered = None try: - tarinfo = filter_function(tarinfo, path) - except (OSError, FilterError) as e: + filtered = filter_function(unfiltered, path) + except (OSError, UnicodeEncodeError, FilterError) as e: self._handle_fatal_error(e) except ExtractError as e: self._handle_nonfatal_error(e) - if tarinfo is None: + if filtered is None: self._dbg(2, "tarfile: Excluded %r" % unfiltered.name) - return None + return None, None + # Prepare the link target for makelink(). - if tarinfo.islnk(): - tarinfo = copy.copy(tarinfo) - tarinfo._link_target = os.path.join(path, tarinfo.linkname) - return tarinfo + if filtered.islnk(): + filtered = copy.copy(filtered) + filtered._link_target = os.path.join(path, filtered.linkname) + return filtered, unfiltered - def _extract_one(self, tarinfo, path, set_attrs, numeric_owner): - """Extract from filtered tarinfo to disk""" + def _extract_one(self, tarinfo, path, set_attrs, numeric_owner, + filter_function=None): + """Extract from filtered tarinfo to disk. + + filter_function is only used when extracting a *different* + member (e.g. as fallback to creating a symlink) + """ self._check("r") try: self._extract_member(tarinfo, os.path.join(path, tarinfo.name), set_attrs=set_attrs, - numeric_owner=numeric_owner) - except OSError as e: + numeric_owner=numeric_owner, + filter_function=filter_function, + extraction_root=path) + except (OSError, UnicodeEncodeError) as e: self._handle_fatal_error(e) except ExtractError as e: self._handle_nonfatal_error(e) @@ -2364,10 +2548,10 @@ def _handle_fatal_error(self, e): self._dbg(1, "tarfile: %s %s" % (type(e).__name__, e)) def extractfile(self, member): - """Extract a member from the archive as a file object. `member' may be - a filename or a TarInfo object. If `member' is a regular file or + """Extract a member from the archive as a file object. 'member' may be + a filename or a TarInfo object. If 'member' is a regular file or a link, an io.BufferedReader object is returned. For all other - existing members, None is returned. If `member' does not appear + existing members, None is returned. If 'member' does not appear in the archive, KeyError is raised. """ self._check("r") @@ -2396,9 +2580,13 @@ def extractfile(self, member): return None def _extract_member(self, tarinfo, targetpath, set_attrs=True, - numeric_owner=False): - """Extract the TarInfo object tarinfo to a physical + numeric_owner=False, *, filter_function=None, + extraction_root=None): + """Extract the filtered TarInfo object tarinfo to a physical file called targetpath. + + filter_function is only used when extracting a *different* + member (e.g. as fallback to creating a symlink) """ # Fetch the TarInfo object for the given name # and build the destination pathname, replacing @@ -2411,7 +2599,7 @@ def _extract_member(self, tarinfo, targetpath, set_attrs=True, if upperdirs and not os.path.exists(upperdirs): # Create directories that are not part of the archive with # default permissions. - os.makedirs(upperdirs) + os.makedirs(upperdirs, exist_ok=True) if tarinfo.islnk() or tarinfo.issym(): self._dbg(1, "%s -> %s" % (tarinfo.name, tarinfo.linkname)) @@ -2427,7 +2615,10 @@ def _extract_member(self, tarinfo, targetpath, set_attrs=True, elif tarinfo.ischr() or tarinfo.isblk(): self.makedev(tarinfo, targetpath) elif tarinfo.islnk() or tarinfo.issym(): - self.makelink(tarinfo, targetpath) + self.makelink_with_filter( + tarinfo, targetpath, + filter_function=filter_function, + extraction_root=extraction_root) elif tarinfo.type not in SUPPORTED_TYPES: self.makeunknown(tarinfo, targetpath) else: @@ -2510,10 +2701,18 @@ def makedev(self, tarinfo, targetpath): os.makedev(tarinfo.devmajor, tarinfo.devminor)) def makelink(self, tarinfo, targetpath): + return self.makelink_with_filter(tarinfo, targetpath, None, None) + + def makelink_with_filter(self, tarinfo, targetpath, + filter_function, extraction_root): """Make a (symbolic) link called targetpath. If it cannot be created (platform limitation), we try to make a copy of the referenced file instead of a link. + + filter_function is only used when extracting a *different* + member (e.g. as fallback to creating a link). """ + keyerror_to_extracterror = False try: # For systems that support symbolic and hard links. if tarinfo.issym(): @@ -2521,18 +2720,41 @@ def makelink(self, tarinfo, targetpath): # Avoid FileExistsError on following os.symlink. os.unlink(targetpath) os.symlink(tarinfo.linkname, targetpath) + return else: if os.path.exists(tarinfo._link_target): + if os.path.lexists(targetpath): + # Avoid FileExistsError on following os.link. + os.unlink(targetpath) os.link(tarinfo._link_target, targetpath) - else: - self._extract_member(self._find_link_target(tarinfo), - targetpath) + return except symlink_exception: + keyerror_to_extracterror = True + + try: + unfiltered = self._find_link_target(tarinfo) + except KeyError: + if keyerror_to_extracterror: + raise ExtractError( + "unable to resolve link inside archive") from None + else: + raise + + if filter_function is None: + filtered = unfiltered + else: + if extraction_root is None: + raise ExtractError( + "makelink_with_filter: if filter_function is not None, " + + "extraction_root must also not be None") try: - self._extract_member(self._find_link_target(tarinfo), - targetpath) - except KeyError: - raise ExtractError("unable to resolve link inside archive") from None + filtered = filter_function(unfiltered, extraction_root) + except _FILTER_ERRORS as cause: + raise LinkFallbackError(tarinfo, unfiltered.name) from cause + if filtered is not None: + self._extract_member(filtered, targetpath, + filter_function=filter_function, + extraction_root=extraction_root) def chown(self, tarinfo, targetpath, numeric_owner): """Set owner of targetpath according to tarinfo. If numeric_owner @@ -2564,7 +2786,8 @@ def chown(self, tarinfo, targetpath, numeric_owner): os.lchown(targetpath, u, g) else: os.chown(targetpath, u, g) - except OSError as e: + except (OSError, OverflowError) as e: + # OverflowError can be raised if an ID doesn't fit in 'id_t' raise ExtractError("could not change owner") from e def chmod(self, tarinfo, targetpath): @@ -2647,7 +2870,9 @@ def next(self): break if tarinfo is not None: - self.members.append(tarinfo) + # if streaming the file we do not want to cache the tarinfo + if not self.stream: + self.members.append(tarinfo) else: self._loaded = True @@ -2698,11 +2923,12 @@ def _getmember(self, name, tarinfo=None, normalize=False): def _load(self): """Read through the entire archive file and look for readable - members. + members. This should not run if the file is set to stream. """ - while self.next() is not None: - pass - self._loaded = True + if not self.stream: + while self.next() is not None: + pass + self._loaded = True def _check(self, mode=None): """Check if TarFile is still open, and if the operation's mode @@ -2812,7 +3038,7 @@ def main(): import argparse description = 'A simple command-line interface for tarfile module.' - parser = argparse.ArgumentParser(description=description) + parser = argparse.ArgumentParser(description=description, color=True) parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Verbose output') parser.add_argument('--filter', metavar='', @@ -2892,6 +3118,9 @@ def main(): '.tbz': 'bz2', '.tbz2': 'bz2', '.tb2': 'bz2', + # zstd + '.zst': 'zst', + '.tzst': 'zst', } tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w' tar_files = args.create diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py index 75ab8c42e92..a433f0c6193 100644 --- a/Lib/test/test_lzma.py +++ b/Lib/test/test_lzma.py @@ -80,7 +80,6 @@ def test_decompressor_after_eof(self): lzd.decompress(COMPRESSED_XZ) self.assertRaises(EOFError, lzd.decompress, b"nyan") - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument memlimit def test_decompressor_memlimit(self): lzd = LZMADecompressor(memlimit=1024) self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ) @@ -101,7 +100,6 @@ def _test_decompressor(self, lzd, data, check, unused_data=b""): self.assertTrue(lzd.eof) self.assertEqual(lzd.unused_data, unused_data) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_decompressor_auto(self): lzd = LZMADecompressor() self._test_decompressor(lzd, COMPRESSED_XZ, lzma.CHECK_CRC64) @@ -109,37 +107,30 @@ def test_decompressor_auto(self): lzd = LZMADecompressor() self._test_decompressor(lzd, COMPRESSED_ALONE, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_decompressor_xz(self): lzd = LZMADecompressor(lzma.FORMAT_XZ) self._test_decompressor(lzd, COMPRESSED_XZ, lzma.CHECK_CRC64) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_decompressor_alone(self): lzd = LZMADecompressor(lzma.FORMAT_ALONE) self._test_decompressor(lzd, COMPRESSED_ALONE, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_decompressor_raw_1(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_1) self._test_decompressor(lzd, COMPRESSED_RAW_1, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_decompressor_raw_2(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_2) self._test_decompressor(lzd, COMPRESSED_RAW_2, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_decompressor_raw_3(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_3) self._test_decompressor(lzd, COMPRESSED_RAW_3, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_decompressor_raw_4(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self._test_decompressor(lzd, COMPRESSED_RAW_4, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_decompressor_chunks(self): lzd = LZMADecompressor() out = [] @@ -260,14 +251,12 @@ def test_decompressor_inputbuf_3(self): out.append(lzd.decompress(COMPRESSED_XZ[300:])) self.assertEqual(b''.join(out), INPUT) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_decompressor_unused_data(self): lzd = LZMADecompressor() extra = b"fooblibar" self._test_decompressor(lzd, COMPRESSED_XZ + extra, lzma.CHECK_CRC64, unused_data=extra) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_decompressor_bad_input(self): lzd = LZMADecompressor() self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_RAW_1) @@ -281,7 +270,6 @@ def test_decompressor_bad_input(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_1) self.assertRaises(LZMAError, lzd.decompress, COMPRESSED_XZ) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_decompressor_bug_28275(self): # Test coverage for Issue 28275 lzd = LZMADecompressor() @@ -291,28 +279,24 @@ def test_decompressor_bug_28275(self): # Test that LZMACompressor->LZMADecompressor preserves the input data. - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_roundtrip_xz(self): lzc = LZMACompressor() cdata = lzc.compress(INPUT) + lzc.flush() lzd = LZMADecompressor() self._test_decompressor(lzd, cdata, lzma.CHECK_CRC64) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_roundtrip_alone(self): lzc = LZMACompressor(lzma.FORMAT_ALONE) cdata = lzc.compress(INPUT) + lzc.flush() lzd = LZMADecompressor() self._test_decompressor(lzd, cdata, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; lzma.LZMAError: Invalid format def test_roundtrip_raw(self): lzc = LZMACompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) cdata = lzc.compress(INPUT) + lzc.flush() lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self._test_decompressor(lzd, cdata, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; lzma.LZMAError: Invalid format def test_roundtrip_raw_empty(self): lzc = LZMACompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) cdata = lzc.compress(INPUT) @@ -323,7 +307,6 @@ def test_roundtrip_raw_empty(self): lzd = LZMADecompressor(lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self._test_decompressor(lzd, cdata, lzma.CHECK_NONE) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_roundtrip_chunks(self): lzc = LZMACompressor() cdata = [] @@ -334,7 +317,6 @@ def test_roundtrip_chunks(self): lzd = LZMADecompressor() self._test_decompressor(lzd, cdata, lzma.CHECK_CRC64) - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_roundtrip_empty_chunks(self): lzc = LZMACompressor() cdata = [] @@ -350,7 +332,6 @@ def test_roundtrip_empty_chunks(self): # LZMADecompressor intentionally does not handle concatenated streams. - @unittest.expectedFailure # TODO: RUSTPYTHON; AttributeError: 'LZMADecompressor' object has no attribute 'check' def test_decompressor_multistream(self): lzd = LZMADecompressor() self._test_decompressor(lzd, COMPRESSED_XZ + COMPRESSED_ALONE, @@ -441,7 +422,6 @@ def test_bad_args(self): lzma.decompress( b"", format=lzma.FORMAT_ALONE, filters=FILTERS_RAW_1) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: memory limit reached def test_decompress_memlimit(self): with self.assertRaises(LZMAError): lzma.decompress(COMPRESSED_XZ, memlimit=1024) @@ -454,7 +434,6 @@ def test_decompress_memlimit(self): # Test LZMADecompressor on known-good input data. - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_decompress_good_input(self): ddata = lzma.decompress(COMPRESSED_XZ) self.assertEqual(ddata, INPUT) @@ -484,7 +463,6 @@ def test_decompress_good_input(self): COMPRESSED_RAW_4, lzma.FORMAT_RAW, filters=FILTERS_RAW_4) self.assertEqual(ddata, INPUT) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_decompress_incomplete_input(self): self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_XZ[:128]) self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_ALONE[:128]) @@ -497,7 +475,6 @@ def test_decompress_incomplete_input(self): self.assertRaises(LZMAError, lzma.decompress, COMPRESSED_RAW_4[:128], format=lzma.FORMAT_RAW, filters=FILTERS_RAW_4) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_decompress_bad_input(self): with self.assertRaises(LZMAError): lzma.decompress(COMPRESSED_BOGUS) @@ -513,7 +490,6 @@ def test_decompress_bad_input(self): # Test that compress()->decompress() preserves the input data. - @unittest.expectedFailure # TODO: RUSTPYTHON; lzma.LZMAError: Invalid format def test_roundtrip(self): cdata = lzma.compress(INPUT) ddata = lzma.decompress(cdata) @@ -539,12 +515,10 @@ def test_decompress_multistream(self): # Test robust handling of non-LZMA data following the compressed stream(s). - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_decompress_trailing_junk(self): ddata = lzma.decompress(COMPRESSED_XZ + COMPRESSED_BOGUS) self.assertEqual(ddata, INPUT) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_decompress_multistream_trailing_junk(self): ddata = lzma.decompress(COMPRESSED_XZ * 3 + COMPRESSED_BOGUS) self.assertEqual(ddata, INPUT * 3) @@ -840,7 +814,6 @@ def test_writable(self): f.close() self.assertRaises(ValueError, f.writable) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_read(self): with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: self.assertEqual(f.read(), INPUT) @@ -888,7 +861,6 @@ def test_read_10(self): chunks.append(result) self.assertEqual(b"".join(chunks), INPUT) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_read_multistream(self): with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f: self.assertEqual(f.read(), INPUT * 5) @@ -909,12 +881,10 @@ def test_read_multistream_buffer_size_aligned(self): finally: _streams.BUFFER_SIZE = saved_buffer_size - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_read_trailing_junk(self): with LZMAFile(BytesIO(COMPRESSED_XZ + COMPRESSED_BOGUS)) as f: self.assertEqual(f.read(), INPUT) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_read_multistream_trailing_junk(self): with LZMAFile(BytesIO(COMPRESSED_XZ * 5 + COMPRESSED_BOGUS)) as f: self.assertEqual(f.read(), INPUT * 5) @@ -1020,7 +990,6 @@ def test_read_bad_args(self): with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: self.assertRaises(TypeError, f.read, float()) - @unittest.expectedFailure # TODO: RUSTPYTHON; OSError: stream/file format not recognized def test_read_bad_data(self): with LZMAFile(BytesIO(COMPRESSED_BOGUS)) as f: self.assertRaises(LZMAError, f.read) @@ -1078,7 +1047,6 @@ def test_peek_bad_args(self): with LZMAFile(BytesIO(), "w") as f: self.assertRaises(ValueError, f.peek) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_iterator(self): with BytesIO(INPUT) as f: lines = f.readlines() @@ -1118,7 +1086,6 @@ def test_decompress_limited(self): self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, "Excessive amount of data was decompressed") - @unittest.expectedFailure # TODO: RUSTPYTHON; lzma.LZMAError: Invalid format def test_write(self): with BytesIO() as dst: with LZMAFile(dst, "w") as f: @@ -1498,7 +1465,6 @@ def test_bad_params(self): with self.assertRaises(ValueError): lzma.open(TESTFN, "rb", newline="\n") - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Expected type 'int' but 'list' found. def test_format_and_filters(self): # Test non-default format and filter chain. options = {"format": lzma.FORMAT_RAW, "filters": FILTERS_RAW_1} @@ -1587,7 +1553,6 @@ def test__encode_filter_properties(self): }) self.assertEqual(props, b"]\x00\x00\x80\x00") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: LZMAError not raised def test__decode_filter_properties(self): with self.assertRaises(TypeError): lzma._decode_filter_properties(lzma.FILTER_X86, {"should be": bytes}) @@ -1611,7 +1576,6 @@ def test__decode_filter_properties(self): filterspec = lzma._decode_filter_properties(f, b"") self.assertEqual(filterspec, {"id": f}) - @unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: expected at most 0 arguments, got 1 def test_filter_properties_roundtrip(self): spec1 = lzma._decode_filter_properties( lzma.FILTER_LZMA1, b"]\x00\x00\x80\x00") diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index 7a921d569a7..d6cbb350428 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -1,18 +1,25 @@ +import errno import sys import os import io from hashlib import sha256 -from contextlib import contextmanager +from contextlib import contextmanager, ExitStack from random import Random import pathlib +import shutil +import re +import warnings +import stat import unittest import unittest.mock import tarfile +from test import archiver_tests from test import support from test.support import os_helper from test.support import script_helper +from test.support import warnings_helper # Check for our compression modules. try: @@ -33,18 +40,24 @@ lzma = None # XXX: RUSTPYTHON; xz is not supported yet lzma = None +try: + from compression import zstd +except ImportError: + zstd = None def sha256sum(data): return sha256(data).hexdigest() TEMPDIR = os.path.abspath(os_helper.TESTFN) + "-tardir" tarextdir = TEMPDIR + '-extract-test' -tarname = support.findfile("testtar.tar") +tarname = support.findfile("testtar.tar", subdir="archivetestdata") gzipname = os.path.join(TEMPDIR, "testtar.tar.gz") bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2") xzname = os.path.join(TEMPDIR, "testtar.tar.xz") +zstname = os.path.join(TEMPDIR, "testtar.tar.zst") tmpname = os.path.join(TEMPDIR, "tmp.tar") dotlessname = os.path.join(TEMPDIR, "testtar") +SPACE = b" " sha256_regtype = ( "e09e4bc8b3c9d9177e77256353b36c159f5f040531bbd4b024a8f9b9196c71ce" @@ -85,6 +98,12 @@ class LzmaTest: open = lzma.LZMAFile if lzma else None taropen = tarfile.TarFile.xzopen +@support.requires_zstd() +class ZstdTest: + tarname = zstname + suffix = 'zst' + open = zstd.ZstdFile if zstd else None + taropen = tarfile.TarFile.zstopen class ReadTest(TarTest): @@ -97,6 +116,14 @@ def setUp(self): def tearDown(self): self.tar.close() +class StreamModeTest(ReadTest): + + # Only needs to change how the tarfile is opened to set + # stream mode + def setUp(self): + self.tar = tarfile.open(self.tarname, mode=self.mode, + encoding="iso8859-1", + stream=True) class UstarReadTest(ReadTest, unittest.TestCase): @@ -110,7 +137,7 @@ def test_fileobj_regular_file(self): "regular file extraction failed") def test_fileobj_readlines(self): - self.tar.extract("ustar/regtype", TEMPDIR) + self.tar.extract("ustar/regtype", TEMPDIR, filter='data') tarinfo = self.tar.getmember("ustar/regtype") with open(os.path.join(TEMPDIR, "ustar/regtype"), "r") as fobj1: lines1 = fobj1.readlines() @@ -128,7 +155,7 @@ def test_fileobj_readlines(self): "fileobj.readlines() failed") def test_fileobj_iter(self): - self.tar.extract("ustar/regtype", TEMPDIR) + self.tar.extract("ustar/regtype", TEMPDIR, filter='data') tarinfo = self.tar.getmember("ustar/regtype") with open(os.path.join(TEMPDIR, "ustar/regtype"), "r") as fobj1: lines1 = fobj1.readlines() @@ -138,7 +165,8 @@ def test_fileobj_iter(self): "fileobj.__iter__() failed") def test_fileobj_seek(self): - self.tar.extract("ustar/regtype", TEMPDIR) + self.tar.extract("ustar/regtype", TEMPDIR, + filter='data') with open(os.path.join(TEMPDIR, "ustar/regtype"), "rb") as fobj: data = fobj.read() @@ -227,13 +255,19 @@ def test_add_dir_getmember(self): self.add_dir_and_getmember('bar') self.add_dir_and_getmember('a'*101) + @unittest.skipUnless(hasattr(os, "getuid") and hasattr(os, "getgid"), + "Missing getuid or getgid implementation") def add_dir_and_getmember(self, name): + def filter(tarinfo): + tarinfo.uid = tarinfo.gid = 100 + return tarinfo + with os_helper.temp_cwd(): with tarfile.open(tmpname, 'w') as tar: tar.format = tarfile.USTAR_FORMAT try: os.mkdir(name) - tar.add(name) + tar.add(name, filter=filter) finally: os.rmdir(name) with tarfile.open(tmpname) as tar: @@ -251,6 +285,8 @@ class Bz2UstarReadTest(Bz2Test, UstarReadTest): class LzmaUstarReadTest(LzmaTest, UstarReadTest): pass +class ZstdUstarReadTest(ZstdTest, UstarReadTest): + pass class ListTest(ReadTest, unittest.TestCase): @@ -304,11 +340,23 @@ def test_list_verbose(self): # accessories if verbose flag is being used # ... # ?rw-r--r-- tarfile/tarfile 7011 2003-01-06 07:19:43 ustar/conttype - # ?rw-r--r-- tarfile/tarfile 7011 2003-01-06 07:19:43 ustar/regtype + # -rw-r--r-- tarfile/tarfile 7011 2003-01-06 07:19:43 ustar/regtype + # drwxr-xr-x tarfile/tarfile 0 2003-01-05 15:19:43 ustar/dirtype/ # ... - self.assertRegex(out, (br'\?rw-r--r-- tarfile/tarfile\s+7011 ' - br'\d{4}-\d\d-\d\d\s+\d\d:\d\d:\d\d ' - br'ustar/\w+type ?\r?\n') * 2) + # + # Array of values to modify the regex below: + # ((file_type, file_permissions, file_length), ...) + type_perm_lengths = ( + (br'\?', b'rw-r--r--', b'7011'), (b'-', b'rw-r--r--', b'7011'), + (b'd', b'rwxr-xr-x', b'0'), (b'd', b'rwxr-xr-x', b'255'), + (br'\?', b'rw-r--r--', b'0'), (b'l', b'rwxrwxrwx', b'0'), + (b'b', b'rw-rw----', b'3,0'), (b'c', b'rw-rw-rw-', b'1,3'), + (b'p', b'rw-r--r--', b'0')) + self.assertRegex(out, b''.join( + [(tp + (br'%s tarfile/tarfile\s+%s ' % (perm, ln) + + br'\d{4}-\d\d-\d\d\s+\d\d:\d\d:\d\d ' + br'ustar/\w+type[/>\sa-z-]*\n')) for tp, perm, ln + in type_perm_lengths])) # Make sure it prints the source of link with verbose flag self.assertIn(b'ustar/symtype -> regtype', out) self.assertIn(b'./ustar/linktest2/symtype -> ../linktest1/regtype', out) @@ -343,6 +391,8 @@ class Bz2ListTest(Bz2Test, ListTest): class LzmaListTest(LzmaTest, ListTest): pass +class ZstdListTest(ZstdTest, ListTest): + pass class CommonReadTest(ReadTest): @@ -354,7 +404,7 @@ def test_is_tarfile_erroneous(self): self.assertFalse(tarfile.is_tarfile(tmpname)) # is_tarfile works on path-like objects - self.assertFalse(tarfile.is_tarfile(pathlib.Path(tmpname))) + self.assertFalse(tarfile.is_tarfile(os_helper.FakePath(tmpname))) # is_tarfile works on file objects with open(tmpname, "rb") as fobj: @@ -368,7 +418,7 @@ def test_is_tarfile_valid(self): self.assertTrue(tarfile.is_tarfile(self.tarname)) # is_tarfile works on path-like objects - self.assertTrue(tarfile.is_tarfile(pathlib.Path(self.tarname))) + self.assertTrue(tarfile.is_tarfile(os_helper.FakePath(self.tarname))) # is_tarfile works on file objects with open(self.tarname, "rb") as fobj: @@ -378,6 +428,18 @@ def test_is_tarfile_valid(self): with open(self.tarname, "rb") as fobj: self.assertTrue(tarfile.is_tarfile(io.BytesIO(fobj.read()))) + def test_is_tarfile_keeps_position(self): + # Test for issue44289: tarfile.is_tarfile() modifies + # file object's current position + with open(self.tarname, "rb") as fobj: + tarfile.is_tarfile(fobj) + self.assertEqual(fobj.tell(), 0) + + with open(self.tarname, "rb") as fobj: + file_like = io.BytesIO(fobj.read()) + tarfile.is_tarfile(file_like) + self.assertEqual(file_like.tell(), 0) + def test_empty_tarfile(self): # Test for issue6123: Allow opening empty archives. # This test checks if tarfile.open() is able to open an empty tar @@ -451,7 +513,7 @@ def test_premature_end_of_archive(self): t = tar.next() with self.assertRaisesRegex(tarfile.ReadError, "unexpected end of data"): - tar.extract(t, TEMPDIR) + tar.extract(t, TEMPDIR, filter='data') with self.assertRaisesRegex(tarfile.ReadError, "unexpected end of data"): tar.extractfile(t).read() @@ -460,15 +522,39 @@ def test_length_zero_header(self): # bpo-39017 (CVE-2019-20907): reading a zero-length header should fail # with an exception with self.assertRaisesRegex(tarfile.ReadError, "file could not be opened successfully"): - with tarfile.open(support.findfile('recursion.tar')) as tar: + with tarfile.open(support.findfile('recursion.tar', subdir='archivetestdata')): pass + def test_extractfile_attrs(self): + # gh-74468: TarFile.name must name a file, not a parent archive. + file = self.tar.getmember('ustar/regtype') + with self.tar.extractfile(file) as fobj: + self.assertEqual(fobj.name, 'ustar/regtype') + self.assertRaises(AttributeError, fobj.fileno) + self.assertEqual(fobj.mode, 'rb') + self.assertIs(fobj.readable(), True) + self.assertIs(fobj.writable(), False) + if self.is_stream: + self.assertRaises(AttributeError, fobj.seekable) + else: + self.assertIs(fobj.seekable(), True) + self.assertIs(fobj.closed, False) + self.assertIs(fobj.closed, True) + self.assertEqual(fobj.name, 'ustar/regtype') + self.assertRaises(AttributeError, fobj.fileno) + self.assertEqual(fobj.mode, 'rb') + self.assertIs(fobj.readable(), True) + self.assertIs(fobj.writable(), False) + if self.is_stream: + self.assertRaises(AttributeError, fobj.seekable) + else: + self.assertIs(fobj.seekable(), True) + + class MiscReadTestBase(CommonReadTest): - def requires_name_attribute(self): - pass + is_stream = False def test_no_name_argument(self): - self.requires_name_attribute() with open(self.tarname, "rb") as fobj: self.assertIsInstance(fobj.name, str) with tarfile.open(fileobj=fobj, mode=self.mode) as tar: @@ -501,7 +587,6 @@ def test_int_name_attribute(self): self.assertIsNone(tar.name) def test_bytes_name_attribute(self): - self.requires_name_attribute() tarname = os.fsencode(self.tarname) with open(tarname, 'rb') as fobj: self.assertIsInstance(fobj.name, bytes) @@ -509,21 +594,23 @@ def test_bytes_name_attribute(self): self.assertIsInstance(tar.name, bytes) self.assertEqual(tar.name, os.path.abspath(fobj.name)) - def test_pathlike_name(self): - tarname = pathlib.Path(self.tarname) + def test_pathlike_name(self, tarname=None): + if tarname is None: + tarname = self.tarname + expected = os.path.abspath(tarname) + tarname = os_helper.FakePath(tarname) with tarfile.open(tarname, mode=self.mode) as tar: - self.assertIsInstance(tar.name, str) - self.assertEqual(tar.name, os.path.abspath(os.fspath(tarname))) + self.assertEqual(tar.name, expected) with self.taropen(tarname) as tar: - self.assertIsInstance(tar.name, str) - self.assertEqual(tar.name, os.path.abspath(os.fspath(tarname))) + self.assertEqual(tar.name, expected) with tarfile.TarFile.open(tarname, mode=self.mode) as tar: - self.assertIsInstance(tar.name, str) - self.assertEqual(tar.name, os.path.abspath(os.fspath(tarname))) + self.assertEqual(tar.name, expected) if self.suffix == '': with tarfile.TarFile(tarname, mode='r') as tar: - self.assertIsInstance(tar.name, str) - self.assertEqual(tar.name, os.path.abspath(os.fspath(tarname))) + self.assertEqual(tar.name, expected) + + def test_pathlike_bytes_name(self): + self.test_pathlike_name(os.fsencode(self.tarname)) def test_illegal_mode_arg(self): with open(tmpname, 'wb'): @@ -606,21 +693,22 @@ def test_find_members(self): def test_extract_hardlink(self): # Test hardlink extraction (e.g. bug #857297). with tarfile.open(tarname, errorlevel=1, encoding="iso8859-1") as tar: - tar.extract("ustar/regtype", TEMPDIR) + tar.extract("ustar/regtype", TEMPDIR, filter='data') self.addCleanup(os_helper.unlink, os.path.join(TEMPDIR, "ustar/regtype")) - tar.extract("ustar/lnktype", TEMPDIR) + tar.extract("ustar/lnktype", TEMPDIR, filter='data') self.addCleanup(os_helper.unlink, os.path.join(TEMPDIR, "ustar/lnktype")) with open(os.path.join(TEMPDIR, "ustar/lnktype"), "rb") as f: data = f.read() self.assertEqual(sha256sum(data), sha256_regtype) - tar.extract("ustar/symtype", TEMPDIR) + tar.extract("ustar/symtype", TEMPDIR, filter='data') self.addCleanup(os_helper.unlink, os.path.join(TEMPDIR, "ustar/symtype")) with open(os.path.join(TEMPDIR, "ustar/symtype"), "rb") as f: data = f.read() self.assertEqual(sha256sum(data), sha256_regtype) + @os_helper.skip_unless_working_chmod def test_extractall(self): # Test if extractall() correctly restores directory permissions # and times (see issue1735). @@ -629,13 +717,14 @@ def test_extractall(self): os.mkdir(DIR) try: directories = [t for t in tar if t.isdir()] - tar.extractall(DIR, directories) + tar.extractall(DIR, directories, filter='fully_trusted') for tarinfo in directories: path = os.path.join(DIR, tarinfo.name) if sys.platform != "win32": # Win32 has no support for fine grained permissions. self.assertEqual(tarinfo.mode & 0o777, - os.stat(path).st_mode & 0o777) + os.stat(path).st_mode & 0o777, + tarinfo.name) def format_mtime(mtime): if isinstance(mtime, float): return "{} ({})".format(mtime, mtime.hex()) @@ -651,6 +740,25 @@ def format_mtime(mtime): tar.close() os_helper.rmtree(DIR) + @staticmethod + def test_extractall_default_filter(): + # Test that the default filter is now "data", and the other filter types are not used. + DIR = pathlib.Path(TEMPDIR) / "extractall_default_filter" + with ( + os_helper.temp_dir(DIR), + tarfile.open(tarname, encoding="iso8859-1") as tar, + unittest.mock.patch("tarfile.data_filter", wraps=tarfile.data_filter) as mock_data_filter, + unittest.mock.patch("tarfile.tar_filter", wraps=tarfile.tar_filter) as mock_tar_filter, + unittest.mock.patch("tarfile.fully_trusted_filter", wraps=tarfile.fully_trusted_filter) as mock_ft_filter + ): + directories = [t for t in tar if t.isdir()] + tar.extractall(DIR, directories) + + mock_data_filter.assert_called() + mock_ft_filter.assert_not_called() + mock_tar_filter.assert_not_called() + + @os_helper.skip_unless_working_chmod def test_extract_directory(self): dirtype = "ustar/dirtype" DIR = os.path.join(TEMPDIR, "extractdir") @@ -658,7 +766,7 @@ def test_extract_directory(self): try: with tarfile.open(tarname, encoding="iso8859-1") as tar: tarinfo = tar.getmember(dirtype) - tar.extract(tarinfo, path=DIR) + tar.extract(tarinfo, path=DIR, filter='fully_trusted') extracted = os.path.join(DIR, dirtype) self.assertEqual(os.path.getmtime(extracted), tarinfo.mtime) if sys.platform != "win32": @@ -666,24 +774,24 @@ def test_extract_directory(self): finally: os_helper.rmtree(DIR) - def test_extractall_pathlike_name(self): - DIR = pathlib.Path(TEMPDIR) / "extractall" + def test_extractall_pathlike_dir(self): + DIR = os.path.join(TEMPDIR, "extractall") with os_helper.temp_dir(DIR), \ tarfile.open(tarname, encoding="iso8859-1") as tar: directories = [t for t in tar if t.isdir()] - tar.extractall(DIR, directories) + tar.extractall(os_helper.FakePath(DIR), directories, filter='fully_trusted') for tarinfo in directories: - path = DIR / tarinfo.name + path = os.path.join(DIR, tarinfo.name) self.assertEqual(os.path.getmtime(path), tarinfo.mtime) - def test_extract_pathlike_name(self): + def test_extract_pathlike_dir(self): dirtype = "ustar/dirtype" - DIR = pathlib.Path(TEMPDIR) / "extractall" + DIR = os.path.join(TEMPDIR, "extractall") with os_helper.temp_dir(DIR), \ tarfile.open(tarname, encoding="iso8859-1") as tar: tarinfo = tar.getmember(dirtype) - tar.extract(tarinfo, path=DIR) - extracted = DIR / dirtype + tar.extract(tarinfo, path=os_helper.FakePath(DIR), filter='fully_trusted') + extracted = os.path.join(DIR, dirtype) self.assertEqual(os.path.getmtime(extracted), tarinfo.mtime) def test_init_close_fobj(self): @@ -722,6 +830,69 @@ def test_zlib_error_does_not_leak(self): with self.assertRaises(tarfile.ReadError): tarfile.open(self.tarname) + def test_next_on_empty_tarfile(self): + fd = io.BytesIO() + tf = tarfile.open(fileobj=fd, mode="w") + tf.close() + + fd.seek(0) + with tarfile.open(fileobj=fd, mode="r|") as tf: + self.assertEqual(tf.next(), None) + + fd.seek(0) + with tarfile.open(fileobj=fd, mode="r") as tf: + self.assertEqual(tf.next(), None) + + def _setup_symlink_to_target(self, temp_dirpath): + target_filepath = os.path.join(temp_dirpath, "target") + ustar_dirpath = os.path.join(temp_dirpath, "ustar") + hardlink_filepath = os.path.join(ustar_dirpath, "lnktype") + with open(target_filepath, "wb") as f: + f.write(b"target") + os.makedirs(ustar_dirpath) + os.symlink(target_filepath, hardlink_filepath) + return target_filepath, hardlink_filepath + + def _assert_on_file_content(self, filepath, digest): + with open(filepath, "rb") as f: + data = f.read() + self.assertEqual(sha256sum(data), digest) + + @unittest.skipUnless( + hasattr(os, "link"), "Missing hardlink implementation" + ) + @os_helper.skip_unless_symlink + def test_extract_hardlink_on_symlink(self): + """ + This test verifies that extracting a hardlink will not follow an + existing symlink after a FileExistsError on os.link. + """ + with os_helper.temp_dir() as DIR: + target_filepath, hardlink_filepath = self._setup_symlink_to_target(DIR) + with tarfile.open(tarname, encoding="iso8859-1") as tar: + tar.extract("ustar/regtype", DIR, filter="data") + tar.extract("ustar/lnktype", DIR, filter="data") + self._assert_on_file_content(target_filepath, sha256sum(b"target")) + self._assert_on_file_content(hardlink_filepath, sha256_regtype) + + @unittest.skipUnless( + hasattr(os, "link"), "Missing hardlink implementation" + ) + @os_helper.skip_unless_symlink + def test_extractall_hardlink_on_symlink(self): + """ + This test verifies that extracting a hardlink will not follow an + existing symlink after a FileExistsError on os.link. + """ + with os_helper.temp_dir() as DIR: + target_filepath, hardlink_filepath = self._setup_symlink_to_target(DIR) + with tarfile.open(tarname, encoding="iso8859-1") as tar: + tar.extractall( + DIR, members=["ustar/regtype", "ustar/lnktype"], filter="data", + ) + self._assert_on_file_content(target_filepath, sha256sum(b"target")) + self._assert_on_file_content(hardlink_filepath, sha256_regtype) + class MiscReadTest(MiscReadTestBase, unittest.TestCase): test_fail_comp = None @@ -730,17 +901,18 @@ class GzipMiscReadTest(GzipTest, MiscReadTestBase, unittest.TestCase): pass class Bz2MiscReadTest(Bz2Test, MiscReadTestBase, unittest.TestCase): - def requires_name_attribute(self): - self.skipTest("BZ2File have no name attribute") + pass class LzmaMiscReadTest(LzmaTest, MiscReadTestBase, unittest.TestCase): - def requires_name_attribute(self): - self.skipTest("LZMAFile have no name attribute") + pass +class ZstdMiscReadTest(ZstdTest, MiscReadTestBase, unittest.TestCase): + pass class StreamReadTest(CommonReadTest, unittest.TestCase): prefix="r|" + is_stream = True def test_read_through(self): # Issue #11224: A poorly designed _FileInFile.read() method @@ -808,6 +980,27 @@ class Bz2StreamReadTest(Bz2Test, StreamReadTest): class LzmaStreamReadTest(LzmaTest, StreamReadTest): pass +class ZstdStreamReadTest(ZstdTest, StreamReadTest): + pass + +class TarStreamModeReadTest(StreamModeTest, unittest.TestCase): + + def test_stream_mode_no_cache(self): + for _ in self.tar: + pass + self.assertEqual(self.tar.members, []) + +class GzipStreamModeReadTest(GzipTest, TarStreamModeReadTest): + pass + +class Bz2StreamModeReadTest(Bz2Test, TarStreamModeReadTest): + pass + +class LzmaStreamModeReadTest(LzmaTest, TarStreamModeReadTest): + pass + +class ZstdStreamModeReadTest(ZstdTest, TarStreamModeReadTest): + pass class DetectReadTest(TarTest, unittest.TestCase): def _testfunc_file(self, name, mode): @@ -870,6 +1063,25 @@ def test_detect_stream_bz2(self): class LzmaDetectReadTest(LzmaTest, DetectReadTest): pass +class ZstdDetectReadTest(ZstdTest, DetectReadTest): + pass + +class GzipBrokenHeaderCorrectException(GzipTest, unittest.TestCase): + """ + See: https://github.com/python/cpython/issues/107396 + """ + def runTest(self): + f = io.BytesIO( + b'\x1f\x8b' # header + b'\x08' # compression method + b'\x04' # flags + b'\0\0\0\0\0\0' # timestamp, compression data, OS ID + b'\0\x01' # size + b'\0\0\0\0\0' # corrupt data (zeros) + ) + with self.assertRaises(tarfile.ReadError): + tarfile.open(fileobj=f, mode='r|gz') + class MemberReadTest(ReadTest, unittest.TestCase): @@ -1019,7 +1231,7 @@ def test_longname_directory(self): os.mkdir(longdir) tar.add(longdir) finally: - os.rmdir(longdir) + os.rmdir(longdir.rstrip("/")) with tarfile.open(tmpname) as tar: self.assertIsNotNone(tar.getmember(longdir)) self.assertIsNotNone(tar.getmember(longdir.removesuffix('/'))) @@ -1038,7 +1250,7 @@ class GNUReadTest(LongnameTest, ReadTest, unittest.TestCase): # an all platforms, and after that a test that will work only on # platforms/filesystems that prove to support sparse files. def _test_sparse_file(self, name): - self.tar.extract(name, TEMPDIR) + self.tar.extract(name, TEMPDIR, filter='data') filename = os.path.join(TEMPDIR, name) with open(filename, "rb") as fobj: data = fobj.read() @@ -1069,7 +1281,7 @@ def _fs_supports_holes(): # # The function returns False if page size is larger than 4 KiB. # For example, ppc64 uses pages of 64 KiB. - if sys.platform.startswith("linux"): + if sys.platform.startswith(("linux", "android")): # Linux evidentially has 512 byte st_blocks units. name = os.path.join(TEMPDIR, "sparse-test") with open(name, "wb") as fobj: @@ -1128,6 +1340,48 @@ def test_pax_number_fields(self): finally: tar.close() + def test_pax_header_bad_formats(self): + # The fields from the pax header have priority over the + # TarInfo. + pax_header_replacements = ( + b" foo=bar\n", + b"0 \n", + b"1 \n", + b"2 \n", + b"3 =\n", + b"4 =a\n", + b"1000000 foo=bar\n", + b"0 foo=bar\n", + b"-12 foo=bar\n", + b"000000000000000000000000036 foo=bar\n", + ) + pax_headers = {"foo": "bar"} + + for replacement in pax_header_replacements: + with self.subTest(header=replacement): + tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, + encoding="iso8859-1") + try: + t = tarfile.TarInfo() + t.name = "pax" # non-ASCII + t.uid = 1 + t.pax_headers = pax_headers + tar.addfile(t) + finally: + tar.close() + + with open(tmpname, "rb") as f: + data = f.read() + self.assertIn(b"11 foo=bar\n", data) + data = data.replace(b"11 foo=bar\n", replacement) + + with open(tmpname, "wb") as f: + f.truncate() + f.write(data) + + with self.assertRaisesRegex(tarfile.ReadError, r"method tar: ReadError\('invalid header'\)"): + tarfile.open(tmpname, encoding="iso8859-1") + class WriteTestBase(TarTest): # Put all write tests in here that are supposed to be tested @@ -1252,11 +1506,11 @@ def test_ordered_recursion(self): def test_gettarinfo_pathlike_name(self): with tarfile.open(tmpname, self.mode) as tar: - path = pathlib.Path(TEMPDIR) / "file" + path = os.path.join(TEMPDIR, "file") with open(path, "wb") as fobj: fobj.write(b"aaa") - tarinfo = tar.gettarinfo(path) - tarinfo2 = tar.gettarinfo(os.fspath(path)) + tarinfo = tar.gettarinfo(os_helper.FakePath(path)) + tarinfo2 = tar.gettarinfo(path) self.assertIsInstance(tarinfo.name, str) self.assertEqual(tarinfo.name, tarinfo2.name) self.assertEqual(tarinfo.size, 3) @@ -1405,7 +1659,8 @@ def test_extractall_symlinks(self): with tarfile.open(temparchive, errorlevel=2) as tar: # this should not raise OSError: [Errno 17] File exists try: - tar.extractall(path=tempdir) + tar.extractall(path=tempdir, + filter='fully_trusted') except OSError: self.fail("extractall failed with symlinked files") finally: @@ -1449,7 +1704,7 @@ def test_cwd(self): try: for t in tar: if t.name != ".": - self.assertTrue(t.name.startswith("./"), t.name) + self.assertStartsWith(t.name, "./") finally: tar.close() @@ -1463,24 +1718,24 @@ def write(self, data): raise exctype f = BadFile() - with self.assertRaises(exctype): - tar = tarfile.open(tmpname, self.mode, fileobj=f, - format=tarfile.PAX_FORMAT, - pax_headers={'non': 'empty'}) + with ( + warnings_helper.check_no_resource_warning(self), + self.assertRaises(exctype), + ): + tarfile.open(tmpname, self.mode, fileobj=f, + format=tarfile.PAX_FORMAT, + pax_headers={'non': 'empty'}) self.assertFalse(f.closed) + def test_missing_fileobj(self): + with tarfile.open(tmpname, self.mode) as tar: + tarinfo = tar.gettarinfo(tarname) + with self.assertRaises(ValueError): + tar.addfile(tarinfo) -class GzipWriteTest(GzipTest, WriteTest): - # TODO: RUSTPYTHON - def expectedSuccess(test_item): - test_item.__unittest_expecting_failure__ = False - return test_item - # TODO: RUSTPYTHON - if sys.platform == "win32": - @expectedSuccess - def test_cwd(self): - super().test_cwd() +class GzipWriteTest(GzipTest, WriteTest): + pass class Bz2WriteTest(Bz2Test, WriteTest): @@ -1490,6 +1745,8 @@ class Bz2WriteTest(Bz2Test, WriteTest): class LzmaWriteTest(LzmaTest, WriteTest): pass +class ZstdWriteTest(ZstdTest, WriteTest): + pass class StreamWriteTest(WriteTestBase, unittest.TestCase): @@ -1514,6 +1771,10 @@ def test_stream_padding(self): @unittest.skipUnless(sys.platform != "win32" and hasattr(os, "umask"), "Missing umask implementation") + @unittest.skipIf( + support.is_emscripten or support.is_wasi, + "Emscripten's/WASI's umask is a stub." + ) def test_file_mode(self): # Test for issue #8464: Create files with correct # permissions. @@ -1529,6 +1790,16 @@ def test_file_mode(self): finally: os.umask(original_umask) + def test_pathlike_name(self): + expected_name = os.path.abspath(tmpname) + tarpath = os_helper.FakePath(tmpname) + + for func in (tarfile.open, tarfile.TarFile.open): + with self.subTest(): + with func(tarpath, self.mode) as tar: + self.assertEqual(tar.name, expected_name) + os_helper.unlink(tmpname) + class GzipStreamWriteTest(GzipTest, StreamWriteTest): def test_source_directory_not_leaked(self): @@ -1547,6 +1818,78 @@ class Bz2StreamWriteTest(Bz2Test, StreamWriteTest): class LzmaStreamWriteTest(LzmaTest, StreamWriteTest): decompressor = lzma.LZMADecompressor if lzma else None +class ZstdStreamWriteTest(ZstdTest, StreamWriteTest): + decompressor = zstd.ZstdDecompressor if zstd else None + +class _CompressedWriteTest(TarTest): + # This is not actually a standalone test. + # It does not inherit WriteTest because it only makes sense with gz,bz2 + source = (b"And we move to Bristol where they have a special, " + + b"Very Silly candidate") + + def _compressed_tar(self, compresslevel): + fobj = io.BytesIO() + with tarfile.open(tmpname, self.mode, fobj, + compresslevel=compresslevel) as tarfl: + tarfl.addfile(tarfile.TarInfo("foo"), io.BytesIO(self.source)) + return fobj + + def _test_bz2_header(self, compresslevel): + fobj = self._compressed_tar(compresslevel) + self.assertEqual(fobj.getvalue()[0:10], + b"BZh%d1AY&SY" % compresslevel) + + def _test_gz_header(self, compresslevel): + fobj = self._compressed_tar(compresslevel) + self.assertEqual(fobj.getvalue()[:3], b"\x1f\x8b\x08") + +class Bz2CompressWriteTest(Bz2Test, _CompressedWriteTest, unittest.TestCase): + prefix = "w:" + def test_compression_levels(self): + self._test_bz2_header(1) + self._test_bz2_header(5) + self._test_bz2_header(9) + +class Bz2CompressStreamWriteTest(Bz2Test, _CompressedWriteTest, + unittest.TestCase): + prefix = "w|" + def test_compression_levels(self): + self._test_bz2_header(1) + self._test_bz2_header(5) + self._test_bz2_header(9) + +class GzCompressWriteTest(GzipTest, _CompressedWriteTest, unittest.TestCase): + prefix = "w:" + def test_compression_levels(self): + self._test_gz_header(1) + self._test_gz_header(5) + self._test_gz_header(9) + +class GzCompressStreamWriteTest(GzipTest, _CompressedWriteTest, + unittest.TestCase): + prefix = "w|" + def test_compression_levels(self): + self._test_gz_header(1) + self._test_gz_header(5) + self._test_gz_header(9) + +class CompressLevelRaises(unittest.TestCase): + def test_compresslevel_wrong_modes(self): + compresslevel = 5 + fobj = io.BytesIO() + with self.assertRaises(TypeError): + tarfile.open(tmpname, "w:", fobj, compresslevel=compresslevel) + + @support.requires_bz2() + def test_wrong_compresslevels(self): + # BZ2 checks that the compresslevel is in [1,9]. gz does not + fobj = io.BytesIO() + with self.assertRaises(ValueError): + tarfile.open(tmpname, "w:bz2", fobj, compresslevel=0) + with self.assertRaises(ValueError): + tarfile.open(tmpname, "w:bz2", fobj, compresslevel=10) + with self.assertRaises(ValueError): + tarfile.open(tmpname, "w|bz2", fobj, compresslevel=10) class GNUWriteTest(unittest.TestCase): # This testcase checks for correct creation of GNU Longname @@ -1738,10 +2081,10 @@ def test_create_existing_taropen(self): self.assertIn("spameggs42", names[0]) def test_create_pathlike_name(self): - with tarfile.open(pathlib.Path(tmpname), self.mode) as tobj: + with tarfile.open(os_helper.FakePath(tmpname), self.mode) as tobj: self.assertIsInstance(tobj.name, str) self.assertEqual(tobj.name, os.path.abspath(tmpname)) - tobj.add(pathlib.Path(self.file_path)) + tobj.add(os_helper.FakePath(self.file_path)) names = tobj.getnames() self.assertEqual(len(names), 1) self.assertIn('spameggs42', names[0]) @@ -1752,10 +2095,10 @@ def test_create_pathlike_name(self): self.assertIn('spameggs42', names[0]) def test_create_taropen_pathlike_name(self): - with self.taropen(pathlib.Path(tmpname), "x") as tobj: + with self.taropen(os_helper.FakePath(tmpname), "x") as tobj: self.assertIsInstance(tobj.name, str) self.assertEqual(tobj.name, os.path.abspath(tmpname)) - tobj.add(pathlib.Path(self.file_path)) + tobj.add(os_helper.FakePath(self.file_path)) names = tobj.getnames() self.assertEqual(len(names), 1) self.assertIn('spameggs42', names[0]) @@ -1793,6 +2136,14 @@ def test_create_with_preset(self): tobj.add(self.file_path) +class ZstdCreateTest(ZstdTest, CreateTest): + + # Unlike gz and bz2, zstd uses the level keyword instead of compresslevel. + # It does not allow for level to be specified when reading. + def test_create_with_level(self): + with tarfile.open(tmpname, self.mode, level=1) as tobj: + tobj.add(self.file_path) + class CreateWithXModeTest(CreateTest): prefix = "x" @@ -2274,6 +2625,8 @@ class Bz2AppendTest(Bz2Test, AppendTestBase, unittest.TestCase): class LzmaAppendTest(LzmaTest, AppendTestBase, unittest.TestCase): pass +class ZstdAppendTest(ZstdTest, AppendTestBase, unittest.TestCase): + pass class LimitsTest(unittest.TestCase): @@ -2414,9 +2767,8 @@ def test__all__(self): 'SubsequentHeaderError', 'ExFileObject', 'main'} support.check__all__(self, tarfile, not_exported=not_exported) - @unittest.expectedFailure # TODO: RUSTPYTHON; FileNotFoundError: [Errno 2] No such file or directory: '/Users/al03219714/Projects/RustPython3/crates/pylib/Lib/test/testtar.tar.xz' def test_useful_error_message_when_modules_missing(self): - fname = os.path.join(os.path.dirname(__file__), 'testtar.tar.xz') + fname = os.path.join(os.path.dirname(__file__), 'archivetestdata', 'testtar.tar.xz') with self.assertRaises(tarfile.ReadError) as excinfo: error = tarfile.CompressionError('lzma module is not available'), with unittest.mock.patch.object(tarfile.TarFile, 'xzopen', side_effect=error): @@ -2427,6 +2779,31 @@ def test_useful_error_message_when_modules_missing(self): str(excinfo.exception), ) + @unittest.skipUnless(os_helper.can_symlink(), 'requires symlink support') + @unittest.skipUnless(hasattr(os, 'chmod'), "missing os.chmod") + @unittest.mock.patch('os.chmod') + def test_deferred_directory_attributes_update(self, mock_chmod): + # Regression test for gh-127987: setting attributes on arbitrary files + tempdir = os.path.join(TEMPDIR, 'test127987') + def mock_chmod_side_effect(path, mode, **kwargs): + target_path = os.path.realpath(path) + if os.path.commonpath([target_path, tempdir]) != tempdir: + raise Exception("should not try to chmod anything outside the destination", target_path) + mock_chmod.side_effect = mock_chmod_side_effect + + outside_tree_dir = os.path.join(TEMPDIR, 'outside_tree_dir') + with ArchiveMaker() as arc: + arc.add('x', symlink_to='.') + arc.add('x', type=tarfile.DIRTYPE, mode='?rwsrwsrwt') + arc.add('x', symlink_to=outside_tree_dir) + + os.makedirs(outside_tree_dir) + try: + arc.open().extractall(path=tempdir, filter='tar') + finally: + os_helper.rmtree(outside_tree_dir) + os_helper.rmtree(tempdir) + class CommandLineTest(unittest.TestCase): @@ -2439,14 +2816,24 @@ def tarfilecmd_failure(self, *args): return script_helper.assert_python_failure('-m', 'tarfile', *args) def make_simple_tarfile(self, tar_name): - files = [support.findfile('tokenize_tests.txt'), + files = [support.findfile('tokenize_tests.txt', + subdir='tokenizedata'), support.findfile('tokenize_tests-no-coding-cookie-' - 'and-utf8-bom-sig-only.txt')] + 'and-utf8-bom-sig-only.txt', + subdir='tokenizedata')] self.addCleanup(os_helper.unlink, tar_name) with tarfile.open(tar_name, 'w') as tf: for tardata in files: tf.add(tardata, arcname=os.path.basename(tardata)) + def make_evil_tarfile(self, tar_name): + self.addCleanup(os_helper.unlink, tar_name) + with tarfile.open(tar_name, 'w') as tf: + benign = tarfile.TarInfo('benign') + tf.addfile(benign, fileobj=io.BytesIO(b'')) + evil = tarfile.TarInfo('../evil') + tf.addfile(evil, fileobj=io.BytesIO(b'')) + def test_bad_use(self): rc, out, err = self.tarfilecmd_failure() self.assertEqual(out, b'') @@ -2471,7 +2858,7 @@ def test_test_command_verbose(self): self.assertIn(b'is a tar archive.\n', out) def test_test_command_invalid_file(self): - zipname = support.findfile('zipdir.zip') + zipname = support.findfile('zipdir.zip', subdir='archivetestdata') rc, out, err = self.tarfilecmd_failure('-t', zipname) self.assertIn(b' is not a tar archive.', err) self.assertEqual(out, b'') @@ -2513,16 +2900,18 @@ def test_list_command_verbose(self): self.assertEqual(out, expected) def test_list_command_invalid_file(self): - zipname = support.findfile('zipdir.zip') + zipname = support.findfile('zipdir.zip', subdir='archivetestdata') rc, out, err = self.tarfilecmd_failure('-l', zipname) self.assertIn(b' is not a tar archive.', err) self.assertEqual(out, b'') self.assertEqual(rc, 1) def test_create_command(self): - files = [support.findfile('tokenize_tests.txt'), + files = [support.findfile('tokenize_tests.txt', + subdir='tokenizedata'), support.findfile('tokenize_tests-no-coding-cookie-' - 'and-utf8-bom-sig-only.txt')] + 'and-utf8-bom-sig-only.txt', + subdir='tokenizedata')] for opt in '-c', '--create': try: out = self.tarfilecmd(opt, tmpname, *files) @@ -2533,9 +2922,11 @@ def test_create_command(self): os_helper.unlink(tmpname) def test_create_command_verbose(self): - files = [support.findfile('tokenize_tests.txt'), + files = [support.findfile('tokenize_tests.txt', + subdir='tokenizedata'), support.findfile('tokenize_tests-no-coding-cookie-' - 'and-utf8-bom-sig-only.txt')] + 'and-utf8-bom-sig-only.txt', + subdir='tokenizedata')] for opt in '-v', '--verbose': try: out = self.tarfilecmd(opt, '-c', tmpname, *files, @@ -2547,7 +2938,7 @@ def test_create_command_verbose(self): os_helper.unlink(tmpname) def test_create_command_dotless_filename(self): - files = [support.findfile('tokenize_tests.txt')] + files = [support.findfile('tokenize_tests.txt', subdir='tokenizedata')] try: out = self.tarfilecmd('-c', dotlessname, *files) self.assertEqual(out, b'') @@ -2558,7 +2949,7 @@ def test_create_command_dotless_filename(self): def test_create_command_dot_started_filename(self): tar_name = os.path.join(TEMPDIR, ".testtar") - files = [support.findfile('tokenize_tests.txt')] + files = [support.findfile('tokenize_tests.txt', subdir='tokenizedata')] try: out = self.tarfilecmd('-c', tar_name, *files) self.assertEqual(out, b'') @@ -2568,10 +2959,12 @@ def test_create_command_dot_started_filename(self): os_helper.unlink(tar_name) def test_create_command_compressed(self): - files = [support.findfile('tokenize_tests.txt'), + files = [support.findfile('tokenize_tests.txt', + subdir='tokenizedata'), support.findfile('tokenize_tests-no-coding-cookie-' - 'and-utf8-bom-sig-only.txt')] - for filetype in (GzipTest, Bz2Test, LzmaTest): + 'and-utf8-bom-sig-only.txt', + subdir='tokenizedata')] + for filetype in (GzipTest, Bz2Test, LzmaTest, ZstdTest): if not filetype.open: continue try: @@ -2603,6 +2996,25 @@ def test_extract_command_verbose(self): finally: os_helper.rmtree(tarextdir) + def test_extract_command_filter(self): + self.make_evil_tarfile(tmpname) + # Make an inner directory, so the member named '../evil' + # is still extracted into `tarextdir` + destdir = os.path.join(tarextdir, 'dest') + os.mkdir(tarextdir) + try: + with os_helper.temp_cwd(destdir): + self.tarfilecmd_failure('-e', tmpname, + '-v', + '--filter', 'data') + out = self.tarfilecmd('-e', tmpname, + '-v', + '--filter', 'fully_trusted', + PYTHONIOENCODING='utf-8') + self.assertIn(b' file is extracted.', out) + finally: + os_helper.rmtree(tarextdir) + def test_extract_command_different_directory(self): self.make_simple_tarfile(tmpname) try: @@ -2613,7 +3025,7 @@ def test_extract_command_different_directory(self): os_helper.rmtree(tarextdir) def test_extract_command_invalid_file(self): - zipname = support.findfile('zipdir.zip') + zipname = support.findfile('zipdir.zip', subdir='archivetestdata') with os_helper.temp_cwd(tarextdir): rc, out, err = self.tarfilecmd_failure('-e', zipname) self.assertIn(b' is not a tar archive.', err) @@ -2686,7 +3098,7 @@ class LinkEmulationTest(ReadTest, unittest.TestCase): # symbolic or hard links tarfile tries to extract these types of members # as the regular files they point to. def _test_link_extraction(self, name): - self.tar.extract(name, TEMPDIR) + self.tar.extract(name, TEMPDIR, filter='fully_trusted') with open(os.path.join(TEMPDIR, name), "rb") as f: data = f.read() self.assertEqual(sha256sum(data), sha256_regtype) @@ -2818,8 +3230,10 @@ def test_extract_with_numeric_owner(self, mock_geteuid, mock_chmod, mock_chown): with self._setup_test(mock_geteuid) as (tarfl, filename_1, _, filename_2): - tarfl.extract(filename_1, TEMPDIR, numeric_owner=True) - tarfl.extract(filename_2 , TEMPDIR, numeric_owner=True) + tarfl.extract(filename_1, TEMPDIR, numeric_owner=True, + filter='fully_trusted') + tarfl.extract(filename_2 , TEMPDIR, numeric_owner=True, + filter='fully_trusted') # convert to filesystem paths f_filename_1 = os.path.join(TEMPDIR, filename_1) @@ -2837,7 +3251,8 @@ def test_extractall_with_numeric_owner(self, mock_geteuid, mock_chmod, mock_chown): with self._setup_test(mock_geteuid) as (tarfl, filename_1, dirname_1, filename_2): - tarfl.extractall(TEMPDIR, numeric_owner=True) + tarfl.extractall(TEMPDIR, numeric_owner=True, + filter='fully_trusted') # convert to filesystem paths f_filename_1 = os.path.join(TEMPDIR, filename_1) @@ -2862,7 +3277,8 @@ def test_extractall_with_numeric_owner(self, mock_geteuid, mock_chmod, def test_extract_without_numeric_owner(self, mock_geteuid, mock_chmod, mock_chown): with self._setup_test(mock_geteuid) as (tarfl, filename_1, _, _): - tarfl.extract(filename_1, TEMPDIR, numeric_owner=False) + tarfl.extract(filename_1, TEMPDIR, numeric_owner=False, + filter='fully_trusted') # convert to filesystem paths f_filename_1 = os.path.join(TEMPDIR, filename_1) @@ -2876,6 +3292,1535 @@ def test_keyword_only(self, mock_geteuid): tarfl.extract, filename_1, TEMPDIR, False, True) +class ReplaceTests(ReadTest, unittest.TestCase): + def test_replace_name(self): + member = self.tar.getmember('ustar/regtype') + replaced = member.replace(name='misc/other') + self.assertEqual(replaced.name, 'misc/other') + self.assertEqual(member.name, 'ustar/regtype') + self.assertEqual(self.tar.getmember('ustar/regtype').name, + 'ustar/regtype') + + def test_replace_deep(self): + member = self.tar.getmember('pax/regtype1') + replaced = member.replace() + replaced.pax_headers['gname'] = 'not-bar' + self.assertEqual(member.pax_headers['gname'], 'bar') + self.assertEqual( + self.tar.getmember('pax/regtype1').pax_headers['gname'], 'bar') + + def test_replace_shallow(self): + member = self.tar.getmember('pax/regtype1') + replaced = member.replace(deep=False) + replaced.pax_headers['gname'] = 'not-bar' + self.assertEqual(member.pax_headers['gname'], 'not-bar') + self.assertEqual( + self.tar.getmember('pax/regtype1').pax_headers['gname'], 'not-bar') + + def test_replace_all(self): + member = self.tar.getmember('ustar/regtype') + for attr_name in ('name', 'mtime', 'mode', 'linkname', + 'uid', 'gid', 'uname', 'gname'): + with self.subTest(attr_name=attr_name): + replaced = member.replace(**{attr_name: None}) + self.assertEqual(getattr(replaced, attr_name), None) + self.assertNotEqual(getattr(member, attr_name), None) + + def test_replace_internal(self): + member = self.tar.getmember('ustar/regtype') + with self.assertRaises(TypeError): + member.replace(offset=123456789) + + +class NoneInfoExtractTests(ReadTest): + # These mainly check that all kinds of members are extracted successfully + # if some metadata is None. + # Some of the methods do additional spot checks. + + # We also test that the default filters can deal with None. + + extraction_filter = None + + @classmethod + def setUpClass(cls): + tar = tarfile.open(tarname, mode='r', encoding="iso8859-1") + cls.control_dir = pathlib.Path(TEMPDIR) / "extractall_ctrl" + tar.errorlevel = 0 + with ExitStack() as cm: + if cls.extraction_filter is None: + cm.enter_context(warnings.catch_warnings( + action="ignore", category=DeprecationWarning)) + tar.extractall(cls.control_dir, filter=cls.extraction_filter) + tar.close() + cls.control_paths = set( + p.relative_to(cls.control_dir) + for p in pathlib.Path(cls.control_dir).glob('**/*')) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.control_dir) + + def check_files_present(self, directory): + got_paths = set( + p.relative_to(directory) + for p in pathlib.Path(directory).glob('**/*')) + if self.extraction_filter in (None, 'data'): + # The 'data' filter is expected to reject special files + for path in 'ustar/fifotype', 'ustar/blktype', 'ustar/chrtype': + got_paths.discard(pathlib.Path(path)) + self.assertEqual(self.control_paths, got_paths) + + @contextmanager + def extract_with_none(self, *attr_names): + DIR = pathlib.Path(TEMPDIR) / "extractall_none" + self.tar.errorlevel = 0 + for member in self.tar.getmembers(): + for attr_name in attr_names: + setattr(member, attr_name, None) + with os_helper.temp_dir(DIR): + self.tar.extractall(DIR, filter='fully_trusted') + self.check_files_present(DIR) + yield DIR + + def test_extractall_none_mtime(self): + # mtimes of extracted files should be later than 'now' -- the mtime + # of a previously created directory. + now = pathlib.Path(TEMPDIR).stat().st_mtime + with self.extract_with_none('mtime') as DIR: + for path in pathlib.Path(DIR).glob('**/*'): + with self.subTest(path=path): + try: + mtime = path.stat().st_mtime + except OSError: + # Some systems can't stat symlinks, ignore those + if not path.is_symlink(): + raise + else: + self.assertGreaterEqual(path.stat().st_mtime, now) + + def test_extractall_none_mode(self): + # modes of directories and regular files should match the mode + # of a "normally" created directory or regular file + dir_mode = pathlib.Path(TEMPDIR).stat().st_mode + regular_file = pathlib.Path(TEMPDIR) / 'regular_file' + regular_file.write_text('') + regular_file_mode = regular_file.stat().st_mode + with self.extract_with_none('mode') as DIR: + for path in pathlib.Path(DIR).glob('**/*'): + with self.subTest(path=path): + if path.is_dir(): + self.assertEqual(path.stat().st_mode, dir_mode) + elif path.is_file(): + self.assertEqual(path.stat().st_mode, + regular_file_mode) + + def test_extractall_none_uid(self): + with self.extract_with_none('uid'): + pass + + def test_extractall_none_gid(self): + with self.extract_with_none('gid'): + pass + + def test_extractall_none_uname(self): + with self.extract_with_none('uname'): + pass + + def test_extractall_none_gname(self): + with self.extract_with_none('gname'): + pass + + def test_extractall_none_ownership(self): + with self.extract_with_none('uid', 'gid', 'uname', 'gname'): + pass + +class NoneInfoExtractTests_Data(NoneInfoExtractTests, unittest.TestCase): + extraction_filter = 'data' + +class NoneInfoExtractTests_FullyTrusted(NoneInfoExtractTests, + unittest.TestCase): + extraction_filter = 'fully_trusted' + +class NoneInfoExtractTests_Tar(NoneInfoExtractTests, unittest.TestCase): + extraction_filter = 'tar' + +class NoneInfoExtractTests_Default(NoneInfoExtractTests, + unittest.TestCase): + extraction_filter = None + +class NoneInfoTests_Misc(unittest.TestCase): + def test_add(self): + # When addfile() encounters None metadata, it raises a ValueError + bio = io.BytesIO() + for tarformat in (tarfile.USTAR_FORMAT, tarfile.GNU_FORMAT, + tarfile.PAX_FORMAT): + with self.subTest(tarformat=tarformat): + tar = tarfile.open(fileobj=bio, mode='w', format=tarformat) + tarinfo = tar.gettarinfo(tarname) + try: + with open(tarname, 'rb') as f: + tar.addfile(tarinfo, f) + except Exception: + if tarformat == tarfile.USTAR_FORMAT: + # In the old, limited format, adding might fail for + # reasons like the UID being too large + pass + else: + raise + else: + for attr_name in ('mtime', 'mode', 'uid', 'gid', + 'uname', 'gname'): + with self.subTest(attr_name=attr_name): + replaced = tarinfo.replace(**{attr_name: None}) + with self.assertRaisesRegex(ValueError, + f"{attr_name}"): + with open(tarname, 'rb') as f: + tar.addfile(replaced, f) + + def test_list(self): + # Change some metadata to None, then compare list() output + # word-for-word. We want list() to not raise, and to only change + # printout for the affected piece of metadata. + # (n.b.: some contents of the test archive are hardcoded.) + for attr_names in ({'mtime'}, {'mode'}, {'uid'}, {'gid'}, + {'uname'}, {'gname'}, + {'uid', 'uname'}, {'gid', 'gname'}): + with (self.subTest(attr_names=attr_names), + tarfile.open(tarname, encoding="iso8859-1") as tar): + tio_prev = io.TextIOWrapper(io.BytesIO(), 'ascii', newline='\n') + with support.swap_attr(sys, 'stdout', tio_prev): + tar.list() + for member in tar.getmembers(): + for attr_name in attr_names: + setattr(member, attr_name, None) + tio_new = io.TextIOWrapper(io.BytesIO(), 'ascii', newline='\n') + with support.swap_attr(sys, 'stdout', tio_new): + tar.list() + for expected, got in zip(tio_prev.detach().getvalue().split(), + tio_new.detach().getvalue().split()): + if attr_names == {'mtime'} and re.match(rb'2003-01-\d\d', expected): + self.assertEqual(got, b'????-??-??') + elif attr_names == {'mtime'} and re.match(rb'\d\d:\d\d:\d\d', expected): + self.assertEqual(got, b'??:??:??') + elif attr_names == {'mode'} and re.match( + rb'.([r-][w-][x-]){3}', expected): + self.assertEqual(got, b'??????????') + elif attr_names == {'uname'} and expected.startswith( + (b'tarfile/', b'lars/', b'foo/')): + exp_user, exp_group = expected.split(b'/') + got_user, got_group = got.split(b'/') + self.assertEqual(got_group, exp_group) + self.assertRegex(got_user, b'[0-9]+') + elif attr_names == {'gname'} and expected.endswith( + (b'/tarfile', b'/users', b'/bar')): + exp_user, exp_group = expected.split(b'/') + got_user, got_group = got.split(b'/') + self.assertEqual(got_user, exp_user) + self.assertRegex(got_group, b'[0-9]+') + elif attr_names == {'uid'} and expected.startswith( + (b'1000/')): + exp_user, exp_group = expected.split(b'/') + got_user, got_group = got.split(b'/') + self.assertEqual(got_group, exp_group) + self.assertEqual(got_user, b'None') + elif attr_names == {'gid'} and expected.endswith((b'/100')): + exp_user, exp_group = expected.split(b'/') + got_user, got_group = got.split(b'/') + self.assertEqual(got_user, exp_user) + self.assertEqual(got_group, b'None') + elif attr_names == {'uid', 'uname'} and expected.startswith( + (b'tarfile/', b'lars/', b'foo/', b'1000/')): + exp_user, exp_group = expected.split(b'/') + got_user, got_group = got.split(b'/') + self.assertEqual(got_group, exp_group) + self.assertEqual(got_user, b'None') + elif attr_names == {'gname', 'gid'} and expected.endswith( + (b'/tarfile', b'/users', b'/bar', b'/100')): + exp_user, exp_group = expected.split(b'/') + got_user, got_group = got.split(b'/') + self.assertEqual(got_user, exp_user) + self.assertEqual(got_group, b'None') + else: + # In other cases the output should be the same + self.assertEqual(expected, got) + +def _filemode_to_int(mode): + """Inverse of `stat.filemode` (for permission bits) + + Using mode strings rather than numbers makes the later tests more readable. + """ + str_mode = mode[1:] + result = ( + {'r': stat.S_IRUSR, '-': 0}[str_mode[0]] + | {'w': stat.S_IWUSR, '-': 0}[str_mode[1]] + | {'x': stat.S_IXUSR, '-': 0, + 's': stat.S_IXUSR | stat.S_ISUID, + 'S': stat.S_ISUID}[str_mode[2]] + | {'r': stat.S_IRGRP, '-': 0}[str_mode[3]] + | {'w': stat.S_IWGRP, '-': 0}[str_mode[4]] + | {'x': stat.S_IXGRP, '-': 0, + 's': stat.S_IXGRP | stat.S_ISGID, + 'S': stat.S_ISGID}[str_mode[5]] + | {'r': stat.S_IROTH, '-': 0}[str_mode[6]] + | {'w': stat.S_IWOTH, '-': 0}[str_mode[7]] + | {'x': stat.S_IXOTH, '-': 0, + 't': stat.S_IXOTH | stat.S_ISVTX, + 'T': stat.S_ISVTX}[str_mode[8]] + ) + # check we did this right + assert stat.filemode(result)[1:] == mode[1:] + + return result + +class ArchiveMaker: + """Helper to create a tar file with specific contents + + Usage: + + with ArchiveMaker() as t: + t.add('filename', ...) + + with t.open() as tar: + ... # `tar` is now a TarFile with 'filename' in it! + """ + def __init__(self, **kwargs): + self.bio = io.BytesIO() + self.tar_kwargs = dict(kwargs) + + def __enter__(self): + self.tar_w = tarfile.TarFile(mode='w', fileobj=self.bio, **self.tar_kwargs) + return self + + def __exit__(self, *exc): + self.tar_w.close() + self.contents = self.bio.getvalue() + self.bio = None + + def add(self, name, *, type=None, symlink_to=None, hardlink_to=None, + mode=None, size=None, content=None, **kwargs): + """Add a member to the test archive. Call within `with`. + + Provides many shortcuts: + - default `type` is based on symlink_to, hardlink_to, and trailing `/` + in name (which is stripped) + - size & content defaults are based on each other + - content can be str or bytes + - mode should be textual ('-rwxrwxrwx') + + (add more! this is unstable internal test-only API) + """ + name = str(name) + tarinfo = tarfile.TarInfo(name).replace(**kwargs) + if content is not None: + if isinstance(content, str): + content = content.encode() + size = len(content) + if size is not None: + tarinfo.size = size + if content is None: + content = bytes(tarinfo.size) + if mode: + tarinfo.mode = _filemode_to_int(mode) + if symlink_to is not None: + type = tarfile.SYMTYPE + tarinfo.linkname = str(symlink_to) + if hardlink_to is not None: + type = tarfile.LNKTYPE + tarinfo.linkname = str(hardlink_to) + if name.endswith('/') and type is None: + type = tarfile.DIRTYPE + if type is not None: + tarinfo.type = type + if tarinfo.isreg(): + fileobj = io.BytesIO(content) + else: + fileobj = None + self.tar_w.addfile(tarinfo, fileobj) + + def open(self, **kwargs): + """Open the resulting archive as TarFile. Call after `with`.""" + bio = io.BytesIO(self.contents) + return tarfile.open(fileobj=bio, **kwargs) + +# Under WASI, `os_helper.can_symlink` is False to make +# `skip_unless_symlink` skip symlink tests. " +# But in the following tests we use can_symlink to *determine* which +# behavior is expected. +# Like other symlink tests, skip these on WASI for now. +if support.is_wasi: + def symlink_test(f): + return unittest.skip("WASI: Skip symlink test for now")(f) +else: + def symlink_test(f): + return f + + +class TestExtractionFilters(unittest.TestCase): + + # A temporary directory for the extraction results. + # All files that "escape" the destination path should still end + # up in this directory. + outerdir = pathlib.Path(TEMPDIR) / 'outerdir' + + # The destination for the extraction, within `outerdir` + destdir = outerdir / 'dest' + + @contextmanager + def check_context(self, tar, filter, *, check_flag=True): + """Extracts `tar` to `self.destdir` and allows checking the result + + If an error occurs, it must be checked using `expect_exception` + + Otherwise, all resulting files must be checked using `expect_file`, + except the destination directory itself and parent directories of + other files. + When checking directories, do so before their contents. + + A file called 'flag' is made in outerdir (i.e. outside destdir) + before extraction; it should not be altered nor should its contents + be read/copied. + """ + with os_helper.temp_dir(self.outerdir): + flag_path = self.outerdir / 'flag' + flag_path.write_text('capture me') + try: + tar.extractall(self.destdir, filter=filter) + except Exception as exc: + self.raised_exception = exc + self.reraise_exception = True + self.expected_paths = set() + else: + self.raised_exception = None + self.reraise_exception = False + self.expected_paths = set(self.outerdir.glob('**/*')) + self.expected_paths.discard(self.destdir) + self.expected_paths.discard(flag_path) + try: + yield self + finally: + tar.close() + if self.reraise_exception: + raise self.raised_exception + self.assertEqual(self.expected_paths, set()) + if check_flag: + self.assertEqual(flag_path.read_text(), 'capture me') + else: + assert filter == 'fully_trusted' + + def expect_file(self, name, type=None, symlink_to=None, mode=None, + size=None, content=None): + """Check a single file. See check_context.""" + if self.raised_exception: + raise self.raised_exception + # use normpath() rather than resolve() so we don't follow symlinks + path = pathlib.Path(os.path.normpath(self.destdir / name)) + self.assertIn(path, self.expected_paths) + self.expected_paths.remove(path) + if mode is not None and os_helper.can_chmod() and os.name != 'nt': + got = stat.filemode(stat.S_IMODE(path.stat().st_mode)) + self.assertEqual(got, mode) + if type is None and isinstance(name, str) and name.endswith('/'): + type = tarfile.DIRTYPE + if symlink_to is not None: + got = (self.destdir / name).readlink() + expected = pathlib.Path(symlink_to) + # The symlink might be the same (textually) as what we expect, + # but some systems change the link to an equivalent path, so + # we fall back to samefile(). + try: + if expected != got: + self.assertTrue(got.samefile(expected)) + except Exception as e: + # attach a note, so it's shown even if `samefile` fails + e.add_note(f'{expected=}, {got=}') + raise + elif type == tarfile.REGTYPE or type is None: + self.assertTrue(path.is_file()) + elif type == tarfile.DIRTYPE: + self.assertTrue(path.is_dir()) + elif type == tarfile.FIFOTYPE: + self.assertTrue(path.is_fifo()) + elif type == tarfile.SYMTYPE: + self.assertTrue(path.is_symlink()) + else: + raise NotImplementedError(type) + if size is not None: + self.assertEqual(path.stat().st_size, size) + if content is not None: + self.assertEqual(path.read_text(), content) + for parent in path.parents: + self.expected_paths.discard(parent) + + def expect_any_tree(self, name): + """Check a directory; forget about its contents.""" + tree_path = (self.destdir / name).resolve() + self.expect_file(tree_path, type=tarfile.DIRTYPE) + self.expected_paths = { + p for p in self.expected_paths + if tree_path not in p.parents + } + + def expect_exception(self, exc_type, message_re='.'): + with self.assertRaisesRegex(exc_type, message_re): + if self.raised_exception is not None: + raise self.raised_exception + self.reraise_exception = False + return self.raised_exception + + def test_benign_file(self): + with ArchiveMaker() as arc: + arc.add('benign.txt') + for filter in 'fully_trusted', 'tar', 'data': + with self.check_context(arc.open(), filter): + self.expect_file('benign.txt') + + def test_absolute(self): + # Test handling a member with an absolute path + # Inspired by 'absolute1' in https://github.com/jwilk/traversal-archives + with ArchiveMaker() as arc: + arc.add(self.outerdir / 'escaped.evil') + + with self.check_context(arc.open(), 'fully_trusted'): + self.expect_file('../escaped.evil') + + for filter in 'tar', 'data': + with self.check_context(arc.open(), filter): + if str(self.outerdir).startswith('/'): + # We strip leading slashes, as e.g. GNU tar does + # (without --absolute-filenames). + outerdir_stripped = str(self.outerdir).lstrip('/') + self.expect_file(f'{outerdir_stripped}/escaped.evil') + else: + # On this system, absolute paths don't have leading + # slashes. + # So, there's nothing to strip. We refuse to unpack + # to an absolute path, nonetheless. + self.expect_exception( + tarfile.AbsolutePathError, + """['"].*escaped.evil['"] has an absolute path""") + + @symlink_test + def test_parent_symlink(self): + # Test interplaying symlinks + # Inspired by 'dirsymlink2a' in jwilk/traversal-archives + with ArchiveMaker() as arc: + + # `current` links to `.` which is both: + # - the destination directory + # - `current` itself + arc.add('current', symlink_to='.') + + # effectively points to ./../ + arc.add('parent', symlink_to='current/..') + + arc.add('parent/evil') + + if os_helper.can_symlink(): + with self.check_context(arc.open(), 'fully_trusted'): + if self.raised_exception is not None: + # Windows will refuse to create a file that's a symlink to itself + # (and tarfile doesn't swallow that exception) + self.expect_exception(FileExistsError) + # The other cases will fail with this error too. + # Skip the rest of this test. + return + else: + self.expect_file('current', symlink_to='.') + self.expect_file('parent', symlink_to='current/..') + self.expect_file('../evil') + + with self.check_context(arc.open(), 'tar'): + self.expect_exception( + tarfile.OutsideDestinationError, + """'parent/evil' would be extracted to ['"].*evil['"], """ + + "which is outside the destination") + + with self.check_context(arc.open(), 'data'): + self.expect_exception( + tarfile.LinkOutsideDestinationError, + """'parent' would link to ['"].*outerdir['"], """ + + "which is outside the destination") + + else: + # No symlink support. The symlinks are ignored. + with self.check_context(arc.open(), 'fully_trusted'): + self.expect_file('parent/evil') + with self.check_context(arc.open(), 'tar'): + self.expect_file('parent/evil') + with self.check_context(arc.open(), 'data'): + self.expect_file('parent/evil') + + @symlink_test + @os_helper.skip_unless_symlink + def test_realpath_limit_attack(self): + # (CVE-2025-4517) + + with ArchiveMaker() as arc: + # populate the symlinks and dirs that expand in os.path.realpath() + # The component length is chosen so that in common cases, the unexpanded + # path fits in PATH_MAX, but it overflows when the final symlink + # is expanded + steps = "abcdefghijklmnop" + if sys.platform == 'win32': + component = 'd' * 25 + elif 'PC_PATH_MAX' in os.pathconf_names: + max_path_len = os.pathconf(self.outerdir.parent, "PC_PATH_MAX") + path_sep_len = 1 + dest_len = len(str(self.destdir)) + path_sep_len + component_len = (max_path_len - dest_len) // (len(steps) + path_sep_len) + component = 'd' * component_len + else: + raise NotImplementedError("Need to guess component length for {sys.platform}") + path = "" + step_path = "" + for i in steps: + arc.add(os.path.join(path, component), type=tarfile.DIRTYPE, + mode='drwxrwxrwx') + arc.add(os.path.join(path, i), symlink_to=component) + path = os.path.join(path, component) + step_path = os.path.join(step_path, i) + # create the final symlink that exceeds PATH_MAX and simply points + # to the top dir. + # this link will never be expanded by + # os.path.realpath(strict=False), nor anything after it. + linkpath = os.path.join(*steps, "l"*254) + parent_segments = [".."] * len(steps) + arc.add(linkpath, symlink_to=os.path.join(*parent_segments)) + # make a symlink outside to keep the tar command happy + arc.add("escape", symlink_to=os.path.join(linkpath, "..")) + # use the symlinks above, that are not checked, to create a hardlink + # to a file outside of the destination path + arc.add("flaglink", hardlink_to=os.path.join("escape", "flag")) + # now that we have the hardlink we can overwrite the file + arc.add("flaglink", content='overwrite') + # we can also create new files as well! + arc.add("escape/newfile", content='new') + + with (self.subTest('fully_trusted'), + self.check_context(arc.open(), filter='fully_trusted', + check_flag=False)): + if sys.platform == 'win32': + self.expect_exception((FileNotFoundError, FileExistsError)) + elif self.raised_exception: + # Cannot symlink/hardlink: tarfile falls back to getmember() + self.expect_exception(KeyError) + # Otherwise, this block should never enter. + else: + self.expect_any_tree(component) + self.expect_file('flaglink', content='overwrite') + self.expect_file('../newfile', content='new') + self.expect_file('escape', type=tarfile.SYMTYPE) + self.expect_file('a', symlink_to=component) + + for filter in 'tar', 'data': + with self.subTest(filter), self.check_context(arc.open(), filter=filter): + exc = self.expect_exception((OSError, KeyError)) + if isinstance(exc, OSError): + if sys.platform == 'win32': + # 3: ERROR_PATH_NOT_FOUND + # 5: ERROR_ACCESS_DENIED + # 206: ERROR_FILENAME_EXCED_RANGE + self.assertIn(exc.winerror, (3, 5, 206)) + else: + self.assertEqual(exc.errno, errno.ENAMETOOLONG) + + @symlink_test + def test_parent_symlink2(self): + # Test interplaying symlinks + # Inspired by 'dirsymlink2b' in jwilk/traversal-archives + + # Posix and Windows have different pathname resolution: + # either symlink or a '..' component resolve first. + # Let's see which we are on. + if os_helper.can_symlink(): + testpath = os.path.join(TEMPDIR, 'resolution_test') + os.mkdir(testpath) + + # testpath/current links to `.` which is all of: + # - `testpath` + # - `testpath/current` + # - `testpath/current/current` + # - etc. + os.symlink('.', os.path.join(testpath, 'current')) + + # we'll test where `testpath/current/../file` ends up + with open(os.path.join(testpath, 'current', '..', 'file'), 'w'): + pass + + if os.path.exists(os.path.join(testpath, 'file')): + # Windows collapses 'current\..' to '.' first, leaving + # 'testpath\file' + dotdot_resolves_early = True + elif os.path.exists(os.path.join(testpath, '..', 'file')): + # Posix resolves 'current' to '.' first, leaving + # 'testpath/../file' + dotdot_resolves_early = False + else: + raise AssertionError('Could not determine link resolution') + + with ArchiveMaker() as arc: + + # `current` links to `.` which is both the destination directory + # and `current` itself + arc.add('current', symlink_to='.') + + # `current/parent` is also available as `./parent`, + # and effectively points to `./../` + arc.add('current/parent', symlink_to='..') + + arc.add('parent/evil') + + with self.check_context(arc.open(), 'fully_trusted'): + if os_helper.can_symlink(): + self.expect_file('current', symlink_to='.') + self.expect_file('parent', symlink_to='..') + self.expect_file('../evil') + else: + self.expect_file('current/') + self.expect_file('parent/evil') + + with self.check_context(arc.open(), 'tar'): + if os_helper.can_symlink(): + # Fail when extracting a file outside destination + self.expect_exception( + tarfile.OutsideDestinationError, + "'parent/evil' would be extracted to " + + """['"].*evil['"], which is outside """ + + "the destination") + else: + self.expect_file('current/') + self.expect_file('parent/evil') + + with self.check_context(arc.open(), 'data'): + if os_helper.can_symlink(): + if dotdot_resolves_early: + # Fail when extracting a file outside destination + self.expect_exception( + tarfile.OutsideDestinationError, + "'parent/evil' would be extracted to " + + """['"].*evil['"], which is outside """ + + "the destination") + else: + # Fail as soon as we have a symlink outside the destination + self.expect_exception( + tarfile.LinkOutsideDestinationError, + "'current/parent' would link to " + + """['"].*outerdir['"], which is outside """ + + "the destination") + else: + self.expect_file('current/') + self.expect_file('parent/evil') + + @symlink_test + def test_absolute_symlink(self): + # Test symlink to an absolute path + # Inspired by 'dirsymlink' in jwilk/traversal-archives + with ArchiveMaker() as arc: + arc.add('parent', symlink_to=self.outerdir) + arc.add('parent/evil') + + with self.check_context(arc.open(), 'fully_trusted'): + if os_helper.can_symlink(): + self.expect_file('parent', symlink_to=self.outerdir) + self.expect_file('../evil') + else: + self.expect_file('parent/evil') + + with self.check_context(arc.open(), 'tar'): + if os_helper.can_symlink(): + self.expect_exception( + tarfile.OutsideDestinationError, + "'parent/evil' would be extracted to " + + """['"].*evil['"], which is outside """ + + "the destination") + else: + self.expect_file('parent/evil') + + with self.check_context(arc.open(), 'data'): + self.expect_exception( + tarfile.AbsoluteLinkError, + "'parent' is a link to an absolute path") + + def test_absolute_hardlink(self): + # Test hardlink to an absolute path + # Inspired by 'dirsymlink' in https://github.com/jwilk/traversal-archives + with ArchiveMaker() as arc: + arc.add('parent', hardlink_to=self.outerdir / 'foo') + + with self.check_context(arc.open(), 'fully_trusted'): + self.expect_exception(KeyError, ".*foo. not found") + + with self.check_context(arc.open(), 'tar'): + self.expect_exception(KeyError, ".*foo. not found") + + with self.check_context(arc.open(), 'data'): + self.expect_exception( + tarfile.AbsoluteLinkError, + "'parent' is a link to an absolute path") + + @symlink_test + def test_sly_relative0(self): + # Inspired by 'relative0' in jwilk/traversal-archives + with ArchiveMaker() as arc: + # points to `../../tmp/moo` + arc.add('../moo', symlink_to='..//tmp/moo') + + try: + with self.check_context(arc.open(), filter='fully_trusted'): + if os_helper.can_symlink(): + if isinstance(self.raised_exception, FileExistsError): + # XXX TarFile happens to fail creating a parent + # directory. + # This might be a bug, but fixing it would hurt + # security. + # Note that e.g. GNU `tar` rejects '..' components, + # so you could argue this is an invalid archive and we + # just raise an bad type of exception. + self.expect_exception(FileExistsError) + else: + self.expect_file('../moo', symlink_to='..//tmp/moo') + else: + # The symlink can't be extracted and is ignored + pass + except FileExistsError: + pass + + for filter in 'tar', 'data': + with self.check_context(arc.open(), filter): + self.expect_exception( + tarfile.OutsideDestinationError, + "'../moo' would be extracted to " + + "'.*moo', which is outside " + + "the destination") + + @symlink_test + def test_sly_relative2(self): + # Inspired by 'relative2' in jwilk/traversal-archives + with ArchiveMaker() as arc: + arc.add('tmp/') + arc.add('tmp/../../moo', symlink_to='tmp/../..//tmp/moo') + + with self.check_context(arc.open(), 'fully_trusted'): + self.expect_file('tmp', type=tarfile.DIRTYPE) + if os_helper.can_symlink(): + self.expect_file('../moo', symlink_to='tmp/../../tmp/moo') + + for filter in 'tar', 'data': + with self.check_context(arc.open(), filter): + self.expect_exception( + tarfile.OutsideDestinationError, + "'tmp/../../moo' would be extracted to " + + """['"].*moo['"], which is outside the """ + + "destination") + + @symlink_test + def test_deep_symlink(self): + # Test that symlinks and hardlinks inside a directory + # point to the correct file (`target` of size 3). + # If links aren't supported we get a copy of the file. + with ArchiveMaker() as arc: + arc.add('targetdir/target', size=3) + # a hardlink's linkname is relative to the archive + arc.add('linkdir/hardlink', hardlink_to=os.path.join( + 'targetdir', 'target')) + # a symlink's linkname is relative to the link's directory + arc.add('linkdir/symlink', symlink_to=os.path.join( + '..', 'targetdir', 'target')) + + for filter in 'tar', 'data', 'fully_trusted': + with self.check_context(arc.open(), filter): + self.expect_file('targetdir/target', size=3) + self.expect_file('linkdir/hardlink', size=3) + if os_helper.can_symlink(): + self.expect_file('linkdir/symlink', size=3, + symlink_to='../targetdir/target') + else: + self.expect_file('linkdir/symlink', size=3) + + @symlink_test + def test_chains(self): + # Test chaining of symlinks/hardlinks. + # Symlinks are created before the files they point to. + with ArchiveMaker() as arc: + arc.add('linkdir/symlink', symlink_to='hardlink') + arc.add('symlink2', symlink_to=os.path.join( + 'linkdir', 'hardlink2')) + arc.add('targetdir/target', size=3) + arc.add('linkdir/hardlink', hardlink_to=os.path.join('targetdir', 'target')) + arc.add('linkdir/hardlink2', hardlink_to=os.path.join('linkdir', 'symlink')) + + for filter in 'tar', 'data', 'fully_trusted': + with self.check_context(arc.open(), filter): + self.expect_file('targetdir/target', size=3) + self.expect_file('linkdir/hardlink', size=3) + self.expect_file('linkdir/hardlink2', size=3) + if os_helper.can_symlink(): + self.expect_file('linkdir/symlink', size=3, + symlink_to='hardlink') + self.expect_file('symlink2', size=3, + symlink_to='linkdir/hardlink2') + else: + self.expect_file('linkdir/symlink', size=3) + self.expect_file('symlink2', size=3) + + @symlink_test + def test_sneaky_hardlink_fallback(self): + # (CVE-2025-4330) + # Test that when hardlink extraction falls back to extracting members + # from the archive, the extracted member is (re-)filtered. + with ArchiveMaker() as arc: + # Create a directory structure so the c/escape symlink stays + # inside the path + arc.add("a/t/dummy") + # Create b/ directory + arc.add("b/") + # Point "c" to the bottom of the tree in "a" + arc.add("c", symlink_to=os.path.join("a", "t")) + # link to non-existant location under "a" + arc.add("c/escape", symlink_to=os.path.join("..", "..", + "link_here")) + # Move "c" to point to "b" ("c/escape" no longer exists) + arc.add("c", symlink_to="b") + # Attempt to create a hard link to "c/escape". Since it doesn't + # exist it will attempt to extract "cescape" but at "boom". + arc.add("boom", hardlink_to=os.path.join("c", "escape")) + + with self.check_context(arc.open(), 'data'): + if not os_helper.can_symlink(): + # When 'c/escape' is extracted, 'c' is a regular + # directory, and 'c/escape' *would* point outside + # the destination if symlinks were allowed. + self.expect_exception( + tarfile.LinkOutsideDestinationError) + elif sys.platform == "win32": + # On Windows, 'c/escape' points outside the destination + self.expect_exception(tarfile.LinkOutsideDestinationError) + else: + e = self.expect_exception( + tarfile.LinkFallbackError, + "link 'boom' would be extracted as a copy of " + + "'c/escape', which was rejected") + self.assertIsInstance(e.__cause__, + tarfile.LinkOutsideDestinationError) + for filter in 'tar', 'fully_trusted': + with self.subTest(filter), self.check_context(arc.open(), filter): + if not os_helper.can_symlink(): + self.expect_file("a/t/dummy") + self.expect_file("b/") + self.expect_file("c/") + else: + self.expect_file("a/t/dummy") + self.expect_file("b/") + self.expect_file("a/t/escape", symlink_to='../../link_here') + self.expect_file("boom", symlink_to='../../link_here') + self.expect_file("c", symlink_to='b') + + @symlink_test + def test_exfiltration_via_symlink(self): + # (CVE-2025-4138) + # Test changing symlinks that result in a symlink pointing outside + # the extraction directory, unless prevented by 'data' filter's + # normalization. + with ArchiveMaker() as arc: + arc.add("escape", symlink_to=os.path.join('link', 'link', '..', '..', 'link-here')) + arc.add("link", symlink_to='./') + + for filter in 'tar', 'data', 'fully_trusted': + with self.check_context(arc.open(), filter): + if os_helper.can_symlink(): + self.expect_file("link", symlink_to='./') + if filter == 'data': + self.expect_file("escape", symlink_to='link-here') + else: + self.expect_file("escape", + symlink_to='link/link/../../link-here') + else: + # Nothing is extracted. + pass + + @symlink_test + def test_chmod_outside_dir(self): + # (CVE-2024-12718) + # Test that members used for delayed updates of directory metadata + # are (re-)filtered. + with ArchiveMaker() as arc: + # "pwn" is a veeeery innocent symlink: + arc.add("a/pwn", symlink_to='.') + # But now "pwn" is also a directory, so it's scheduled to have its + # metadata updated later: + arc.add("a/pwn/", mode='drwxrwxrwx') + # Oops, "pwn" is not so innocent any more: + arc.add("a/pwn", symlink_to='x/../') + # Newly created symlink points to the dest dir, + # so it's OK for the "data" filter. + arc.add('a/x', symlink_to=('../')) + # But now "pwn" points outside the dest dir + + for filter in 'tar', 'data', 'fully_trusted': + with self.check_context(arc.open(), filter) as cc: + if not os_helper.can_symlink(): + self.expect_file("a/pwn/") + elif filter == 'data': + self.expect_file("a/x", symlink_to='../') + self.expect_file("a/pwn", symlink_to='.') + else: + self.expect_file("a/x", symlink_to='../') + self.expect_file("a/pwn", symlink_to='x/../') + if sys.platform != "win32": + st_mode = cc.outerdir.stat().st_mode + self.assertNotEqual(st_mode & 0o777, 0o777) + + def test_link_fallback_normalizes(self): + # Make sure hardlink fallbacks work for non-normalized paths for all + # filters + with ArchiveMaker() as arc: + arc.add("dir/") + arc.add("dir/../afile") + arc.add("link1", hardlink_to='dir/../afile') + arc.add("link2", hardlink_to='dir/../dir/../afile') + + for filter in 'tar', 'data', 'fully_trusted': + with self.check_context(arc.open(), filter) as cc: + self.expect_file("dir/") + self.expect_file("afile") + self.expect_file("link1") + self.expect_file("link2") + + def test_modes(self): + # Test how file modes are extracted + # (Note that the modes are ignored on platforms without working chmod) + with ArchiveMaker() as arc: + arc.add('all_bits', mode='?rwsrwsrwt') + arc.add('perm_bits', mode='?rwxrwxrwx') + arc.add('exec_group_other', mode='?rw-rwxrwx') + arc.add('read_group_only', mode='?---r-----') + arc.add('no_bits', mode='?---------') + arc.add('dir/', mode='?---rwsrwt') + arc.add('dir_all_bits/', mode='?rwsrwsrwt') + + # On some systems, setting the uid, gid, and/or sticky bit is a no-ops. + # Check which bits we can set, so we can compare tarfile machinery to + # a simple chmod. + tmp_filename = os.path.join(TEMPDIR, "tmp.file") + with open(tmp_filename, 'w'): + pass + try: + new_mode = (os.stat(tmp_filename).st_mode + | stat.S_ISVTX | stat.S_ISGID | stat.S_ISUID) + try: + os.chmod(tmp_filename, new_mode) + except OSError as exc: + if exc.errno == getattr(errno, "EFTYPE", 0): + # gh-108948: On FreeBSD, regular users cannot set + # the sticky bit. + self.skipTest("chmod() failed with EFTYPE: " + "regular users cannot set sticky bit") + else: + raise + + got_mode = os.stat(tmp_filename).st_mode + _t_file = 't' if (got_mode & stat.S_ISVTX) else 'x' + _suid_file = 's' if (got_mode & stat.S_ISUID) else 'x' + _sgid_file = 's' if (got_mode & stat.S_ISGID) else 'x' + finally: + os.unlink(tmp_filename) + + os.mkdir(tmp_filename) + new_mode = (os.stat(tmp_filename).st_mode + | stat.S_ISVTX | stat.S_ISGID | stat.S_ISUID) + os.chmod(tmp_filename, new_mode) + got_mode = os.stat(tmp_filename).st_mode + _t_dir = 't' if (got_mode & stat.S_ISVTX) else 'x' + _suid_dir = 's' if (got_mode & stat.S_ISUID) else 'x' + _sgid_dir = 's' if (got_mode & stat.S_ISGID) else 'x' + os.rmdir(tmp_filename) + + with self.check_context(arc.open(), 'fully_trusted'): + self.expect_file('all_bits', + mode=f'?rw{_suid_file}rw{_sgid_file}rw{_t_file}') + self.expect_file('perm_bits', mode='?rwxrwxrwx') + self.expect_file('exec_group_other', mode='?rw-rwxrwx') + self.expect_file('read_group_only', mode='?---r-----') + self.expect_file('no_bits', mode='?---------') + self.expect_file('dir/', mode=f'?---rw{_sgid_dir}rw{_t_dir}') + self.expect_file('dir_all_bits/', + mode=f'?rw{_suid_dir}rw{_sgid_dir}rw{_t_dir}') + + with self.check_context(arc.open(), 'tar'): + self.expect_file('all_bits', mode='?rwxr-xr-x') + self.expect_file('perm_bits', mode='?rwxr-xr-x') + self.expect_file('exec_group_other', mode='?rw-r-xr-x') + self.expect_file('read_group_only', mode='?---r-----') + self.expect_file('no_bits', mode='?---------') + self.expect_file('dir/', mode='?---r-xr-x') + self.expect_file('dir_all_bits/', mode='?rwxr-xr-x') + + with self.check_context(arc.open(), 'data'): + normal_dir_mode = stat.filemode(stat.S_IMODE( + self.outerdir.stat().st_mode)) + self.expect_file('all_bits', mode='?rwxr-xr-x') + self.expect_file('perm_bits', mode='?rwxr-xr-x') + self.expect_file('exec_group_other', mode='?rw-r--r--') + self.expect_file('read_group_only', mode='?rw-r-----') + self.expect_file('no_bits', mode='?rw-------') + self.expect_file('dir/', mode=normal_dir_mode) + self.expect_file('dir_all_bits/', mode=normal_dir_mode) + + def test_pipe(self): + # Test handling of a special file + with ArchiveMaker() as arc: + arc.add('foo', type=tarfile.FIFOTYPE) + + for filter in 'fully_trusted', 'tar': + with self.check_context(arc.open(), filter): + if hasattr(os, 'mkfifo'): + self.expect_file('foo', type=tarfile.FIFOTYPE) + else: + # The pipe can't be extracted and is skipped. + pass + + with self.check_context(arc.open(), 'data'): + self.expect_exception( + tarfile.SpecialFileError, + "'foo' is a special file") + + def test_special_files(self): + # Creating device files is tricky. Instead of attempting that let's + # only check the filter result. + for special_type in tarfile.FIFOTYPE, tarfile.CHRTYPE, tarfile.BLKTYPE: + tarinfo = tarfile.TarInfo('foo') + tarinfo.type = special_type + trusted = tarfile.fully_trusted_filter(tarinfo, '') + self.assertIs(trusted, tarinfo) + tar = tarfile.tar_filter(tarinfo, '') + self.assertEqual(tar.type, special_type) + with self.assertRaises(tarfile.SpecialFileError) as cm: + tarfile.data_filter(tarinfo, '') + self.assertIsInstance(cm.exception.tarinfo, tarfile.TarInfo) + self.assertEqual(cm.exception.tarinfo.name, 'foo') + + def test_fully_trusted_filter(self): + # The 'fully_trusted' filter returns the original TarInfo objects. + with tarfile.TarFile.open(tarname) as tar: + for tarinfo in tar.getmembers(): + filtered = tarfile.fully_trusted_filter(tarinfo, '') + self.assertIs(filtered, tarinfo) + + def test_tar_filter(self): + # The 'tar' filter returns TarInfo objects with the same name/type. + # (It can also fail for particularly "evil" input, but we don't have + # that in the test archive.) + with tarfile.TarFile.open(tarname, encoding="iso8859-1") as tar: + for tarinfo in tar.getmembers(): + try: + filtered = tarfile.tar_filter(tarinfo, '') + except UnicodeEncodeError: + continue + self.assertIs(filtered.name, tarinfo.name) + self.assertIs(filtered.type, tarinfo.type) + + def test_data_filter(self): + # The 'data' filter either raises, or returns TarInfo with the same + # name/type. + with tarfile.TarFile.open(tarname, encoding="iso8859-1") as tar: + for tarinfo in tar.getmembers(): + try: + filtered = tarfile.data_filter(tarinfo, '') + except (tarfile.FilterError, UnicodeEncodeError): + continue + self.assertIs(filtered.name, tarinfo.name) + self.assertIs(filtered.type, tarinfo.type) + + @unittest.skipIf(sys.platform == 'win32', 'requires native bytes paths') + def test_filter_unencodable(self): + # Sanity check using a valid path. + tarinfo = tarfile.TarInfo(os_helper.TESTFN) + filtered = tarfile.tar_filter(tarinfo, '') + self.assertIs(filtered.name, tarinfo.name) + filtered = tarfile.data_filter(tarinfo, '') + self.assertIs(filtered.name, tarinfo.name) + + tarinfo = tarfile.TarInfo('test\x00') + self.assertRaises(ValueError, tarfile.tar_filter, tarinfo, '') + self.assertRaises(ValueError, tarfile.data_filter, tarinfo, '') + tarinfo = tarfile.TarInfo('\ud800') + self.assertRaises(UnicodeEncodeError, tarfile.tar_filter, tarinfo, '') + self.assertRaises(UnicodeEncodeError, tarfile.data_filter, tarinfo, '') + + @unittest.skipIf(sys.platform == 'win32', 'requires native bytes paths') + def test_extract_unencodable(self): + # Create a member with name \xed\xa0\x80 which is UTF-8 encoded + # lone surrogate \ud800. + with ArchiveMaker(encoding='ascii', errors='surrogateescape') as arc: + arc.add('\udced\udca0\udc80') + with os_helper.temp_cwd() as tmp: + tar = arc.open(encoding='utf-8', errors='surrogatepass', + errorlevel=1) + self.assertEqual(tar.getnames(), ['\ud800']) + with self.assertRaises(UnicodeEncodeError): + tar.extractall() + self.assertEqual(os.listdir(), []) + + tar = arc.open(encoding='utf-8', errors='surrogatepass', + errorlevel=0, debug=1) + with support.captured_stderr() as stderr: + tar.extractall() + self.assertEqual(os.listdir(), []) + self.assertIn('tarfile: UnicodeEncodeError ', stderr.getvalue()) + + def test_change_default_filter_on_instance(self): + tar = tarfile.TarFile(tarname, 'r') + def strict_filter(tarinfo, path): + if tarinfo.name == 'ustar/regtype': + return tarinfo + else: + return None + tar.extraction_filter = strict_filter + with self.check_context(tar, None): + self.expect_file('ustar/regtype') + + def test_change_default_filter_on_class(self): + def strict_filter(tarinfo, path): + if tarinfo.name == 'ustar/regtype': + return tarinfo + else: + return None + tar = tarfile.TarFile(tarname, 'r') + with support.swap_attr(tarfile.TarFile, 'extraction_filter', + staticmethod(strict_filter)): + with self.check_context(tar, None): + self.expect_file('ustar/regtype') + + def test_change_default_filter_on_subclass(self): + class TarSubclass(tarfile.TarFile): + def extraction_filter(self, tarinfo, path): + if tarinfo.name == 'ustar/regtype': + return tarinfo + else: + return None + + tar = TarSubclass(tarname, 'r') + with self.check_context(tar, None): + self.expect_file('ustar/regtype') + + def test_change_default_filter_to_string(self): + tar = tarfile.TarFile(tarname, 'r') + tar.extraction_filter = 'data' + with self.check_context(tar, None): + self.expect_exception(TypeError) + + def test_custom_filter(self): + def custom_filter(tarinfo, path): + self.assertIs(path, self.destdir) + if tarinfo.name == 'move_this': + return tarinfo.replace(name='moved') + if tarinfo.name == 'ignore_this': + return None + return tarinfo + + with ArchiveMaker() as arc: + arc.add('move_this') + arc.add('ignore_this') + arc.add('keep') + with self.check_context(arc.open(), custom_filter): + self.expect_file('moved') + self.expect_file('keep') + + def test_bad_filter_name(self): + with ArchiveMaker() as arc: + arc.add('foo') + with self.check_context(arc.open(), 'bad filter name'): + self.expect_exception(ValueError) + + def test_stateful_filter(self): + # Stateful filters should be possible. + # (This doesn't really test tarfile. Rather, it demonstrates + # that third parties can implement a stateful filter.) + class StatefulFilter: + def __enter__(self): + self.num_files_processed = 0 + return self + + def __call__(self, tarinfo, path): + try: + tarinfo = tarfile.data_filter(tarinfo, path) + except tarfile.FilterError: + return None + self.num_files_processed += 1 + return tarinfo + + def __exit__(self, *exc_info): + self.done = True + + with ArchiveMaker() as arc: + arc.add('good') + arc.add('bad', symlink_to='/') + arc.add('good') + with StatefulFilter() as custom_filter: + with self.check_context(arc.open(), custom_filter): + self.expect_file('good') + self.assertEqual(custom_filter.num_files_processed, 2) + self.assertEqual(custom_filter.done, True) + + def test_errorlevel(self): + def extracterror_filter(tarinfo, path): + raise tarfile.ExtractError('failed with ExtractError') + def filtererror_filter(tarinfo, path): + raise tarfile.FilterError('failed with FilterError') + def oserror_filter(tarinfo, path): + raise OSError('failed with OSError') + def tarerror_filter(tarinfo, path): + raise tarfile.TarError('failed with base TarError') + def valueerror_filter(tarinfo, path): + raise ValueError('failed with ValueError') + + with ArchiveMaker() as arc: + arc.add('file') + + # If errorlevel is 0, errors affected by errorlevel are ignored + + with self.check_context(arc.open(errorlevel=0), extracterror_filter): + pass + + with self.check_context(arc.open(errorlevel=0), filtererror_filter): + pass + + with self.check_context(arc.open(errorlevel=0), oserror_filter): + pass + + with self.check_context(arc.open(errorlevel=0), tarerror_filter): + self.expect_exception(tarfile.TarError) + + with self.check_context(arc.open(errorlevel=0), valueerror_filter): + self.expect_exception(ValueError) + + # If 1, all fatal errors are raised + + with self.check_context(arc.open(errorlevel=1), extracterror_filter): + pass + + with self.check_context(arc.open(errorlevel=1), filtererror_filter): + self.expect_exception(tarfile.FilterError) + + with self.check_context(arc.open(errorlevel=1), oserror_filter): + self.expect_exception(OSError) + + with self.check_context(arc.open(errorlevel=1), tarerror_filter): + self.expect_exception(tarfile.TarError) + + with self.check_context(arc.open(errorlevel=1), valueerror_filter): + self.expect_exception(ValueError) + + # If 2, all non-fatal errors are raised as well. + + with self.check_context(arc.open(errorlevel=2), extracterror_filter): + self.expect_exception(tarfile.ExtractError) + + with self.check_context(arc.open(errorlevel=2), filtererror_filter): + self.expect_exception(tarfile.FilterError) + + with self.check_context(arc.open(errorlevel=2), oserror_filter): + self.expect_exception(OSError) + + with self.check_context(arc.open(errorlevel=2), tarerror_filter): + self.expect_exception(tarfile.TarError) + + with self.check_context(arc.open(errorlevel=2), valueerror_filter): + self.expect_exception(ValueError) + + # We only handle ExtractionError, FilterError & OSError specially. + + with self.check_context(arc.open(errorlevel='boo!'), filtererror_filter): + self.expect_exception(TypeError) # errorlevel is not int + + +class OverwriteTests(archiver_tests.OverwriteTests, unittest.TestCase): + testdir = os.path.join(TEMPDIR, "testoverwrite") + + @classmethod + def setUpClass(cls): + p = cls.ar_with_file = os.path.join(TEMPDIR, 'tar-with-file.tar') + cls.addClassCleanup(os_helper.unlink, p) + with tarfile.open(p, 'w') as tar: + t = tarfile.TarInfo('test') + t.size = 10 + tar.addfile(t, io.BytesIO(b'newcontent')) + + p = cls.ar_with_dir = os.path.join(TEMPDIR, 'tar-with-dir.tar') + cls.addClassCleanup(os_helper.unlink, p) + with tarfile.open(p, 'w') as tar: + tar.addfile(tar.gettarinfo(os.curdir, 'test')) + + p = os.path.join(TEMPDIR, 'tar-with-implicit-dir.tar') + cls.ar_with_implicit_dir = p + cls.addClassCleanup(os_helper.unlink, p) + with tarfile.open(p, 'w') as tar: + t = tarfile.TarInfo('test/file') + t.size = 10 + tar.addfile(t, io.BytesIO(b'newcontent')) + + def open(self, path): + return tarfile.open(path, 'r') + + def extractall(self, ar): + ar.extractall(self.testdir, filter='fully_trusted') + + +class OffsetValidationTests(unittest.TestCase): + tarname = tmpname + invalid_posix_header = ( + # name: 100 bytes + tarfile.NUL * tarfile.LENGTH_NAME + # mode, space, null terminator: 8 bytes + + b"000755" + SPACE + tarfile.NUL + # uid, space, null terminator: 8 bytes + + b"000001" + SPACE + tarfile.NUL + # gid, space, null terminator: 8 bytes + + b"000001" + SPACE + tarfile.NUL + # size, space: 12 bytes + + b"\xff" * 11 + SPACE + # mtime, space: 12 bytes + + tarfile.NUL * 11 + SPACE + # chksum: 8 bytes + + b"0011407" + tarfile.NUL + # type: 1 byte + + tarfile.REGTYPE + # linkname: 100 bytes + + tarfile.NUL * tarfile.LENGTH_LINK + # magic: 6 bytes, version: 2 bytes + + tarfile.POSIX_MAGIC + # uname: 32 bytes + + tarfile.NUL * 32 + # gname: 32 bytes + + tarfile.NUL * 32 + # devmajor, space, null terminator: 8 bytes + + tarfile.NUL * 6 + SPACE + tarfile.NUL + # devminor, space, null terminator: 8 bytes + + tarfile.NUL * 6 + SPACE + tarfile.NUL + # prefix: 155 bytes + + tarfile.NUL * tarfile.LENGTH_PREFIX + # padding: 12 bytes + + tarfile.NUL * 12 + ) + invalid_gnu_header = ( + # name: 100 bytes + tarfile.NUL * tarfile.LENGTH_NAME + # mode, null terminator: 8 bytes + + b"0000755" + tarfile.NUL + # uid, null terminator: 8 bytes + + b"0000001" + tarfile.NUL + # gid, space, null terminator: 8 bytes + + b"0000001" + tarfile.NUL + # size, space: 12 bytes + + b"\xff" * 11 + SPACE + # mtime, space: 12 bytes + + tarfile.NUL * 11 + SPACE + # chksum: 8 bytes + + b"0011327" + tarfile.NUL + # type: 1 byte + + tarfile.REGTYPE + # linkname: 100 bytes + + tarfile.NUL * tarfile.LENGTH_LINK + # magic: 8 bytes + + tarfile.GNU_MAGIC + # uname: 32 bytes + + tarfile.NUL * 32 + # gname: 32 bytes + + tarfile.NUL * 32 + # devmajor, null terminator: 8 bytes + + tarfile.NUL * 8 + # devminor, null terminator: 8 bytes + + tarfile.NUL * 8 + # padding: 167 bytes + + tarfile.NUL * 167 + ) + invalid_v7_header = ( + # name: 100 bytes + tarfile.NUL * tarfile.LENGTH_NAME + # mode, space, null terminator: 8 bytes + + b"000755" + SPACE + tarfile.NUL + # uid, space, null terminator: 8 bytes + + b"000001" + SPACE + tarfile.NUL + # gid, space, null terminator: 8 bytes + + b"000001" + SPACE + tarfile.NUL + # size, space: 12 bytes + + b"\xff" * 11 + SPACE + # mtime, space: 12 bytes + + tarfile.NUL * 11 + SPACE + # chksum: 8 bytes + + b"0010070" + tarfile.NUL + # type: 1 byte + + tarfile.REGTYPE + # linkname: 100 bytes + + tarfile.NUL * tarfile.LENGTH_LINK + # padding: 255 bytes + + tarfile.NUL * 255 + ) + valid_gnu_header = tarfile.TarInfo("filename").tobuf(tarfile.GNU_FORMAT) + data_block = b"\xff" * tarfile.BLOCKSIZE + + def _write_buffer(self, buffer): + with open(self.tarname, "wb") as f: + f.write(buffer) + + def _get_members(self, ignore_zeros=None): + with open(self.tarname, "rb") as f: + with tarfile.open( + mode="r", fileobj=f, ignore_zeros=ignore_zeros + ) as tar: + return tar.getmembers() + + def _assert_raises_read_error_exception(self): + with self.assertRaisesRegex( + tarfile.ReadError, "file could not be opened successfully" + ): + self._get_members() + + def test_invalid_offset_header_validations(self): + for tar_format, invalid_header in ( + ("posix", self.invalid_posix_header), + ("gnu", self.invalid_gnu_header), + ("v7", self.invalid_v7_header), + ): + with self.subTest(format=tar_format): + self._write_buffer(invalid_header) + self._assert_raises_read_error_exception() + + def test_early_stop_at_invalid_offset_header(self): + buffer = self.valid_gnu_header + self.invalid_gnu_header + self.valid_gnu_header + self._write_buffer(buffer) + members = self._get_members() + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "filename") + self.assertEqual(members[0].offset, 0) + + def test_ignore_invalid_archive(self): + # 3 invalid headers with their respective data + buffer = (self.invalid_gnu_header + self.data_block) * 3 + self._write_buffer(buffer) + members = self._get_members(ignore_zeros=True) + self.assertEqual(len(members), 0) + + def test_ignore_invalid_offset_headers(self): + for first_block, second_block, expected_offset in ( + ( + (self.valid_gnu_header), + (self.invalid_gnu_header + self.data_block), + 0, + ), + ( + (self.invalid_gnu_header + self.data_block), + (self.valid_gnu_header), + 1024, + ), + ): + self._write_buffer(first_block + second_block) + members = self._get_members(ignore_zeros=True) + self.assertEqual(len(members), 1) + self.assertEqual(members[0].name, "filename") + self.assertEqual(members[0].offset, expected_offset) + + def setUpModule(): os_helper.unlink(TEMPDIR) os.makedirs(TEMPDIR) @@ -2886,7 +4831,7 @@ def setUpModule(): data = fobj.read() # Create compressed tarfiles. - for c in GzipTest, Bz2Test, LzmaTest: + for c in GzipTest, Bz2Test, LzmaTest, ZstdTest: if c.open: os_helper.unlink(c.tarname) testtarnames.append(c.tarname) diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index 3ea68d12bc9..0f16a06e61b 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -139,8 +139,8 @@ pkcs8 = { version = "0.10", features = ["encryption", "pkcs5", "pem"], optional [target.'cfg(not(any(target_os = "android", target_arch = "wasm32")))'.dependencies] libsqlite3-sys = { version = "0.36", features = ["bundled"], optional = true } -lzma-sys = "0.1" -xz2 = "0.1" +liblzma = "0.4" +liblzma-sys = "0.4" [target.'cfg(windows)'.dependencies] paste = { workspace = true } diff --git a/crates/stdlib/src/compression.rs b/crates/stdlib/src/compression.rs index a857b4e53de..8cc48b74782 100644 --- a/crates/stdlib/src/compression.rs +++ b/crates/stdlib/src/compression.rs @@ -296,6 +296,10 @@ impl DecompressState { self.eof } + pub fn decompressor(&self) -> &D { + &self.decompress + } + pub fn unused_data(&self) -> PyBytesRef { self.unused_data.clone() } diff --git a/crates/stdlib/src/lzma.rs b/crates/stdlib/src/lzma.rs index d9f19f0f24e..b98cfb1bf94 100644 --- a/crates/stdlib/src/lzma.rs +++ b/crates/stdlib/src/lzma.rs @@ -1,4 +1,4 @@ -// spell-checker:ignore ARMTHUMB +// spell-checker:ignore ARMTHUMB memlimit pub(crate) use _lzma::module_def; @@ -9,53 +9,63 @@ mod _lzma { DecompressError, DecompressState, DecompressStatus, Decompressor, }; use alloc::fmt; - #[pyattr] - use lzma_sys::{ - LZMA_CHECK_CRC32 as CHECK_CRC32, LZMA_CHECK_CRC64 as CHECK_CRC64, - LZMA_CHECK_NONE as CHECK_NONE, LZMA_CHECK_SHA256 as CHECK_SHA256, + use liblzma::stream::{ + Action, Check, Error, Filters, LzmaOptions, MatchFinder, Mode, Status, Stream, + TELL_ANY_CHECK, TELL_NO_CHECK, }; + // lzma_check, lzma_mode, lzma_match_finder have platform-dependent signedness + // (i32 on Windows, u32 elsewhere). Define as fixed-type const to avoid mismatch. #[pyattr] - use lzma_sys::{ + use liblzma_sys::{ LZMA_FILTER_ARM as FILTER_ARM, LZMA_FILTER_ARMTHUMB as FILTER_ARMTHUMB, - LZMA_FILTER_IA64 as FILTER_IA64, LZMA_FILTER_LZMA1 as FILTER_LZMA1, - LZMA_FILTER_LZMA2 as FILTER_LZMA2, LZMA_FILTER_POWERPC as FILTER_POWERPC, - LZMA_FILTER_SPARC as FILTER_SPARC, LZMA_FILTER_X86 as FILTER_X86, - }; - #[pyattr] - use lzma_sys::{ - LZMA_MF_BT2 as MF_BT2, LZMA_MF_BT3 as MF_BT3, LZMA_MF_BT4 as MF_BT4, LZMA_MF_HC3 as MF_HC3, - LZMA_MF_HC4 as MF_HC4, + LZMA_FILTER_DELTA as FILTER_DELTA, LZMA_FILTER_IA64 as FILTER_IA64, + LZMA_FILTER_LZMA1 as FILTER_LZMA1, LZMA_FILTER_LZMA2 as FILTER_LZMA2, + LZMA_FILTER_POWERPC as FILTER_POWERPC, LZMA_FILTER_SPARC as FILTER_SPARC, + LZMA_FILTER_X86 as FILTER_X86, }; #[pyattr] - use lzma_sys::{LZMA_MODE_FAST as MODE_FAST, LZMA_MODE_NORMAL as MODE_NORMAL}; - #[pyattr] - use lzma_sys::{ + use liblzma_sys::{ LZMA_PRESET_DEFAULT as PRESET_DEFAULT, LZMA_PRESET_EXTREME as PRESET_EXTREME, - LZMA_PRESET_LEVEL_MASK as PRESET_LEVEL_MASK, }; use rustpython_common::lock::PyMutex; - use rustpython_vm::builtins::{PyBaseExceptionRef, PyBytesRef, PyType, PyTypeRef}; + use rustpython_vm::builtins::{PyBaseExceptionRef, PyBytesRef, PyDict, PyType, PyTypeRef}; use rustpython_vm::convert::ToPyException; use rustpython_vm::function::ArgBytesLike; use rustpython_vm::types::Constructor; use rustpython_vm::{Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; - use xz2::stream::{Action, Check, Error, Filters, LzmaOptions, Status, Stream}; - - #[cfg(windows)] - type EnumVal = i32; - #[cfg(not(windows))] - type EnumVal = u32; const BUFSIZ: usize = 8192; - // TODO: can't find this in lzma-sys, but find way not to hardcode this + + // liblzma_sys enum types have platform-dependent signedness; `as _` normalizes to i32 + #[pyattr] + const CHECK_NONE: i32 = liblzma_sys::LZMA_CHECK_NONE as _; #[pyattr] - const FILTER_DELTA: i32 = 3; + const CHECK_CRC32: i32 = liblzma_sys::LZMA_CHECK_CRC32 as _; + #[pyattr] + const CHECK_CRC64: i32 = liblzma_sys::LZMA_CHECK_CRC64 as _; + #[pyattr] + const CHECK_SHA256: i32 = liblzma_sys::LZMA_CHECK_SHA256 as _; #[pyattr] const CHECK_ID_MAX: i32 = 15; #[pyattr] const CHECK_UNKNOWN: i32 = CHECK_ID_MAX + 1; - // the variant ids are hardcoded to be equivalent to the C enum values + #[pyattr] + const MF_HC3: i32 = liblzma_sys::LZMA_MF_HC3 as _; + #[pyattr] + const MF_HC4: i32 = liblzma_sys::LZMA_MF_HC4 as _; + #[pyattr] + const MF_BT2: i32 = liblzma_sys::LZMA_MF_BT2 as _; + #[pyattr] + const MF_BT3: i32 = liblzma_sys::LZMA_MF_BT3 as _; + #[pyattr] + const MF_BT4: i32 = liblzma_sys::LZMA_MF_BT4 as _; + + #[pyattr] + const MODE_FAST: i32 = liblzma_sys::LZMA_MODE_FAST as _; + #[pyattr] + const MODE_NORMAL: i32 = liblzma_sys::LZMA_MODE_NORMAL as _; + enum Format { Auto = 0, Xz = 1, @@ -86,44 +96,106 @@ mod _lzma { vm.new_exception_msg(vm.class("lzma", "LZMAError"), msg.into()) } - #[pyfunction] - fn is_check_supported(check: i32) -> bool { - unsafe { lzma_sys::lzma_check_is_supported(check as _) != 0 } + fn catch_lzma_error(err: Error, vm: &VirtualMachine) -> PyBaseExceptionRef { + match err { + Error::UnsupportedCheck => new_lzma_error("Unsupported integrity check", vm), + Error::Mem => vm.new_memory_error(""), + Error::MemLimit => new_lzma_error("Memory usage limit exceeded", vm), + Error::Format => new_lzma_error("Input format not supported by decoder", vm), + Error::Options => new_lzma_error("Invalid or unsupported options", vm), + Error::Data => new_lzma_error("Corrupt input data", vm), + Error::Program => new_lzma_error("Internal error", vm), + Error::NoCheck => new_lzma_error("Corrupt input data", vm), + } } - // TODO: To implement these we need a function to convert a pyobject to a lzma filter and related structs - #[pyfunction] - fn _encode_filter_properties() -> PyResult<()> { - Ok(()) + fn int_to_check(check: i32) -> Option { + if check == -1 { + return Some(Check::Crc64); + } + match check { + CHECK_NONE => Some(Check::None), + CHECK_CRC32 => Some(Check::Crc32), + CHECK_CRC64 => Some(Check::Crc64), + CHECK_SHA256 => Some(Check::Sha256), + _ => None, + } } - #[pyfunction] - fn _decode_filter_properties(_filter_id: u64, _buffer: ArgBytesLike) -> PyResult<()> { - Ok(()) + fn u32_to_mode(val: u32) -> Option { + match val as i32 { + MODE_FAST => Some(Mode::Fast), + MODE_NORMAL => Some(Mode::Normal), + _ => None, + } } - #[pyattr] - #[pyclass(name = "LZMADecompressor")] - #[derive(PyPayload)] - struct LZMADecompressor { - state: PyMutex>, + fn u32_to_mf(val: u32) -> Option { + match val as i32 { + MF_HC3 => Some(MatchFinder::HashChain3), + MF_HC4 => Some(MatchFinder::HashChain4), + MF_BT2 => Some(MatchFinder::BinaryTree2), + MF_BT3 => Some(MatchFinder::BinaryTree3), + MF_BT4 => Some(MatchFinder::BinaryTree4), + _ => None, + } + } + + struct LzmaStream { + stream: Stream, + check: i32, + header_buf: [u8; 8], + header_collected: u8, + track_header: bool, + } + + impl LzmaStream { + fn new(stream: Stream, check: i32, track_header: bool) -> Self { + Self { + stream, + check, + header_buf: [0u8; 8], + header_collected: 0, + track_header, + } + } } - impl Decompressor for Stream { + impl Decompressor for LzmaStream { type Flush = (); type Status = Status; type Error = Error; fn total_in(&self) -> u64 { - self.total_in() + self.stream.total_in() } + fn decompress_vec( &mut self, input: &[u8], output: &mut Vec, (): Self::Flush, ) -> Result { - self.process_vec(input, output, Action::Run) + if self.track_header && self.header_collected < 8 { + let need = (8 - self.header_collected) as usize; + let n = need.min(input.len()); + self.header_buf[self.header_collected as usize..][..n].copy_from_slice(&input[..n]); + self.header_collected += n as u8; + } + + match self.stream.process_vec(input, output, Action::Run) { + Ok(Status::GetCheck) => { + if self.header_collected >= 8 { + self.check = (self.header_buf[7] & 0x0F) as i32; + } + Ok(Status::Ok) + } + Err(Error::NoCheck) => { + self.check = CHECK_NONE; + Ok(Status::Ok) + } + other => other, + } } } @@ -133,43 +205,426 @@ mod _lzma { } } + fn get_dict_opt_u32( + spec: &PyObjectRef, + key: &str, + vm: &VirtualMachine, + ) -> PyResult> { + let dict = spec.downcast_ref::().ok_or_else(|| { + vm.new_type_error("Filter specifier must be a dict or dict-like object") + })?; + match dict.get_item_opt(key, vm)? { + Some(obj) => Ok(Some(obj.try_into_value::(vm)?)), + None => Ok(None), + } + } + + fn get_dict_opt_u64( + spec: &PyObjectRef, + key: &str, + vm: &VirtualMachine, + ) -> PyResult> { + let dict = spec.downcast_ref::().ok_or_else(|| { + vm.new_type_error("Filter specifier must be a dict or dict-like object") + })?; + match dict.get_item_opt(key, vm)? { + Some(obj) => Ok(Some(obj.try_into_value::(vm)?)), + None => Ok(None), + } + } + + fn parse_filter_spec_lzma(spec: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + let preset = get_dict_opt_u32(spec, "preset", vm)?.unwrap_or(PRESET_DEFAULT); + + let mut opts = LzmaOptions::new_preset(preset) + .map_err(|_| new_lzma_error(format!("Invalid compression preset: {preset}"), vm))?; + + if let Some(v) = get_dict_opt_u32(spec, "dict_size", vm)? { + opts.dict_size(v); + } + if let Some(v) = get_dict_opt_u32(spec, "lc", vm)? { + opts.literal_context_bits(v); + } + if let Some(v) = get_dict_opt_u32(spec, "lp", vm)? { + opts.literal_position_bits(v); + } + if let Some(v) = get_dict_opt_u32(spec, "pb", vm)? { + opts.position_bits(v); + } + if let Some(v) = get_dict_opt_u32(spec, "mode", vm)? { + let mode = u32_to_mode(v) + .ok_or_else(|| vm.new_value_error("Invalid filter specifier for LZMA filter"))?; + opts.mode(mode); + } + if let Some(v) = get_dict_opt_u32(spec, "nice_len", vm)? { + opts.nice_len(v); + } + if let Some(v) = get_dict_opt_u32(spec, "mf", vm)? { + let mf = u32_to_mf(v) + .ok_or_else(|| vm.new_value_error("Invalid filter specifier for LZMA filter"))?; + opts.match_finder(mf); + } + if let Some(v) = get_dict_opt_u32(spec, "depth", vm)? { + opts.depth(v); + } + + Ok(opts) + } + + fn parse_filter_spec_delta(spec: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + let dist = get_dict_opt_u32(spec, "dist", vm)?.unwrap_or(1); + if dist == 0 || dist > 256 { + return Err(vm.new_value_error("Invalid filter specifier for delta filter")); + } + Ok(dist) + } + + fn parse_filter_spec_bcj(spec: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(get_dict_opt_u32(spec, "start_offset", vm)?.unwrap_or(0)) + } + + fn add_bcj_filter( + filters: &mut Filters, + filter_id: u64, + start_offset: u32, + ) -> Result<(), Error> { + if start_offset == 0 { + match filter_id { + FILTER_X86 => { + filters.x86(); + } + FILTER_POWERPC => { + filters.powerpc(); + } + FILTER_IA64 => { + filters.ia64(); + } + FILTER_ARM => { + filters.arm(); + } + FILTER_ARMTHUMB => { + filters.arm_thumb(); + } + FILTER_SPARC => { + filters.sparc(); + } + _ => unreachable!(), + } + Ok(()) + } else { + let props = start_offset.to_le_bytes(); + match filter_id { + FILTER_X86 => { + filters.x86_properties(&props)?; + } + FILTER_POWERPC => { + filters.powerpc_properties(&props)?; + } + FILTER_IA64 => { + filters.ia64_properties(&props)?; + } + FILTER_ARM => { + filters.arm_properties(&props)?; + } + FILTER_ARMTHUMB => { + filters.arm_thumb_properties(&props)?; + } + FILTER_SPARC => { + filters.sparc_properties(&props)?; + } + _ => unreachable!(), + } + Ok(()) + } + } + + fn parse_filter_chain_spec( + filter_specs: Vec, + vm: &VirtualMachine, + ) -> PyResult { + const LZMA_FILTERS_MAX: usize = 4; + if filter_specs.len() > LZMA_FILTERS_MAX { + return Err(new_lzma_error( + format!("Too many filters - liblzma supports a maximum of {LZMA_FILTERS_MAX}"), + vm, + )); + } + + let mut filters = Filters::new(); + for spec in &filter_specs { + let filter_id = get_dict_opt_u64(spec, "id", vm)? + .ok_or_else(|| vm.new_value_error("Filter specifier must have an \"id\" entry"))?; + + match filter_id { + FILTER_LZMA1 => { + let opts = parse_filter_spec_lzma(spec, vm)?; + filters.lzma1(&opts); + } + FILTER_LZMA2 => { + let opts = parse_filter_spec_lzma(spec, vm)?; + filters.lzma2(&opts); + } + FILTER_DELTA => { + let dist = parse_filter_spec_delta(spec, vm)?; + filters + .delta_properties(&[(dist - 1) as u8]) + .map_err(|e| catch_lzma_error(e, vm))?; + } + FILTER_X86 | FILTER_POWERPC | FILTER_IA64 | FILTER_ARM | FILTER_ARMTHUMB + | FILTER_SPARC => { + let start_offset = parse_filter_spec_bcj(spec, vm)?; + add_bcj_filter(&mut filters, filter_id, start_offset) + .map_err(|e| catch_lzma_error(e, vm))?; + } + _ => { + return Err(vm.new_value_error(format!("Invalid filter ID: {filter_id}"))); + } + } + } + + Ok(filters) + } + + const DEFAULT_LC: u32 = liblzma_sys::LZMA_LC_DEFAULT; + const DEFAULT_LP: u32 = liblzma_sys::LZMA_LP_DEFAULT; + const DEFAULT_PB: u32 = liblzma_sys::LZMA_PB_DEFAULT; + const DICT_POW2: [u8; 10] = [18, 20, 21, 22, 22, 23, 23, 24, 25, 26]; + + fn preset_dict_size(preset: u32) -> u32 { + let level = (preset & liblzma_sys::LZMA_PRESET_LEVEL_MASK) as usize; + if level > 9 { + return 0; + } + 1u32 << DICT_POW2[level] + } + + fn lzma2_dict_size_from_prop(prop: u8) -> u32 { + if prop > 40 { + return u32::MAX; + } + if prop == 40 { + return u32::MAX; + } + let prop = prop as u32; + (2 | (prop & 1)) << (prop / 2 + 11) + } + + fn lzma2_prop_from_dict_size(dict_size: u32) -> u8 { + if dict_size == u32::MAX { + return 40; + } + for i in 0u8..40 { + if lzma2_dict_size_from_prop(i) >= dict_size { + return i; + } + } + 40 + } + + fn encode_lzma1_properties(lc: u32, lp: u32, pb: u32, dict_size: u32) -> Vec { + let mut result = vec![0u8; 5]; + result[0] = ((pb * 5 + lp) * 9 + lc) as u8; + result[1..5].copy_from_slice(&dict_size.to_le_bytes()); + result + } + + fn decode_lzma1_properties(props: &[u8]) -> Option<(u32, u32, u32, u32)> { + if props.len() < 5 { + return None; + } + let mut d = props[0] as u32; + let lc = d % 9; + d /= 9; + let lp = d % 5; + let pb = d / 5; + let dict_size = u32::from_le_bytes([props[1], props[2], props[3], props[4]]); + Some((lc, lp, pb, dict_size)) + } + + fn build_filter_spec( + filter_id: u64, + props: &[u8], + vm: &VirtualMachine, + ) -> PyResult { + let dict = vm.ctx.new_dict(); + dict.set_item("id", vm.new_pyobj(filter_id), vm)?; + + match filter_id { + FILTER_LZMA1 => { + let (lc, lp, pb, dict_size) = decode_lzma1_properties(props) + .ok_or_else(|| new_lzma_error("Invalid or unsupported options", vm))?; + dict.set_item("lc", vm.new_pyobj(lc), vm)?; + dict.set_item("lp", vm.new_pyobj(lp), vm)?; + dict.set_item("pb", vm.new_pyobj(pb), vm)?; + dict.set_item("dict_size", vm.new_pyobj(dict_size), vm)?; + } + FILTER_LZMA2 => { + if props.len() != 1 { + return Err(new_lzma_error("Invalid or unsupported options", vm)); + } + let dict_size = lzma2_dict_size_from_prop(props[0]); + dict.set_item("dict_size", vm.new_pyobj(dict_size), vm)?; + } + FILTER_DELTA => { + if props.len() != 1 { + return Err(new_lzma_error("Invalid or unsupported options", vm)); + } + let dist = props[0] as u32 + 1; + dict.set_item("dist", vm.new_pyobj(dist), vm)?; + } + FILTER_X86 | FILTER_POWERPC | FILTER_IA64 | FILTER_ARM | FILTER_ARMTHUMB + | FILTER_SPARC => { + if props.is_empty() { + // default: no start_offset + } else if props.len() == 4 { + let start_offset = u32::from_le_bytes([props[0], props[1], props[2], props[3]]); + dict.set_item("start_offset", vm.new_pyobj(start_offset), vm)?; + } else { + return Err(new_lzma_error("Invalid or unsupported options", vm)); + } + } + _ => { + return Err(vm.new_value_error(format!("Invalid filter ID: {filter_id}"))); + } + } + + Ok(dict.into()) + } + + #[pyfunction] + fn is_check_supported(check_id: i32) -> bool { + unsafe { liblzma_sys::lzma_check_is_supported(check_id as _) != 0 } + } + + #[pyfunction] + fn _encode_filter_properties( + filter_spec: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + let filter_id = get_dict_opt_u64(&filter_spec, "id", vm)? + .ok_or_else(|| vm.new_value_error("Filter specifier must have an \"id\" entry"))?; + + match filter_id { + FILTER_LZMA1 => { + let preset = + get_dict_opt_u32(&filter_spec, "preset", vm)?.unwrap_or(PRESET_DEFAULT); + let lc = get_dict_opt_u32(&filter_spec, "lc", vm)?.unwrap_or(DEFAULT_LC); + let lp = get_dict_opt_u32(&filter_spec, "lp", vm)?.unwrap_or(DEFAULT_LP); + let pb = get_dict_opt_u32(&filter_spec, "pb", vm)?.unwrap_or(DEFAULT_PB); + let dict_size = get_dict_opt_u32(&filter_spec, "dict_size", vm)? + .unwrap_or_else(|| preset_dict_size(preset)); + Ok(encode_lzma1_properties(lc, lp, pb, dict_size)) + } + FILTER_LZMA2 => { + let preset = + get_dict_opt_u32(&filter_spec, "preset", vm)?.unwrap_or(PRESET_DEFAULT); + let dict_size = get_dict_opt_u32(&filter_spec, "dict_size", vm)? + .unwrap_or_else(|| preset_dict_size(preset)); + Ok(vec![lzma2_prop_from_dict_size(dict_size)]) + } + FILTER_DELTA => { + let dist = get_dict_opt_u32(&filter_spec, "dist", vm)?.unwrap_or(1); + if dist == 0 || dist > 256 { + return Err(vm.new_value_error("Invalid filter specifier for delta filter")); + } + Ok(vec![(dist - 1) as u8]) + } + FILTER_X86 | FILTER_POWERPC | FILTER_IA64 | FILTER_ARM | FILTER_ARMTHUMB + | FILTER_SPARC => { + let start_offset = get_dict_opt_u32(&filter_spec, "start_offset", vm)?.unwrap_or(0); + if start_offset == 0 { + Ok(vec![]) + } else { + Ok(start_offset.to_le_bytes().to_vec()) + } + } + _ => Err(vm.new_value_error(format!("Invalid filter ID: {filter_id}"))), + } + } + + #[pyfunction] + fn _decode_filter_properties( + filter_id: u64, + encoded_props: ArgBytesLike, + vm: &VirtualMachine, + ) -> PyResult { + let props = encoded_props.borrow_buf(); + build_filter_spec(filter_id, &props, vm) + } + + #[pyattr] + #[pyclass(name = "LZMADecompressor")] + #[derive(PyPayload)] + struct LZMADecompressor { + state: PyMutex>, + } + impl fmt::Debug for LZMADecompressor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "_lzma.LZMADecompressor") } } + #[derive(FromArgs)] pub struct LZMADecompressorConstructorArgs { #[pyarg(any, default = FORMAT_AUTO)] format: i32, #[pyarg(any, optional)] - mem_limit: Option, + memlimit: Option, #[pyarg(any, optional)] - filters: Option, + filters: Option>, } impl Constructor for LZMADecompressor { type Args = LZMADecompressorConstructorArgs; fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { - if args.format == FORMAT_RAW && args.mem_limit.is_some() { + if args.format == FORMAT_RAW && args.memlimit.is_some() { return Err(vm.new_value_error("Cannot specify memory limit with FORMAT_RAW")); } - let mem_limit = args.mem_limit.unwrap_or(u64::MAX); - let filters = args.filters.unwrap_or(0); - let stream_result = match args.format { - FORMAT_AUTO => Stream::new_auto_decoder(mem_limit, filters), - FORMAT_XZ => Stream::new_stream_decoder(mem_limit, filters), - FORMAT_ALONE => Stream::new_lzma_decoder(mem_limit), - // TODO: FORMAT_RAW - _ => return Err(new_lzma_error("Invalid format", vm)), + + if args.format == FORMAT_RAW && args.filters.is_none() { + return Err(vm.new_value_error("Must specify filters for FORMAT_RAW")); + } + if args.format != FORMAT_RAW && args.filters.is_some() { + return Err(vm.new_value_error("Cannot specify filters except with FORMAT_RAW")); + } + + let memlimit = args.memlimit.unwrap_or(u64::MAX); + let decoder_flags = TELL_ANY_CHECK | TELL_NO_CHECK; + + let lzma_stream = match args.format { + FORMAT_AUTO => { + let stream = Stream::new_auto_decoder(memlimit, decoder_flags) + .map_err(|e| catch_lzma_error(e, vm))?; + LzmaStream::new(stream, CHECK_UNKNOWN, true) + } + FORMAT_XZ => { + let stream = Stream::new_stream_decoder(memlimit, decoder_flags) + .map_err(|e| catch_lzma_error(e, vm))?; + LzmaStream::new(stream, CHECK_UNKNOWN, true) + } + FORMAT_ALONE => { + let stream = + Stream::new_lzma_decoder(memlimit).map_err(|e| catch_lzma_error(e, vm))?; + LzmaStream::new(stream, CHECK_NONE, false) + } + FORMAT_RAW => { + let filter_specs = args.filters.unwrap(); // safe: checked above + let filters = parse_filter_chain_spec(filter_specs, vm)?; + let stream = + Stream::new_raw_decoder(&filters).map_err(|e| catch_lzma_error(e, vm))?; + LzmaStream::new(stream, CHECK_NONE, false) + } + _ => { + return Err( + vm.new_value_error(format!("Invalid container format: {}", args.format)) + ); + } }; + Ok(Self { - state: PyMutex::new(DecompressState::new( - stream_result - .map_err(|_| new_lzma_error("Failed to initialize decoder", vm))?, - vm, - )), + state: PyMutex::new(DecompressState::new(lzma_stream, vm)), }) } } @@ -185,11 +640,16 @@ mod _lzma { state .decompress(data, max_length, BUFSIZ, vm) .map_err(|e| match e { - DecompressError::Decompress(err) => vm.new_os_error(err.to_string()), + DecompressError::Decompress(err) => catch_lzma_error(err, vm), DecompressError::Eof(err) => err.to_pyexception(vm), }) } + #[pygetset] + fn check(&self) -> i32 { + self.state.lock().decompressor().check + } + #[pygetset] fn eof(&self) -> bool { self.state.lock().eof() @@ -202,12 +662,8 @@ mod _lzma { #[pygetset] fn needs_input(&self) -> bool { - // False if the decompress() method can provide more - // decompressed data before requiring new uncompressed input. self.state.lock().needs_input() } - - // TODO: mro()? } struct CompressorInner { @@ -246,7 +702,7 @@ mod _lzma { ) -> PyResult { self.stream .process_vec(input, output, flush) - .map_err(|_| new_lzma_error("Failed to compress data", vm)) + .map_err(|e| catch_lzma_error(e, vm)) } fn total_in(&mut self) -> usize { @@ -277,36 +733,6 @@ mod _lzma { } } - fn int_to_check(check: i32) -> Option { - if check == -1 { - return Some(Check::None); - } - match check as EnumVal { - CHECK_NONE => Some(Check::None), - CHECK_CRC32 => Some(Check::Crc32), - CHECK_CRC64 => Some(Check::Crc64), - CHECK_SHA256 => Some(Check::Sha256), - _ => None, - } - } - - fn parse_filter_chain_spec( - filter_specs: Vec, - vm: &VirtualMachine, - ) -> PyResult { - // TODO: don't hardcode - const LZMA_FILTERS_MAX: usize = 4; - if filter_specs.len() > LZMA_FILTERS_MAX { - return Err(new_lzma_error( - format!("Too many filters - liblzma supports a maximum of {LZMA_FILTERS_MAX}"), - vm, - )); - } - let filters = Filters::new(); - for _item in filter_specs {} - Ok(filters) - } - impl LZMACompressor { fn init_xz( check: i32, @@ -315,14 +741,13 @@ mod _lzma { vm: &VirtualMachine, ) -> PyResult { let real_check = - int_to_check(check).ok_or_else(|| vm.new_type_error("Invalid check value"))?; - if let Some(filters) = filters { - let filters = parse_filter_chain_spec(filters, vm)?; - Ok(Stream::new_stream_encoder(&filters, real_check) - .map_err(|_| new_lzma_error("Failed to initialize encoder", vm))?) + int_to_check(check).ok_or_else(|| vm.new_value_error("Invalid check value"))?; + if let Some(filter_specs) = filters { + let filters = parse_filter_chain_spec(filter_specs, vm)?; + Stream::new_stream_encoder(&filters, real_check) + .map_err(|e| catch_lzma_error(e, vm)) } else { - Ok(Stream::new_easy_encoder(preset, real_check) - .map_err(|_| new_lzma_error("Failed to initialize encoder", vm))?) + Stream::new_easy_encoder(preset, real_check).map_err(|e| catch_lzma_error(e, vm)) } } @@ -332,65 +757,76 @@ mod _lzma { vm: &VirtualMachine, ) -> PyResult { if let Some(_filter_specs) = filter_specs { - Err(new_lzma_error( - "TODO: RUSTPYTHON: LZMA: Alone filter specs", - vm, - )) + // TODO: validate single LZMA1 filter and use its options + let options = LzmaOptions::new_preset(preset).map_err(|_| { + new_lzma_error(format!("Invalid compression preset: {preset}"), vm) + })?; + Stream::new_lzma_encoder(&options).map_err(|e| catch_lzma_error(e, vm)) } else { - let options = LzmaOptions::new_preset(preset) - .map_err(|_| new_lzma_error("Failed to initialize encoder", vm))?; - let stream = Stream::new_lzma_encoder(&options) - .map_err(|_| new_lzma_error("Failed to initialize encoder", vm))?; - Ok(stream) + let options = LzmaOptions::new_preset(preset).map_err(|_| { + new_lzma_error(format!("Invalid compression preset: {preset}"), vm) + })?; + Stream::new_lzma_encoder(&options).map_err(|e| catch_lzma_error(e, vm)) } } + + fn init_raw( + filter_specs: Option>, + vm: &VirtualMachine, + ) -> PyResult { + let filter_specs = filter_specs + .ok_or_else(|| vm.new_value_error("Must specify filters for FORMAT_RAW"))?; + let filters = parse_filter_chain_spec(filter_specs, vm)?; + Stream::new_raw_encoder(&filters).map_err(|e| catch_lzma_error(e, vm)) + } } #[derive(FromArgs)] pub struct LZMACompressorConstructorArgs { - // format=FORMAT_XZ, check=-1, preset=None, filters=None - // {'format': 3, 'filters': [{'id': 3, 'dist': 2}, {'id': 33, 'preset': 2147483654}]} #[pyarg(any, default = FORMAT_XZ)] format: i32, #[pyarg(any, default = -1)] check: i32, #[pyarg(any, optional)] - preset: Option, + preset: Option, #[pyarg(any, optional)] filters: Option>, - #[pyarg(any, optional)] - _filter_specs: Option>, - #[pyarg(positional, optional)] - preset_obj: Option, } impl Constructor for LZMACompressor { type Args = LZMACompressorConstructorArgs; fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { - let preset = args.preset.unwrap_or(PRESET_DEFAULT); - #[allow(clippy::unnecessary_cast)] - if args.format != FORMAT_XZ as i32 - && args.check != -1 - && args.check != CHECK_NONE as i32 - { + if args.format != FORMAT_XZ && args.check != -1 && args.check != CHECK_NONE { return Err(new_lzma_error( "Integrity checks are only supported by FORMAT_XZ", vm, )); } - if args.preset_obj.is_some() && args.filters.is_some() { + + if args.preset.is_some() && args.filters.is_some() { return Err(new_lzma_error( "Cannot specify both preset and filter chain", vm, )); } + + let preset: u32 = match &args.preset { + Some(obj) => obj.clone().try_into_value(vm)?, + None => PRESET_DEFAULT, + }; + let stream = match args.format { FORMAT_XZ => Self::init_xz(args.check, preset, args.filters, vm)?, FORMAT_ALONE => Self::init_alone(preset, args.filters, vm)?, - // TODO: RAW - _ => return Err(new_lzma_error("Invalid format", vm)), + FORMAT_RAW => Self::init_raw(args.filters, vm)?, + _ => { + return Err( + vm.new_value_error(format!("Invalid container format: {}", args.format)) + ); + } }; + Ok(Self { state: PyMutex::new(CompressState::new(CompressorInner::new(stream))), }) @@ -402,15 +838,12 @@ mod _lzma { #[pymethod] fn compress(&self, data: ArgBytesLike, vm: &VirtualMachine) -> PyResult> { let mut state = self.state.lock(); - // TODO: Flush check state.compress(&data.borrow_buf(), vm) } #[pymethod] fn flush(&self, vm: &VirtualMachine) -> PyResult> { - // TODO: flush check let mut state = self.state.lock(); - // TODO: check if action is correct state.flush(Action::Finish, vm) } }