import aiohttp
import asyncio
from abc import ABC, abstractmethod
from base64 import b64encode
from dataclasses import dataclass, field
from typing import Dict, Optional, AsyncIterator, AsyncContextManager, Mapping, TypeVar, Generic
from urllib.parse import urljoin
from .error import RequestError
# Type of the HTTP session object an Authenticator configures
# (AioBasicAuthenticator binds it to aiohttp.ClientSession).
SessionT = TypeVar("SessionT")
@dataclass(frozen=True)
class Request:
    """Description of a single HTTP request to be sent through a Requester.

    The dataclass is frozen (fields cannot be reassigned), but the dict fields
    themselves are still ordinary mutable dicts.
    """
    # HTTP method name; AioRequester upper-cases it before sending.
    method: str
    # Target URL; AioRequester resolves it against its base URL via urljoin.
    url: str
    # Query-string parameters passed to the HTTP client.
    params: Dict[str, str] = field(default_factory=dict)
    # Per-request headers passed to the HTTP client.
    headers: Dict[str, str] = field(default_factory=dict)
    # Raw request body, or None to send no body.
    content: Optional[bytes] = None
@dataclass(frozen=True)
class Response(AsyncIterator[bytes]):
    """An HTTP response whose body can be consumed chunk-wise with ``async for``.

    Attributes:
        status: HTTP status code.
        headers: Response headers.
        content_type: Media type of the body as reported by the HTTP client.
        charset: Charset parameter of the Content-Type, or None if not given.
        reader: Async iterator yielding the raw body in chunks.
    """

    status: int
    headers: Mapping[str, str]
    content_type: str
    charset: Optional[str]
    reader: AsyncIterator[bytes]

    async def __anext__(self) -> bytes:
        """Return the next body chunk, delegating to ``reader``."""
        return await self.reader.__anext__()

    def raise_for(self, status: int, content_type: Optional[str] = None, charset: Optional[str] = None) -> None:
        """Validate the response against the expected status, media type and charset.

        Media types and charset names are case-insensitive tokens in HTTP, so they
        are compared without regard to case (e.g. ``charset="utf-8"`` accepts a
        server-sent ``UTF-8``). The status must match exactly.

        Args:
            status: Expected HTTP status code.
            content_type: Expected media type, or None to skip the check.
            charset: Expected charset, or None to skip the check.

        Raises:
            RequestError: If the status or media type differs, or if a charset is
                expected but missing or different.
        """
        if status != self.status:
            raise RequestError(f"Response status {self.status}, expected {status}")
        if content_type is not None and content_type.lower() != self.content_type.lower():
            raise RequestError(f"Response type '{self.content_type}', expected '{content_type}'")
        if charset is not None:
            if self.charset is None:
                raise RequestError(f"Response charset not given, expected '{charset}'")
            if charset.lower() != self.charset.lower():
                raise RequestError(f"Response charset '{self.charset}', expected '{charset}'")
class Authenticator(Generic[SessionT], ABC):
    """Strategy that attaches credentials to an HTTP session of type ``SessionT``."""

    @abstractmethod
    async def authenticate(self, session: SessionT, base_url: str, verify_ssl: bool) -> None:
        """Install authentication information on ``session``, for example via Authorization or Cookie headers.

        ``base_url`` and ``verify_ssl`` are passed through for implementations that
        may need them; the basic authenticator ignores them.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def username(self) -> str:
        """Name of the user being authenticated (used in the string representation)."""
        raise NotImplementedError

    def __str__(self) -> str:
        cls_name = type(self).__name__
        return f"<{cls_name} {self.username}>"
class AioBasicAuthenticator(Authenticator[aiohttp.ClientSession]):
    """HTTP Basic authentication for an aiohttp session.

    The credentials are installed once as a default ``Authorization`` header of
    the session, so every request made through it carries them.
    """

    def __init__(self, username: str, password: str) -> None:
        self._username: str = username
        self._password: str = password

    async def authenticate(self, session: aiohttp.ClientSession, base_url: str, verify_ssl: bool) -> None:
        # base_url and verify_ssl are not needed: Basic auth requires no server round-trip.
        raw_credentials = f"{self._username}:{self._password}".encode()
        token = b64encode(raw_credentials).decode()
        auth_header = {"Authorization": "Basic " + token}
        session.headers.update(auth_header)

    @property
    def username(self) -> str:
        return self._username
class RequestContext(AsyncContextManager[Response], ABC):
    """Proxy to actually start and cleanup a request, yielding a Response.

    Entering the context starts the request and returns its Response; exiting
    it cleans up whatever the underlying client holds for that request.
    """
class AioRequestContext(RequestContext):
    """RequestContext backed by an aiohttp request context manager.

    Transport failures are translated into RequestError, both when the request is
    started and while the body is streamed. This covers aiohttp client errors,
    OS-level socket errors, and timeouts from the session's ClientTimeout.
    """

    class AioStreamReader(AsyncIterator[bytes]):
        """Async iterator over the chunks of an aiohttp response body."""

        def __init__(self, reader: aiohttp.StreamReader) -> None:
            self._reader: aiohttp.StreamReader = reader

        async def __anext__(self) -> bytes:
            try:
                rv: bytes = await self._reader.readany()
            except aiohttp.EofStream:
                raise StopAsyncIteration
            except (aiohttp.ClientError, OSError, asyncio.TimeoutError) as e:
                # asyncio.TimeoutError is only an OSError subclass from Python 3.11 on,
                # so it is listed explicitly to keep timeouts inside RequestError.
                raise RequestError(str(e) or type(e).__name__) from e
            if rv == b"":
                # An empty chunk marks the end of the body.
                raise StopAsyncIteration
            return rv

    def __init__(self, url: str,
                 ctx: AsyncContextManager[aiohttp.ClientResponse],
                 limit: Optional[asyncio.Semaphore]) -> None:
        """
        :param url: Request URL as given by the caller, used in error messages.
        :param ctx: Not-yet-entered aiohttp request context manager.
        :param limit: Optional semaphore bounding concurrently starting requests.
        """
        self._url: str = url
        self._ctx: AsyncContextManager[aiohttp.ClientResponse] = ctx
        self._limit: Optional[asyncio.Semaphore] = limit

    async def __aenter__(self) -> Response:
        """Send the request and wrap the aiohttp response.

        :raises RequestError: if the request cannot be sent or the response headers
            cannot be received (client error, socket error or timeout).
        """
        if self._limit is not None:
            await self._limit.acquire()
        try:
            response: aiohttp.ClientResponse = await self._ctx.__aenter__()
            return Response(response.status,
                            response.headers, response.content_type, response.charset,
                            AioRequestContext.AioStreamReader(response.content))
        except (aiohttp.ClientError, OSError, asyncio.TimeoutError) as e:
            reason = str(e) or type(e).__name__  # timeouts often carry no message
            raise RequestError(f"Cannot request '{self._url}': {reason}") from e
        finally:
            # The slot is freed once the response headers arrive (or the request fails).
            # Streaming the body afterwards is not counted against the limit.
            if self._limit is not None:
                self._limit.release()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Forward exit to the aiohttp context so the response/connection is cleaned up."""
        await self._ctx.__aexit__(exc_type, exc_val, exc_tb)
