ayaka.driver.ayakabot.template

  1import functools
  2from string import Formatter
  3from typing import (
  4    TYPE_CHECKING,
  5    Any,
  6    Set,
  7    Dict,
  8    List,
  9    Type,
 10    Tuple,
 11    Union,
 12    Generic,
 13    Mapping,
 14    TypeVar,
 15    Callable,
 16    Optional,
 17    Sequence,
 18    cast,
 19    overload,
 20)
 21
 22if TYPE_CHECKING:
 23    from .message import Message, MessageSegment
 24
 25TM = TypeVar("TM", bound="Message")
 26TF = TypeVar("TF", str, "Message")
 27
 28FormatSpecFunc = Callable[[Any], str]
 29FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)
 30
 31
 32class MessageTemplate(Formatter, Generic[TF]):
 33    """消息模板格式化实现类。
 34
 35    参数:
 36        template: 模板
 37        factory: 消息类型工厂,默认为 `str`
 38    """
 39
 40    @overload
 41    def __init__(
 42        self: "MessageTemplate[str]", template: str, factory: Type[str] = str
 43    ) -> None:
 44        ...
 45
 46    @overload
 47    def __init__(
 48        self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
 49    ) -> None:
 50        ...
 51
 52    def __init__(self, template, factory=str) -> None:
 53        self.template: TF = template
 54        self.factory: Type[TF] = factory
 55        self.format_specs: Dict[str, FormatSpecFunc] = {}
 56
 57    def add_format_spec(
 58        self, spec: FormatSpecFunc_T, name: Optional[str] = None
 59    ) -> FormatSpecFunc_T:
 60        name = name or spec.__name__
 61        if name in self.format_specs:
 62            raise ValueError(f"Format spec {name} already exists!")
 63        self.format_specs[name] = spec
 64        return spec
 65
 66    def format(self, *args, **kwargs):
 67        """根据传入参数和模板生成消息对象"""
 68        return self._format(args, kwargs)
 69
 70    def format_map(self, mapping: Mapping[str, Any]) -> TF:
 71        """根据传入字典和模板生成消息对象, 在传入字段名不是有效标识符时有用"""
 72        return self._format([], mapping)
 73
 74    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
 75        msg = self.factory()
 76        if isinstance(self.template, str):
 77            msg += self.vformat(self.template, args, kwargs)
 78        elif isinstance(self.template, self.factory):
 79            template = cast("Message[MessageSegment]", self.template)
 80            for seg in template:
 81                msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
 82        else:
 83            raise TypeError("template must be a string or instance of Message!")
 84
 85        return msg  # type:ignore
 86
 87    def vformat(
 88        self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
 89    ) -> TF:
 90        used_args = set()
 91        result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
 92        self.check_unused_args(list(used_args), args, kwargs)
 93        return result
 94
 95    def _vformat(
 96        self,
 97        format_string: str,
 98        args: Sequence[Any],
 99        kwargs: Mapping[str, Any],
100        used_args: Set[Union[int, str]],
101        recursion_depth: int,
102        auto_arg_index: int = 0,
103    ) -> Tuple[TF, int]:
104        if recursion_depth < 0:
105            raise ValueError("Max string recursion exceeded")
106
107        results: List[Any] = [self.factory()]
108
109        for (literal_text, field_name, format_spec, conversion) in self.parse(
110            format_string
111        ):
112
113            # output the literal text
114            if literal_text:
115                results.append(literal_text)
116
117            # if there's a field, output it
118            if field_name is not None:
119                # this is some markup, find the object and do
120                #  the formatting
121
122                # handle arg indexing when empty field_names are given.
123                if field_name == "":
124                    if auto_arg_index is False:
125                        raise ValueError(
126                            "cannot switch from manual field specification to "
127                            "automatic field numbering"
128                        )
129                    field_name = str(auto_arg_index)
130                    auto_arg_index += 1
131                elif field_name.isdigit():
132                    if auto_arg_index:
133                        raise ValueError(
134                            "cannot switch from manual field specification to "
135                            "automatic field numbering"
136                        )
137                    # disable auto arg incrementing, if it gets
138                    # used later on, then an exception will be raised
139                    auto_arg_index = False
140
141                # given the field_name, find the object it references
142                #  and the argument it came from
143                obj, arg_used = self.get_field(field_name, args, kwargs)
144                used_args.add(arg_used)
145
146                assert format_spec is not None
147
148                # do any conversion on the resulting object
149                obj = self.convert_field(obj, conversion) if conversion else obj
150
151                # expand the format spec, if needed
152                format_control, auto_arg_index = self._vformat(
153                    format_spec,
154                    args,
155                    kwargs,
156                    used_args,
157                    recursion_depth - 1,
158                    auto_arg_index,
159                )
160
161                # format the object and append to the result
162                formatted_text = self.format_field(obj, str(format_control))
163                results.append(formatted_text)
164
165        return functools.reduce(self._add, results), auto_arg_index
166
167    def format_field(self, value: Any, format_spec: str) -> Any:
168        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
169        if formatter is None and not issubclass(self.factory, str):
170            segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
171            method = getattr(segment_class, format_spec, None)
172            if callable(method) and not cast(str, method.__name__).startswith("_"):
173                formatter = getattr(segment_class, format_spec)
174        return (
175            super().format_field(value, format_spec)
176            if formatter is None
177            else formatter(value)
178        )
179
180    def _add(self, a: Any, b: Any) -> Any:
181        try:
182            return a + b
183        except TypeError:
184            return a + str(b)
class MessageTemplate(string.Formatter, typing.Generic[~TF]):
 33class MessageTemplate(Formatter, Generic[TF]):
 34    """消息模板格式化实现类。
 35
 36    参数:
 37        template: 模板
 38        factory: 消息类型工厂,默认为 `str`
 39    """
 40
 41    @overload
 42    def __init__(
 43        self: "MessageTemplate[str]", template: str, factory: Type[str] = str
 44    ) -> None:
 45        ...
 46
 47    @overload
 48    def __init__(
 49        self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
 50    ) -> None:
 51        ...
 52
 53    def __init__(self, template, factory=str) -> None:
 54        self.template: TF = template
 55        self.factory: Type[TF] = factory
 56        self.format_specs: Dict[str, FormatSpecFunc] = {}
 57
 58    def add_format_spec(
 59        self, spec: FormatSpecFunc_T, name: Optional[str] = None
 60    ) -> FormatSpecFunc_T:
 61        name = name or spec.__name__
 62        if name in self.format_specs:
 63            raise ValueError(f"Format spec {name} already exists!")
 64        self.format_specs[name] = spec
 65        return spec
 66
 67    def format(self, *args, **kwargs):
 68        """根据传入参数和模板生成消息对象"""
 69        return self._format(args, kwargs)
 70
 71    def format_map(self, mapping: Mapping[str, Any]) -> TF:
 72        """根据传入字典和模板生成消息对象, 在传入字段名不是有效标识符时有用"""
 73        return self._format([], mapping)
 74
 75    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
 76        msg = self.factory()
 77        if isinstance(self.template, str):
 78            msg += self.vformat(self.template, args, kwargs)
 79        elif isinstance(self.template, self.factory):
 80            template = cast("Message[MessageSegment]", self.template)
 81            for seg in template:
 82                msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
 83        else:
 84            raise TypeError("template must be a string or instance of Message!")
 85
 86        return msg  # type:ignore
 87
 88    def vformat(
 89        self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
 90    ) -> TF:
 91        used_args = set()
 92        result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
 93        self.check_unused_args(list(used_args), args, kwargs)
 94        return result
 95
 96    def _vformat(
 97        self,
 98        format_string: str,
 99        args: Sequence[Any],
100        kwargs: Mapping[str, Any],
101        used_args: Set[Union[int, str]],
102        recursion_depth: int,
103        auto_arg_index: int = 0,
104    ) -> Tuple[TF, int]:
105        if recursion_depth < 0:
106            raise ValueError("Max string recursion exceeded")
107
108        results: List[Any] = [self.factory()]
109
110        for (literal_text, field_name, format_spec, conversion) in self.parse(
111            format_string
112        ):
113
114            # output the literal text
115            if literal_text:
116                results.append(literal_text)
117
118            # if there's a field, output it
119            if field_name is not None:
120                # this is some markup, find the object and do
121                #  the formatting
122
123                # handle arg indexing when empty field_names are given.
124                if field_name == "":
125                    if auto_arg_index is False:
126                        raise ValueError(
127                            "cannot switch from manual field specification to "
128                            "automatic field numbering"
129                        )
130                    field_name = str(auto_arg_index)
131                    auto_arg_index += 1
132                elif field_name.isdigit():
133                    if auto_arg_index:
134                        raise ValueError(
135                            "cannot switch from manual field specification to "
136                            "automatic field numbering"
137                        )
138                    # disable auto arg incrementing, if it gets
139                    # used later on, then an exception will be raised
140                    auto_arg_index = False
141
142                # given the field_name, find the object it references
143                #  and the argument it came from
144                obj, arg_used = self.get_field(field_name, args, kwargs)
145                used_args.add(arg_used)
146
147                assert format_spec is not None
148
149                # do any conversion on the resulting object
150                obj = self.convert_field(obj, conversion) if conversion else obj
151
152                # expand the format spec, if needed
153                format_control, auto_arg_index = self._vformat(
154                    format_spec,
155                    args,
156                    kwargs,
157                    used_args,
158                    recursion_depth - 1,
159                    auto_arg_index,
160                )
161
162                # format the object and append to the result
163                formatted_text = self.format_field(obj, str(format_control))
164                results.append(formatted_text)
165
166        return functools.reduce(self._add, results), auto_arg_index
167
168    def format_field(self, value: Any, format_spec: str) -> Any:
169        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
170        if formatter is None and not issubclass(self.factory, str):
171            segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
172            method = getattr(segment_class, format_spec, None)
173            if callable(method) and not cast(str, method.__name__).startswith("_"):
174                formatter = getattr(segment_class, format_spec)
175        return (
176            super().format_field(value, format_spec)
177            if formatter is None
178            else formatter(value)
179        )
180
181    def _add(self, a: Any, b: Any) -> Any:
182        try:
183            return a + b
184        except TypeError:
185            return a + str(b)

