Sfoglia il codice sorgente

Merge pull request #479 from afreydev/tags

Tags migration
Cristian Vargas 5 anni fa
parent
commit
71e111e13f

+ 6 - 5
.github/workflows/test.yml

@@ -3,6 +3,7 @@ on: [push]
 
 
 env:
 env:
   MAX_LINE_LENGTH: 110
   MAX_LINE_LENGTH: 110
+  DOCKER_IMAGE: archivebox-ci
 
 
 jobs:
 jobs:
   lint:
   lint:
@@ -118,12 +119,12 @@ jobs:
 
 
       - name: Build image
       - name: Build image
         run: |
         run: |
-          docker build . -t archivebox
+          docker build . -t "$DOCKER_IMAGE"
 
 
       - name: Init data dir
       - name: Init data dir
         run: |
         run: |
           mkdir data
           mkdir data
-          docker run -v "$PWD"/data:/data archivebox init
+          docker run -v "$PWD"/data:/data "$DOCKER_IMAGE" init
 
 
       - name: Run test server
       - name: Run test server
         run: |
         run: |
@@ -132,7 +133,7 @@ jobs:
 
 
       - name: Add link
       - name: Add link
         run: |
         run: |
-          docker run -v "$PWD"/data:/data --network host archivebox add http://www.test-nginx-1.local
+          docker run -v "$PWD"/data:/data --network host "$DOCKER_IMAGE" add http://www.test-nginx-1.local
 
 
       - name: Add stdin link
       - name: Add stdin link
         run: |
         run: |
@@ -140,8 +141,8 @@ jobs:
 
 
       - name: List links
       - name: List links
         run: |
         run: |
-          docker run -v "$PWD"/data:/data archivebox list | grep -q "www.test-nginx-1.local" || { echo "The site 1 isn't in the list"; exit 1; }
-          docker run -v "$PWD"/data:/data archivebox list | grep -q "www.test-nginx-2.local" || { echo "The site 2 isn't in the list"; exit 1; }
+          docker run -v "$PWD"/data:/data "$DOCKER_IMAGE" list | grep -q "www.test-nginx-1.local" || { echo "The site 1 isn't in the list"; exit 1; }
+          docker run -v "$PWD"/data:/data "$DOCKER_IMAGE" list | grep -q "www.test-nginx-2.local" || { echo "The site 2 isn't in the list"; exit 1; }
 
 
       - name: Start docker-compose stack
       - name: Start docker-compose stack
         run: |
         run: |

+ 38 - 4
archivebox/core/admin.py

@@ -9,9 +9,10 @@ from django.utils.html import format_html
 from django.utils.safestring import mark_safe
 from django.utils.safestring import mark_safe
 from django.shortcuts import render, redirect
 from django.shortcuts import render, redirect
 from django.contrib.auth import get_user_model
 from django.contrib.auth import get_user_model
+from django import forms
 
 
 from core.models import Snapshot
 from core.models import Snapshot
-from core.forms import AddLinkForm
+from core.forms import AddLinkForm, TagField
 from core.utils import get_icons
 from core.utils import get_icons
 
 
 from util import htmldecode, urldecode, ansi_to_html
 from util import htmldecode, urldecode, ansi_to_html
@@ -55,6 +56,32 @@ def delete_snapshots(modeladmin, request, queryset):
 delete_snapshots.short_description = "Delete"
 delete_snapshots.short_description = "Delete"
 
 
 
 
+class SnapshotAdminForm(forms.ModelForm):
+    tags = TagField(required=False)
+
+    class Meta:
+        model = Snapshot
+        fields = "__all__"
+
+    def save(self, commit=True):
+        # Based on: https://stackoverflow.com/a/49933068/3509554
+
+        # Get the unsave instance
+        instance = forms.ModelForm.save(self, False)
+        tags = self.cleaned_data.pop("tags")
+
+        #update save_m2m
+        def new_save_m2m():
+            instance.save_tags(tags)
+
+        # Do we need to save all changes now?
+        self.save_m2m = new_save_m2m
+        if commit:
+            instance.save()
+
+        return instance
+
+
 class SnapshotAdmin(admin.ModelAdmin):
 class SnapshotAdmin(admin.ModelAdmin):
     list_display = ('added', 'title_str', 'url_str', 'files', 'size')
     list_display = ('added', 'title_str', 'url_str', 'files', 'size')
     sort_fields = ('title_str', 'url_str', 'added')
     sort_fields = ('title_str', 'url_str', 'added')
