From fce49fe799ea4ee0950d3a51c333800ce1846e92 Mon Sep 17 00:00:00 2001 From: andrecs <12188364+andrecsilva@users.noreply.github.com> Date: Tue, 23 Apr 2024 09:57:04 -0300 Subject: [PATCH] Refactored url-sandbox codemod --- src/core_codemods/url_sandbox.py | 197 +++++++++++-------------------- 1 file changed, 66 insertions(+), 131 deletions(-) diff --git a/src/core_codemods/url_sandbox.py b/src/core_codemods/url_sandbox.py index 904c9bb3..884d3fc0 100644 --- a/src/core_codemods/url_sandbox.py +++ b/src/core_codemods/url_sandbox.py @@ -1,24 +1,19 @@ -from typing import List, Optional, Union - import libcst as cst -from libcst import CSTNode, matchers from libcst.codemod import CodemodContext, ContextAwareVisitor -from libcst.codemod.visitors import AddImportsVisitor, ImportItem -from libcst.metadata import PositionProvider, ScopeProvider +from libcst.metadata import ( + ImportAssignment, + ParentNodeProvider, + PositionProvider, + ScopeProvider, +) -from codemodder.codemods.base_visitor import UtilsMixin from codemodder.codemods.libcst_transformer import ( LibcstResultTransformer, LibcstTransformerPipeline, ) -from codemodder.codemods.semgrep import SemgrepRuleDetector -from codemodder.codemods.transformations.remove_unused_imports import ( - RemoveUnusedImportsCodemod, -) from codemodder.codemods.utils import ReplaceNodes -from codemodder.codetf import Change +from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin from codemodder.dependency import Security -from codemodder.file_context import FileContext from core_codemods.api import CoreCodemod, Metadata, Reference, ReviewGuidance replacement_import = "safe_requests" @@ -26,125 +21,81 @@ class UrlSandboxTransformer(LibcstResultTransformer): change_description = "Switch use of requests for security.safe_requests" - METADATA_DEPENDENCIES = (PositionProvider, ScopeProvider) adds_dependency = True + METADATA_DEPENDENCIES = ( + ScopeProvider, + PositionProvider, + ParentNodeProvider, + ) + def transform_module_impl(self, tree: cst.Module) -> cst.Module: - # we first gather all the nodes we want to change together with their replacements - find_requests_visitor = FindRequestCallsAndImports( - self.context, - self.file_context, - self.file_context.findings, - ) - tree.visit(find_requests_visitor) - if find_requests_visitor.nodes_to_change: - self.file_context.codemod_changes.extend( - find_requests_visitor.changes_in_file - ) - new_tree = tree.visit(ReplaceNodes(find_requests_visitor.nodes_to_change)) - self.add_dependency(Security) - # if it finds any request.get(...), try to remove the imports - if any( - ( - matchers.matches(n, matchers.Call()) - for n in find_requests_visitor.nodes_to_change - ) - ): - new_tree = AddImportsVisitor( - self.context, - [ImportItem(Security.name, replacement_import, None, 0)], - ).transform_module(new_tree) - new_tree = RemoveUnusedImportsCodemod(self.context).transform_module( - new_tree - ) - return new_tree + visitor = UrlSandboxVisitor(self.context) + tree.visit(visitor) + nodes_to_change = {} + for call, (original, replacement) in visitor.calls_to_change.items(): + if self.node_is_selected(call): + if not isinstance(original, cst.ImportFrom | cst.Import): + self.add_needed_import("security", replacement_import) + self.remove_unused_import(call) + self.report_change(call, self.change_description) + nodes_to_change[original] = replacement + self.add_dependency(Security) + if nodes_to_change: + result = tree.visit(ReplaceNodes(nodes_to_change)) + print(tree.code) + print(result.code) + return result + return tree -class FindRequestCallsAndImports(ContextAwareVisitor, UtilsMixin): - METADATA_DEPENDENCIES = (ScopeProvider,) +class UrlSandboxVisitor(ContextAwareVisitor, NameAndAncestorResolutionMixin): - def __init__( - self, codemod_context: CodemodContext, file_context: FileContext, results - ): - self.nodes_to_change: dict[ - cst.CSTNode, Union[cst.CSTNode, cst.FlattenSentinel, cst.RemovalSentinel] - ] = {} - self.changes_in_file: List[Change] = [] - ContextAwareVisitor.__init__(self, codemod_context) - UtilsMixin.__init__( - self, - results=results, - line_include=file_context.line_include, - line_exclude=file_context.line_exclude, - ) + def __init__(self, context: CodemodContext) -> None: + self.calls_to_change: dict[cst.Call, tuple[cst.CSTNode, cst.CSTNode]] = {} + super().__init__(context) def leave_Call(self, original_node: cst.Call): - if not self.node_is_selected(original_node): - return - - line_number = self.node_position(original_node).start.line - match original_node.args[0].value: - case cst.SimpleString(): - return - - match original_node: - # case get(...) - case cst.Call(func=cst.Name()): - # find if get(...) comes from an from requests import get - match self.find_single_assignment(original_node): - case cst.ImportFrom() as node: - self.nodes_to_change.update( - { - node: cst.ImportFrom( + # is first arg a hardcoded string? + match original_node.args: + case [cst.Arg(value=first_arg), *_]: + resolved_arg = self.resolve_expression(first_arg) + if isinstance(resolved_arg, cst.SimpleString): + return + + resolved = self.resolve_expression(original_node) + true_name = self.find_base_name(resolved) + if true_name in ("requests.get",): + # is it aliased? (i.e. get(...) or requests.get(...)) + match resolved: + case cst.Call(func=cst.Name()): + origin = self.find_single_assignment(resolved) + # sanity check mostly, ImportFrom is the only possibility here + match origin: + case ImportAssignment(node=cst.ImportFrom() as node): + + self.calls_to_change[original_node] = ( + node, + cst.ImportFrom( module=cst.Attribute( value=cst.Name(Security.name), attr=cst.Name(replacement_import), ), names=node.names, - ) - } - ) - self.changes_in_file.append( - Change( - lineNumber=line_number, - description=UrlSandboxTransformer.change_description, + ), ) + case _: + maybe_imported = self.get_imported_prefix(original_node) + if maybe_imported: + self.calls_to_change[original_node] = ( + original_node, + cst.Call( + func=cst.parse_expression(replacement_import + ".get"), + args=original_node.args, + ), ) - # case req.get(...) - case _: - self.nodes_to_change.update( - { - original_node: cst.Call( - func=cst.parse_expression(replacement_import + ".get"), - args=original_node.args, - ) - } - ) - self.changes_in_file.append( - Change( - lineNumber=line_number, - description=UrlSandboxTransformer.change_description, - ) - ) - - def _find_assignments(self, node: CSTNode): - """ - Given a MetadataWrapper and a CSTNode representing an access, find all the possible assignments that it refers. - """ - scope = self.get_metadata(ScopeProvider, node) - return next(iter(scope.accesses[node]))._Access__assignments - - def find_single_assignment(self, node: CSTNode) -> Optional[CSTNode]: - """ - Given a MetadataWrapper and a CSTNode representing an access, find if there is a single assignment that it refers to. - """ - assignments = self._find_assignments(node) - if len(assignments) == 1: - return next(iter(assignments)).node - return None - UrlSandbox = CoreCodemod( metadata=Metadata( @@ -165,22 +116,6 @@ def find_single_assignment(self, node: CSTNode) -> Optional[CSTNode]: Reference(url="https://blog.assetnote.io/2021/01/13/blind-ssrf-chains/"), ], ), - detector=SemgrepRuleDetector( - """ - rules: - - id: url-sandbox - message: Unbounded URL creation - severity: WARNING - languages: - - python - pattern-either: - - patterns: - - pattern: requests.get(...) - - pattern-not: requests.get("...") - - pattern-inside: | - import requests - ... - """ - ), + detector=None, transformer=LibcstTransformerPipeline(UrlSandboxTransformer), )