手机上运行的网盘代码

资源推荐2周前发布 aitew
7 00
除了没有文件移动功能外,基础功能都已具备,目前应该没有bug。程序在Termux上运行,并附带Termux的安卓安装包。
具体信息请看视频:https://b23.tv/Zm823vt
代码地址:https://wwamp.lanzouu.com/izuJN3o4ogre
手机上运行的网盘代码

下载附件  保存到相册

手机上运行的网盘代码

[Python]

  1. #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
     
    import os
    import sys
    import json
    import asyncio
    import socket
    import hashlib
    import mimetypes
    import uuid
    import shutil
    import re
    import time
    import subprocess
    from datetime import datetime, timedelta
    from pathlib import Path
    from urllib.parse import unquote, quote
    from typing import Dict, List, Optional, Tuple
     
    # -------------------- 自动安装依赖 --------------------
    def ensure_dependencies(required: Optional[Dict[str, str]] = None,
                            optional: Tuple[str, ...] = ('netifaces',)):
        """Make sure third-party modules are importable, pip-installing missing ones.

        Args:
            required: mapping of import name -> pip package name. Defaults to
                the server's own dependencies (aiohttp, aiofiles, aiosqlite,
                netifaces).
            optional: import names whose installation may fail without
                aborting start-up. netifaces is only used behind an
                ``except ImportError`` guard (IPv6 discovery), so the server
                can run without it.

        Exits the process with status 1 if a non-optional package cannot be
        installed.
        """
        import importlib

        if required is None:
            required = {
                'aiohttp': 'aiohttp',
                'aiofiles': 'aiofiles',
                'aiosqlite': 'aiosqlite',
                'netifaces': 'netifaces',
            }
        for module, package in required.items():
            try:
                importlib.import_module(module)
            except ImportError:
                print(f"[*] 缺少依赖: {package},正在安装...")
                try:
                    subprocess.check_call(
                        [sys.executable, "-m", "pip", "install", package]
                    )
                except (subprocess.CalledProcessError, OSError) as e:
                    if module in optional:
                        # Optional dependency: carry on without it.
                        print(f"[!] 可选依赖 {package} 安装失败,已跳过: {e}")
                        continue
                    print(f"[-] 安装失败: {e},请手动安装 {package}")
                    sys.exit(1)
                # Import finders cache directory contents; invalidate them so a
                # package installed after interpreter start-up can be imported.
                importlib.invalidate_caches()
                print(f"[+] {package} 安装成功")
     
    # Runs at import time: install any missing third-party packages before the
    # imports below need them.
    ensure_dependencies()
     
    import aiohttp
    from aiohttp import web
    # ClientDisconnectedError is looked up in two locations; if neither import
    # succeeds, aiohttp.ClientError is bound under that name instead.
    # NOTE(review): confirm which aiohttp versions still ship this class.
    try:
        from aiohttp.client_exceptions import ClientDisconnectedError
    except ImportError:
        try:
            from aiohttp import ClientDisconnectedError
        except ImportError:
            from aiohttp import ClientError as ClientDisconnectedError
    import aiofiles
    import aiofiles.os
    import aiosqlite
     
    # Initialise the mimetypes module's type map before any lookups.
    mimetypes.init()
     
     
    # -------------------- 日志记录器 --------------------
    class Logger:
        """Activity logger: every entry is printed and appended to a UTF-8 log file.

        File writes are best effort. Write failures are swallowed so that
        logging never breaks request handling.
        """

        # When the download de-duplication cache holds more keys than this,
        # clear_download_cache() evicts the oldest half.
        DOWNLOAD_CACHE_LIMIT = 1000

        def __init__(self, log_file='cloud_drive.log'):
            """Prepare the log file's directory and write a start-up banner.

            Args:
                log_file: path of the log file. Relative paths are opened
                    relative to the current working directory.
            """
            self.log_file = log_file
            log_dir = os.path.dirname(log_file)
            if log_dir and not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            self._write_startup_marker()
            # "user:path" -> request ids already logged for that download, so
            # repeated calls carrying the same id are logged only once.
            self.download_requests: Dict[str, set] = {}

        @staticmethod
        def _now() -> str:
            """Return the current local time in the log's timestamp format."""
            return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        def _append(self, text: str) -> None:
            """Append *text* verbatim to the log file (best effort)."""
            try:
                with open(self.log_file, 'a', encoding='utf-8') as f:
                    f.write(text)
            except (OSError, ValueError):
                # Disk/permission problems or un-encodable text must not
                # propagate into request handlers.
                pass

        def _emit(self, entry: str) -> None:
            """Print *entry* and append it as one line to the log file."""
            print(entry)
            self._append(entry + "\n")

        def _write_startup_marker(self):
            """Append a separator banner recording the program start time."""
            self._append(f"\n{'='*60}\n程序启动时间: {self._now()}\n{'='*60}\n")

        def _log(self, ip: str, user: str, action: str, details: str = "", status: str = ""):
            """Emit "[time] ip user - action details <status>", skipping empty fields.

            The status, when given, is prefixed with the literal "状态".
            """
            parts = [f"[{self._now()}]"]
            if ip:
                parts.append(ip)
            if user:
                parts.append(f"{user} -")
            parts.append(action)
            if details:
                parts.append(details)
            if status:
                parts.append(f"状态{status}")
            self._emit(" ".join(parts))

        def log_access(self, ip, user, path, status):
            """Log a successful page access; an empty path is shown as "/"."""
            self._log(ip, user, "访问成功", path or "/", str(status))

        def log_download(self, ip, user, path, force_download, status, is_range_request=False, request_id=None):
            """Log a file download or preview.

            Range requests are not logged. When *request_id* is given, repeated
            calls with the same id for the same user and path are ignored. The
            de-duplication cache is kept bounded via clear_download_cache().
            """
            if is_range_request:
                return
            if request_id:
                seen = self.download_requests.setdefault(f"{user}:{path}", set())
                if request_id in seen:
                    return
                seen.add(request_id)
                # Without this the cache grows for the lifetime of the process.
                self.clear_download_cache()
            download_type = "下载" if force_download else "预览"
            self._log(ip, user, "文件下载", f"{path} ({download_type})", str(status))

        def log_upload(self, ip, user, file_count, path="", status=200):
            """Log an upload of *file_count* files, optionally into *path*."""
            details = f"{file_count} 个文件"
            if path:
                details += f" {path}"
            self._log(ip, user, "文件上传", details, str(status))

        def log_create_folder(self, ip, user, path, status=200):
            """Log a successful folder creation."""
            self._log(ip, user, "创建文件夹成功", path, str(status))

        def log_rename(self, ip, user, old_path, new_name, status=200):
            """Log a successful rename of *old_path* to *new_name*."""
            self._log(ip, user, "重命名成功", f"{old_path} -> {new_name}", str(status))

        def log_delete(self, ip, user, path, status=200):
            """Log a successful deletion."""
            self._log(ip, user, "删除成功", path, str(status))

        def log_server_event(self, msg):
            """Log a free-form server event prefixed only by the timestamp."""
            self._emit(f"[{self._now()}] {msg}")

        def log_error(self, ip="", user="", action="", details="", error="", is_range_request=False):
            """Log a failure. The action is suffixed with "失败" (failed).

            For range requests the two listed connection-closed errors are
            dropped; these are presumably clients aborting a partial download.
            """
            if is_range_request and error in ("Cannot write to closing transport", "Connection lost"):
                return
            parts = [f"[{self._now()}]"]
            if ip:
                parts.append(ip)
            if user:
                parts.append(f"{user} -")
            if action:
                parts.append(f"{action}失败")
            if details:
                parts.append(details)
            if error:
                parts.append(f"错误: {error}")
            self._emit(" ".join(parts))

        def log_login_failed(self, ip, username, reason=""):
            """Log a failed login attempt by an anonymous visitor."""
            self._log(ip, "游客", "登录失败", f"用户名: {username} {reason}".strip(), "401")

        def log_register_success(self, ip, username):
            """Log a successful registration."""
            self._log(ip, "游客", "注册成功", f"新用户: {username}", "200")

        def log_delete_user_success(self, ip, username, user_id):
            """Log a successful user deletion."""
            self._log(ip, username, "删除用户成功", f"用户: {username} (ID: {user_id})", "200")

        def log_visitor_login_page(self, ip):
            """Log an anonymous visit to the login page."""
            self._log(ip, "游客", "访问登录页面")

        def log_visitor_register_page(self, ip):
            """Log an anonymous visit to the registration page."""
            self._log(ip, "游客", "访问注册页面")

        def clear_download_cache(self):
            """Drop the oldest half of the cache once it exceeds DOWNLOAD_CACHE_LIMIT keys."""
            if len(self.download_requests) > self.DOWNLOAD_CACHE_LIMIT:
                for key in list(self.download_requests)[:self.DOWNLOAD_CACHE_LIMIT // 2]:
                    del self.download_requests[key]
     
     
    # -------------------- 云盘服务器 --------------------
    class CloudDriveServer:
        def __init__(self, config: dict):
            """Configure paths, logging and server state from *config*.

            Recognised keys are ``port`` (default 8000) and ``allow_register``
            (default True). Data lives next to this script: ``uploads/`` is
            created here if missing, and the database is ``cloud_drive.db``.
            The Logger is built with its default relative file name, so the
            log goes to the current working directory, not base_dir.
            """
            lookup = config.get
            self.port = lookup('port', 8000)
            self.allow_register = lookup('allow_register', True)
            root = Path(__file__).parent.resolve()
            self.base_dir = root
            self.uploads_dir = root.joinpath('uploads')
            self.db_path = root.joinpath('cloud_drive.db')
            self.logger = Logger()
            # Server state placeholders; presumably filled in when the server
            # starts (that code is not visible here).
            self.active_connections = set()
            self.is_running = False
            self.site = self.runner = self.app = None
            self.uploads_dir.mkdir(exist_ok=True)
     
        # ---------- 工具方法 ----------
        def hash_password(self, password: str, salt: Optional[bytes] = None) -> Tuple[str, str]:
            """Derive a PBKDF2-HMAC-SHA256 hash (100000 iterations) of *password*.

            Args:
                password: plaintext password, UTF-8 encoded before hashing.
                salt: salt to re-derive a stored hash. When None, a fresh
                    random 32-byte salt is generated.

            Returns:
                (hash_hex, salt_hex), both hex-encoded strings.
            """
            if salt is None:
                # 32 bytes straight from the OS CSPRNG. Hashing random bytes
                # with SHA-256 first adds nothing.
                salt = os.urandom(32)
            pwdhash = hashlib.pbkdf2_hmac('sha256', password.encode('utf-8'), salt, 100000)
            return pwdhash.hex(), salt.hex()
     
        def normalize_path(self, path: str) -> str:
            """Canonicalise a client path.

            Backslashes become '/', runs of slashes collapse to one, and the
            leading and trailing slashes are removed. Falsy input yields ''.
            """
            if not path:
                return ''
            unified = re.sub(r'/+', '/', path.replace('\\', '/'))
            # After collapsing there is at most one leading slash, so strip()
            # removes exactly what the separate prefix/suffix handling would.
            return unified.strip('/')
     
        def sanitize_path(self, path: str) -> str:
            """Strip dangerous characters from a path string.

            The following rules are applied in order:
            1. Control characters are dropped.
            2. ``<>:"|?*`` are dropped.
            3. Runs of two or more slashes/backslashes become a single '/'.
            4. Runs of dots collapse to a single dot.
            5. Leading and trailing spaces and dots are removed.

            Falsy input is returned unchanged. Because leading dots are
            stripped, the result can start with '/' (e.g. '../x' -> '/x').
            """
            if not path:
                return path
            rules = (
                (r'[\x00-\x1f\x7f]', ''),
                (r'[<>:"|?*]', ''),
                (r'[/\\]{2,}', '/'),
                (r'\.{2,}', '.'),
            )
            result = path
            for pattern, replacement in rules:
                result = re.sub(pattern, replacement, result)
            return result.strip(' .')
     
        def preprocess_path(self, raw_path: str) -> str:
            if not raw_path:
                return ''
            try:
                decoded = unquote(raw_path)
            except Exception:
                decoded = raw_path
            normalized = self.normalize_path(decoded)
            cleaned = self.sanitize_path(normalized)
            return cleaned
     
        def is_safe_path(self, base_path: Path, target_path: str) -> Tuple[bool, str]:
            try:
                base = base_path.absolute()
                processed = self.preprocess_path(target_path)
                target = (base / processed).resolve() if processed else base
                safe = target.is_relative_to(base)
                if not safe:
                    self.logger.log_error(action="路径安全检查", details=f"越权: {target_path}")
                return safe, processed
            except Exception as e:
                self.logger.log_error(action="路径安全检查", error=str(e))
                return False, ""
     
        def format_file_size(self, size_bytes: int) -> str:
            """Render a byte count with binary units (B, KB, MB, GB, TB).

            Plain bytes are shown as an integer. Larger units get one decimal
            place, and TB is the largest unit used.
            """
            if size_bytes == 0:
                return "0 B"
            units = ('B', 'KB', 'MB', 'GB', 'TB')
            value = size_bytes
            idx = 0
            while value >= 1024 and idx + 1 < len(units):
                value /= 1024.0
                idx += 1
            return f"{int(value)} {units[idx]}" if idx == 0 else f"{value:.1f} {units[idx]}"
     
        def get_user_dir(self, username: str) -> Path:
            return self.uploads_dir / username
     
        def read_template(self, template_name: str, context: Dict = None) -> str:
            template_path = self.base_dir / 'templates' / template_name
            try:
                with open(template_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                if context:
                    for key, value in context.items():
                        content = content.replace(f'{{{{ {key} }}}}', str(value))
                return content
            except Exception as e:
                self.logger.log_error(action="读取模板", details=template_name, error=str(e))
                return ''
     
        # -------------------- 网络地址获取 --------------------
        @staticmethod
        def _get_ipv4_by_udp():
            try:
                s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                s.settimeout(1)
                s.connect(('8.8.8.8', 80))
                ip = s.getsockname()[0]
                s.close()
                return ip
            except Exception:
                return None
     
        @staticmethod
        def _get_ipv6_by_udp():
            try:
                s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
                s.settimeout(1)
                s.connect(('2001:4860:4860::8888', 80))
                ip = s.getsockname()[0]
                s.close()
                if ip == '::1' or ip.startswith('fe80:'):
                    return None
                return ip
            except Exception:
                return None
     
        @staticmethod
        def _get_all_ipv6_from_proc():
            ipv6s = set()
            proc_path = '/proc/net/if_inet6'
            if not os.path.isfile(proc_path):
                return ipv6s
            try:
                with open(proc_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 1:
                            continue
                        hex_str = parts[0]
                        ipv6 = ':'.join(hex_str[i:i+4] for i in range(0, 32, 4))
                        if ipv6 == '::1' or ipv6.startswith('fe80:'):
                            continue
                        ipv6s.add(ipv6)
            except Exception:
                pass
            return ipv6s
     
        @staticmethod
        def _get_all_ipv6_via_command():
            ipv6s = set()
            try:
                result = subprocess.run(['ip', '-6', 'addr', 'show'], capture_output=True, text=True, timeout=5)
                if result.returncode != 0:
                    result = subprocess.run(['ip', 'addr', 'show'], capture_output=True, text=True, timeout=5)
                output = result.stdout
            except Exception:
                try:
                    result = subprocess.run(['ifconfig'], capture_output=True, text=True, timeout=5)
                    output = result.stdout
                except Exception:
                    output = ""
            if output:
                for line in output.splitlines():
                    if 'inet6' not in line:
                        continue
                    m = re.search(r'inet6\s+([0-9a-fA-F:]+)', line)
                    if not m:
                        continue
                    ip = m.group(1).split('/')[0]
                    if (ip == '::1' or ip.startswith('fe80:') or
                        ip.startswith('ff') or ip.startswith('::ffff:')):
                        continue
                    if 'scope host' in line or 'scope link' in line:
                        continue
                    ipv6s.add(ip)
            return ipv6s
     
        @staticmethod
        def _get_all_ipv6_via_netifaces():
            ipv6s = set()
            try:
                import netifaces
                for interface in netifaces.interfaces():
                    addrs = netifaces.ifaddresses(interface).get(netifaces.AF_INET6, [])
                    for addr_info in addrs:
                        ip = addr_info.get('addr', '')
                        if not ip:
                            continue
                        if (ip == '::1' or ip.startswith('fe80:') or
                            ip.startswith('ff') or ip.startswith('::ffff:')):
                            continue
                        ipv6s.add(ip)
            except ImportError:
                pass
            except Exception:
                pass
            return ipv6s
     
        def get_network_addresses(self):
            ipv4_set = set()
            ipv6_set = set()
     
            # IPv4
            udp_ip = self._get_ipv4_by_udp()
            if udp_ip and not udp_ip.startswith('127.'):
                ipv4_set.add(udp_ip)
            try:
                result = subprocess.run(['ip', 'addr', 'show'], capture_output=True, text=True, timeout=5)
                if result.returncode != 0:
                    result = subprocess.run(['ifconfig'], capture_output=True, text=True, timeout=5)
                output = result.stdout
            except Exception:
                try:
                    result = subprocess.run(['ifconfig'], capture_output=True, text=True, timeout=5)
                    output = result.stdout
                except Exception:
                    output = ""
            for m in re.finditer(r'inet\s+([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)', output):
                ip = m.group(1)
                if not ip.startswith('127.'):
                    ipv4_set.add(ip)
            if not ipv4_set:
                try:
                    hostname = socket.gethostname()
                    for addr in socket.gethostbyname_ex(hostname)[2]:
                        if not addr.startswith('127.'):
                            ipv4_set.add(addr)
                except:
                    pass
     
            # IPv6
            udp_ipv6 = self._get_ipv6_by_udp()
            if udp_ipv6:
                ipv6_set.add(udp_ipv6)
            ipv6_set |= self._get_all_ipv6_from_proc()
            ipv6_set |= self._get_all_ipv6_via_command()
            ipv6_set |= self._get_all_ipv6_via_netifaces()
            try:
                hostname = socket.gethostname()
                for info in socket.getaddrinfo(hostname, None, socket.AF_INET6):
                    ip = info[4][0]
                    if (ip == '::1' or ip.startswith('fe80:') or
                        ip.startswith('ff') or ip.startswith('::ffff:')):
                        continue
                    ipv6_set.add(ip)
            except:
                pass
     
            ipv4_list = sorted(list(ipv4_set), key=lambda x: [int(n) for n in x.split('.')])
            ipv6_list = sorted(list(ipv6_set))
            return ipv4_list, ipv6_list
     
        # ---------- 数据库操作 ----------
        async def init_db(self):
            """Create the ``users`` and ``sessions`` tables if they do not exist.

            Safe to call on every start-up (CREATE TABLE IF NOT EXISTS). A
            server event is logged once the schema is committed.
            """
            ddl_statements = (
                '''
                    CREATE TABLE IF NOT EXISTS users (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        username TEXT UNIQUE NOT NULL,
                        password_hash TEXT NOT NULL,
                        salt TEXT NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        storage_used INTEGER DEFAULT 0
                    )
                ''',
                '''
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        user_id INTEGER NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        expires_at TIMESTAMP NOT NULL,
                        FOREIGN KEY (user_id) REFERENCES users (id)
                    )
                ''',
            )
            async with aiosqlite.connect(self.db_path) as conn:
                for ddl in ddl_statements:
                    await conn.execute(ddl)
                await conn.commit()
            self.logger.log_server_event("数据库初始化完成")
     
        async def get_user_info(self, user_id: int) -> Optional[Dict]:
            """Look up a user's name and storage usage.

            Returns:
                None if no row has this id. Otherwise a dict with:
                - ``username``
                - ``used``: the storage_used column, formatted as bytes
                - ``used_formatted``: human-readable size
                - ``percentage``: usage relative to a fixed 1 GiB
                  (1024**3 bytes), capped at 100
            """
            async with aiosqlite.connect(self.db_path) as conn:
                cursor = await conn.execute(
                    'SELECT username, storage_used FROM users WHERE id = ?', (user_id,)
                )
                row = await cursor.fetchone()
            if not row:
                return None
            name, used_bytes = row
            if used_bytes > 0:
                percentage = min(100, (used_bytes / (1024**3)) * 100)
            else:
                percentage = 0
            return {
                'username': name,
                'used': used_bytes,
                'used_formatted': self.format_file_size(used_bytes),
                'percentage': percentage,
            }
     
        async def verify_user(self, username: str, password: str) -> Optional[int]:
            """Return the user's id if *password* matches the stored hash, else None.

            The stored salt (hex) is reused to re-derive the hash via
            hash_password. The derivation runs in a worker thread because
            PBKDF2 would otherwise block the event loop.
            """
            import hmac

            async with aiosqlite.connect(self.db_path) as db:
                cursor = await db.execute(
                    'SELECT id, password_hash, salt FROM users WHERE username = ?', (username,)
                )
                user = await cursor.fetchone()
            if not user:
                return None
            user_id, stored_hash, salt = user
            hashed, _ = await asyncio.to_thread(self.hash_password, password, bytes.fromhex(salt))
            # Constant-time comparison: response timing must not reveal how
            # much of the hash matched.
            if hmac.compare_digest(hashed, stored_hash):
                return user_id
            return None
     
        async def create_user(self, username: str, password: str) -> bool:
            """Register a new account and create its upload directory.

            Returns False when registration is disabled or the username is
            taken, or on any error (which is logged). The user directory is
            created before the INSERT is committed.
            """
            if not self.allow_register:
                return False
            try:
                digest, salt_hex = self.hash_password(password)
                async with aiosqlite.connect(self.db_path) as conn:
                    existing = await conn.execute('SELECT id FROM users WHERE username = ?', (username,))
                    if await existing.fetchone():
                        return False
                    await conn.execute(
                        'INSERT INTO users (username, password_hash, salt) VALUES (?, ?, ?)',
                        (username, digest, salt_hex)
                    )
                    home = self.get_user_dir(username)
                    await aiofiles.os.makedirs(str(home), exist_ok=True)
                    await conn.commit()
                return True
            except Exception as exc:
                self.logger.log_error(action="创建用户", error=str(exc))
                return False
     
        async def create_session(self, user_id: int) -> str:
            """Create a session valid for one day for *user_id* and return its id.

            Expired sessions are purged in the same transaction. Otherwise only
            logout and user deletion ever remove session rows, so the table
            would grow without bound.
            """
            session_id = str(uuid.uuid4())
            now = datetime.now()
            # Store timestamps as text in the exact form sqlite3's default
            # datetime adapter produced (isoformat(' ')). That adapter is
            # deprecated since Python 3.12, and matching its format keeps
            # existing rows comparing correctly.
            expires_at = (now + timedelta(days=1)).isoformat(' ')
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(
                    'DELETE FROM sessions WHERE expires_at <= ?', (now.isoformat(' '),)
                )
                await db.execute(
                    'INSERT INTO sessions (session_id, user_id, expires_at) VALUES (?, ?, ?)',
                    (session_id, user_id, expires_at)
                )
                await db.commit()
            return session_id
     
        async def validate_session(self, session_id: str) -> Optional[int]:
            """Return the user id owning *session_id* if it has not expired, else None.

            The current time is passed as isoformat(' ') text, the same form
            sqlite3's deprecated default datetime adapter produced, so the
            string comparison against stored expires_at values is unchanged.
            """
            now = datetime.now().isoformat(' ')
            async with aiosqlite.connect(self.db_path) as db:
                cursor = await db.execute(
                    'SELECT user_id FROM sessions WHERE session_id = ? AND expires_at > ?',
                    (session_id, now)
                )
                result = await cursor.fetchone()
                return result[0] if result else None
     
        async def delete_session(self, session_id: str):
            """Remove the session row with this id. It is a no-op if there is none."""
            statement = 'DELETE FROM sessions WHERE session_id = ?'
            async with aiosqlite.connect(self.db_path) as conn:
                await conn.execute(statement, (session_id,))
                await conn.commit()
     
        async def delete_user(self, user_id: int):
            """Delete a user's sessions, account row and upload directory.

            Does nothing if *user_id* is unknown. Database rows are removed
            and committed first. Success is logged only after that.

            get_user_dir joins the username into the path verbatim, so the
            directory is only removed if it resolves to a location strictly
            inside ``uploads_dir``. The removal runs in a worker thread so the
            blocking rmtree does not stall the event loop.
            """
            async with aiosqlite.connect(self.db_path) as db:
                cursor = await db.execute('SELECT username FROM users WHERE id = ?', (user_id,))
                result = await cursor.fetchone()
                if not result:
                    return
                username = result[0]
                # sessions.user_id references users(id), so drop sessions first.
                await db.execute('DELETE FROM sessions WHERE user_id = ?', (user_id,))
                await db.execute('DELETE FROM users WHERE id = ?', (user_id,))
                await db.commit()

            user_dir = self.get_user_dir(username)
            uploads_root = self.uploads_dir.resolve()
            resolved = user_dir.resolve()
            if resolved == uploads_root or not resolved.is_relative_to(uploads_root):
                self.logger.log_error(action="删除用户目录", details=str(user_dir), error="路径越界,已跳过")
            elif user_dir.exists():
                try:
                    await asyncio.to_thread(shutil.rmtree, str(user_dir))
                except OSError as e:
                    self.logger.log_error(action="删除用户目录", error=str(e))
            self.logger.log_delete_user_success("127.0.0.1", username, user_id)
     
        # ---------- 网络请求处理 ----------
        async def check_auth(self, request):
            """Return the logged-in user's id from the ``session_id`` cookie, or None.

            A valid session's expiry is pushed out to one day from now
            (sliding expiration). The cookie itself is not refreshed here.
            """
            session_id = request.cookies.get('session_id')
            if not session_id:
                return None
            user_id = await self.validate_session(session_id)
            if not user_id:
                return None
            # Same text form as sqlite3's deprecated default datetime adapter
            # (isoformat(' ')), so stored values keep comparing correctly.
            expires_at = (datetime.now() + timedelta(days=1)).isoformat(' ')
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute(
                    'UPDATE sessions SET expires_at = ? WHERE session_id = ?',
                    (expires_at, session_id)
                )
                await db.commit()
            return user_id
     
        async def handle_index(self, request):
            """Redirect ``/`` to /cloud if the session cookie is valid, else to /login."""
            sid = request.cookies.get('session_id')
            destination = '/login'
            if sid:
                if await self.validate_session(sid):
                    destination = '/cloud'
            return web.HTTPFound(destination)
     
        async def handle_login(self, request):
            """Serve the login form (GET) or authenticate posted credentials (POST).

            GET returns login.html, or HTTP 500 if the template is missing.
            On POST, username and password are whitespace-stripped. Success
            creates a session and redirects to /cloud with an HttpOnly
            ``session_id`` cookie (max_age one day). Failures are logged and
            re-render the form with an ``error`` message.
            """
            ip = request.remote

            if request.method == 'GET':
                self.logger.log_visitor_login_page(ip)
                page = self.read_template('login.html')
                if page:
                    return web.Response(text=page, content_type='text/html')
                return web.Response(text='模板缺失', status=500)

            form = await request.post()
            username = form.get('username', '').strip()
            password = form.get('password', '').strip()

            def render_error(message):
                return web.Response(
                    text=self.read_template('login.html', {'error': message}),
                    content_type='text/html',
                )

            if not (username and password):
                self.logger.log_login_failed(ip, username, "用户名或密码为空")
                return render_error('用户名或密码不能为空')

            user_id = await self.verify_user(username, password)
            if not user_id:
                self.logger.log_login_failed(ip, username, "密码错误")
                return render_error('用户名或密码错误')

            session_id = await self.create_session(user_id)
            resp = web.HTTPFound('/cloud')
            resp.set_cookie('session_id', session_id, httponly=True, max_age=86400)
            self.logger.log_access(ip, username, '/', 200)
            return resp
     
        async def handle_register(self, request):
            """Serve the registration form (GET) or create an account (POST).

            GET returns a static "registration closed" page when registration
            is disabled. Otherwise it returns register.html, or HTTP 500 if
            the template is missing (as handle_login does).

            POST is refused with 403 while registration is disabled.
            Otherwise it validates the form, creates the user, starts a
            session and redirects to /cloud. Validation problems re-render the
            form with an ``error`` message.

            Usernames are used verbatim as a directory name under uploads/
            (get_user_dir), and delete_user later passes that directory to
            shutil.rmtree. They are therefore restricted to word characters
            plus '.' and '-', and must start with a word character, so a name
            like '../..' cannot escape uploads/.
            """
            client_ip = request.remote
            if request.method == 'GET':
                self.logger.log_visitor_register_page(client_ip)
                if not self.allow_register:
                    html = '''
                    <!DOCTYPE html>
                    <html><head><meta charset="UTF-8"><title>注册关闭</title>
                    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
                    <style>.error-container{max-width:500px;margin:100px auto;padding:40px;text-align:center;}
                    .error-icon{font-size:64px;color:#dc3545;margin-bottom:20px;}
                    .error-title{font-size:24px;color:#333;margin-bottom:15px;}
                    .error-message{color:#666;margin-bottom:30px;}
                    .error-back-link{color:#007bff;text-decoration:none;}
                    </style></head><body class="bg-gray-100">
                    <div class="error-container"><div class="error-icon"><i class="fas fa-lock"></i></div>
                    <h1 class="error-title">注册功能已关闭</h1>
                    <p class="error-message">请联系管理员获取权限</p>
                    <a href="/" class="error-back-link"><i class="fas fa-arrow-left"></i>返回登录</a></div>
                    </body></html>'''
                    return web.Response(text=html, content_type='text/html')
                html = self.read_template('register.html')
                if not html:
                    return web.Response(text='模板缺失', status=500)
                return web.Response(text=html, content_type='text/html')

            # A direct POST must not slip past a closed registration. The old
            # path fell through to create_user, which refused, and then showed
            # the misleading "username already exists" error.
            if not self.allow_register:
                return web.Response(text='注册功能已关闭', status=403)

            def render_error(message):
                return web.Response(
                    text=self.read_template('register.html', {'error': message}),
                    content_type='text/html',
                )

            data = await request.post()
            username = data.get('username', '').strip()
            password = data.get('password', '').strip()
            if not username or not password:
                return render_error('用户名或密码不能为空')
            if len(username) < 3 or len(username) > 20:
                return render_error('用户名长度3-20字符')
            if not re.fullmatch(r'\w[\w.\-]*', username):
                return render_error('用户名只能包含字母、数字、下划线、点或连字符,且不能以点或连字符开头')

            success = await self.create_user(username, password)
            if not success:
                return render_error('用户名已存在')
            user_id = await self.verify_user(username, password)
            if not user_id:
                # The account exists but could not be read back; let the user
                # log in manually instead of failing with a NULL user_id.
                return web.HTTPFound('/login')
            session_id = await self.create_session(user_id)
            resp = web.HTTPFound('/cloud')
            resp.set_cookie('session_id', session_id, httponly=True, max_age=86400)
            self.logger.log_register_success(client_ip, username)
            return resp
     
        async def handle_logout(self, request):
            """Delete the current session (if any), clear its cookie and go to /login."""
            if sid := request.cookies.get('session_id'):
                await self.delete_session(sid)
            redirect = web.HTTPFound('/login')
            redirect.del_cookie('session_id')
            return redirect
     
        async def handle_cloud(self, request):
            user_id = await self.check_auth(request)
            if not user_id:
                return web.HTTPFound('/login')
            user_info = await self.get_user_info(user_id)
            if not user_info:
                return web.HTTPFound('/logout')
     
            path_param = request.query.get('path', '')
            search_query = request.query.get('q', '')
            current_path = self.normalize_path(path_param)
            user_dir = self.get_user_dir(user_info['username'])
     
            if current_path and not self.is_safe_path(user_dir, current_path):
                return web.Response(text='路径不安全', status=403)
     
            current_dir = user_dir / current_path if current_path else user_dir
            items = []
     
            if search_query:
                search_query = search_query.lower()
                for root, dirs, files in os.walk(str(user_dir)):
                    rel_root = Path(root).relative_to(user_dir)
                    rel_root_posix = rel_root.as_posix() if str(rel_root) != '.' else ''
                    for d in dirs:
                        if search_query in d.lower():
                            dir_path = rel_root / d
                            try:
                                stat = (user_dir / dir_path).stat()
                                items.append({
                                    'name': d,
                                    'path': dir_path.as_posix(),
                                    'is_dir': True,
                                    'size': '-',
                                    'mtime': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M'),
                                    'search_result': True,
                                    'dir_path': rel_root_posix
                                })
                            except: pass
                    for f in files:
                        if search_query in f.lower():
                            file_path = rel_root / f
                            try:
                                stat = (user_dir / file_path).stat()
                                items.append({
                                    'name': f,
                                    'path': file_path.as_posix(),
                                    'is_dir': False,
                                    'size': self.format_file_size(stat.st_size),
                                    'mtime': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M'),
                                    'search_result': True,
                                    'dir_path': rel_root_posix
                                })
                            except: pass
            else:
                try:
                    if current_dir.exists() and current_dir.is_dir():
                        for item in current_dir.iterdir():
                            try:
                                stat = item.stat()
                                rel = item.relative_to(user_dir)
                                if item.is_dir():
                                    items.append({
                                        'name': item.name,
                                        'path': rel.as_posix(),
                                        'is_dir': True,
                                        'size': '-',
                                        'mtime': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M')
                                    })
                                else:
                                    items.append({
                                        'name': item.name,
                                        'path': rel.as_posix(),
                                        'is_dir': False,
                                        'size': self.format_file_size(stat.st_size),
                                        'mtime': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M')
                                    })
                            except: pass
                except Exception as e:
                    self.logger.log_error(action="读取目录", error=str(e))
     
            items.sort(key=lambda x: (not x['is_dir'], x['name'].lower()))
     
            breadcrumbs = [('根目录', '')]
            if current_path:
                parts = current_path.split('/')
                current = ''
                for part in parts:
                    if part:
                        current += f'{part}/'
                        breadcrumbs.append((part, current.rstrip('/')))
     
            def make_breadcrumbs_html(bcs):
                html = []
                for i, (name, path) in enumerate(bcs):
                    if i > 0:
                        html.append('<li class="separator"><i class="fas fa-chevron-right"></i></li>')
                    if i == len(bcs) - 1:
                        html.append(f'<li><span>{name}</span></li>')
                    else:
                        html.append(f'<li><a href="/cloud?path={quote(path)}">{name}</a></li>')
                return ''.join(html)
     
            def make_file_list_html(items):
                if not items:
                    return '<div class="empty-state"><i class="fas fa-folder-open"></i><h3>文件夹为空</h3></div>'
                html = ['<ul class="file-list">']
                for item in items:
                    url_path = quote(item['path'])
                    raw_path = item['path'].replace("'", "\\'")
                    name_escaped = item['name'].replace("'", "\\'")
                    icon = 'fa-folder' if item['is_dir'] else 'fa-file'
                    link = f'/cloud?path={url_path}' if item['is_dir'] else f'/download/{url_path}'
                    html.append(f'<li class="file-item">')
                    html.append(f'<div class="file-icon"><i class="fas {icon}"></i></div>')
                    html.append(f'<div class="file-info"><div class="file-name"><a href="{link}">{item["name"]}</a></div>')
                    html.append(f'<div class="file-meta"><div class="file-size">{item["size"]}</div><div class="file-time">{item["mtime"]}</div></div></div>')
                    html.append(f'<div class="file-actions">')
                    if not item['is_dir']:
                        html.append(f'<button class="file-action-btn download-btn" onclick="downloadFile(\'{url_path}\')"><i class="fas fa-download"></i></button>')
                    html.append(f'<button class="file-action-btn rename-btn" onclick="showRenameModal(\'{raw_path}\', \'{name_escaped}\')"><i class="fas fa-edit"></i></button>')
                    html.append(f'<button class="file-action-btn delete-btn" onclick="showDeleteConfirm(\'{raw_path}\', \'{name_escaped}\')"><i class="fas fa-trash-alt"></i></button>')
                    html.append('</div></li>')
                html.append('</ul>')
                return ''.join(html)
     
            context = {
                'username': user_info['username'],
                'current_path': current_path,
                'search_query': search_query,
                'used_formatted': user_info['used_formatted'],
                'percentage': user_info['percentage'],
                'breadcrumbs_html': make_breadcrumbs_html(breadcrumbs),
                'file_list_html': make_file_list_html(items)
            }
     
            html = self.read_template('cloud.html', context)
            if not html:
                return web.Response(text='模板缺失', status=500)
     
            self.logger.log_access(request.remote, user_info['username'], f'/{current_path}' if current_path else '/', 200)
            return web.Response(text=html, content_type='text/html')
     
        # -------------------- 文件操作接口 --------------------
        async def handle_download(self, request):
            """Stream a file from the logged-in user's directory.

            URL: ``/download/{path}``, where *path* is relative to the user's root.
            Query ``?force=true`` sends the file as an attachment
            (``application/octet-stream``) instead of its guessed MIME type.

            A single ``Range: bytes=...`` header is honoured with a 206 response.
            Open-ended (``bytes=N-``) and suffix (``bytes=-N``) forms are
            supported. A last-byte position past the end of the file is clamped
            to the last byte (RFC 7233 section 2.1). An unsatisfiable range gets
            a 416 carrying ``Content-Range: bytes */<size>``. A Range header that
            does not parse is ignored, and the whole file is sent.

            Status codes:
              * 401 when not logged in.
              * 404 when the user or file is missing, or the path is rejected
                by ``is_safe_path``.
              * 500 on unexpected errors.
            """
            user_id = await self.check_auth(request)
            if not user_id:
                return web.Response(text='未授权', status=401)
            user_info = await self.get_user_info(user_id)
            if not user_info:
                return web.Response(text='用户不存在', status=404)

            filepath = unquote(request.match_info['path'])
            user_dir = self.get_user_dir(user_info['username'])
            file_path = user_dir / filepath
            safe, _ = self.is_safe_path(user_dir, filepath)
            if not safe or not file_path.exists() or not file_path.is_file():
                return web.Response(text='文件不存在', status=404)

            force_download = request.query.get('force', '').lower() == 'true'
            file_size = file_path.stat().st_size
            mime_type, _ = mimetypes.guess_type(str(file_path))
            mime_type = mime_type or 'application/octet-stream'

            client_ip = request.remote
            range_header = request.headers.get('Range')
            is_range_request = range_header is not None
            request_id = f"{client_ip}:{filepath}:{int(time.time())}"

            # Inclusive byte window to send; defaults to the whole file.
            start, end = 0, file_size - 1
            partial = False
            if range_header:
                match = re.match(r'bytes=(\d*)-(\d*)', range_header.strip())
                if match and (match.group(1) or match.group(2)):
                    first, last = match.group(1), match.group(2)
                    if first:
                        start = int(first)
                        if last:
                            # A last-byte-pos beyond EOF is clamped, not rejected.
                            end = min(int(last), file_size - 1)
                    else:
                        # Suffix form "bytes=-N": the final N bytes of the file.
                        start = max(file_size - int(last), 0)
                    if start >= file_size or start > end:
                        return web.Response(
                            status=416,
                            headers={'Content-Range': f'bytes */{file_size}'},
                        )
                    partial = True

            chunk_size = 64 * 1024
            disconnect_errors = (ClientDisconnectedError, ConnectionResetError,
                                 asyncio.CancelledError, ConnectionError)

            async def send_window(response):
                """Write bytes [start, end] of the file to *response*.

                Returns False if the client went away mid-transfer, else True.
                """
                remain = end - start + 1
                async with aiofiles.open(str(file_path), 'rb') as f:
                    if start:
                        await f.seek(start)
                    while remain > 0:
                        chunk = await f.read(min(chunk_size, remain))
                        if not chunk:
                            break  # EOF earlier than the stat() size (file shrank)
                        try:
                            await response.write(chunk)
                        except disconnect_errors:
                            return False
                        remain -= len(chunk)
                return True

            try:
                response = web.StreamResponse()
                if force_download:
                    # RFC 6266: non-ASCII names must travel percent-encoded in
                    # filename* (RFC 5987 syntax). filename= keeps an ASCII-only
                    # fallback for clients that ignore filename*.
                    ascii_name = ''.join(
                        ch if 32 <= ord(ch) < 127 and ch not in '"\\' else '_'
                        for ch in file_path.name
                    )
                    response.headers['Content-Disposition'] = (
                        f'attachment; filename="{ascii_name}"; '
                        f"filename*=UTF-8''{quote(file_path.name, safe='')}"
                    )
                    response.headers['Content-Type'] = 'application/octet-stream'
                else:
                    response.headers['Content-Type'] = mime_type
                response.headers['Accept-Ranges'] = 'bytes'

                if partial:
                    response.set_status(206)
                    response.headers['Content-Range'] = f'bytes {start}-{end}/{file_size}'
                response.headers['Content-Length'] = str(end - start + 1)
                await response.prepare(request)
                if not partial:
                    # 206 responses are not logged; only full transfers are.
                    self.logger.log_download(client_ip, user_info['username'], filepath,
                                             force_download, 200, is_range_request, request_id)

                if await send_window(response):
                    await response.write_eof()
                return response

            except disconnect_errors:
                return web.Response(status=204)
            except Exception as e:
                self.logger.log_error(client_ip, user_info['username'], "文件下载", filepath, str(e), is_range_request)
                return web.Response(text='下载失败', status=500)
     
        async def handle_upload(self, request):
            """Store the files of a ``multipart/form-data`` upload.

            Recognised parts:
              * ``path`` - destination directory relative to the user's root.
                It applies to file parts that come after it; earlier files
                land in the root.
              * any part with a filename - saved under its sanitised name. An
                existing name gets a ``_1``, ``_2``, ... suffix instead of being
                overwritten.

            Only bytes of completely written files are added to the user's
            ``storage_used``. A file whose write fails is deleted and not
            counted. If a later part aborts the request (unsafe path, directory
            creation failure, malformed multipart), files saved before it are
            still accounted for before the error response is returned.

            Returns JSON ``{"success": true, "uploaded_files": [...]}``, or a
            plain-text 401/404/400/403/500 response.
            """
            user_id = await self.check_auth(request)
            if not user_id:
                return web.Response(text='未授权', status=401)
            user_info = await self.get_user_info(user_id)
            if not user_info:
                return web.Response(text='用户不存在', status=404)
            client_ip = request.remote

            try:
                reader = await request.multipart()
            except Exception as e:
                self.logger.log_error(client_ip, user_info['username'], "上传", "解析请求失败", str(e))
                return web.Response(text='请求解析失败', status=400)

            user_dir = self.get_user_dir(user_info['username'])
            upload_path = ''
            target_dir = None
            uploaded_files = []
            total_size = 0
            # Set when a part aborts the upload. It is returned only after the
            # files already written have been added to storage_used.
            error_response = None

            while True:
                try:
                    part = await reader.next()
                except Exception as e:
                    self.logger.log_error(client_ip, user_info['username'], "上传", "解析请求失败", str(e))
                    error_response = web.Response(text='请求解析失败', status=400)
                    break
                if part is None:
                    break

                if part.name == 'path':
                    text = await part.text()
                    upload_path = self.normalize_path(text.strip())
                    target_dir = user_dir / upload_path if upload_path else user_dir
                    safe, _ = self.is_safe_path(user_dir, upload_path)
                    if not safe:
                        error_response = web.Response(text='路径不安全', status=403)
                        break
                    try:
                        target_dir.mkdir(parents=True, exist_ok=True)
                    except Exception as e:
                        self.logger.log_error(client_ip, user_info['username'], "创建上传目录", str(e))
                        error_response = web.Response(text='无法创建目录', status=500)
                        break

                elif part.filename:
                    filename = self.sanitize_path(part.filename)
                    if not filename:
                        continue
                    if target_dir is None:
                        target_dir = user_dir
                        target_dir.mkdir(parents=True, exist_ok=True)

                    # Never overwrite: append _1, _2, ... until the name is free.
                    file_path = target_dir / filename
                    counter = 1
                    name, ext = os.path.splitext(filename)
                    while file_path.exists():
                        filename = f"{name}_{counter}{ext}"
                        file_path = target_dir / filename
                        counter += 1

                    written = 0
                    try:
                        async with aiofiles.open(str(file_path), 'wb') as f:
                            while True:
                                chunk = await part.read_chunk(65536)
                                if not chunk:
                                    break
                                await f.write(chunk)
                                written += len(chunk)
                    except Exception as e:
                        self.logger.log_error(client_ip, user_info['username'], "保存文件", filename, str(e))
                        # Remove the truncated file so it neither lingers on disk
                        # nor is counted toward storage_used. This is best effort:
                        # the file may never have been created.
                        try:
                            file_path.unlink()
                        except OSError:
                            pass
                        continue
                    total_size += written
                    uploaded_files.append(filename)

            if uploaded_files:
                async with aiosqlite.connect(self.db_path) as db:
                    await db.execute(
                        'UPDATE users SET storage_used = storage_used + ? WHERE id = ?',
                        (total_size, user_id)
                    )
                    await db.commit()
                self.logger.log_upload(client_ip, user_info['username'], len(uploaded_files), upload_path, 200)

            if error_response is not None:
                return error_response
            return web.json_response({'success': True, 'uploaded_files': uploaded_files})
     
        async def handle_create_folder(self, request):
            """Create a new folder inside the logged-in user's directory.

            Expects JSON ``{"name": <folder name>, "path": <parent directory
            relative to the user's root>}``.

            Responds with ``{"status": "success"}`` or
            ``{"status": "error", "message": ...}``. Authentication, body-format
            and unknown-user errors carry 401/400/404 statuses.
            """
            user_id = await self.check_auth(request)
            if not user_id:
                return web.json_response({'status': 'error', 'message': '未授权'}, status=401)
            try:
                data = await request.json()
            except Exception:
                # Not a bare except: that would also swallow asyncio.CancelledError
                # (a BaseException since Python 3.8) if the handler is cancelled
                # while the body is being read.
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            if not isinstance(data, dict):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            folder_name = data.get('name') or ''
            raw_path = data.get('path') or ''
            if not isinstance(folder_name, str) or not isinstance(raw_path, str):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            folder_name = folder_name.strip()
            raw_path = raw_path.strip()
            if not folder_name:
                return web.json_response({'status': 'error', 'message': '文件夹名称不能为空'})
            folder_name = self.sanitize_path(folder_name)
            if not folder_name:
                return web.json_response({'status': 'error', 'message': '文件夹名称包含非法字符'})

            user_info = await self.get_user_info(user_id)
            if not user_info:
                return web.json_response({'status': 'error', 'message': '用户不存在'}, status=404)
            user_dir = self.get_user_dir(user_info['username'])
            safe, parent_path = self.is_safe_path(user_dir, raw_path)
            if not safe:
                return web.json_response({'status': 'error', 'message': '路径不安全'})

            target_dir = user_dir / parent_path if parent_path else user_dir
            folder_path = target_dir / folder_name
            try:
                # exist_ok=False so an existing folder is reported, not ignored.
                folder_path.mkdir(exist_ok=False)
                rel_path = folder_path.relative_to(user_dir).as_posix()
                full_path = f"/{rel_path}" if rel_path else "/"
                self.logger.log_create_folder(request.remote, user_info['username'], full_path, 200)
                return web.json_response({'status': 'success'})
            except FileExistsError:
                return web.json_response({'status': 'error', 'message': '文件夹已存在'})
            except FileNotFoundError:
                # mkdir without parents=True fails when the parent is missing.
                return web.json_response({'status': 'error', 'message': '上级目录不存在'})
            except Exception as e:
                self.logger.log_error(request.remote, user_info['username'], "创建文件夹", str(e))
                return web.json_response({'status': 'error', 'message': str(e)})
     
        async def handle_delete(self, request):
            """Delete a file, or a directory recursively, inside the user's root.

            Expects JSON ``{"path": <path relative to the user's root>}``. The
            user's root itself is never deleted. The removed bytes are
            subtracted from ``storage_used``, clamped at 0.

            Responds with ``{"status": "success"}`` or
            ``{"status": "error", "message": ...}``.
            """
            user_id = await self.check_auth(request)
            if not user_id:
                return web.json_response({'status': 'error', 'message': '未授权'}, status=401)
            try:
                data = await request.json()
            except Exception:
                # Not a bare except: that would also swallow asyncio.CancelledError.
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            raw_path = data.get('path', '') if isinstance(data, dict) else None
            if not isinstance(raw_path, str):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            raw_path = raw_path.strip()
            user_info = await self.get_user_info(user_id)
            if not user_info:
                return web.json_response({'status': 'error', 'message': '用户不存在'}, status=404)
            user_dir = self.get_user_dir(user_info['username'])
            safe, processed_path = self.is_safe_path(user_dir, raw_path)
            if not safe:
                return web.json_response({'status': 'error', 'message': '路径不安全'})

            full_path = user_dir / processed_path if processed_path else user_dir
            # Refuse to delete the user's root: an empty processed path maps to
            # user_dir itself, and rmtree would wipe every file in the account.
            if full_path.resolve() == user_dir.resolve():
                return web.json_response({'status': 'error', 'message': '不能删除根目录'})
            if not full_path.exists():
                return web.json_response({'status': 'error', 'message': '文件/文件夹不存在'})

            try:
                if full_path.is_file():
                    total_size = full_path.stat().st_size
                    full_path.unlink()
                else:
                    # Measure before removing, so storage_used can be reduced.
                    total_size = sum(
                        entry.stat().st_size
                        for entry in full_path.rglob('*')
                        if entry.is_file()
                    )
                    shutil.rmtree(str(full_path))

                async with aiosqlite.connect(self.db_path) as db:
                    await db.execute(
                        'UPDATE users SET storage_used = MAX(0, storage_used - ?) WHERE id = ?',
                        (total_size, user_id)
                    )
                    await db.commit()
                self.logger.log_delete(request.remote, user_info['username'], f"/{processed_path}" if processed_path else "/", 200)
                return web.json_response({'status': 'success'})
            except Exception as e:
                self.logger.log_error(request.remote, user_info['username'], "删除", str(e))
                return web.json_response({'status': 'error', 'message': str(e)})
     
        async def handle_rename(self, request):
            """Rename a file or folder in place, keeping it in the same directory.

            Expects JSON ``{"old_path": <path relative to the user's root>,
            "new_name": <new base name>}``.

            The following are refused:
              * renaming the user's root;
              * renaming onto an existing entry, which ``Path.rename`` would
                otherwise silently replace on POSIX.

            Responds with ``{"status": "success"}`` or
            ``{"status": "error", "message": ...}``.
            """
            user_id = await self.check_auth(request)
            if not user_id:
                return web.json_response({'status': 'error', 'message': '未授权'}, status=401)
            try:
                data = await request.json()
            except Exception:
                # Not a bare except: that would also swallow asyncio.CancelledError.
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            if not isinstance(data, dict):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            old_path = data.get('old_path') or ''
            new_name = data.get('new_name') or ''
            if not isinstance(old_path, str) or not isinstance(new_name, str):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            old_path = old_path.strip()
            new_name = new_name.strip()
            if not old_path or not new_name:
                return web.json_response({'status': 'error', 'message': '参数不完整'})
            new_name = self.sanitize_path(new_name)
            if not new_name:
                return web.json_response({'status': 'error', 'message': '新名称包含非法字符'})

            user_info = await self.get_user_info(user_id)
            if not user_info:
                return web.json_response({'status': 'error', 'message': '用户不存在'}, status=404)
            user_dir = self.get_user_dir(user_info['username'])
            safe, processed_old = self.is_safe_path(user_dir, old_path)
            if not safe:
                return web.json_response({'status': 'error', 'message': '路径不安全'})

            old_full = user_dir / processed_old if processed_old else user_dir
            # Renaming the root would move the user's whole directory inside
            # its parent, i.e. outside the user's own area.
            if old_full.resolve() == user_dir.resolve():
                return web.json_response({'status': 'error', 'message': '不能重命名根目录'})
            if not old_full.exists():
                return web.json_response({'status': 'error', 'message': '目标不存在'})
            new_full = old_full.parent / new_name
            # An existing target is only acceptable when it is the same file,
            # e.g. a case-only rename on a case-insensitive filesystem.
            if new_full != old_full and new_full.exists() and not new_full.samefile(old_full):
                return web.json_response({'status': 'error', 'message': '同名文件或文件夹已存在'})
            try:
                old_full.rename(new_full)
                self.logger.log_rename(request.remote, user_info['username'],
                                       f"/{processed_old}" if processed_old else "/", new_name, 200)
                return web.json_response({'status': 'success'})
            except Exception as e:
                self.logger.log_error(request.remote, user_info['username'], "重命名", str(e))
                return web.json_response({'status': 'error', 'message': str(e)})
     
        async def handle_storage_info(self, request):
            """Return the logged-in user's info dict (from ``get_user_info``) as JSON.

            Plain-text 401 when the session check fails, 404 when no user record
            is found for the session's user id.
            """
            uid = await self.check_auth(request)
            if uid:
                summary = await self.get_user_info(uid)
                if summary:
                    return web.json_response(summary)
                return web.Response(text='用户不存在', status=404)
            return web.Response(text='未授权', status=401)
     
        async def handle_delete_account(self, request):
            """Delete the logged-in user's account after re-verifying the password.

            Expects JSON ``{"password": ...}``. The password is hashed with the
            stored salt via ``hash_password`` and compared in constant time. On a
            match, ``delete_user`` is called and the ``session_id`` cookie is
            cleared.

            Error responses are JSON ``{"status": "error", "message": ...}``,
            except the plain-text 401 for a missing session.
            """
            import hmac

            def _as_bytes(value):
                # compare_digest needs bytes (or ASCII-only str); normalise both sides.
                return value.encode('utf-8') if isinstance(value, str) else bytes(value)

            user_id = await self.check_auth(request)
            if not user_id:
                return web.Response(text='未授权', status=401)
            try:
                data = await request.json()
            except Exception:
                # Not a bare except: that would also swallow asyncio.CancelledError.
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            if not isinstance(data, dict):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            password = data.get('password') or ''
            if not isinstance(password, str):
                return web.json_response({'status': 'error', 'message': '请求格式错误'}, status=400)
            password = password.strip()
            if not password:
                return web.json_response({'status': 'error', 'message': '请输入密码'})
            async with aiosqlite.connect(self.db_path) as db:
                cursor = await db.execute('SELECT username, password_hash, salt FROM users WHERE id = ?', (user_id,))
                user = await cursor.fetchone()
            if not user:
                return web.json_response({'status': 'error', 'message': '用户不存在'})
            _username, stored_hash, salt = user
            hashed, _ = self.hash_password(password, bytes.fromhex(salt))
            # Constant-time comparison: response timing must not reveal how much
            # of the hash matched. Differing types (e.g. NULL in the DB) never match.
            matches = (type(hashed) is type(stored_hash)
                       and hmac.compare_digest(_as_bytes(hashed), _as_bytes(stored_hash)))
            if not matches:
                return web.json_response({'status': 'error', 'message': '密码错误'})
            await self.delete_user(user_id)
            resp = web.json_response({'status': 'success'})
            resp.del_cookie('session_id')
            return resp
     
        @web.middleware
        async def connection_tracker(self, request, handler):
            """aiohttp middleware that records in-flight requests.

            The request's ``id()`` stays in ``self.active_connections`` for as long
            as the downstream handler runs. ``stop_server`` polls that set while
            shutting down.
            """
            key = id(request)
            self.active_connections.add(key)
            try:
                response = await handler(request)
            finally:
                self.active_connections.discard(key)
            return response
     
        # ---------- 启动/停止 ----------
        async def start_server(self):
            """Initialise the DB, register routes, start listening and run until stopped.

            Listens on ``self.port``. It first tries an IPv6 socket with
            ``IPV6_V6ONLY=0`` (so IPv4 clients are accepted too) and falls back
            to an IPv4-only socket if that setup fails.

            Once listening, it logs the reachable URLs and starts a background
            task that calls ``self.logger.clear_download_cache()`` every 300 s.
            It then blocks while ``self.is_running`` is true, and cancels the
            background task when it returns.
            """
            await self.init_db()
            self.app = web.Application(middlewares=[self.connection_tracker])

            self.app.router.add_get('/', self.handle_index)
            self.app.router.add_get('/login', self.handle_login)
            self.app.router.add_post('/login', self.handle_login)
            self.app.router.add_get('/register', self.handle_register)
            self.app.router.add_post('/register', self.handle_register)
            self.app.router.add_get('/logout', self.handle_logout)
            self.app.router.add_get('/cloud', self.handle_cloud)
            self.app.router.add_post('/upload', self.handle_upload)
            self.app.router.add_post('/create_folder', self.handle_create_folder)
            self.app.router.add_post('/delete', self.handle_delete)
            self.app.router.add_post('/rename', self.handle_rename)
            self.app.router.add_get('/storage_info', self.handle_storage_info)
            self.app.router.add_post('/delete_account', self.handle_delete_account)
            self.app.router.add_get('/download/{path:.+}', self.handle_download)
            self.app.router.add_static('/static', str(self.base_dir / 'static'))

            sock = None
            try:
                sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
                sock.bind(('::', self.port))
                sock.listen(100)
            except Exception:
                # Close the half-configured IPv6 socket before falling back,
                # otherwise its file descriptor leaks.
                if sock is not None:
                    sock.close()
                sock = None

            if sock is not None:
                self.logger.log_server_event("使用IPv6 socket(兼容IPv4)")
            else:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                try:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    sock.bind(('0.0.0.0', self.port))
                    sock.listen(100)
                except Exception:
                    sock.close()
                    raise

            self.runner = web.AppRunner(self.app)
            await self.runner.setup()
            self.site = web.SockSite(self.runner, sock)
            await self.site.start()

            ipv4_addresses, ipv6_addresses = self.get_network_addresses()

            self.logger.log_server_event("=" * 50)
            self.logger.log_server_event("云盘服务器启动成功!")
            self.logger.log_server_event(f"端口: {self.port}")
            self.logger.log_server_event(f"允许注册: {'是' if self.allow_register else '否'}")
            self.logger.log_server_event("访问地址:")

            if ipv4_addresses:
                self.logger.log_server_event("[IPv4]")
                for ip in ipv4_addresses:
                    self.logger.log_server_event(f"  http://{ip}:{self.port}/")
            if ipv6_addresses:
                self.logger.log_server_event("[IPv6]")
                for ip in ipv6_addresses:
                    self.logger.log_server_event(f"  http://[{ip}]:{self.port}/")
            self.logger.log_server_event("[本地]")
            self.logger.log_server_event(f"  http://127.0.0.1:{self.port}/")
            self.logger.log_server_event("=" * 50)

            self.is_running = True

            async def periodic_cleanup():
                while self.is_running:
                    await asyncio.sleep(300)
                    try:
                        self.logger.clear_download_cache()
                    except Exception as e:
                        # Keep the loop alive; otherwise one failure would end it.
                        self.logger.log_error(action="清理下载缓存", error=str(e))

            # The event loop keeps only weak references to tasks, so hold a
            # strong one; otherwise the task may be garbage-collected mid-sleep.
            self._cleanup_task = asyncio.ensure_future(periodic_cleanup())
            try:
                while self.is_running:
                    await asyncio.sleep(1)
            finally:
                self._cleanup_task.cancel()
     
        async def stop_server(self):
            """Gracefully stop the server.

            Clears ``self.is_running``, then waits up to 30 s (polling once per
            second) for ``self.active_connections`` to drain. It then stops the
            site and cleans up the runner.

            Safe to call when ``start_server`` failed part-way, or more than once:
            missing or already-cleared components are skipped.
            """
            self.is_running = False
            if self.active_connections:
                self.logger.log_server_event(f'等待 {len(self.active_connections)} 个连接关闭...')
                wait = 30
                for _ in range(wait):
                    if not self.active_connections:
                        break
                    await asyncio.sleep(1)
            if getattr(self, 'site', None):
                await self.site.stop()
                self.site = None
            if getattr(self, 'runner', None):
                # AppRunner.cleanup() already calls Application.cleanup(), which
                # fires the app's on_cleanup handlers. Calling app.cleanup() again
                # would run those handlers a second time.
                await self.runner.cleanup()
                self.runner = None
            elif getattr(self, 'app', None):
                # No runner was set up, so the app's cleanup has not run yet.
                await self.app.cleanup()
            if getattr(self, 'app', None):
                self.app = None
            self.logger.log_server_event('服务器已停止')
     
     
    # ---------- 配置加载 ----------
    def load_config():
        """Load ``config.json`` from this script's directory, creating it if absent.

        Recognised keys (missing ones are filled from the defaults):
          * ``port`` - TCP port to listen on. Converted to ``int``; falls back to
            8000 when the value is not a valid integer.
          * ``allow_register`` - normalised to ``bool``. Booleans pass through;
            strings are true only for ``yes``/``true``/``1``/``y``
            (case-insensitive); any other type counts as true.

        A missing file is written with the defaults. An unreadable file, or one
        whose top level is not a JSON object, is reported and the defaults are
        used. The returned dict is normalised in every case, including on the
        first run.

        Returns:
            dict: the configuration.
        """
        config_path = Path(__file__).parent / 'config.json'
        defaults = {
            'port': 8000,
            'allow_register': 'yes'
        }
        config = {}
        if config_path.exists():
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            except Exception as e:
                print(f"配置文件读取失败,使用默认配置: {e}")
            else:
                if isinstance(loaded, dict):
                    config = loaded
                else:
                    print("配置文件读取失败,使用默认配置: 顶层必须是 JSON 对象")
        else:
            try:
                with open(config_path, 'w', encoding='utf-8') as f:
                    json.dump(defaults, f, indent=2, ensure_ascii=False)
                print(f"已生成配置文件: {config_path}")
            except OSError as e:
                # Not fatal: the server can still run on the in-memory defaults.
                print(f"配置文件生成失败: {e}")
            # No early return: the defaults go through the same normalisation
            # below, so allow_register is a bool on the first run as well.

        for key, val in defaults.items():
            config.setdefault(key, val)

        try:
            config['port'] = int(config['port'])
        except (TypeError, ValueError):
            print(f"端口配置无效: {config['port']!r}, 使用默认端口 {defaults['port']}")
            config['port'] = defaults['port']

        allow = config.get('allow_register', 'yes')
        if isinstance(allow, bool):
            config['allow_register'] = allow
        elif isinstance(allow, str):
            config['allow_register'] = allow.lower() in ('yes', 'true', '1', 'y')
        else:
            config['allow_register'] = True

        return config
     
     
    async def main():
        """Build the server from ``load_config()`` and run it until it stops.

        Cancellation and ``KeyboardInterrupt`` count as a normal shutdown. Any
        other exception is logged rather than propagated. ``stop_server`` runs
        in every case.
        """
        server = CloudDriveServer(load_config())
        try:
            await server.start_server()
        except (asyncio.CancelledError, KeyboardInterrupt):
            pass
        except Exception as exc:
            server.logger.log_error(action="服务器运行", error=str(exc))
        finally:
            await server.stop_server()
     
     
    if __name__ == "__main__":
        # Script entry point: run main() with asyncio.run().
        # A KeyboardInterrupt that escapes asyncio.run() is reported with a
        # short message instead of a traceback.
        try:
            asyncio.run(main())
        except KeyboardInterrupt:
            print("\n服务器已手动停止")
    

     

© 版权声明

相关文章

暂无评论

none
暂无评论...