class Requester(AsyncContextManager, ABC):
    """
    Main interface for Request -> RequestContext -> Response.

    Must be entered as AsyncContextManager (by the API) for triggering authentication
    and to maintain the internal session with connection pool.
    """

    @abstractmethod
    def request(self, r: Request) -> RequestContext:
        """Prepare ``r``; the returned context performs the request when entered."""
        raise NotImplementedError

    @property
    @abstractmethod
    def username(self) -> str:
        """Name of the user this requester authenticates as."""
        raise NotImplementedError

    @property
    @abstractmethod
    def base_url(self) -> str:
        """Base URL of the API endpoint."""
        raise NotImplementedError

    def __str__(self) -> str:
        kind = type(self).__name__
        user = self.username
        endpoint = self.base_url
        return f"<{kind} {user} @ {endpoint}>"
class AioRequester(Requester):
    """
    Use the aiohttp library for async requests to the API endpoint.
    A limiting semaphore can avoid 'database is locked' errors on some installations.
    """

    def __init__(self, base_url: str, authenticator: Authenticator[aiohttp.ClientSession],
                 verify_ssl: bool = True, limit: int = 0, timeout: Optional[float] = None) -> None:
        """
        :param base_url: Base URL that request URLs are resolved against (urljoin).
        :param authenticator: Installs credentials on each newly created session.
        :param verify_ssl: If False, TLS certificate verification is disabled.
        :param limit: Maximum number of requests waiting for their response headers at
            the same time; 0 or less means unlimited.
        :param timeout: Total timeout in seconds for each request, None for no timeout.
        """
        self._base_url: str = base_url
        self._authenticator: Authenticator[aiohttp.ClientSession] = authenticator
        self._verify_ssl: bool = verify_ssl
        self._timeout: Optional[float] = timeout
        self._limit: Optional[asyncio.Semaphore] = asyncio.BoundedSemaphore(limit) if limit > 0 else None
        self._session: Optional[aiohttp.ClientSession] = None

    def request(self, r: Request) -> RequestContext:
        """Prepare ``r`` on the current session; the request is sent when the context is entered.

        :raises RequestError: if the requester has not been entered (no session) or
            aiohttp rejects the request arguments.
        """
        if self._session is None:
            raise RequestError(f"Cannot request '{r.url}': No HTTP session")
        try:
            ctx = self._session.request(
                method=r.method.upper(),
                url=urljoin(self._base_url, r.url),
                params=r.params,
                headers=r.headers,
                data=r.content,
                # None keeps aiohttp's default certificate checks; False disables them.
                ssl=None if self._verify_ssl else False,
            )
        except (aiohttp.ClientError, OSError, ValueError) as e:
            raise RequestError(f"Cannot request '{r.url}': {str(e)}") from e
        return AioRequestContext(r.url, ctx, self._limit)

    async def _create_session(self) -> aiohttp.ClientSession:
        """Session factory: create a session and let the authenticator configure it.

        If authentication fails (or is cancelled), the new session is closed before
        the error propagates so its connector is not leaked.
        """
        session: aiohttp.ClientSession = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self._timeout),
            cookie_jar=aiohttp.DummyCookieJar(),
            headers={"User-Agent": "Mozilla/5.0 (compatible; nextcloud-tasks-api)"},
        )
        try:
            await self._authenticator.authenticate(session, self._base_url, self._verify_ssl)
        except BaseException:
            await session.close()
            raise
        return session

    async def __aenter__(self) -> 'AioRequester':
        """Create and authenticate the session on first entry; re-entry reuses it."""
        if self._session is None:
            self._session = await self._create_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the session, if any.

        The session is detached first, so a failing close() does not leave a
        closed session attached to the requester.
        """
        session, self._session = self._session, None
        if session is not None:
            await session.close()

    @property
    def username(self) -> str:
        return self._authenticator.username

    @property
    def base_url(self) -> str:
        return self._base_url