ayaka.driver.ayakabot_driver.template
import functools
from string import Formatter
from typing import (
    TYPE_CHECKING,
    Any,
    Set,
    Dict,
    List,
    Type,
    Tuple,
    Union,
    Generic,
    Mapping,
    TypeVar,
    Callable,
    Optional,
    Sequence,
    cast,
    overload,
)

if TYPE_CHECKING:
    from .message import Message, MessageSegment

TM = TypeVar("TM", bound="Message")
TF = TypeVar("TF", str, "Message")

FormatSpecFunc = Callable[[Any], str]
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)


class MessageTemplate(Formatter, Generic[TF]):
    """Message template formatter.

    A ``string.Formatter`` whose output is assembled with ``factory``
    (``str`` by default, or a message class), so one template can yield
    either plain text or a message object.

    Args:
        template: The template; a ``str`` or an instance of ``factory``.
        factory: Message type factory; defaults to ``str``.
    """

    @overload
    def __init__(
        self: "MessageTemplate[str]", template: str, factory: Type[str] = str
    ) -> None:
        ...

    @overload
    def __init__(
        self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
    ) -> None:
        ...

    def __init__(self, template, factory=str) -> None:
        self.template: TF = template
        self.factory: Type[TF] = factory
        # Custom format specs registered through add_format_spec(), by name.
        self.format_specs: Dict[str, FormatSpecFunc] = {}

    def add_format_spec(
        self, spec: FormatSpecFunc_T, name: Optional[str] = None
    ) -> FormatSpecFunc_T:
        """Register ``spec`` as a custom format spec.

        Args:
            spec: Callable applied to a field's value when the field's
                format spec equals ``name``.
            name: Spec name; defaults to ``spec.__name__``.

        Returns:
            ``spec`` itself, so this also works as a decorator.

        Raises:
            ValueError: If a spec with the same name is already registered.
        """
        name = name or spec.__name__
        if name in self.format_specs:
            raise ValueError(f"Format spec {name} already exists!")
        self.format_specs[name] = spec
        return spec

    def format(self, *args, **kwargs) -> TF:
        """Build a message from the template with the given arguments."""
        return self._format(args, kwargs)

    def format_map(self, mapping: Mapping[str, Any]) -> TF:
        """Build a message from the template using ``mapping`` for fields.

        Useful when field names are not valid Python identifiers.
        """
        return self._format([], mapping)

    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
        """Render the whole template into a new ``factory`` instance.

        Raises:
            TypeError: If the template is neither a ``str`` nor an instance
                of ``factory``.
        """
        full_message = self.factory()
        used_args: Set[Union[int, str]] = set()
        # Shared across all text segments so automatic field numbering
        # ("{}") continues from one segment to the next instead of
        # restarting at 0 after every non-text segment.
        arg_index = 0

        if isinstance(self.template, str):
            msg, arg_index = self._vformat(
                self.template, args, kwargs, used_args, 2, arg_index
            )
            full_message += msg
        elif isinstance(self.template, self.factory):
            template = cast("Message[MessageSegment]", self.template)
            for seg in template:
                if not seg.is_text():
                    # Non-text segments are copied through untouched.
                    full_message += seg
                    continue
                msg, arg_index = self._vformat(
                    str(seg), args, kwargs, used_args, 2, arg_index
                )
                full_message += msg
        else:
            raise TypeError("template must be a string or instance of Message!")

        # Checked once, against the arguments used by the whole template.
        self.check_unused_args(list(used_args), args, kwargs)
        return full_message  # type:ignore

    def vformat(
        self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
    ) -> TF:
        """Format a single template string into a ``factory`` instance."""
        used_args = set()
        result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
        self.check_unused_args(list(used_args), args, kwargs)
        return result

    def _vformat(
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        used_args: Set[Union[int, str]],
        recursion_depth: int,
        auto_arg_index: int = 0,
    ) -> Tuple[TF, int]:
        """Format ``format_string``, returning ``(result, next_auto_index)``.

        Mirrors the structure of the stdlib's private
        ``string.Formatter._vformat``. Pieces are accumulated with
        ``self._add`` so non-string field values (e.g. message segments)
        are preserved. ``auto_arg_index`` becomes ``False`` once manual
        numbering is seen, which forbids switching back to ``{}``.
        """
        if recursion_depth < 0:
            raise ValueError("Max string recursion exceeded")

        results: List[Any] = [self.factory()]

        for (literal_text, field_name, format_spec, conversion) in self.parse(
            format_string
        ):

            # output the literal text
            if literal_text:
                results.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # handle arg indexing when empty field_names are given.
                if field_name == "":
                    if auto_arg_index is False:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    field_name = str(auto_arg_index)
                    auto_arg_index += 1
                elif field_name.isdigit():
                    if auto_arg_index:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    # disable auto arg incrementing, if it gets
                    # used later on, then an exception will be raised
                    auto_arg_index = False

                # given the field_name, find the object it references
                # and the argument it came from
                obj, arg_used = self.get_field(field_name, args, kwargs)
                used_args.add(arg_used)

                assert format_spec is not None

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion) if conversion else obj

                # expand the format spec, if needed (nested fields)
                format_control, auto_arg_index = self._vformat(
                    format_spec,
                    args,
                    kwargs,
                    used_args,
                    recursion_depth - 1,
                    auto_arg_index,
                )

                # format the object and append to the result
                formatted_text = self.format_field(obj, str(format_control))
                results.append(formatted_text)

        return functools.reduce(self._add, results), auto_arg_index

    def format_field(self, value: Any, format_spec: str) -> Any:
        """Format one field value according to ``format_spec``.

        Lookup order:

        1. a formatter registered with ``add_format_spec``;
        2. when ``factory`` is not a ``str`` subclass, a public callable
           attribute named ``format_spec`` on the factory's segment class;
        3. the default ``string.Formatter.format_field``.
        """
        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
        if formatter is None and not issubclass(self.factory, str):
            segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
            method = getattr(segment_class, format_spec, None)
            # Reject private/dunder names by the *requested* attribute name as
            # well as the callable's own name: e.g. "__class__" resolves to
            # ``type``, whose __name__ looks public. getattr() with a default
            # avoids AttributeError for callables without a __name__.
            if (
                callable(method)
                and not format_spec.startswith("_")
                and not str(getattr(method, "__name__", "")).startswith("_")
            ):
                formatter = method
        return (
            super().format_field(value, format_spec)
            if formatter is None
            else formatter(value)
        )

    def _add(self, a: Any, b: Any) -> Any:
        """Concatenate two partial results, retrying with ``str(b)``."""
        try:
            return a + b
        except TypeError:
            return a + str(b)