@@ -65,6 +92,13 @@ class SnapshotAdmin(admin.ModelAdmin):
     ordering = ['-added']
     ordering = ['-added']
     actions = [delete_snapshots, overwrite_snapshots, update_snapshots, update_titles, verify_snapshots]
     actions = [delete_snapshots, overwrite_snapshots, update_snapshots, update_titles, verify_snapshots]
     actions_template = 'admin/actions_as_select.html'
     actions_template = 'admin/actions_as_select.html'
+    form = SnapshotAdminForm
+
+    def get_queryset(self, request):
+        return super().get_queryset(request).prefetch_related('tags')
+
+    def tag_list(self, obj):
+        return ', '.join(obj.tags.values_list('name', flat=True))
 
 
     def id_str(self, obj):
     def id_str(self, obj):
         return format_html(
         return format_html(
@@ -75,9 +109,9 @@ class SnapshotAdmin(admin.ModelAdmin):
     def title_str(self, obj):
     def title_str(self, obj):
         canon = obj.as_link().canonical_outputs()
         canon = obj.as_link().canonical_outputs()
         tags = ''.join(
         tags = ''.join(
-            format_html('<span>{}</span>', tag.strip())
-            for tag in obj.tags.split(',')
-        ) if obj.tags else ''
+            format_html(' <a href="/admin/core/snapshot/?tags__id__exact={}"><span class="tag">{}</span></a> ', tag.id, tag)
+            for tag in obj.tags.all()
+        )
         return format_html(
         return format_html(
             '<a href="/{}">'
             '<a href="/{}">'
                 '<img src="/{}/{}" class="favicon" onerror="this.remove()">'
                 '<img src="/{}/{}" class="favicon" onerror="this.remove()">'

+ 42 - 0
archivebox/core/forms.py

@@ -3,6 +3,7 @@ __package__ = 'archivebox.core'
 from django import forms
 from django import forms
 
 
 from ..util import URL_REGEX
 from ..util import URL_REGEX
+from .utils_taggit import edit_string_for_tags, parse_tags
 
 
 CHOICES = (
 CHOICES = (
     ('0', 'depth = 0 (archive just these URLs)'),
     ('0', 'depth = 0 (archive just these URLs)'),
@@ -12,3 +13,44 @@ CHOICES = (
 class AddLinkForm(forms.Form):
 class AddLinkForm(forms.Form):
     url = forms.RegexField(label="URLs (one per line)", regex=URL_REGEX, min_length='6', strip=True, widget=forms.Textarea, required=True)
     url = forms.RegexField(label="URLs (one per line)", regex=URL_REGEX, min_length='6', strip=True, widget=forms.Textarea, required=True)
     depth = forms.ChoiceField(label="Archive depth", choices=CHOICES, widget=forms.RadioSelect, initial='0')
     depth = forms.ChoiceField(label="Archive depth", choices=CHOICES, widget=forms.RadioSelect, initial='0')
+
+
+class TagWidgetMixin:
+    def format_value(self, value):
+        if value is not None and not isinstance(value, str):
+            value = edit_string_for_tags(value)
+        return super().format_value(value)
+
+class TagWidget(TagWidgetMixin, forms.TextInput):
+    pass
+
+class TagField(forms.CharField):
+    widget = TagWidget
+
+    def clean(self, value):
+        value = super().clean(value)
+        try:
+            return parse_tags(value)
+        except ValueError:
+            raise forms.ValidationError(
+                "Please provide a comma-separated list of tags."
+            )
+
+    def has_changed(self, initial_value, data_value):
+        # Always return False if the field is disabled since self.bound_data
+        # always uses the initial value in this case.
+        if self.disabled:
+            return False
+
+        try:
+            data_value = self.clean(data_value)
+        except forms.ValidationError:
+            pass
+
+        if initial_value is None:
+            initial_value = []
+
+        initial_value = [tag.name for tag in initial_value]
+        initial_value.sort()
+
+        return initial_value != data_value

+ 70 - 0
archivebox/core/migrations/0006_auto_20201012_1520.py

@@ -0,0 +1,70 @@
+# Generated by Django 3.0.8 on 2020-10-12 15:20
+
+from django.db import migrations, models
+from django.utils.text import slugify
+
+def forwards_func(apps, schema_editor):
+    SnapshotModel = apps.get_model("core", "Snapshot")
+    TagModel = apps.get_model("core", "Tag")
+
+    db_alias = schema_editor.connection.alias
+    snapshots = SnapshotModel.objects.all()
+    for snapshot in snapshots:
+        tags = snapshot.tags
+        tag_set = (
+            set(tag.strip() for tag in (snapshot.tags_old or '').split(','))
+        )
+        tag_set.discard("")
+
+        for tag in tag_set:
+            to_add, _ = TagModel.objects.get_or_create(name=tag, slug=slugify(tag))
+            snapshot.tags.add(to_add)
+
+
+def reverse_func(apps, schema_editor):
+    SnapshotModel = apps.get_model("core", "Snapshot")
+    TagModel = apps.get_model("core", "Tag")
+
+    db_alias = schema_editor.connection.alias
+    snapshots = SnapshotModel.objects.all()
+    for snapshot in snapshots:
+        tags = snapshot.tags.values_list("name", flat=True)
+        snapshot.tags_old = ",".join([tag for tag in tags])
+        snapshot.save()
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('core', '0005_auto_20200728_0326'),
+    ]
+
+    operations = [
+        migrations.RenameField(
+            model_name='snapshot',
+            old_name='tags',
+            new_name='tags_old',
+        ),
+        migrations.CreateModel(
+            name='Tag',
+            fields=[
+                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+                ('name', models.CharField(max_length=100, unique=True, verbose_name='name')),
+                ('slug', models.SlugField(max_length=100, unique=True, verbose_name='slug')),
+            ],
+            options={
+                'verbose_name': 'Tag',
+                'verbose_name_plural': 'Tags',
+            },
+        ),
+        migrations.AddField(
+            model_name='snapshot',
+            name='tags',
+            field=models.ManyToManyField(to='core.Tag'),
+        ),
+        migrations.RunPython(forwards_func, reverse_func),
+        migrations.RemoveField(
+            model_name='snapshot',
+            name='tags_old',
+        ),
+    ]

+ 60 - 4
archivebox/core/models.py

@@ -2,13 +2,55 @@ __package__ = 'archivebox.core'
 
 
 import uuid
 import uuid
 
 
-from django.db import models
+from django.db import models, transaction
 from django.utils.functional import cached_property
 from django.utils.functional import cached_property
+from django.utils.text import slugify
 
 
 from ..util import parse_date
 from ..util import parse_date
 from ..index.schema import Link
 from ..index.schema import Link
 
 
 
 
+class Tag(models.Model):
+    """
+    Based on django-taggit model
+    """
+    name = models.CharField(verbose_name="name", unique=True, blank=False, max_length=100)
+    slug = models.SlugField(verbose_name="slug", unique=True, max_length=100)
+
+    class Meta:
+        verbose_name = "Tag"
+        verbose_name_plural = "Tags"
+
+    def __str__(self):
+        return self.name
+
+    def slugify(self, tag, i=None):
+        slug = slugify(tag)
+        if i is not None:
+            slug += "_%d" % i
+        return slug
+
+    def save(self, *args, **kwargs):
+        if self._state.adding and not self.slug:
+            self.slug = self.slugify(self.name)
+
+            with transaction.atomic():
+                slugs = set(
+                    type(self)
+                    ._default_manager.filter(slug__startswith=self.slug)
+                    .values_list("slug", flat=True)
+                )
+
+                i = None
+                while True:
+                    slug = self.slugify(self.name, i)
+                    if slug not in slugs:
+                        self.slug = slug
+                        return super().save(*args, **kwargs)
+                    i = 1 if i is None else i+1
+        else:
+            return super().save(*args, **kwargs)
+
 class Snapshot(models.Model):
 class Snapshot(models.Model):
     id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
     id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
 
 
@@ -16,11 +58,10 @@ class Snapshot(models.Model):
     timestamp = models.CharField(max_length=32, unique=True, db_index=True)
     timestamp = models.CharField(max_length=32, unique=True, db_index=True)
 
 
     title = models.CharField(max_length=128, null=True, blank=True, db_index=True)
     title = models.CharField(max_length=128, null=True, blank=True, db_index=True)
-    tags = models.CharField(max_length=256, null=True, blank=True, db_index=True)
 
 
     added = models.DateTimeField(auto_now_add=True, db_index=True)
     added = models.DateTimeField(auto_now_add=True, db_index=True)
     updated = models.DateTimeField(null=True, blank=True, db_index=True)
     updated = models.DateTimeField(null=True, blank=True, db_index=True)
-    # bookmarked = models.DateTimeField()
+    tags = models.ManyToManyField(Tag)
 
 
     keys = ('url', 'timestamp', 'title', 'tags', 'updated')
     keys = ('url', 'timestamp', 'title', 'tags', 'updated')
 
 
@@ -41,7 +82,8 @@ class Snapshot(models.Model):
         args = args or self.keys
         args = args or self.keys
         return {
         return {
             key: getattr(self, key)
             key: getattr(self, key)
-            for key in args
+            if key != 'tags' else self.get_tags_str()
+            for key in args 
         }
         }
 
 
     def as_link(self) -> Link:
     def as_link(self) -> Link:
@@ -50,6 +92,13 @@ class Snapshot(models.Model):
     def as_link_with_details(self) -> Link:
     def as_link_with_details(self) -> Link:
         from ..index import load_link_details
         from ..index import load_link_details
         return load_link_details(self.as_link())
         return load_link_details(self.as_link())
+    
+    def get_tags_str(self) -> str:
+        tags = ','.join(
+            tag.name
+            for tag in self.tags.all()
+        ) if self.tags.all() else ''
+        return tags
 
 
     @cached_property
     @cached_property
     def bookmarked(self):
     def bookmarked(self):
@@ -96,3 +145,10 @@ class Snapshot(models.Model):
             and self.history['title'][-1].output.strip()):
             and self.history['title'][-1].output.strip()):
             return self.history['title'][-1].output.strip()
             return self.history['title'][-1].output.strip()
         return None
         return None
+
+    def save_tags(self, tags=[]):
+        tags_id = []
+        for tag in tags:
+            tags_id.append(Tag.objects.get_or_create(name=tag)[0].id)
+        self.tags.clear()
+        self.tags.add(*tags_id)

+ 113 - 0
archivebox/core/utils_taggit.py

@@ -0,0 +1,113 @@
+# Taken from https://github.com/jazzband/django-taggit/blob/3b56adb637ab95aca5036c37a358402c825a367c/taggit/utils.py
+
+def parse_tags(tagstring):
+    """
+    Parses tag input, with multiple word input being activated and
+    delineated by commas and double quotes. Quotes take precedence, so
+    they may contain commas.
+
+    Returns a sorted list of unique tag names.
+
+    Ported from Jonathan Buchanan's `django-tagging
+    <http://django-tagging.googlecode.com/>`_
+    """
+    if not tagstring:
+        return []
+
+    # Special case - if there are no commas or double quotes in the
+    # input, we don't *do* a recall... I mean, we know we only need to
+    # split on spaces.
+    if "," not in tagstring and '"' not in tagstring:
+        words = list(set(split_strip(tagstring, " ")))
+        words.sort()
+        return words
+
+    words = []
+    buffer = []
+    # Defer splitting of non-quoted sections until we know if there are
+    # any unquoted commas.
+    to_be_split = []
+    saw_loose_comma = False
+    open_quote = False
+    i = iter(tagstring)
+    try:
+        while True:
+            c = next(i)
+            if c == '"':
+                if buffer:
+                    to_be_split.append("".join(buffer))
+                    buffer = []
+                # Find the matching quote
+                open_quote = True
+                c = next(i)
+                while c != '"':
+                    buffer.append(c)
+                    c = next(i)
+                if buffer:
+                    word = "".join(buffer).strip()
+                    if word:
+                        words.append(word)
+                    buffer = []
+                open_quote = False
+            else:
+                if not saw_loose_comma and c == ",":
+                    saw_loose_comma = True
+                buffer.append(c)
+    except StopIteration:
+        # If we were parsing an open quote which was never closed treat
+        # the buffer as unquoted.
+        if buffer:
+            if open_quote and "," in buffer:
+                saw_loose_comma = True
+            to_be_split.append("".join(buffer))
+    if to_be_split:
+        if saw_loose_comma:
+            delimiter = ","
+        else:
+            delimiter = " "
+        for chunk in to_be_split:
+            words.extend(split_strip(chunk, delimiter))
+    words = list(set(words))
+    words.sort()
+    return words
+
+
+def split_strip(string, delimiter=","):
+    """
+    Splits ``string`` on ``delimiter``, stripping each resulting string
+    and returning a list of non-empty strings.
+
+    Ported from Jonathan Buchanan's `django-tagging
+    <http://django-tagging.googlecode.com/>`_
+    """
+    if not string:
+        return []
+
+    words = [w.strip() for w in string.split(delimiter)]
+    return [w for w in words if w]
+
+
+def edit_string_for_tags(tags):
+    """
+    Given list of ``Tag`` instances, creates a string representation of
+    the list suitable for editing by the user, such that submitting the
+    given string representation back without changing it will give the
+    same list of tags.
+
+    Tag names which contain commas will be double quoted.
+
+    If any tag name which isn't being quoted contains whitespace, the
+    resulting string of tag names will be comma-delimited, otherwise
+    it will be space-delimited.
+
+    Ported from Jonathan Buchanan's `django-tagging
+    <http://django-tagging.googlecode.com/>`_
+    """
+    names = []
+    for tag in tags:
+        name = tag.name
+        if "," in name or " " in name:
+            names.append('"%s"' % name)
+        else:
+            names.append(name)
+    return ", ".join(sorted(names))

+ 14 - 2
archivebox/index/sql.py

@@ -34,13 +34,19 @@ def remove_from_sql_main_index(snapshots: QuerySet, out_dir: Path=OUTPUT_DIR) ->
 def write_link_to_sql_index(link: Link):
 def write_link_to_sql_index(link: Link):
     from core.models import Snapshot
     from core.models import Snapshot
     info = {k: v for k, v in link._asdict().items() if k in Snapshot.keys}
     info = {k: v for k, v in link._asdict().items() if k in Snapshot.keys}
+    tags = info.pop("tags")
+    if tags is None:
+        tags = []
+
     try:
     try:
         info["timestamp"] = Snapshot.objects.get(url=link.url).timestamp
         info["timestamp"] = Snapshot.objects.get(url=link.url).timestamp
     except Snapshot.DoesNotExist:
     except Snapshot.DoesNotExist:
         while Snapshot.objects.filter(timestamp=info["timestamp"]).exists():
         while Snapshot.objects.filter(timestamp=info["timestamp"]).exists():
             info["timestamp"] = str(float(info["timestamp"]) + 1.0)
             info["timestamp"] = str(float(info["timestamp"]) + 1.0)
 
 
-    return Snapshot.objects.update_or_create(url=link.url, defaults=info)[0]
+    snapshot, _ = Snapshot.objects.update_or_create(url=link.url, defaults=info)
+    snapshot.save_tags(tags)
+    return snapshot
 
 
 
 
 @enforce_types
 @enforce_types
@@ -65,8 +71,14 @@ def write_sql_link_details(link: Link, out_dir: Path=OUTPUT_DIR) -> None:
         except Snapshot.DoesNotExist:
         except Snapshot.DoesNotExist:
             snap = write_link_to_sql_index(link)
             snap = write_link_to_sql_index(link)
         snap.title = link.title
         snap.title = link.title
-        snap.tags = link.tags
+
+        tag_set = (
+            set(tag.strip() for tag in (link.tags or '').split(','))
+        )
+        tag_list = list(tag_set) or []
+
         snap.save()
         snap.save()
+        snap.save_tags(tag_list)
 
 
 
 
 
 

+ 8 - 0
archivebox/themes/default/static/admin.css

@@ -222,3 +222,11 @@ body.model-snapshot.change-list #content .object-tools {
   0% { transform: rotate(0deg); }
   0% { transform: rotate(0deg); }
   100% { transform: rotate(360deg); }
   100% { transform: rotate(360deg); }
 }
 }
+
+.tags > a > .tag {
+  border: 1px solid;
+  border-radius: 10px;
+  background-color: #f3f3f3;
+  padding: 3px;
+}
+

+ 1 - 1
docker-compose.yml

@@ -12,7 +12,7 @@ version: '3.7'
 services:
 services:
     archivebox:
     archivebox:
         # build: .
         # build: .
-        image: nikisweeting/archivebox:latest
+        image: ${DOCKER_IMAGE:-nikisweeting/archivebox:latest} 
         command: server 0.0.0.0:8000
         command: server 0.0.0.0:8000
         stdin_open: true
         stdin_open: true
         tty: true
         tty: true

BIN
tests/tags_migration/index.sqlite3


+ 40 - 2
tests/test_init.py

@@ -4,7 +4,7 @@
 import os
 import os
 import subprocess
 import subprocess
 from pathlib import Path
 from pathlib import Path
-import json
+import json, shutil
 import sqlite3
 import sqlite3
 
 
 from archivebox.config import OUTPUT_PERMISSIONS
 from archivebox.config import OUTPUT_PERMISSIONS
@@ -131,4 +131,42 @@ def test_unrecognized_folders(tmp_path, process, disable_extractors_dict):
 
 
     init_process = subprocess.run(['archivebox', 'init'], capture_output=True, env=disable_extractors_dict)
     init_process = subprocess.run(['archivebox', 'init'], capture_output=True, env=disable_extractors_dict)
     assert "Skipped adding 1 invalid link data directories" in init_process.stdout.decode("utf-8")
     assert "Skipped adding 1 invalid link data directories" in init_process.stdout.decode("utf-8")
-    assert init_process.returncode == 0
+    assert init_process.returncode == 0
+
+def test_tags_migration(tmp_path, disable_extractors_dict):
+    
+    base_sqlite_path = Path(__file__).parent / 'tags_migration'
+    
+    if os.path.exists(tmp_path):
+        shutil.rmtree(tmp_path)
+    shutil.copytree(str(base_sqlite_path), tmp_path)
+    os.chdir(tmp_path)
+
+    conn = sqlite3.connect("index.sqlite3")
+    conn.row_factory = sqlite3.Row
+    c = conn.cursor()
+    c.execute("SELECT id, tags from core_snapshot")
+    snapshots = c.fetchall()
+    snapshots_dict = { sn['id']: sn['tags'] for sn in snapshots}
+    conn.commit()
+    conn.close()
+    
+    init_process = subprocess.run(['archivebox', 'init'], capture_output=True, env=disable_extractors_dict)
+
+    conn = sqlite3.connect("index.sqlite3")
+    conn.row_factory = sqlite3.Row
+    c = conn.cursor()
+    c.execute("""
+        SELECT core_snapshot.id, core_tag.name from core_snapshot
+        JOIN core_snapshot_tags on core_snapshot_tags.snapshot_id=core_snapshot.id
+        JOIN core_tag on core_tag.id=core_snapshot_tags.tag_id
+    """)
+    tags = c.fetchall()
+    conn.commit()
+    conn.close()
+
+    for tag in tags:
+        snapshot_id = tag["id"]
+        tag_name = tag["name"]
+        # Check each tag migrated is in the previous field
+        assert tag_name in snapshots_dict[snapshot_id]