Skip to content

protocols.py

文件信息

  • 📄 原文件:03_protocols.py
  • 🔤 语言:python

Python 协议与鸭子类型 本文件介绍 Python 中的协议(Protocol)、鸭子类型和结构子类型。

"如果它走起来像鸭子,叫起来也像鸭子,那它就是鸭子。"

完整代码

python
from typing import Protocol, runtime_checkable, Iterable, Iterator
from abc import ABC, abstractmethod


def main01_duck_typing():
    """
    ============================================================
                    1. 鸭子类型
    ============================================================
    Python 是动态类型语言,关注对象的行为而非类型
    """
    print("=" * 60)
    print("1. 鸭子类型")
    print("=" * 60)

    # 【鸭子类型基本概念】
    class Duck:
        def quack(self):
            return "嘎嘎!"

        def walk(self):
            return "摇摇摆摆走"

    class Person:
        def quack(self):
            return "我在模仿鸭子叫:嘎嘎!"

        def walk(self):
            return "正常走路"

    class Robot:
        def quack(self):
            return "合成声音:嘎嘎"

        def walk(self):
            return "机械行走"

    def make_it_quack(thing):
        """不关心类型,只关心有没有 quack 方法"""
        print(f"  {thing.__class__.__name__}: {thing.quack()}")

    print("鸭子类型演示:")
    for obj in [Duck(), Person(), Robot()]:
        make_it_quack(obj)

    # 【文件类对象】
    print(f"\n--- 文件类对象 ---")

    class StringWriter:
        """像文件一样的字符串写入器"""
        def __init__(self):
            self.content = []

        def write(self, text):
            self.content.append(text)

        def read(self):
            return ''.join(self.content)

    def write_greeting(file_like):
        """接受任何有 write 方法的对象"""
        file_like.write("Hello, ")
        file_like.write("World!")

    sw = StringWriter()
    write_greeting(sw)
    print(f"StringWriter 内容: {sw.read()}")


def main02_builtin_protocols():
    """
    ============================================================
                2. 内置协议(特殊方法)
    ============================================================
    Python 通过特殊方法实现内置协议
    """
    print("\n" + "=" * 60)
    print("2. 内置协议(特殊方法)")
    print("=" * 60)

    # 【可迭代协议】__iter__
    print("--- 可迭代协议 ---")

    class Countdown:
        def __init__(self, start):
            self.start = start

        def __iter__(self):
            n = self.start
            while n > 0:
                yield n
                n -= 1

    print(f"Countdown(5): {list(Countdown(5))}")

    # 【序列协议】__len__, __getitem__
    print(f"\n--- 序列协议 ---")

    class Sentence:
        def __init__(self, text):
            self.words = text.split()

        def __len__(self):
            return len(self.words)

        def __getitem__(self, index):
            return self.words[index]

    s = Sentence("Hello World Python")
    print(f"len(s) = {len(s)}")
    print(f"s[0] = {s[0]}")
    print(f"s[-1] = {s[-1]}")
    print(f"for word in s: {[w for w in s]}")  # 自动支持迭代

    # 【上下文管理协议】__enter__, __exit__
    print(f"\n--- 上下文管理协议 ---")

    class Timer:
        def __enter__(self):
            import time
            self.start = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            import time
            self.elapsed = time.perf_counter() - self.start
            print(f"  耗时: {self.elapsed*1000:.2f}ms")
            return False

    with Timer():
        sum(range(100000))

    # 【可调用协议】__call__
    print(f"\n--- 可调用协议 ---")

    class Multiplier:
        def __init__(self, factor):
            self.factor = factor

        def __call__(self, x):
            return x * self.factor

    double = Multiplier(2)
    triple = Multiplier(3)
    print(f"double(5) = {double(5)}")
    print(f"triple(5) = {triple(5)}")

    # 【哈希协议】__hash__, __eq__
    print(f"\n--- 哈希协议 ---")

    class Point:
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __hash__(self):
            return hash((self.x, self.y))

        def __eq__(self, other):
            return self.x == other.x and self.y == other.y

    p1 = Point(1, 2)
    p2 = Point(1, 2)
    p3 = Point(3, 4)

    # 可以用作字典键和集合元素
    points = {p1: "origin", p3: "other"}
    print(f"p1 in points: {p1 in points}")
    print(f"p2 in points: {p2 in points}")  # True,因为 p1 == p2