消息模板格式化实现类。

参数:
    template: 模板
    factory: 消息类型工厂, 默认为 str

MessageTemplate(template, factory=<class 'str'>)
def __init__(self, template, factory=str) -> None:
    """Remember the template and result factory; start with no custom specs.

    Args:
        template: The template to render (a ``str`` or a ``factory`` instance).
        factory: Type used to build rendered results; defaults to ``str``.
    """
    specs: Dict[str, FormatSpecFunc] = {}
    self.template: TF = template
    self.factory: Type[TF] = factory
    self.format_specs = specs
def add_format_spec( self, spec: ~FormatSpecFunc_T, name: Union[str, NoneType] = None) -> ~FormatSpecFunc_T:
def add_format_spec(
    self, spec: FormatSpecFunc_T, name: Optional[str] = None
) -> FormatSpecFunc_T:
    """Register ``spec`` as a named format spec and hand it back unchanged.

    Args:
        spec: Callable applied to a field value when the spec is used.
        name: Spec name used in templates; ``spec.__name__`` is used when this
            is omitted or empty.

    Raises:
        ValueError: If the name is already registered.
    """
    key = spec.__name__ if not name else name
    if key not in self.format_specs:
        self.format_specs[key] = spec
        return spec
    raise ValueError(f"Format spec {key} already exists!")
