diff --git a/pyproject.toml b/pyproject.toml index 4bb4240..b917b52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,8 @@ [build-system] -requires = ["setuptools >= 61.0"] +requires = [ + "setuptools >= 61.0", + "requests" +] build-backend = "setuptools.build_meta" [project] @@ -7,7 +10,8 @@ name = "socket-sdk-python" dynamic = ["version"] requires-python = ">= 3.9" dependencies = [ - 'requests' + 'requests', + 'typing-extensions>=4.12.2' ] readme = "README.rst" license = {file = "LICENSE"} @@ -57,6 +61,7 @@ include = [ "socketdev.sbom", "socketdev.settings", "socketdev.tools", + "socketdev.utils", ] [tool.setuptools.dynamic] diff --git a/socketdev/__init__.py b/socketdev/__init__.py index ac1a1e8..440f15a 100644 --- a/socketdev/__init__.py +++ b/socketdev/__init__.py @@ -1,10 +1,7 @@ import logging -import requests -import base64 -from socketdev.core.classes import Response +from socketdev.core.api import API from socketdev.dependencies import Dependencies -from socketdev.exceptions import APIKeyMissing, APIFailure, APIAccessDenied, APIInsufficientQuota, APIResourceNotFound from socketdev.export import Export from socketdev.fullscans import FullScans from socketdev.npm import NPM @@ -17,106 +14,40 @@ from socketdev.repositories import Repositories from socketdev.sbom import Sbom from socketdev.settings import Settings +from socketdev.version import __version__ +from socketdev.utils import Utils, IntegrationType, INTEGRATION_TYPES - +__version__ = __version__ __author__ = "socket.dev" -__version__ = "1.0.14" -__all__ = ["socketdev"] - - -global encoded_key -encoded_key: str +__all__ = ["socketdev", "Utils", "IntegrationType", "INTEGRATION_TYPES"] -api_url = "https://api.socket.dev/v0" -request_timeout = 30 log = logging.getLogger("socketdev") log.addHandler(logging.NullHandler()) -def encode_key(token: str): - global encoded_key - encoded_key = base64.b64encode(token.encode()).decode("ascii") - - -def do_request( - path: str, headers: dict = None, payload: [dict, str] = None, files: list = None, method: str = "GET" -) -> Response: - """ - Shared function for performing the requests against the API. - :param path: String path of the URL - :param headers: Optional dictionary of the headers to include in the request. Defaults to None - :param payload: Optional dictionary or string of the payload to POST. Defaults to None - :param files: Optional list of files to send. Defaults to None - :param method: Optional string of the method for the Request. Defaults to GET - """ - - if encoded_key is None or encoded_key == "": - raise APIKeyMissing - - if headers is None: - headers = { - "Authorization": f"Basic {encoded_key}", - "User-Agent": f"SocketPythonScript/{__version__}", - "accept": "application/json", - } - url = f"{api_url}/{path}" - try: - response = requests.request( - method.upper(), url, headers=headers, data=payload, files=files, timeout=request_timeout - ) - if response.status_code >= 400: - raise APIFailure("Bad Request") - elif response.status_code == 401: - raise APIAccessDenied("Unauthorized") - elif response.status_code == 403: - raise APIInsufficientQuota("Insufficient max_quota for API method") - elif response.status_code == 404: - raise APIResourceNotFound(f"Path not found {path}") - elif response.status_code == 429: - raise APIInsufficientQuota("Insufficient quota for API route") - except Exception as error: - response = Response(text=f"{error}", error=True, status_code=500) - raise APIFailure(response) - return response - - class socketdev: - token: str - timeout: int - dependencies: Dependencies - npm: NPM - openapi: OpenAPI - org: Orgs - quota: Quota - report: Report - sbom: Sbom - purl: Purl - fullscans: FullScans - export: Export - repositories: Repositories - settings: Settings - repos: Repos - def __init__(self, token: str, timeout: int = 30): + self.api = API() self.token = token + ":" - encode_key(self.token) - self.timeout = timeout - socketdev.set_timeout(self.timeout) - self.dependencies = Dependencies() - self.npm = NPM() - self.openapi = OpenAPI() - self.org = Orgs() - self.quota = Quota() - self.report = Report() - self.sbom = Sbom() - self.purl = Purl() - self.fullscans = FullScans() - self.export = Export() - self.repositories = Repositories() - self.repos = Repos() - self.settings = Settings() + self.api.encode_key(self.token) + self.api.set_timeout(timeout) + + self.dependencies = Dependencies(self.api) + self.npm = NPM(self.api) + self.openapi = OpenAPI(self.api) + self.org = Orgs(self.api) + self.quota = Quota(self.api) + self.report = Report(self.api) + self.sbom = Sbom(self.api) + self.purl = Purl(self.api) + self.fullscans = FullScans(self.api) + self.export = Export(self.api) + self.repositories = Repositories(self.api) + self.repos = Repos(self.api) + self.settings = Settings(self.api) + self.utils = Utils() @staticmethod def set_timeout(timeout: int): - global request_timeout - request_timeout = timeout + # Kept for backwards compatibility + pass diff --git a/socketdev/core/api.py b/socketdev/core/api.py new file mode 100644 index 0000000..67ee6f1 --- /dev/null +++ b/socketdev/core/api.py @@ -0,0 +1,53 @@ +import base64 +import requests +from socketdev.core.classes import Response +from socketdev.exceptions import APIKeyMissing, APIFailure, APIAccessDenied, APIInsufficientQuota, APIResourceNotFound +from socketdev.version import __version__ + + +class API: + def __init__(self): + self.encoded_key = None + self.api_url = "https://api.socket.dev/v0" + self.request_timeout = 30 + + def encode_key(self, token: str): + self.encoded_key = base64.b64encode(token.encode()).decode("ascii") + + def set_timeout(self, timeout: int): + self.request_timeout = timeout + + def do_request( + self, path: str, headers: dict | None = None, payload: [dict, str] = None, files: list = None, method: str = "GET" + ) -> Response: + if self.encoded_key is None or self.encoded_key == "": + raise APIKeyMissing + + if headers is None: + headers = { + "Authorization": f"Basic {self.encoded_key}", + "User-Agent": f"SocketPythonScript/{__version__}", + "accept": "application/json", + } + url = f"{self.api_url}/{path}" + try: + response = requests.request( + method.upper(), url, headers=headers, data=payload, files=files, timeout=self.request_timeout + ) + + if response.status_code == 401: + raise APIAccessDenied("Unauthorized") + if response.status_code == 403: + raise APIInsufficientQuota("Insufficient max_quota for API method") + if response.status_code == 404: + raise APIResourceNotFound(f"Path not found {path}") + if response.status_code == 429: + raise APIInsufficientQuota("Insufficient quota for API route") + if response.status_code >= 400: + raise APIFailure("Bad Request") + + return response + + except Exception as error: + response = Response(text=f"{error}", error=True, status_code=500) + raise APIFailure(response) diff --git a/socketdev/dependencies/__init__.py b/socketdev/dependencies/__init__.py index b229aa1..45ea8c5 100644 --- a/socketdev/dependencies/__init__.py +++ b/socketdev/dependencies/__init__.py @@ -1,20 +1,18 @@ -import socketdev -from socketdev.tools import load_files -from urllib.parse import urlencode import json +from urllib.parse import urlencode + +from socketdev.tools import load_files class Dependencies: - @staticmethod - def post(files: list, params: dict) -> dict: + def __init__(self, api): + self.api = api + + def post(self, files: list, params: dict) -> dict: loaded_files = [] loaded_files = load_files(files, loaded_files) path = "dependencies/upload?" + urlencode(params) - response = socketdev.do_request( - path=path, - files=loaded_files, - method="POST" - ) + response = self.api.do_request(path=path, files=loaded_files, method="POST") if response.status_code == 200: result = response.json() else: @@ -23,22 +21,15 @@ def post(files: list, params: dict) -> dict: print(response.text) return result - @staticmethod def get( - limit: int = 50, - offset: int = 0, + self, + limit: int = 50, + offset: int = 0, ) -> dict: path = "dependencies/search" - payload = { - "limit": limit, - "offset": offset - } + payload = {"limit": limit, "offset": offset} payload_str = json.dumps(payload) - response = socketdev.do_request( - path=path, - method="POST", - payload=payload_str - ) + response = self.api.do_request(path=path, method="POST", payload=payload_str) if response.status_code == 200: result = response.json() else: diff --git a/socketdev/export/__init__.py b/socketdev/export/__init__.py index 4e3907e..d56f886 100644 --- a/socketdev/export/__init__.py +++ b/socketdev/export/__init__.py @@ -1,7 +1,6 @@ from urllib.parse import urlencode from dataclasses import dataclass, asdict from typing import Optional -import socketdev @dataclass @@ -21,8 +20,10 @@ def to_query_params(self) -> str: class Export: - @staticmethod - def cdx_bom(org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None) -> dict: + def __init__(self, api): + self.api = api + + def cdx_bom(self, org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None) -> dict: """ Export a Socket SBOM as a CycloneDX SBOM :param org_slug: String - The slug of the organization @@ -33,16 +34,15 @@ def cdx_bom(org_slug: str, id: str, query_params: Optional[ExportQueryParams] = path = f"orgs/{org_slug}/export/cdx/{id}" if query_params: path += query_params.to_query_params() - result = socketdev.do_request(path=path) + response = self.api.do_request(path=path) try: - sbom = result.json() + sbom = response.json() sbom["success"] = True except Exception as error: sbom = {"success": False, "message": str(error)} return sbom - @staticmethod - def spdx_bom(org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None) -> dict: + def spdx_bom(self, org_slug: str, id: str, query_params: Optional[ExportQueryParams] = None) -> dict: """ Export a Socket SBOM as an SPDX SBOM :param org_slug: String - The slug of the organization @@ -53,9 +53,9 @@ def spdx_bom(org_slug: str, id: str, query_params: Optional[ExportQueryParams] = path = f"orgs/{org_slug}/export/spdx/{id}" if query_params: path += query_params.to_query_params() - result = socketdev.do_request(path=path) + response = self.api.do_request(path=path) try: - sbom = result.json() + sbom = response.json() sbom["success"] = True except Exception as error: sbom = {"success": False, "message": str(error)} diff --git a/socketdev/fullscans/__init__.py b/socketdev/fullscans/__init__.py index 7d4d200..ec950bf 100644 --- a/socketdev/fullscans/__init__.py +++ b/socketdev/fullscans/__init__.py @@ -1,127 +1,822 @@ -import socketdev -from socketdev.tools import load_files import json import logging +from enum import Enum +from typing import Any, Dict, List, Optional +from dataclasses import dataclass, asdict, field + + +from ..utils import IntegrationType, Utils log = logging.getLogger("socketdev") +class SocketPURL_Type(str, Enum): + UNKNOWN = "unknown" + NPM = "npm" + PYPI = "pypi" + GOLANG = "golang" + + +class SocketIssueSeverity(str, Enum): + LOW = "low" + MIDDLE = "middle" + HIGH = "high" + CRITICAL = "critical" + + +class SocketCategory(str, Enum): + SUPPLY_CHAIN_RISK = "supplyChainRisk" + QUALITY = "quality" + MAINTENANCE = "maintenance" + VULNERABILITY = "vulnerability" + LICENSE = "license" + MISCELLANEOUS = "miscellaneous" + +class DiffType(str, Enum): + ADDED = "added" + REMOVED = "removed" + UNCHANGED = "unchanged" + REPLACED = "replaced" + UPDATED = "updated" + +@dataclass(kw_only=True) +class SocketPURL: + type: SocketPURL_Type + name: Optional[str] = None + namespace: Optional[str] = None + release: Optional[str] = None + subpath: Optional[str] = None + version: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SocketPURL": + return cls( + type=SocketPURL_Type(data["type"]), + name=data.get("name"), + namespace=data.get("namespace"), + release=data.get("release"), + subpath=data.get("subpath"), + version=data.get("version") + ) + +@dataclass +class SocketManifestReference: + file: str + start: Optional[int] = None + end: Optional[int] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SocketManifestReference": + return cls( + file=data["file"], + start=data.get("start"), + end=data.get("end") + ) + +@dataclass +class FullScanParams: + repo: str + org_slug: Optional[str] = None + branch: Optional[str] = None + commit_message: Optional[str] = None + commit_hash: Optional[str] = None + pull_request: Optional[int] = None + committers: Optional[List[str]] = None + integration_type: Optional[IntegrationType] = None + integration_org_slug: Optional[str] = None + make_default_branch: Optional[bool] = None + set_as_pending_head: Optional[bool] = None + tmp: Optional[bool] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "FullScanParams": + integration_type = data.get("integration_type") + return cls( + repo=data["repo"], + org_slug=data.get("org_slug"), + branch=data.get("branch"), + commit_message=data.get("commit_message"), + commit_hash=data.get("commit_hash"), + pull_request=data.get("pull_request"), + committers=data.get("committers"), + integration_type=IntegrationType(integration_type) if integration_type else None, + integration_org_slug=data.get("integration_org_slug"), + make_default_branch=data.get("make_default_branch"), + set_as_pending_head=data.get("set_as_pending_head"), + tmp=data.get("tmp") + ) + +@dataclass +class FullScanMetadata: + id: str + created_at: str + updated_at: str + organization_id: str + repository_id: str + branch: str + html_report_url: str + repo: Optional[str] = None # In docs, never shows up + organization_slug: Optional[str] = None # In docs, never shows up + committers: Optional[List[str]] = None + commit_message: Optional[str] = None + commit_hash: Optional[str] = None + pull_request: Optional[int] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "FullScanMetadata": + return cls( + id=data["id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + organization_id=data["organization_id"], + repository_id=data["repository_id"], + branch=data["branch"], + html_report_url=data["html_report_url"], + repo=data.get("repo"), + organization_slug=data.get("organization_slug"), + committers=data.get("committers"), + commit_message=data.get("commit_message"), + commit_hash=data.get("commit_hash"), + pull_request=data.get("pull_request") + ) + +@dataclass +class CreateFullScanResponse: + success: bool + status: int + data: Optional[FullScanMetadata] = None + message: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "CreateFullScanResponse": + return cls( + success=data["success"], + status=data["status"], + message=data.get("message"), + data=FullScanMetadata.from_dict(data.get("data")) if data.get("data") else None + ) + +@dataclass +class GetFullScanMetadataResponse: + success: bool + status: int + data: Optional[FullScanMetadata] = None + message: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "GetFullScanMetadataResponse": + return cls( + success=data["success"], + status=data["status"], + message=data.get("message"), + data=FullScanMetadata.from_dict(data.get("data")) if data.get("data") else None + ) + +@dataclass +class DependencyRef: + direct: bool + toplevelAncestors: List[str] + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "DependencyRef": + return cls( + direct=data["direct"], + toplevelAncestors=data["toplevelAncestors"] + ) + +@dataclass +class SocketScore: + supplyChain: float + quality: float + maintenance: float + vulnerability: float + license: float + overall: float + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SocketScore": + return cls( + supplyChain=data["supplyChain"], + quality=data["quality"], + maintenance=data["maintenance"], + vulnerability=data["vulnerability"], + license=data["license"], + overall=data["overall"] + ) + +@dataclass +class SecurityCapabilities: + env: bool + eval: bool + fs: bool + net: bool + shell: bool + unsafe: bool + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SecurityCapabilities": + return cls( + env=data["env"], + eval=data["eval"], + fs=data["fs"], + net=data["net"], + shell=data["shell"], + unsafe=data["unsafe"] + ) + +@dataclass +class Alert: + key: str + type: int + file: str + start: int + end: int + props: Dict[str, Any] + action: str + actionPolicyIndex: int + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "Alert": + return cls( + key=data["key"], + type=data["type"], + file=data["file"], + start=data["start"], + end=data["end"], + props=data["props"], + action=data["action"], + actionPolicyIndex=data["actionPolicyIndex"] + ) + +@dataclass +class LicenseMatch: + licenseId: str + licenseExceptionId: str + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "LicenseMatch": + return cls( + licenseId=data["licenseId"], + licenseExceptionId=data["licenseExceptionId"] + ) + +@dataclass +class LicenseDetail: + authors: List[str] + charEnd: int + charStart: int + filepath: str + match_strength: int + filehash: str + provenance: str + spdxDisj: List[List[LicenseMatch]] + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "LicenseDetail": + return cls( + authors=data["authors"], + charEnd=data["charEnd"], + charStart=data["charStart"], + filepath=data["filepath"], + match_strength=data["match_strength"], + filehash=data["filehash"], + provenance=data["provenance"], + spdxDisj=[[LicenseMatch.from_dict(match) for match in group] + for group in data["spdxDisj"]] + ) + +@dataclass +class AttributionData: + purl: str + foundAuthors: List[str] + foundInFilepath: Optional[str] = None + spdxExpr: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "AttributionData": + return cls( + purl=data["purl"], + foundAuthors=data["foundAuthors"], + foundInFilepath=data.get("foundInFilepath"), + spdxExpr=data.get("spdxExpr") + ) + +@dataclass +class LicenseAttribution: + attribText: str + attribData: List[AttributionData] + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "LicenseAttribution": + return cls( + attribText=data["attribText"], + attribData=[AttributionData.from_dict(item) for item in data["attribData"]] + ) + +@dataclass +class DiffArtifactAlert: + key: str + type: str + severity: Optional[SocketIssueSeverity] = None + category: Optional[SocketCategory] = None + file: Optional[str] = None + start: Optional[int] = None + end: Optional[int] = None + props: Optional[Dict[str, Any]] = None + action: Optional[str] = None + actionPolicyIndex: Optional[int] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "DiffArtifactAlert": + severity = data.get("severity") + category = data.get("category") + return cls( + key=data["key"], + type=data["type"], + severity=SocketIssueSeverity(severity) if severity else None, + category=SocketCategory(category) if category else None, + file=data.get("file"), + start=data.get("start"), + end=data.get("end"), + props=data.get("props"), + action=data.get("action"), + actionPolicyIndex=data.get("actionPolicyIndex") + ) + +@dataclass +class DiffArtifact: + diffType: DiffType + id: str + type: str + name: str + license: str + scores: SocketScore + capabilities: SecurityCapabilities + files: str + version: str + alerts: List[DiffArtifactAlert] + licenseDetails: List[LicenseDetail] + base: Optional[DependencyRef] = None + head: Optional[DependencyRef] = None + namespace: Optional[str] = None + subpath: Optional[str] = None + artifact_id: Optional[str] = None + artifactId: Optional[str] = None + qualifiers: Optional[Dict[str, Any]] = None + size: Optional[int] = None + author: Optional[str] = None + state: Optional[str] = None + error: Optional[str] = None + licenseAttrib: Optional[List[LicenseAttribution]] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "DiffArtifact": + return cls( + diffType=DiffType(data["diffType"]), + id=data["id"], + type=data["type"], + name=data["name"], + license=data.get("license", ""), + scores=SocketScore.from_dict(data["scores"]), + capabilities=SecurityCapabilities.from_dict(data["capabilities"]), + files=data["files"], + version=data["version"], + alerts=[DiffArtifactAlert.from_dict(alert) for alert in data["alerts"]], + licenseDetails=[LicenseDetail.from_dict(detail) for detail in data["licenseDetails"]], + base=DependencyRef.from_dict(data["base"]) if data.get("base") else None, + head=DependencyRef.from_dict(data["head"]) if data.get("head") else None, + namespace=data.get("namespace"), + subpath=data.get("subpath"), + artifact_id=data.get("artifact_id"), + artifactId=data.get("artifactId"), + qualifiers=data.get("qualifiers"), + size=data.get("size"), + author=data.get("author"), + state=data.get("state"), + error=data.get("error"), + licenseAttrib=[LicenseAttribution.from_dict(attrib) for attrib in data["licenseAttrib"]] if data.get("licenseAttrib") else None + ) + +@dataclass +class DiffArtifacts: + added: List[DiffArtifact] + removed: List[DiffArtifact] + unchanged: List[DiffArtifact] + replaced: List[DiffArtifact] + updated: List[DiffArtifact] + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "DiffArtifacts": + return cls( + added=[DiffArtifact.from_dict(a) for a in data["added"]], + removed=[DiffArtifact.from_dict(a) for a in data["removed"]], + unchanged=[DiffArtifact.from_dict(a) for a in data["unchanged"]], + replaced=[DiffArtifact.from_dict(a) for a in data["replaced"]], + updated=[DiffArtifact.from_dict(a) for a in data["updated"]] + ) + +@dataclass +class CommitInfo: + repository_id: str + branch: str + id: str + organization_id: str + committers: List[str] + commit_message: Optional[str] = None + commit_hash: Optional[str] = None + pull_request: Optional[int] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "CommitInfo": + return cls( + repository_id=data["repository_id"], + branch=data["branch"], + id=data["id"], + organization_id=data["organization_id"], + committers=data["committers"], + commit_message=data.get("commit_message"), + commit_hash=data.get("commit_hash"), + pull_request=data.get("pull_request") + ) + +@dataclass +class FullScanDiffReport: + before: CommitInfo + after: CommitInfo + directDependenciesChanged: bool + diff_report_url: str + artifacts: DiffArtifacts + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "FullScanDiffReport": + return cls( + before=CommitInfo.from_dict(data["before"]), + after=CommitInfo.from_dict(data["after"]), + directDependenciesChanged=data["directDependenciesChanged"], + diff_report_url=data["diff_report_url"], + artifacts=DiffArtifacts.from_dict(data["artifacts"]) + ) + +@dataclass +class StreamDiffResponse: + success: bool + status: int + data: Optional[FullScanDiffReport] = None + message: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "StreamDiffResponse": + return cls( + success=data["success"], + status=data["status"], + message=data.get("message"), + data=FullScanDiffReport.from_dict(data.get("data")) if data.get("data") else None + ) + +@dataclass(kw_only=True) +class SocketArtifactLink: + topLevelAncestors: List[str] + artifact: Optional[Dict] = None + dependencies: Optional[List[str]] = None + direct: Optional[bool] = None + manifestFiles: Optional[List[SocketManifestReference]] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SocketArtifactLink": + manifest_files = data.get("manifestFiles") + return cls( + topLevelAncestors=data["topLevelAncestors"], + artifact=data.get("artifact"), + dependencies=data.get("dependencies"), + direct=data.get("direct"), + manifestFiles=[SocketManifestReference.from_dict(m) for m in manifest_files] if manifest_files else None + ) + +@dataclass +class SocketAlert: + key: str + type: str + severity: SocketIssueSeverity + category: SocketCategory + file: Optional[str] = None + start: Optional[int] = None + end: Optional[int] = None + props: Optional[Dict[str, Any]] = None + action: Optional[str] = None + actionPolicyIndex: Optional[int] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SocketAlert": + return cls( + key=data["key"], + type=data["type"], + severity=SocketIssueSeverity(data["severity"]), + category=SocketCategory(data["category"]), + file=data.get("file"), + start=data.get("start"), + end=data.get("end"), + props=data.get("props"), + action=data.get("action"), + actionPolicyIndex=data.get("actionPolicyIndex") + ) + +@dataclass(kw_only=True) +class SocketArtifact(SocketPURL, SocketArtifactLink): + id: str + alerts: Optional[List[SocketAlert]] = field(default_factory=list) + author: Optional[List[str]] = field(default_factory=list) + batchIndex: Optional[int] = None + license: Optional[str] = None + licenseAttrib: Optional[List[LicenseAttribution]] = field(default_factory=list) + licenseDetails: Optional[List[LicenseDetail]] = field(default_factory=list) + score: Optional[SocketScore] = None + size: Optional[float] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SocketArtifact": + # First get the base class data + purl_data = {k: data.get(k) for k in SocketPURL.__dataclass_fields__} + link_data = {k: data.get(k) for k in SocketArtifactLink.__dataclass_fields__} + + # Handle nested types + alerts = data.get("alerts") + license_attrib = data.get("licenseAttrib") + license_details = data.get("licenseDetails") + score = data.get("score") + + return cls( + **purl_data, + **link_data, + id=data["id"], + alerts=[SocketAlert.from_dict(a) for a in alerts] if alerts is not None else [], + author=data.get("author"), + batchIndex=data.get("batchIndex"), + license=data.get("license"), + licenseAttrib=[LicenseAttribution.from_dict(la) for la in license_attrib] if license_attrib else None, + licenseDetails=[LicenseDetail.from_dict(ld) for ld in license_details] if license_details else None, + score=SocketScore.from_dict(score) if score else None, + size=data.get("size") + ) + +@dataclass +class FullScanStreamResponse: + success: bool + status: int + artifacts: Optional[Dict[str, SocketArtifact]] = None + message: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "FullScanStreamResponse": + return cls( + success=data["success"], + status=data["status"], + message=data.get("message"), + artifacts={ + k: SocketArtifact.from_dict(v) + for k, v in data["artifacts"].items() + } if data.get("artifacts") else None + ) class FullScans: - @staticmethod - def create_params_string(params: dict) -> str: + def __init__(self, api): + self.api = api + + def create_params_string(self, params: dict) -> str: param_str = "" - for name in params: - value = params[name] + for name, value in params.items(): if value: - param_str += f"&{name}={value}" + if name == "committers" and isinstance(value, list): + # Handle committers specially - add multiple params + for committer in value: + param_str += f"&{name}={committer}" + else: + param_str += f"&{name}={value}" param_str = "?" + param_str.lstrip("&") return param_str - @staticmethod - def get(org_slug: str, params: dict) -> dict: - params_arg = FullScans.create_params_string(params) + def get(self, org_slug: str, params: dict) -> GetFullScanMetadataResponse: + params_arg = self.create_params_string(params) + Utils.validate_integration_type(params.get("integration_type", "")) path = "orgs/" + org_slug + "/full-scans" + str(params_arg) headers = None payload = None - response = socketdev.do_request(path=path, headers=headers, payload=payload) + response = self.api.do_request(path=path, headers=headers, payload=payload) if response.status_code == 200: result = response.json() - result["success"] = True - result["status"] = 200 - return result - - result = {"success": False, "status": response.status_code, "message": response.text} - - return result + print(f"get full scan metadata result: {result}") + return GetFullScanMetadataResponse.from_dict({ + "success": True, + "status": 200, + "data": result + }) + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error getting full scan metadata: {response.status_code}, message: {error_message}") + return GetFullScanMetadataResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) + + def post(self, files: list, params: FullScanParams) -> CreateFullScanResponse: + + org_slug = str(params.org_slug) + params_dict = params.to_dict() + params_dict.pop("org_slug") + params_arg = self.create_params_string(params_dict) # Convert params to dict - @staticmethod - def post(files: list, params: dict) -> dict: - loaded_files = [] - loaded_files = load_files(files, loaded_files) - - params_arg = FullScans.create_params_string(params) - - path = "orgs/" + str(params["org_slug"]) + "/full-scans" + str(params_arg) - - response = socketdev.do_request(path=path, method="POST", files=loaded_files) + path = "orgs/" + org_slug + "/full-scans" + str(params_arg) + response = self.api.do_request(path=path, method="POST", files=files) + print("finished post") if response.status_code == 201: result = response.json() - else: - print(f"Error posting {files} to the Fullscans API") - print(response.text) - result = response.text - - return result - - @staticmethod - def delete(org_slug: str, full_scan_id: str) -> dict: + print(f"create new full scan result: {result}") + return CreateFullScanResponse.from_dict({ + "success": True, + "status": 201, + "data": result + }) + + log.error(f"Error posting {files} to the Fullscans API") + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(error_message) + + return CreateFullScanResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) + + def delete(self, org_slug: str, full_scan_id: str) -> dict: path = "orgs/" + org_slug + "/full-scans/" + full_scan_id - response = socketdev.do_request(path=path, method="DELETE") + response = self.api.do_request(path=path, method="DELETE") if response.status_code == 200: result = response.json() - else: - result = {} + return result - return result + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error deleting full scan: {response.status_code}, message: {error_message}") + return {} - @staticmethod - def stream_diff(org_slug: str, before: str, after: str, preview: bool = False) -> dict: - path = f"orgs/{org_slug}/full-scans/stream-diff?before={before}&after={after}&preview={preview}" + def stream_diff(self, org_slug: str, before: str, after: str) -> StreamDiffResponse: + path = f"orgs/{org_slug}/full-scans/diff?before={before}&after={after}" - response = socketdev.do_request(path=path, method="GET") + response = self.api.do_request(path=path, method="GET") if response.status_code == 200: - result = response.json() - else: - result = {} - - return result - - @staticmethod - def stream(org_slug: str, full_scan_id: str) -> dict: + return StreamDiffResponse.from_dict({ + "success": True, + "status": 200, + "data": response.json() + }) + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error streaming diff: {response.status_code}, message: {error_message}") + return StreamDiffResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) + + def stream(self, org_slug: str, full_scan_id: str) -> FullScanStreamResponse: path = "orgs/" + org_slug + "/full-scans/" + full_scan_id - response = socketdev.do_request(path=path, method="GET") - + response = self.api.do_request(path=path, method="GET") + if response.status_code == 200: - stream_str = [] - stream_dict = {} - result = response.text - result.strip('"') - result.strip() - for line in result.split("\n"): - if line != '"' and line != "" and line is not None: - item = json.loads(line) - stream_str.append(item) - for val in stream_str: - stream_dict[val["id"]] = val - - stream_dict["success"] = True - stream_dict["status"] = 200 - - return stream_dict + try: + stream_str = [] + artifacts = {} + result = response.text + result.strip('"') + result.strip() + for line in result.split("\n"): + if line != '"' and line != "" and line is not None: + item = json.loads(line) + stream_str.append(item) + for val in stream_str: + artifacts[val["id"]] = val # Just store the raw dict + + return FullScanStreamResponse.from_dict({ + "success": True, + "status": 200, + "artifacts": artifacts # Let from_dict handle the conversion + }) + except Exception as e: + error_message = f"Error parsing stream response: {str(e)}" + log.error(error_message) + return FullScanStreamResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error streaming full scan: {response.status_code}, message: {error_message}") + return FullScanStreamResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) + + def metadata(self, org_slug: str, full_scan_id: str) -> GetFullScanMetadataResponse: + path = "orgs/" + org_slug + "/full-scans/" + full_scan_id + "/metadata" - stream_dict = {"success": False, "status": response.status_code, "message": response.text} + response = self.api.do_request(path=path, method="GET") - return stream_dict + if response.status_code == 200: + return GetFullScanMetadataResponse.from_dict({ + "success": True, + "status": 200, + "data": response.json() + }) - @staticmethod - def metadata(org_slug: str, full_scan_id: str) -> dict: - path = "orgs/" + org_slug + "/full-scans/" + full_scan_id + "/metadata" + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error getting metadata: {response.status_code}, message: {error_message}") + return GetFullScanMetadataResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) - response = socketdev.do_request(path=path, method="GET") - if response.status_code == 200: - result = response.json() - else: - result = {} - return result diff --git a/socketdev/npm/__init__.py b/socketdev/npm/__init__.py index 11bf4b1..a54a6ba 100644 --- a/socketdev/npm/__init__.py +++ b/socketdev/npm/__init__.py @@ -1,20 +1,21 @@ -import socketdev + class NPM: - @staticmethod - def issues(package: str, version: str) -> list: + def __init__(self, api): + self.api = api + + def issues(self, package: str, version: str) -> list: path = f"npm/{package}/{version}/issues" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) issues = [] if response.status_code == 200: issues = response.json() return issues - @staticmethod - def score(package: str, version: str) -> list: + def score(self, package: str, version: str) -> list: path = f"npm/{package}/{version}/score" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) issues = [] if response.status_code == 200: issues = response.json() diff --git a/socketdev/openapi/__init__.py b/socketdev/openapi/__init__.py index 71c4f03..b3df1da 100644 --- a/socketdev/openapi/__init__.py +++ b/socketdev/openapi/__init__.py @@ -1,11 +1,13 @@ -import socketdev + class OpenAPI: - @staticmethod - def get() -> dict: + def __init__(self, api): + self.api = api + + def get(self) -> dict: path = "openapi" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: openapi = response.json() else: diff --git a/socketdev/org/__init__.py b/socketdev/org/__init__.py index 433ed94..12d906c 100644 --- a/socketdev/org/__init__.py +++ b/socketdev/org/__init__.py @@ -1,20 +1,24 @@ -import socketdev +from typing import TypedDict, Dict +class Organization(TypedDict): + id: str + name: str + image: str + plan: str + slug: str + +class OrganizationsResponse(TypedDict): + organizations: Dict[str, Organization] + # Add other fields from the response if needed class Orgs: - @staticmethod - def get() -> dict: - path = "organizations" - headers = None - payload = None + def __init__(self, api): + self.api = api - response = socketdev.do_request( - path=path, - headers=headers, - payload=payload - ) + def get(self) -> OrganizationsResponse: + path = "organizations" + response = self.api.do_request(path=path) if response.status_code == 200: - result = response.json() + return response.json() # Return the full response else: - result = {} - return result + return {"organizations": {}} # Return an empty structure \ No newline at end of file diff --git a/socketdev/purl/__init__.py b/socketdev/purl/__init__.py index 03166d3..6842e11 100644 --- a/socketdev/purl/__init__.py +++ b/socketdev/purl/__init__.py @@ -1,16 +1,16 @@ import json -import socketdev - class Purl: - @staticmethod - def post(license: str = "true", components: list = []) -> dict: + def __init__(self, api): + self.api = api + + def post(self, license: str = "true", components: list = []) -> dict: path = "purl?" + "license=" + license components = {"components": components} components = json.dumps(components) - response = socketdev.do_request(path=path, payload=components, method="POST") + response = self.api.do_request(path=path, payload=components, method="POST") if response.status_code == 200: purl = [] purl_dict = {} diff --git a/socketdev/quota/__init__.py b/socketdev/quota/__init__.py index aebadd9..1494888 100644 --- a/socketdev/quota/__init__.py +++ b/socketdev/quota/__init__.py @@ -1,11 +1,11 @@ -import socketdev - class Quota: - @staticmethod - def get() -> dict: + def __init__(self, api): + self.api = api + + def get(self) -> dict: path = "quota" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: quota = response.json() else: diff --git a/socketdev/report/__init__.py b/socketdev/report/__init__.py index 0c6376d..5483b2a 100644 --- a/socketdev/report/__init__.py +++ b/socketdev/report/__init__.py @@ -1,13 +1,15 @@ -import socketdev + from datetime import datetime, timedelta, timezone class Report: - @staticmethod - def list(from_time: int = None) -> dict: + def __init__(self, api): + self.api = api + + def list(self, from_time: int = None) -> dict: """ This function will return all reports from time specified. - :param from_time: Unix epoch time in seconds. Will default to 30 days + :param from_time: Unix epoch time in seconds. Will default self, to 30 days """ if from_time is None: from_time = int((datetime.now(timezone.utc) - timedelta(days=30)).timestamp()) @@ -17,60 +19,48 @@ def list(from_time: int = None) -> dict: path = "report/list" if from_time is not None: path += f"?from={from_time}" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: reports = response.json() else: reports = {} return reports - @staticmethod - def delete(report_id: str) -> bool: + def delete(self, report_id: str) -> bool: path = f"report/delete/{report_id}" - response = socketdev.do_request( - path=path, - method="DELETE" - ) + response = self.api.do_request(path=path, method="DELETE") if response.status_code == 200: deleted = True else: deleted = False return deleted - @staticmethod - def view(report_id) -> dict: + def view(self, report_id) -> dict: path = f"report/view/{report_id}" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: report = response.json() else: report = {} return report - @staticmethod - def supported() -> dict: + def supported(self) -> dict: path = "report/supported" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: report = response.json() else: report = {} return report - @staticmethod - def create(files: list) -> dict: + def create(self, files: list) -> dict: open_files = [] for name, path in files: - file_info = (name, (name, open(path, 'rb'), 'text/plain')) + file_info = (name, (name, open(path, "rb"), "text/plain")) open_files.append(file_info) path = "report/upload" payload = {} - response = socketdev.do_request( - path=path, - method="PUT", - files=open_files, - payload=payload - ) + response = self.api.do_request(path=path, method="PUT", files=open_files, payload=payload) if response.status_code == 200: reports = response.json() else: diff --git a/socketdev/repos/__init__.py b/socketdev/repos/__init__.py index f057d09..fb4d50a 100644 --- a/socketdev/repos/__init__.py +++ b/socketdev/repos/__init__.py @@ -1,15 +1,73 @@ -import socketdev +import logging +from typing import List, Optional +from dataclasses import dataclass, asdict +log = logging.getLogger("socketdev") + +@dataclass +class RepositoryInfo: + id: str + created_at: str # Could be datetime if we want to parse it + updated_at: str # Could be datetime if we want to parse it + head_full_scan_id: str + name: str + description: str + homepage: str + visibility: str + archived: bool + default_branch: str + slug: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "RepositoryInfo": + return cls( + id=data["id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + head_full_scan_id=data["head_full_scan_id"], + name=data["name"], + description=data["description"], + homepage=data["homepage"], + visibility=data["visibility"], + archived=data["archived"], + default_branch=data["default_branch"], + slug=data.get("slug") + ) + +@dataclass +class GetRepoResponse: + success: bool + status: int + data: Optional[RepositoryInfo] = None + message: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "GetRepoResponse": + return cls( + success=data["success"], + status=data["status"], + message=data.get("message"), + data=RepositoryInfo.from_dict(data.get("data")) if data.get("data") else None + ) class Repos: - @staticmethod - def get(org_slug: str, **kwargs) -> dict[str,]: + def __init__(self, api): + self.api = api + + def get(self, org_slug: str, **kwargs) -> dict[str, List[RepositoryInfo]]: query_params = {} if kwargs: for key, val in kwargs.items(): query_params[key] = val if len(query_params) == 0: return {} + path = "orgs/" + org_slug + "/repos" if query_params is not None: path += "?" @@ -17,68 +75,87 @@ def get(org_slug: str, **kwargs) -> dict[str,]: value = query_params[param] path += f"{param}={value}&" path = path.rstrip("&") - response = socketdev.do_request(path=path) + + response = self.api.do_request(path=path) + if response.status_code == 200: - result = response.json() - else: - result = {} - return result + raw_result = response.json() + result = { + key: [RepositoryInfo.from_dict(repo) for repo in repos] + for key, repos in raw_result.items() + } + return result - @staticmethod - def repo(org_slug: str, repo_name: str) -> dict: + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error getting repositories: {response.status_code}, message: {error_message}") + return {} + + def repo(self, org_slug: str, repo_name: str) -> GetRepoResponse: path = f"orgs/{org_slug}/repos/{repo_name}" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) + if response.status_code == 200: result = response.json() - else: - result = {} - return result + return GetRepoResponse.from_dict({ + "success": True, + "status": 200, + "data": result + }) + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Failed to get repository: {response.status_code}, message: {error_message}") + return GetRepoResponse.from_dict({ + "success": False, + "status": response.status_code, + "message": error_message + }) - @staticmethod - def delete(org_slug: str, name: str) -> dict: + def delete(self, org_slug: str, name: str) -> dict: path = f"orgs/{org_slug}/repos/{name}" - response = socketdev.do_request(path=path, method="DELETE") + response = self.api.do_request(path=path, method="DELETE") + if response.status_code == 200: result = response.json() - else: - result = {} - return result + return result + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error deleting repository: {response.status_code}, message: {error_message}") + return {} - @staticmethod - def post(org_slug: str, **kwargs) -> dict: + def post(self, org_slug: str, **kwargs) -> dict: params = {} if kwargs: for key, val in kwargs.items(): params[key] = val if len(params) == 0: return {} + path = "orgs/" + org_slug + "/repos" - response = socketdev.do_request( - path=path, - method="POST", - payload=params.__dict__ - ) - result = {} + response = self.api.do_request(path=path, method="POST", payload=params.__dict__) + if response.status_code == 200: result = response.json() - return result + return result + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error creating repository: {response.status_code}, message: {error_message}") + return {} - @staticmethod - def update(org_slug: str, repo_name: str, **kwargs) -> dict: + def update(self, org_slug: str, repo_name: str, **kwargs) -> dict: params = {} if kwargs: for key, val in kwargs.keys(): params[key] = val if len(params) == 0: return {} + path = f"orgs/{org_slug}/repos/{repo_name}" - response = socketdev.do_request( - path=path, - method="POST", - payload=params.__dict__ - ) + response = self.api.do_request(path=path, method="POST", payload=params.__dict__) + if response.status_code == 200: result = response.json() - else: - result = {} - return result + return result + + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Error updating repository: {response.status_code}, message: {error_message}") + return {} diff --git a/socketdev/repositories/__init__.py b/socketdev/repositories/__init__.py index 91eab3b..f1eaaa6 100644 --- a/socketdev/repositories/__init__.py +++ b/socketdev/repositories/__init__.py @@ -1,4 +1,3 @@ -import socketdev from typing import TypedDict @@ -12,10 +11,12 @@ class Repo(TypedDict): class Repositories: - @staticmethod - def list() -> dict: + def __init__(self, api): + self.api = api + + def list(self) -> dict: path = "repos" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: repos = response.json() else: diff --git a/socketdev/sbom/__init__.py b/socketdev/sbom/__init__.py index 13f8e37..4752e54 100644 --- a/socketdev/sbom/__init__.py +++ b/socketdev/sbom/__init__.py @@ -1,13 +1,14 @@ -import socketdev import json from socketdev.core.classes import Package class Sbom: - @staticmethod - def view(report_id: str) -> dict[str, dict]: + def __init__(self, api): + self.api = api + + def view(self, report_id: str) -> dict[str, dict]: path = f"sbom/view/{report_id}" - response = socketdev.do_request(path=path) + response = self.api.do_request(path=path) if response.status_code == 200: sbom = [] sbom_dict = {} @@ -19,13 +20,12 @@ def view(report_id: str) -> dict[str, dict]: item = json.loads(line) sbom.append(item) for val in sbom: - sbom_dict[val['id']] = val + sbom_dict[val["id"]] = val else: sbom_dict = {} return sbom_dict - @staticmethod - def create_packages_dict(sbom: dict[str, dict]) -> dict[str, Package]: + def create_packages_dict(self, sbom: dict[str, dict]) -> dict[str, Package]: """ Converts the SBOM Artifacts from the FulLScan into a Dictionary for parsing :param sbom: list - Raw artifacts for the SBOM diff --git a/socketdev/settings/__init__.py b/socketdev/settings/__init__.py index 4cdcb10..e4929b6 100644 --- a/socketdev/settings/__init__.py +++ b/socketdev/settings/__init__.py @@ -1,18 +1,92 @@ -import json +import logging +from enum import Enum +from typing import Dict, Optional +from dataclasses import dataclass, asdict +log = logging.getLogger("socketdev") -import socketdev +class SecurityAction(str, Enum): + DEFER = 'defer' + ERROR = 'error' + WARN = 'warn' + MONITOR = 'monitor' + IGNORE = 'ignore' +@dataclass +class SecurityPolicyRule: + action: SecurityAction + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "SecurityPolicyRule": + return cls( + action=SecurityAction(data["action"]) + ) + +@dataclass +class OrgSecurityPolicyResponse: + success: bool + status: int + securityPolicyRules: Optional[Dict[str, SecurityPolicyRule]] = None + message: Optional[str] = None + + def __getitem__(self, key): return getattr(self, key) + def to_dict(self): return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "OrgSecurityPolicyResponse": + return cls( + securityPolicyRules={ + k: SecurityPolicyRule.from_dict(v) + for k, v in data["securityPolicyRules"].items() + } if data.get("securityPolicyRules") else None, + success=data["success"], + status=data["status"], + message=data.get("message") + ) class Settings: - @staticmethod - def get(org_id: str) -> dict: - settings = {} - path = "settings" - payload = [{"organization": org_id}] - response = socketdev.do_request(path=path, method="POST", payload=json.dumps(payload)) - - if response.status_code != 200: - return settings - - settings = response.json() - return settings + def __init__(self, api): + self.api = api + + def create_params_string(self, params: dict) -> str: + param_str = "" + + for name, value in params.items(): + if value: + if name == "committers" and isinstance(value, list): + # Handle committers specially - add multiple params + for committer in value: + param_str += f"&{name}={committer}" + else: + param_str += f"&{name}={value}" + + param_str = "?" + param_str.lstrip("&") + + return param_str + + def get(self, org_slug: str, custom_rules_only: bool = False) -> OrgSecurityPolicyResponse: + path = f"orgs/{org_slug}/settings/security-policy" + params = {"custom_rules_only": custom_rules_only} + params_args = self.create_params_string(params) if custom_rules_only else "" + path += params_args + print(f"path: {path}") + response = self.api.do_request(path=path, method="GET") + + if response.status_code == 200: + rules = response.json() + return OrgSecurityPolicyResponse.from_dict({ + "securityPolicyRules": rules.get("securityPolicyRules", {}), + "success": True, + "status": 200 + }) + else: + error_message = response.json().get("error", {}).get("message", "Unknown error") + log.error(f"Failed to get security policy: {response.status_code}, message: {error_message}") + return OrgSecurityPolicyResponse.from_dict({ + "securityPolicyRules": {}, + "success": False, + "status": response.status_code, + "message": error_message + }) diff --git a/socketdev/utils/__init__.py b/socketdev/utils/__init__.py new file mode 100644 index 0000000..dd90b16 --- /dev/null +++ b/socketdev/utils/__init__.py @@ -0,0 +1,12 @@ +from typing import Literal + +IntegrationType = Literal["api", "github", "gitlab", "bitbucket", "azure"] +INTEGRATION_TYPES = ("api", "github", "gitlab", "bitbucket", "azure") + + +class Utils: + @staticmethod + def validate_integration_type(integration_type: str) -> IntegrationType: + if integration_type not in INTEGRATION_TYPES: + raise ValueError(f"Invalid integration type: {integration_type}") + return integration_type # type: ignore diff --git a/socketdev/version.py b/socketdev/version.py new file mode 100644 index 0000000..8c0d5d5 --- /dev/null +++ b/socketdev/version.py @@ -0,0 +1 @@ +__version__ = "2.0.0"