def main03_typing_protocol():
    """
    ============================================================
                3. typing.Protocol(结构子类型)
    ============================================================
    Python 3.8+ 引入的正式协议支持
    """
    print("\n" + "=" * 60)
    print("3. typing.Protocol(结构子类型)")
    print("=" * 60)

    # 【定义协议】
    class Drawable(Protocol):
        """可绘制协议"""
        def draw(self) -> str:
            ...

    class Resizable(Protocol):
        """可调整大小协议"""
        def resize(self, factor: float) -> None:
            ...

    # 【实现协议(隐式)】
    class Circle:
        def __init__(self, radius: float):
            self.radius = radius

        def draw(self) -> str:
            return f"绘制圆形,半径={self.radius}"

        def resize(self, factor: float) -> None:
            self.radius *= factor

    class Rectangle:
        def __init__(self, width: float, height: float):
            self.width = width
            self.height = height

        def draw(self) -> str:
            return f"绘制矩形,{self.width}x{self.height}"

        def resize(self, factor: float) -> None:
            self.width *= factor
            self.height *= factor

    # 【使用协议作为类型提示】
    def render(shape: Drawable) -> None:
        print(f"  {shape.draw()}")

    def scale_up(shape: Resizable) -> None:
        shape.resize(2.0)

    print("协议类型检查(类型检查器使用):")
    circle = Circle(5)
    rect = Rectangle(4, 3)

    render(circle)
    render(rect)

    # 【运行时可检查的协议】
    print(f"\n--- 运行时检查 ---")

    @runtime_checkable
    class Speakable(Protocol):
        def speak(self) -> str:
            ...

    class Dog:
        def speak(self) -> str:
            return "汪!"

    class Cat:
        def speak(self) -> str:
            return "喵!"

    class Rock:
        pass

    dog = Dog()
    rock = Rock()

    print(f"isinstance(dog, Speakable): {isinstance(dog, Speakable)}")
    print(f"isinstance(rock, Speakable): {isinstance(rock, Speakable)}")


def main04_protocol_vs_abc():
    """
    ============================================================
                4. Protocol vs ABC
    ============================================================
    """
    print("\n" + "=" * 60)
    print("4. Protocol vs ABC")
    print("=" * 60)

    # 【ABC:名义类型】需要显式继承
    class AnimalABC(ABC):
        @abstractmethod
        def speak(self) -> str:
            pass

    class DogABC(AnimalABC):  # 必须继承
        def speak(self) -> str:
            return "汪!"

    # 【Protocol:结构类型】不需要继承
    class AnimalProtocol(Protocol):
        def speak(self) -> str:
            ...

    class DogProtocol:  # 不需要继承,只要实现方法
        def speak(self) -> str:
            return "汪!"

    print("""
    ABC(抽象基类):
    - 需要显式继承
    - 运行时强制检查
    - 适合定义接口规范

    Protocol(协议):
    - 不需要继承(结构子类型)
    - 主要用于静态类型检查
    - 更灵活,适合鸭子类型
    """)

    # 【组合使用】
    @runtime_checkable
    class Closeable(Protocol):
        def close(self) -> None:
            ...

    class FileWrapper:
        def __init__(self, name):
            self.name = name

        def close(self) -> None:
            print(f"    关闭 {self.name}")

    def cleanup(resource: Closeable) -> None:
        resource.close()

    print("组合使用:")
    fw = FileWrapper("test.txt")
    print(f"  isinstance(fw, Closeable): {isinstance(fw, Closeable)}")
    cleanup(fw)


def main05_common_protocols():
    """
    ============================================================
                5. 常用协议示例
    ============================================================
    """
    print("\n" + "=" * 60)
    print("5. 常用协议示例")
    print("=" * 60)

    # 【Comparable 协议】
    print("--- Comparable 协议 ---")

    class Comparable(Protocol):
        def __lt__(self, other) -> bool: ...
        def __le__(self, other) -> bool: ...
        def __gt__(self, other) -> bool: ...
        def __ge__(self, other) -> bool: ...

    class Score:
        def __init__(self, value: int):
            self.value = value

        def __lt__(self, other: 'Score') -> bool:
            return self.value < other.value

        def __le__(self, other: 'Score') -> bool:
            return self.value <= other.value

        def __gt__(self, other: 'Score') -> bool:
            return self.value > other.value

        def __ge__(self, other: 'Score') -> bool:
            return self.value >= other.value

        def __repr__(self):
            return f"Score({self.value})"

    scores = [Score(85), Score(92), Score(78)]
    print(f"排序前: {scores}")
    print(f"排序后: {sorted(scores)}")

    # 【Hashable 协议】
    print(f"\n--- Hashable 协议 ---")

    @runtime_checkable
    class Hashable(Protocol):
        def __hash__(self) -> int: ...

    class ImmutablePoint:
        def __init__(self, x: int, y: int):
            self._x = x
            self._y = y

        def __hash__(self) -> int:
            return hash((self._x, self._y))

        def __eq__(self, other) -> bool:
            return self._x == other._x and self._y == other._y

    p = ImmutablePoint(1, 2)
    print(f"isinstance(p, Hashable): {isinstance(p, Hashable)}")
    print(f"可以用作字典键: {{{p}: 'point'}}")

    # 【SupportsAdd 协议】
    print(f"\n--- SupportsAdd 协议 ---")

    class SupportsAdd(Protocol):
        def __add__(self, other): ...

    def double(x: SupportsAdd):
        return x + x

    print(f"double(5) = {double(5)}")
    print(f"double('hello') = {double('hello')}")
    print(f"double([1, 2]) = {double([1, 2])}")


