mirror of
https://github.com/Eugeny/tabby-web.git
synced 2025-06-08 13:39:56 +00:00
214 lines
6.0 KiB
Python
214 lines
6.0 KiB
Python
import fsspec
|
|
import os
|
|
import asyncio
|
|
import random
|
|
from django.conf import settings
|
|
from django.contrib.auth import logout
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from rest_framework import fields, status
|
|
from rest_framework.exceptions import APIException, PermissionDenied, NotFound
|
|
from rest_framework.permissions import IsAuthenticated
|
|
from rest_framework.response import Response
|
|
from rest_framework.mixins import ListModelMixin, RetrieveModelMixin, UpdateModelMixin
|
|
from rest_framework.views import APIView
|
|
from rest_framework.viewsets import GenericViewSet, ModelViewSet
|
|
from rest_framework.serializers import ModelSerializer, Serializer
|
|
from rest_framework_dataclasses.serializers import DataclassSerializer
|
|
from social_django.models import UserSocialAuth
|
|
from typing import List
|
|
from urllib.parse import urlparse
|
|
|
|
from .gateway import GatewayAdminConnection
|
|
from .sponsors import check_is_sponsor, check_is_sponsor_cached
|
|
from .models import Config, Gateway, User
|
|
|
|
|
|
@dataclass
|
|
class AppVersion:
|
|
version: str
|
|
plugins: List[str]
|
|
|
|
|
|
class AppVersionSerializer(DataclassSerializer):
|
|
class Meta:
|
|
dataclass = AppVersion
|
|
|
|
|
|
class GatewaySerializer(ModelSerializer):
|
|
url = fields.SerializerMethodField()
|
|
auth_token = fields.CharField()
|
|
|
|
class Meta:
|
|
fields = '__all__'
|
|
model = Gateway
|
|
|
|
def get_url(self, gw):
|
|
return f'{"wss" if gw.secure else "ws"}://{gw.host}:{gw.port}/'
|
|
|
|
|
|
class ConfigSerializer(ModelSerializer):
|
|
name = fields.CharField(required=False)
|
|
|
|
class Meta:
|
|
model = Config
|
|
read_only_fields = ('user', 'created_at', 'modified_at')
|
|
fields = '__all__'
|
|
|
|
|
|
class ConfigViewSet(ModelViewSet):
|
|
queryset = Config.objects.all()
|
|
serializer_class = ConfigSerializer
|
|
permission_classes = [IsAuthenticated]
|
|
|
|
def get_queryset(self):
|
|
if self.request.user.is_authenticated:
|
|
return Config.objects.filter(user=self.request.user)
|
|
return Config.objects.none()
|
|
|
|
def perform_create(self, serializer):
|
|
serializer.save(user=self.request.user)
|
|
|
|
|
|
class AppVersionViewSet(ListModelMixin, GenericViewSet):
|
|
serializer_class = AppVersionSerializer
|
|
lookup_field = 'id'
|
|
lookup_value_regex = r'[\w\d.-]+'
|
|
queryset = ''
|
|
|
|
def _get_versions(self):
|
|
fs = fsspec.filesystem(urlparse(settings.APP_DIST_STORAGE).scheme)
|
|
return [
|
|
self._get_version(x['name'])
|
|
for x in fs.listdir(settings.APP_DIST_STORAGE)
|
|
]
|
|
|
|
def _get_version(self, dir):
|
|
fs = fsspec.filesystem(urlparse(settings.APP_DIST_STORAGE).scheme)
|
|
plugins = [
|
|
os.path.basename(x['name'])
|
|
for x in fs.listdir(dir)
|
|
if x['type'] == 'directory' and os.path.basename(x['name'])
|
|
not in [
|
|
'tabby-web-container',
|
|
'tabby-web-demo',
|
|
]
|
|
]
|
|
|
|
return AppVersion(
|
|
version=os.path.basename(dir),
|
|
plugins=plugins,
|
|
)
|
|
|
|
def list(self, request, *args, **kwargs):
|
|
return Response(
|
|
self.serializer_class(
|
|
self._get_versions(),
|
|
many=True,
|
|
).data
|
|
)
|
|
|
|
|
|
class UserSerializer(ModelSerializer):
|
|
id = fields.IntegerField()
|
|
is_pro = fields.SerializerMethodField()
|
|
is_sponsor = fields.SerializerMethodField()
|
|
github_username = fields.SerializerMethodField()
|
|
|
|
class Meta:
|
|
model = User
|
|
fields = (
|
|
'id',
|
|
'username',
|
|
'active_config',
|
|
'custom_connection_gateway',
|
|
'custom_connection_gateway_token',
|
|
'config_sync_token',
|
|
'is_pro',
|
|
'is_sponsor',
|
|
'github_username',
|
|
)
|
|
read_only_fields = ('id', 'username')
|
|
|
|
def get_is_pro(self, obj):
|
|
return check_is_sponsor_cached(obj) or obj.force_pro
|
|
|
|
def get_is_sponsor(self, obj):
|
|
return check_is_sponsor_cached(obj)
|
|
|
|
def get_github_username(self, obj):
|
|
social_auth = UserSocialAuth.objects.filter(user=obj, provider='github').first()
|
|
if not social_auth:
|
|
return None
|
|
|
|
return social_auth.extra_data.get('login')
|
|
|
|
|
|
class UserViewSet(RetrieveModelMixin, UpdateModelMixin, GenericViewSet):
|
|
queryset = User.objects.all()
|
|
serializer_class = UserSerializer
|
|
|
|
def get_object(self):
|
|
if self.request.user.is_authenticated:
|
|
return self.request.user
|
|
raise PermissionDenied()
|
|
|
|
|
|
class LogoutView(APIView):
|
|
def post(self, request, format=None):
|
|
logout(request)
|
|
return Response(None)
|
|
|
|
|
|
class InstanceInfoSerializer(Serializer):
|
|
login_enabled = fields.BooleanField()
|
|
homepage_enabled = fields.BooleanField()
|
|
|
|
|
|
class InstanceInfoViewSet(RetrieveModelMixin, GenericViewSet):
|
|
queryset = '' # type: ignore
|
|
serializer_class = InstanceInfoSerializer
|
|
|
|
def get_object(self):
|
|
return {
|
|
'login_enabled': settings.ENABLE_LOGIN,
|
|
'homepage_enabled': settings.ENABLE_HOMEPAGE,
|
|
}
|
|
|
|
|
|
class NoGatewaysError(APIException):
|
|
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
|
default_detail = 'No connection gateways available.'
|
|
default_code = 'no_gateways'
|
|
|
|
|
|
class ChooseGatewayViewSet(RetrieveModelMixin, GenericViewSet):
|
|
queryset = Gateway.objects.filter(enabled=True)
|
|
serializer_class = GatewaySerializer
|
|
|
|
async def _authorize_client(self, gw):
|
|
c = GatewayAdminConnection(gw)
|
|
await c.connect()
|
|
token = await c.authorize_client()
|
|
await c.close()
|
|
return token
|
|
|
|
def get_object(self):
|
|
gateways = list(self.queryset)
|
|
random.shuffle(gateways)
|
|
if not len(gateways):
|
|
raise NotFound()
|
|
|
|
loop = asyncio.new_event_loop()
|
|
try:
|
|
for gw in gateways:
|
|
try:
|
|
gw.auth_token = loop.run_until_complete(self._authorize_client(gw))
|
|
except ConnectionError:
|
|
continue
|
|
return gw
|
|
|
|
raise NoGatewaysError()
|
|
finally:
|
|
loop.close()
|