def format(self, *args, **kwargs):
def format(self, *args, **kwargs):
    """Render the template with the given positional/keyword arguments.

    The arguments are forwarded unchanged to ``self._format``.
    """
    render = self._format
    return render(args, kwargs)

根据传入参数和模板生成消息对象

def format_map(self, mapping: Mapping[str, Any]) -> ~TF:
def format_map(self, mapping: Mapping[str, Any]) -> TF:
    """Render the template using ``mapping`` to resolve named fields.

    Handy when field names are not valid identifiers and so cannot be passed
    as keyword arguments.
    """
    no_positional: List[Any] = []
    return self._format(no_positional, mapping)

根据传入字典和模板生成消息对象, 在传入字段名不是有效标识符时有用

def vformat( self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> ~TF:
def vformat(
    self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
) -> TF:
    """Format ``format_string`` and report which arguments were consumed.

    Nested replacement fields are allowed up to a recursion depth of 2.
    """
    seen: Set[Union[int, str]] = set()
    rendered = self._vformat(format_string, args, kwargs, seen, 2)[0]
    self.check_unused_args(list(seen), args, kwargs)
    return rendered
def format_field(self, value: Any, format_spec: str) -> Any:
def format_field(self, value: Any, format_spec: str) -> Any:
    """Format a single field value according to ``format_spec``.

    Lookup order: a spec registered via ``add_format_spec``; then, for
    non-``str`` factories, a public callable attribute of the factory's
    segment class with that name; otherwise the builtin
    ``string.Formatter.format_field``.
    """
    formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
    if formatter is None and not issubclass(self.factory, str):
        segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
        method = getattr(segment_class, format_spec, None)
        # Reject private names by the *requested* spec name: the resolved
        # object's ``__name__`` can look public even for dunder attributes
        # (``__class__`` resolves to ``type``, named "type"). The object's own
        # ``__name__`` is still checked when present, and objects without one
        # no longer raise AttributeError.
        if (
            callable(method)
            and not format_spec.startswith("_")
            and not str(getattr(method, "__name__", "")).startswith("_")
        ):
            formatter = method
    if formatter is None:
        return super().format_field(value, format_spec)
    return formatter(value)
Inherited Members
string.Formatter
get_value
check_unused_args
convert_field
parse
get_field