def main06_generic_protocols():
    """
    ============================================================
                6. 泛型协议
    ============================================================
    """
    print("\n" + "=" * 60)
    print("6. 泛型协议")
    print("=" * 60)

    from typing import TypeVar, Generic

    T = TypeVar('T')
    T_co = TypeVar('T_co', covariant=True)  # 协变

    # 【泛型协议】
    class Container(Protocol[T]):
        def get(self) -> T: ...
        def set(self, value: T) -> None: ...

    class Box(Generic[T]):
        def __init__(self, value: T):
            self._value = value

        def get(self) -> T:
            return self._value

        def set(self, value: T) -> None:
            self._value = value

    box: Container[int] = Box(42)
    print(f"box.get() = {box.get()}")
    box.set(100)
    print(f"box.get() = {box.get()}")

    # 【协变协议】
    print(f"\n--- 协变协议 ---")

    class Reader(Protocol[T_co]):
        def read(self) -> T_co: ...

    class StringReader:
        def read(self) -> str:
            return "Hello"

    def process_reader(reader: Reader[str]) -> str:
        return reader.read()

    sr = StringReader()
    print(f"process_reader(sr) = {process_reader(sr)}")


def main07_practical_examples():
    """
    ============================================================
                7. 实际应用示例
    ============================================================
    """
    print("\n" + "=" * 60)
    print("7. 实际应用示例")
    print("=" * 60)

    # 【存储协议】
    print("--- 存储协议 ---")

    class Storage(Protocol):
        def save(self, key: str, data: dict) -> None: ...
        def load(self, key: str) -> dict | None: ...
        def delete(self, key: str) -> None: ...

    class MemoryStorage:
        def __init__(self):
            self._data = {}

        def save(self, key: str, data: dict) -> None:
            self._data[key] = data
            print(f"    [Memory] 保存: {key}")

        def load(self, key: str) -> dict | None:
            return self._data.get(key)

        def delete(self, key: str) -> None:
            self._data.pop(key, None)
            print(f"    [Memory] 删除: {key}")

    class FileStorage:
        def __init__(self, directory: str):
            self.directory = directory

        def save(self, key: str, data: dict) -> None:
            print(f"    [File] 保存到 {self.directory}/{key}.json")

        def load(self, key: str) -> dict | None:
            print(f"    [File] 从 {self.directory}/{key}.json 加载")
            return {"mock": "data"}

        def delete(self, key: str) -> None:
            print(f"    [File] 删除 {self.directory}/{key}.json")

    class UserService:
        def __init__(self, storage: Storage):
            self.storage = storage

        def create_user(self, user_id: str, name: str) -> None:
            self.storage.save(user_id, {"name": name})

        def get_user(self, user_id: str) -> dict | None:
            return self.storage.load(user_id)

    # 可以轻松切换存储实现
    print("使用内存存储:")
    service1 = UserService(MemoryStorage())
    service1.create_user("1", "Alice")

    print("\n使用文件存储:")
    service2 = UserService(FileStorage("/data"))
    service2.create_user("1", "Alice")

    # 【日志协议】
    print(f"\n--- 日志协议 ---")

    class Logger(Protocol):
        def info(self, msg: str) -> None: ...
        def error(self, msg: str) -> None: ...

    class ConsoleLogger:
        def info(self, msg: str) -> None:
            print(f"    [INFO] {msg}")

        def error(self, msg: str) -> None:
            print(f"    [ERROR] {msg}")

    class FileLogger:
        def __init__(self, filename: str):
            self.filename = filename

        def info(self, msg: str) -> None:
            print(f"    [INFO -> {self.filename}] {msg}")

        def error(self, msg: str) -> None:
            print(f"    [ERROR -> {self.filename}] {msg}")

    def process_data(data: list, logger: Logger) -> None:
        logger.info(f"开始处理 {len(data)} 条数据")
        # 处理逻辑
        logger.info("处理完成")

    print("使用控制台日志:")
    process_data([1, 2, 3], ConsoleLogger())

    print("\n使用文件日志:")
    process_data([1, 2, 3], FileLogger("app.log"))


if __name__ == "__main__":
    main01_duck_typing()
    main02_builtin_protocols()
    main03_typing_protocol()
    main04_protocol_vs_abc()
    main05_common_protocols()
    main06_generic_protocols()
    main07_practical_examples()

💬 讨论

使用 GitHub 账号登录后即可参与讨论

基于 MIT 许可发布