class
MessageTemplate(string.Formatter, typing.Generic[~TF]):
class MessageTemplate(Formatter, Generic[TF]):
    """Message template formatter.

    A ``string.Formatter`` whose output is assembled with ``factory``
    (``str`` by default, or a message class), so one template can yield
    either plain text or a message object.

    Args:
        template: The template; a ``str`` or an instance of ``factory``.
        factory: Message type factory; defaults to ``str``.
    """

    @overload
    def __init__(
        self: "MessageTemplate[str]", template: str, factory: Type[str] = str
    ) -> None:
        ...

    @overload
    def __init__(
        self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
    ) -> None:
        ...

    def __init__(self, template, factory=str) -> None:
        self.template: TF = template
        self.factory: Type[TF] = factory
        # Custom format specs registered through add_format_spec(), by name.
        self.format_specs: Dict[str, FormatSpecFunc] = {}

    def add_format_spec(
        self, spec: FormatSpecFunc_T, name: Optional[str] = None
    ) -> FormatSpecFunc_T:
        """Register ``spec`` as a custom format spec.

        Args:
            spec: Callable applied to a field's value when the field's
                format spec equals ``name``.
            name: Spec name; defaults to ``spec.__name__``.

        Returns:
            ``spec`` itself, so this also works as a decorator.

        Raises:
            ValueError: If a spec with the same name is already registered.
        """
        name = name or spec.__name__
        if name in self.format_specs:
            raise ValueError(f"Format spec {name} already exists!")
        self.format_specs[name] = spec
        return spec

    def format(self, *args, **kwargs) -> TF:
        """Build a message from the template with the given arguments."""
        return self._format(args, kwargs)

    def format_map(self, mapping: Mapping[str, Any]) -> TF:
        """Build a message from the template using ``mapping`` for fields.

        Useful when field names are not valid Python identifiers.
        """
        return self._format([], mapping)

    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
        """Render the whole template into a new ``factory`` instance.

        Raises:
            TypeError: If the template is neither a ``str`` nor an instance
                of ``factory``.
        """
        full_message = self.factory()
        used_args: Set[Union[int, str]] = set()
        # Shared across all text segments so automatic field numbering
        # ("{}") continues from one segment to the next instead of
        # restarting at 0 after every non-text segment.
        arg_index = 0

        if isinstance(self.template, str):
            msg, arg_index = self._vformat(
                self.template, args, kwargs, used_args, 2, arg_index
            )
            full_message += msg
        elif isinstance(self.template, self.factory):
            template = cast("Message[MessageSegment]", self.template)
            for seg in template:
                if not seg.is_text():
                    # Non-text segments are copied through untouched.
                    full_message += seg
                    continue
                msg, arg_index = self._vformat(
                    str(seg), args, kwargs, used_args, 2, arg_index
                )
                full_message += msg
        else:
            raise TypeError("template must be a string or instance of Message!")

        # Checked once, against the arguments used by the whole template.
        self.check_unused_args(list(used_args), args, kwargs)
        return full_message  # type:ignore

    def vformat(
        self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
    ) -> TF:
        """Format a single template string into a ``factory`` instance."""
        used_args = set()
        result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
        self.check_unused_args(list(used_args), args, kwargs)
        return result

    def _vformat(
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        used_args: Set[Union[int, str]],
        recursion_depth: int,
        auto_arg_index: int = 0,
    ) -> Tuple[TF, int]:
        """Format ``format_string``, returning ``(result, next_auto_index)``.

        Mirrors the structure of the stdlib's private
        ``string.Formatter._vformat``. Pieces are accumulated with
        ``self._add`` so non-string field values (e.g. message segments)
        are preserved. ``auto_arg_index`` becomes ``False`` once manual
        numbering is seen, which forbids switching back to ``{}``.
        """
        if recursion_depth < 0:
            raise ValueError("Max string recursion exceeded")

        results: List[Any] = [self.factory()]

        for (literal_text, field_name, format_spec, conversion) in self.parse(
            format_string
        ):

            # output the literal text
            if literal_text:
                results.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # handle arg indexing when empty field_names are given.
                if field_name == "":
                    if auto_arg_index is False:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    field_name = str(auto_arg_index)
                    auto_arg_index += 1
                elif field_name.isdigit():
                    if auto_arg_index:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    # disable auto arg incrementing, if it gets
                    # used later on, then an exception will be raised
                    auto_arg_index = False

                # given the field_name, find the object it references
                # and the argument it came from
                obj, arg_used = self.get_field(field_name, args, kwargs)
                used_args.add(arg_used)

                assert format_spec is not None

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion) if conversion else obj

                # expand the format spec, if needed (nested fields)
                format_control, auto_arg_index = self._vformat(
                    format_spec,
                    args,
                    kwargs,
                    used_args,
                    recursion_depth - 1,
                    auto_arg_index,
                )

                # format the object and append to the result
                formatted_text = self.format_field(obj, str(format_control))
                results.append(formatted_text)

        return functools.reduce(self._add, results), auto_arg_index

    def format_field(self, value: Any, format_spec: str) -> Any:
        """Format one field value according to ``format_spec``.

        Lookup order:

        1. a formatter registered with ``add_format_spec``;
        2. when ``factory`` is not a ``str`` subclass, a public callable
           attribute named ``format_spec`` on the factory's segment class;
        3. the default ``string.Formatter.format_field``.
        """
        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
        if formatter is None and not issubclass(self.factory, str):
            segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
            method = getattr(segment_class, format_spec, None)
            # Reject private/dunder names by the *requested* attribute name as
            # well as the callable's own name: e.g. "__class__" resolves to
            # ``type``, whose __name__ looks public. getattr() with a default
            # avoids AttributeError for callables without a __name__.
            if (
                callable(method)
                and not format_spec.startswith("_")
                and not str(getattr(method, "__name__", "")).startswith("_")
            ):
                formatter = method
        return (
            super().format_field(value, format_spec)
            if formatter is None
            else formatter(value)
        )

    def _add(self, a: Any, b: Any) -> Any:
        """Concatenate two partial results, retrying with ``str(b)``."""
        try:
            return a + b
        except TypeError:
            return a + str(b)
消息模板格式化实现类。
参数:
template: 模板
factory: 消息类型工厂,默认为 `str`
def
add_format_spec( self, spec: ~FormatSpecFunc_T, name: Union[str, NoneType] = None) -> ~FormatSpecFunc_T:
def
format_map(self, mapping: Mapping[str, Any]) -> ~TF:
def format_map(self, mapping: Mapping[str, Any]) -> TF:
    """Render the template with named fields taken from ``mapping``.

    No positional arguments are supplied. This is handy when field names
    are not valid Python identifiers and so cannot be passed as keyword
    arguments to ``format``.
    """
    no_positional: List[Any] = []
    return self._format(no_positional, mapping)
根据传入字典和模板生成消息对象, 在传入字段名不是有效标识符时有用
def
format_field(self, value: Any, format_spec: str) -> Any:
def format_field(self, value: Any, format_spec: str) -> Any:
    """Format one field value according to ``format_spec``.

    Lookup order:

    1. a formatter registered with ``add_format_spec``;
    2. when ``factory`` is not a ``str`` subclass, a public callable
       attribute named ``format_spec`` on the factory's segment class;
    3. the default ``string.Formatter.format_field``.
    """
    formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
    if formatter is None and not issubclass(self.factory, str):
        segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
        method = getattr(segment_class, format_spec, None)
        # Reject private/dunder names by the *requested* attribute name as
        # well as the callable's own name: e.g. "__class__" resolves to
        # ``type``, whose __name__ looks public. getattr() with a default
        # avoids AttributeError for callables without a __name__.
        if (
            callable(method)
            and not format_spec.startswith("_")
            and not str(getattr(method, "__name__", "")).startswith("_")
        ):
            formatter = method
    return (
        super().format_field(value, format_spec)
        if formatter is None
        else formatter(value)
    )
Inherited Members
- string.Formatter
- get_value
- check_unused_args
- convert_field
- parse
- get_field