aboutsummaryrefslogtreecommitdiff
path: root/src/hydrilla/proxy/state_impl/repos.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/hydrilla/proxy/state_impl/repos.py')
-rw-r--r--src/hydrilla/proxy/state_impl/repos.py147
1 files changed, 135 insertions, 12 deletions
diff --git a/src/hydrilla/proxy/state_impl/repos.py b/src/hydrilla/proxy/state_impl/repos.py
index be11a88..5553ec2 100644
--- a/src/hydrilla/proxy/state_impl/repos.py
+++ b/src/hydrilla/proxy/state_impl/repos.py
@@ -32,19 +32,52 @@ inside Haketilo.
# Enable using with Python 3.7.
from __future__ import annotations
+import re
import typing as t
import dataclasses as dc
+from urllib.parse import urlparse
from datetime import datetime
+import sqlite3
+
from .. import state as st
from . import base
+from . import prune_packages
+
+
+def validate_repo_url(url: str) -> None:
+ try:
+ parsed = urlparse(url)
+ except:
+ raise st.RepoUrlInvalid()
+
+ if parsed.scheme not in ('http', 'https'):
+ raise st.RepoUrlInvalid()
+
+
+def ensure_repo_not_deleted(cursor: sqlite3.Cursor, repo_id: str) -> None:
+ cursor.execute(
+ 'SELECT deleted FROM repos WHERE repo_id = ?;',
+ (repo_id,)
+ )
+
+ rows = cursor.fetchall()
+
+ if rows == []:
+ raise st.MissingItemError()
+
+ (deleted,), = rows
+
+ if deleted:
+ raise st.MissingItemError()
+
def make_repo_display_info(
ref: st.RepoRef,
name: str,
- url: t.Optional[str],
- deleted: t.Optional[bool],
+ url: str,
+ deleted: bool,
last_refreshed: t.Optional[int],
resource_count: int,
mapping_count: int
@@ -54,13 +87,14 @@ def make_repo_display_info(
last_refreshed_converted = datetime.fromtimestamp(last_refreshed)
return st.RepoDisplayInfo(
- ref = ref,
- name = name,
- url = url,
- deleted = deleted,
- last_refreshed = last_refreshed_converted,
- resource_count = resource_count,
- mapping_count = mapping_count
+ ref = ref,
+ is_local_semirepo = ref.id == '1',
+ name = name,
+ url = url,
+ deleted = deleted,
+ last_refreshed = last_refreshed_converted,
+ resource_count = resource_count,
+ mapping_count = mapping_count
)
@@ -70,15 +104,47 @@ class ConcreteRepoRef(st.RepoRef):
state: base.HaketiloStateWithFields = dc.field(hash=False, compare=False)
def remove(self) -> None:
- raise NotImplementedError()
+ with self.state.cursor(transaction=True) as cursor:
+ ensure_repo_not_deleted(cursor, self.id)
+
+ cursor.execute(
+ '''
+ UPDATE
+ repos
+ SET
+ deleted = TRUE,
+ url = '',
+ active_iteration_id = NULL,
+ last_refreshed = NULL
+ WHERE
+ repo_id = ?;
+ ''',
+ (self.id,)
+ )
def update(
self,
*,
name: t.Optional[str] = None,
url: t.Optional[str] = None
- ) -> st.RepoRef:
- raise NotImplementedError()
+ ) -> None:
+ if name is not None:
+ raise NotImplementedError()
+
+ if url is None:
+ return
+
+ validate_repo_url(url)
+
+ with self.state.cursor(transaction=True) as cursor:
+ ensure_repo_not_deleted(cursor, self.id)
+
+ cursor.execute(
+ 'UPDATE repos SET url = ? WHERE repo_id = ?;',
+ (url, self.id)
+ )
+
+ prune_packages.prune(cursor)
def refresh(self) -> st.RepoIterationRef:
raise NotImplementedError()
@@ -108,6 +174,19 @@ class ConcreteRepoRef(st.RepoRef):
return make_repo_display_info(self, *row)
+repo_name_regex = re.compile(r'''
+^
+(?:
+ []a-zA-Z0-9()<>^&$.!,?@#|;:%"'*{}[/_=+-]+ # allowed non-whitespace characters
+
+ (?: # optional additional words separated by single spaces
+ [ ]
+ []a-zA-Z0-9()<>^&$.!,?@#|;:%"'*{}[/_=+-]+
+ )*
+)
+$
+''', re.VERBOSE)
+
@dc.dataclass(frozen=True)
class ConcreteRepoStore(st.RepoStore):
state: base.HaketiloStateWithFields
@@ -115,6 +194,50 @@ class ConcreteRepoStore(st.RepoStore):
def get(self, id: str) -> st.RepoRef:
return ConcreteRepoRef(id, self.state)
+ def add(self, name: str, url: str) -> st.RepoRef:
+ name = name.strip()
+ if repo_name_regex.match(name) is None:
+ raise st.RepoNameInvalid()
+
+ validate_repo_url(url)
+
+ with self.state.cursor(transaction=True) as cursor:
+ cursor.execute(
+ '''
+ SELECT
+ COUNT(repo_id)
+ FROM
+ repos
+ WHERE
+ NOT deleted AND name = ?;
+ ''',
+ (name,)
+ )
+ (name_taken,), = cursor.fetchall()
+
+ if name_taken:
+ raise st.RepoNameTaken()
+
+ cursor.execute(
+ '''
+ INSERT INTO repos(name, url, deleted, next_iteration)
+ VALUES (?, ?, FALSE, 1)
+ ON CONFLICT (name)
+ DO UPDATE SET
+ name = excluded.name,
+ url = excluded.url,
+ deleted = FALSE,
+ last_refreshed = NULL;
+ ''',
+ (name, url)
+ )
+
+ cursor.execute('SELECT repo_id FROM repos WHERE name = ?;', (name,))
+
+ (repo_id,), = cursor.fetchall()
+
+ return ConcreteRepoRef(str(repo_id), self.state)
+
def get_display_infos(self, include_deleted: bool = False) \
-> t.Iterable[st.RepoDisplayInfo]:
with self.state.cursor() as cursor: