microcore.utils

  1import asyncio
  2import builtins
  3import dataclasses
  4import inspect
  5import json
  6import os
  7import sys
  8import re
  9import subprocess
 10from dataclasses import dataclass
 11from fnmatch import fnmatch
 12from pathlib import Path
 13from typing import Any, Union, Callable
 14
 15import tiktoken
 16from colorama import Fore
 17
 18from .configuration import Config
 19from .types import BadAIAnswer
 20from .message_types import UserMsg, SysMsg, AssistantMsg
 21
 22
 23def is_chat_model(model: str, config: Config = None) -> bool:
 24    """Detects if model is chat model or text completion model"""
 25    if config and config.CHAT_MODE is not None:
 26        return config.CHAT_MODE
 27    completion_keywords = ["instruct", "davinci", "babbage", "curie", "ada"]
 28    return not any(keyword in str(model).lower() for keyword in completion_keywords)
 29
 30
 31class ConvertableToMessage:
 32    @property
 33    def as_user(self) -> UserMsg:
 34        return UserMsg(str(self))
 35
 36    @property
 37    def as_system(self) -> SysMsg:
 38        return SysMsg(str(self))
 39
 40    @property
 41    def as_assistant(self) -> AssistantMsg:
 42        return AssistantMsg(str(self))
 43
 44    @property
 45    def as_model(self) -> AssistantMsg:
 46        return self.as_assistant
 47
 48
 49class ExtendedString(str):
 50    """
 51    Provides a way of extending string with attributes and methods
 52    """
 53
 54    def __new__(cls, string: str, attrs: dict = None):
 55        """
 56        Allows string to have attributes.
 57        """
 58        obj = str.__new__(cls, string)
 59        if attrs:
 60            for k, v in attrs.items():
 61                setattr(obj, k, v)
 62        return obj
 63
 64    def __getattr__(self, item):
 65        """
 66        Provides chaining of global functions
 67        """
 68        global_func = inspect.currentframe().f_back.f_globals.get(item) or vars(
 69            builtins
 70        ).get(item, None)
 71        if callable(global_func):
 72
 73            def method_handler(*args, **kwargs):
 74                res = global_func(self, *args, **kwargs)
 75                if isinstance(res, str) and not isinstance(res, ExtendedString):
 76                    res = ExtendedString(res)
 77                return res
 78
 79            return method_handler
 80
 81        # If there's not a global function with that name, raise an AttributeError as usual
 82        raise AttributeError(
 83            f"'{self.__class__.__name__}' object has no attribute '{item}'"
 84        )
 85
 86    def to_tokens(
 87        self,
 88        for_model: str = None,
 89        encoding: str | tiktoken.Encoding = None
 90    ):
 91        from .tokenizing import encode
 92        return encode(self, for_model=for_model, encoding=encoding)
 93
 94    def num_tokens(
 95        self,
 96        for_model: str = None,
 97        encoding: str | tiktoken.Encoding = None
 98    ):
 99        return len(self.to_tokens(for_model=for_model, encoding=encoding))
100
101
class DataclassEncoder(json.JSONEncoder):
    """@private

    JSON encoder that serializes dataclass instances as dicts
    (via ``dataclasses.asdict``) and defers everything else to the stock
    ``json.JSONEncoder.default``.
    """

    def default(self, o):
        # ``is_dataclass`` is also True for dataclass *classes*, which ``asdict``
        # rejects; only instances are converted, so classes get the standard
        # "not JSON serializable" TypeError from the original implementation.
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            return dataclasses.asdict(o)
        return _default(self, o)


# Keep the stock implementation, then patch it globally so plain
# ``json.dumps(...)`` (without ``cls=``) also serializes dataclass instances.
_default = json.JSONEncoder.default
json.JSONEncoder.default = DataclassEncoder().default
113
114
def parse(
    text: str, field_format: str = r"\[\[(.*?)\]\]", required_fields: list = None
) -> dict:
    """
    Parse a document divided into sections and convert it into a dictionary.

    Each section starts with a header line matching ``field_format`` (by
    default ``[[name]]``). Its body runs until the next header line or the end
    of the text. Section names are normalized with ``strip().lower()``.

    Args:
        text: Document to parse.
        field_format: Regex for a section header. It must contain exactly one
            capturing group (the section name) because it is used twice in the
            combined pattern: once for the header and once in the look-ahead.
        required_fields: Section names that must be present. They are
            normalized the same way as the parsed keys before comparison.

    Returns:
        dict: Mapping of normalized section name to section body.

    Raises:
        BadAIAnswer: If a required field is missing.
    """
    pattern = rf"{field_format}\n(.*?)(?=\n{field_format}|$)"
    matches = re.findall(pattern, text, re.DOTALL)
    # findall yields (name, body, look-ahead name) triples; the last is unused.
    result = {key.strip().lower(): value for key, value, _ in matches}
    for field in required_fields or []:
        # Normalize like the keys above so e.g. "Answer" matches "[[answer]]".
        if str(field).strip().lower() not in result:
            raise BadAIAnswer(f"Field '{field}' is required but not found")
    return result
129
130
131def file_link(file_path: str | Path):
132    """Returns file name in format displayed in PyCharm console as a link."""
133    return "file:///" + str(Path(file_path).absolute()).replace("\\", "/")
134
135
136def list_files(
137    target_dir: str | Path = "",
138    exclude: list[str | Path] = None,
139    relative_to: str | Path = None,
140    absolute: bool = False,
141    posix: bool = False,
142) -> list[Path]:
143    """
144    Lists files in a specified directory, excluding those that match given patterns.
145
146    This function traverses the specified directory recursively and returns a list of all files
147    that do not match the specified exclusion patterns. It can return absolute paths,
148    paths relative to the target directory, or paths relative to a specified directory.
149
150    Args:
151        target_dir (str | Path): The directory to search in.
152        exclude (list[str | Path]): Patterns of files to exclude.
153        relative_to (str | Path, optional): Base directory for relative paths.
154            If None, paths are relative to `target_dir`. Defaults to None.
155        absolute (bool, optional): If True, returns absolute paths. Defaults to False.
156        posix (bool, optional): If True, returns posix paths. Defaults to False.
157
158    Returns:
159        list[Path]: A list of Path objects representing the files found.
160
161    Example:
162        exclude_patterns = ['*.pyc', '__pycache__/*']
163        target_directory = '/path/to/target'
164        files = list_files(target_directory, exclude_patterns)
165    """
166    exclude = exclude or []
167    target = Path(target_dir or os.getcwd()).resolve()
168    relative_to = Path(relative_to).resolve() if relative_to else None
169    if absolute and relative_to is not None:
170        raise ValueError(
171            "list_files(): Cannot specify both 'absolute' and 'relative_to'. Choose one."
172        )
173    return [
174        p.as_posix() if posix else p
175        for p in (
176            path.resolve() if absolute else path.relative_to(relative_to or target)
177            for path in target.rglob("*")
178            if path.is_file()
179            and not any(
180                fnmatch(str(path.relative_to(target)), str(pattern))
181                for pattern in exclude
182            )
183        )
184    ]
185
186
def is_kaggle() -> bool:
    """True when the ``KAGGLE_KERNEL_RUN_TYPE`` environment variable is set."""
    return os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None
189
190
def is_notebook() -> bool:
    """True when a module named ``ipykernel`` is registered in ``sys.modules``."""
    loaded_modules = sys.modules
    return "ipykernel" in loaded_modules
193
194
def is_google_colab() -> bool:
    """True when a module named ``google.colab`` is registered in ``sys.modules``."""
    loaded_modules = sys.modules
    return "google.colab" in loaded_modules
197
198
def get_vram_usage(as_string=True, color=Fore.GREEN):
    """
    Query per-GPU memory usage with ``nvidia-smi``.

    Args:
        as_string: If True, return a report with one line per GPU. If False,
            return a list of ``_MemUsage`` records.
        color: Prefix (e.g. a colorama ``Fore`` code) placed before each value
            in the report and reset with ``Fore.RESET`` after it. A falsy value
            disables coloring.

    Returns:
        The report string or the list of records. If ``nvidia-smi`` is missing
        or exits with an error, returns a red error message when ``as_string``
        is True, otherwise None.
    """

    @dataclass
    class _MemUsage:
        name: str
        used: int  # MiB (queried with "nounits")
        free: int
        total: int

    def _to_int(raw: str):
        # nvidia-smi can print placeholders such as "[N/A]" for devices that do
        # not report memory; keep the raw text rather than failing.
        try:
            return int(raw)
        except ValueError:
            return raw

    # Run the binary directly: no shell features are needed, and a missing
    # binary no longer makes the shell print "not found" noise to stderr.
    cmd = [
        "nvidia-smi",
        "--query-gpu=name,memory.used,memory.free,memory.total",
        "--format=csv,noheader,nounits",
    ]
    try:
        out = subprocess.check_output(
            cmd, text=True, stderr=subprocess.DEVNULL
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError when nvidia-smi is not on PATH.
        msg = "No GPU found or nvidia-smi is not installed"
        return f"{Fore.RED}{msg}{Fore.RESET}" if as_string else None

    mu = []
    for line in out.splitlines():
        # rsplit keeps any comma inside the GPU name (first column) intact.
        name, used, free, total = (part.strip() for part in line.rsplit(",", 3))
        mu.append(_MemUsage(name, _to_int(used), _to_int(free), _to_int(total)))
    if not as_string:
        return mu
    c, r = (color, Fore.RESET) if color else ("", "")
    return "\n".join(
        f"GPU: {c}{i.name}{r}, "
        f"VRAM: {c}{i.used}{r}/{c}{i.total}{r} MiB used, "
        f"{c}{i.free}{r} MiB free"
        for i in mu
    )
233
234
def show_vram_usage():
    """Print the GPU VRAM report returned by ``get_vram_usage(as_string=True)``."""
    report = get_vram_usage(as_string=True)
    print(report)
238
239
def return_default(default, *args):
    """
    Resolve a fallback value.

    - An exception class is instantiated and raised; an exception instance is
      raised as-is.
    - A builtin function (or bound builtin method) is called with ``*args``.
    - A Python function or bound method is called with ``*args`` when it can
      take that many positional arguments (or accepts ``*args``), otherwise
      with no arguments.
    - Anything else is returned unchanged.

    Args:
        default: Exception, callable, or plain fallback value.
        *args: Positional arguments forwarded to a callable ``default``.
    """
    if isinstance(default, type) and issubclass(default, BaseException):
        raise default()
    if isinstance(default, BaseException):
        raise default
    if inspect.isbuiltin(default):
        return default(*args)
    is_method = inspect.ismethod(default)
    # Bound methods are not "functions" for inspect, so they need their own
    # check. Previously this test was nested under isfunction() and could never
    # be true, so bound methods were returned instead of called.
    if inspect.isfunction(default) or is_method:
        code = (default.__func__ if is_method else default).__code__
        # co_argcount includes the implicit self/cls of a bound method.
        arg_count = code.co_argcount - (1 if is_method else 0)
        takes_varargs = bool(code.co_flags & inspect.CO_VARARGS)
        if takes_varargs or arg_count >= len(args):
            return default(*args)
        return default()
    return default
254
255
256def extract_number(
257    text: str,
258    default=None,
259    position="last",
260    dtype: type | str = float,
261    rounding: bool = False,
262) -> int | float | Any:
263    """
264    Extract a number from a string.
265    """
266    assert position in ["last", "first"], f"Invalid position: {position}"
267    idx = {"last": -1, "first": 0}[position]
268
269    dtype = {"int": int, "float": float}.get(dtype, dtype)
270    assert dtype in [int, float], f"Invalid dtype: {dtype}"
271    if rounding:
272        dtype = float
273    regex = {int: r"[-+]?\d+", float: r"[-+]?\d*\.?\d+"}[dtype]
274
275    numbers = re.findall(regex, str(text))
276    if numbers:
277        try:
278            value = dtype(numbers[idx].strip())
279            return round(value) if rounding else value
280        except (ValueError, OverflowError):
281            ...
282    return return_default(default, text)
283
284
def dedent(text: str) -> str:
    """
    Drop leading and trailing blank lines, then strip the smallest indentation
    shared by the non-blank lines from every line.

    A line shorter than that indentation (e.g. a short whitespace-only line) is
    kept unchanged.
    """
    rows = text.splitlines()
    first, stop = 0, len(rows)
    while first < stop and not rows[first].strip():
        first += 1
    while stop > first and not rows[stop - 1].strip():
        stop -= 1
    rows = rows[first:stop]

    indents = [len(row) - len(row.lstrip()) for row in rows if row.strip()]
    if not indents:
        return "\n".join(rows)
    cut = min(indents)
    return "\n".join(
        row[cut:] if row and len(row) >= cut else row for row in rows
    )
305
306
async def run_parallel(tasks: list, max_concurrent_tasks: int):
    """
    Await ``tasks`` concurrently, with at most ``max_concurrent_tasks`` of them
    in progress at any moment.

    Args:
        tasks: Awaitables (e.g. coroutine objects) to run.
        max_concurrent_tasks: Concurrency limit; must be at least 1.

    Returns:
        list: Results in the same order as ``tasks`` (``asyncio.gather``).

    Raises:
        ValueError: If ``max_concurrent_tasks`` is less than 1. A limit of 0
            would otherwise make every worker wait on the semaphore forever.
    """
    if max_concurrent_tasks < 1:
        # Close coroutines that will never be awaited to avoid
        # "coroutine ... was never awaited" warnings.
        for task in tasks:
            if asyncio.iscoroutine(task):
                task.close()
        raise ValueError(
            f"max_concurrent_tasks must be >= 1, got {max_concurrent_tasks}"
        )
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    async def worker(task):
        async with semaphore:
            return await task

    return await asyncio.gather(*[worker(task) for task in tasks])
318
319
def resolve_callable(
    fn: Union[Callable, str, None], allow_empty=False
) -> Union[Callable, None]:
    """
    Resolves a callable from a reference.

    Accepts:
      - a callable, returned unchanged;
      - a dotted path ``"package.module.function"``, imported and looked up;
      - a bare name, looked up in this module's globals, then in builtins.

    Args:
        fn: Callable, reference string, or an empty value.
        allow_empty: If True, an empty ``fn`` (None, "") resolves to None
            instead of raising.

    Returns:
        The resolved callable, or None for an allowed empty value.

    Raises:
        ValueError: If ``fn`` is empty (and that is not allowed) or cannot be
            resolved to a callable.
    """
    if callable(fn):
        return fn
    if not fn:
        if allow_empty:
            return None
        raise ValueError("Function is not specified")
    name = fn  # keep the requested reference for error messages
    missing = object()
    try:
        if "." not in name:
            resolved = globals().get(name, missing)
            if resolved is missing:
                # Same fallback order as ExtendedString: globals, then builtins.
                resolved = getattr(builtins, name)
        else:
            module_name, _, func_name = name.rpartition(".")
            if not module_name:
                raise ValueError(f"Invalid module name: {module_name}")
            module = __import__(module_name, fromlist=[func_name])
            resolved = getattr(module, func_name)
    except (ImportError, AttributeError, ValueError, TypeError) as e:
        raise ValueError(f"Can't resolve callable by name '{name}', {e}") from e
    # Explicit check rather than assert, so it also holds under ``python -O``.
    if not callable(resolved):
        raise ValueError(
            f"Can't resolve callable by name '{name}', object is not callable"
        )
    return resolved
def is_chat_model(model: str, config: microcore.Config = None) -> bool:
def is_chat_model(model: str, config: Config = None) -> bool:
    """
    Decide whether ``model`` should be treated as a chat model (True) or a
    text-completion model (False).

    An explicit ``config.CHAT_MODE`` (when it is not None) wins. Otherwise the
    model name is lower-cased and checked for substrings typical of legacy
    completion models; any hit means "not a chat model".
    """
    if config and config.CHAT_MODE is not None:
        return config.CHAT_MODE
    # NOTE(review): "instruct" also matches chat-tuned names such as
    # "...-Instruct"; confirm that treating those as completion models is intended.
    lowered_name = str(model).lower()
    for marker in ("instruct", "davinci", "babbage", "curie", "ada"):
        if marker in lowered_name:
            return False
    return True

Detects whether a model is a chat model or a text-completion model.

class ConvertableToMessage:
class ConvertableToMessage:
    """
    Exposes the object's text (``str(self)``) as chat messages of each role.
    """

    @property
    def as_user(self) -> UserMsg:
        """``str(self)`` wrapped in a ``UserMsg``."""
        text = str(self)
        return UserMsg(text)

    @property
    def as_system(self) -> SysMsg:
        """``str(self)`` wrapped in a ``SysMsg``."""
        text = str(self)
        return SysMsg(text)

    @property
    def as_assistant(self) -> AssistantMsg:
        """``str(self)`` wrapped in an ``AssistantMsg``."""
        text = str(self)
        return AssistantMsg(text)

    @property
    def as_model(self) -> AssistantMsg:
        """Alias for ``as_assistant``."""
        return self.as_assistant
as_user: microcore.UserMsg
33    @property
34    def as_user(self) -> UserMsg:
35        return UserMsg(str(self))
as_system: microcore.SysMsg
37    @property
38    def as_system(self) -> SysMsg:
39        return SysMsg(str(self))
as_assistant: microcore.AssistantMsg
41    @property
42    def as_assistant(self) -> AssistantMsg:
43        return AssistantMsg(str(self))
as_model: microcore.AssistantMsg
45    @property
46    def as_model(self) -> AssistantMsg:
47        return self.as_assistant
class ExtendedString(builtins.str):
 50class ExtendedString(str):
 51    """
 52    Provides a way of extending string with attributes and methods
 53    """
 54
 55    def __new__(cls, string: str, attrs: dict = None):
 56        """
 57        Allows string to have attributes.
 58        """
 59        obj = str.__new__(cls, string)
 60        if attrs:
 61            for k, v in attrs.items():
 62                setattr(obj, k, v)
 63        return obj
 64
 65    def __getattr__(self, item):
 66        """
 67        Provides chaining of global functions
 68        """
 69        global_func = inspect.currentframe().f_back.f_globals.get(item) or vars(
 70            builtins
 71        ).get(item, None)
 72        if callable(global_func):
 73
 74            def method_handler(*args, **kwargs):
 75                res = global_func(self, *args, **kwargs)
 76                if isinstance(res, str) and not isinstance(res, ExtendedString):
 77                    res = ExtendedString(res)
 78                return res
 79
 80            return method_handler
 81
 82        # If there's not a global function with that name, raise an AttributeError as usual
 83        raise AttributeError(
 84            f"'{self.__class__.__name__}' object has no attribute '{item}'"
 85        )
 86
 87    def to_tokens(
 88        self,
 89        for_model: str = None,
 90        encoding: str | tiktoken.Encoding = None
 91    ):
 92        from .tokenizing import encode
 93        return encode(self, for_model=for_model, encoding=encoding)
 94
 95    def num_tokens(
 96        self,
 97        for_model: str = None,
 98        encoding: str | tiktoken.Encoding = None
 99    ):
100        return len(self.to_tokens(for_model=for_model, encoding=encoding))

Provides a way of extending strings with attributes and methods.

ExtendedString(string: str, attrs: dict = None)
55    def __new__(cls, string: str, attrs: dict = None):
56        """
57        Allows string to have attributes.
58        """
59        obj = str.__new__(cls, string)
60        if attrs:
61            for k, v in attrs.items():
62                setattr(obj, k, v)
63        return obj

Allows a string to carry custom attributes.

def to_tokens( self, for_model: str = None, encoding: str | tiktoken.core.Encoding = None):
87    def to_tokens(
88        self,
89        for_model: str = None,
90        encoding: str | tiktoken.Encoding = None
91    ):
92        from .tokenizing import encode
93        return encode(self, for_model=for_model, encoding=encoding)
def num_tokens( self, for_model: str = None, encoding: str | tiktoken.core.Encoding = None):
 95    def num_tokens(
 96        self,
 97        for_model: str = None,
 98        encoding: str | tiktoken.Encoding = None
 99    ):
100        return len(self.to_tokens(for_model=for_model, encoding=encoding))
Inherited Members
builtins.str
encode
replace
split
rsplit
join
capitalize
casefold
title
center
count
expandtabs
find
partition
index
ljust
lower
lstrip
rfind
rindex
rjust
rstrip
rpartition
splitlines
strip
swapcase
translate
upper
startswith
endswith
removeprefix
removesuffix
isascii
islower
isupper
istitle
isspace
isdecimal
isdigit
isnumeric
isalpha
isalnum
isidentifier
isprintable
zfill
format
format_map
maketrans
def parse( text: str, field_format: str = '\\[\\[(.*?)\\]\\]', required_fields: list = None) -> dict:
def parse(
    text: str, field_format: str = r"\[\[(.*?)\]\]", required_fields: list = None
) -> dict:
    """
    Parse a document divided into sections and convert it into a dictionary.

    Each section starts with a header line matching ``field_format`` (by
    default ``[[name]]``). Its body runs until the next header line or the end
    of the text. Section names are normalized with ``strip().lower()``.

    Args:
        text: Document to parse.
        field_format: Regex for a section header. It must contain exactly one
            capturing group (the section name) because it is used twice in the
            combined pattern: once for the header and once in the look-ahead.
        required_fields: Section names that must be present. They are
            normalized the same way as the parsed keys before comparison.

    Returns:
        dict: Mapping of normalized section name to section body.

    Raises:
        BadAIAnswer: If a required field is missing.
    """
    pattern = rf"{field_format}\n(.*?)(?=\n{field_format}|$)"
    matches = re.findall(pattern, text, re.DOTALL)
    # findall yields (name, body, look-ahead name) triples; the last is unused.
    result = {key.strip().lower(): value for key, value, _ in matches}
    for field in required_fields or []:
        # Normalize like the keys above so e.g. "Answer" matches "[[answer]]".
        if str(field).strip().lower() not in result:
            raise BadAIAnswer(f"Field '{field}' is required but not found")
    return result

Parse a document divided into sections and convert it into a dictionary.

def list_files( target_dir: str | pathlib.Path = '', exclude: list[str | pathlib.Path] = None, relative_to: str | pathlib.Path = None, absolute: bool = False, posix: bool = False) -> list[pathlib.Path]:
137def list_files(
138    target_dir: str | Path = "",
139    exclude: list[str | Path] = None,
140    relative_to: str | Path = None,
141    absolute: bool = False,
142    posix: bool = False,
143) -> list[Path]:
144    """
145    Lists files in a specified directory, excluding those that match given patterns.
146
147    This function traverses the specified directory recursively and returns a list of all files
148    that do not match the specified exclusion patterns. It can return absolute paths,
149    paths relative to the target directory, or paths relative to a specified directory.
150
151    Args:
152        target_dir (str | Path): The directory to search in.
153        exclude (list[str | Path]): Patterns of files to exclude.
154        relative_to (str | Path, optional): Base directory for relative paths.
155            If None, paths are relative to `target_dir`. Defaults to None.
156        absolute (bool, optional): If True, returns absolute paths. Defaults to False.
157        posix (bool, optional): If True, returns posix paths. Defaults to False.
158
159    Returns:
160        list[Path]: A list of Path objects representing the files found.
161
162    Example:
163        exclude_patterns = ['*.pyc', '__pycache__/*']
164        target_directory = '/path/to/target'
165        files = list_files(target_directory, exclude_patterns)
166    """
167    exclude = exclude or []
168    target = Path(target_dir or os.getcwd()).resolve()
169    relative_to = Path(relative_to).resolve() if relative_to else None
170    if absolute and relative_to is not None:
171        raise ValueError(
172            "list_files(): Cannot specify both 'absolute' and 'relative_to'. Choose one."
173        )
174    return [
175        p.as_posix() if posix else p
176        for p in (
177            path.resolve() if absolute else path.relative_to(relative_to or target)
178            for path in target.rglob("*")
179            if path.is_file()
180            and not any(
181                fnmatch(str(path.relative_to(target)), str(pattern))
182                for pattern in exclude
183            )
184        )
185    ]

Lists files in a specified directory, excluding those that match given patterns.

This function traverses the specified directory recursively and returns a list of all files that do not match the specified exclusion patterns. It can return absolute paths, paths relative to the target directory, or paths relative to a specified directory.

Arguments:
  • target_dir (str | Path): The directory to search in.
  • exclude (list[str | Path]): Patterns of files to exclude.
  • relative_to (str | Path, optional): Base directory for relative paths. If None, paths are relative to target_dir. Defaults to None.
  • absolute (bool, optional): If True, returns absolute paths. Defaults to False.
  • posix (bool, optional): If True, returns posix paths. Defaults to False.
Returns:

list[Path]: A list of Path objects representing the files found (POSIX-style strings instead when posix=True).

Example:

exclude_patterns = ['*.pyc', '__pycache__/*']
target_directory = '/path/to/target'
files = list_files(target_directory, exclude_patterns)

def is_kaggle() -> bool:
def is_kaggle() -> bool:
    """True when the ``KAGGLE_KERNEL_RUN_TYPE`` environment variable is set."""
    return os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None
def is_notebook() -> bool:
def is_notebook() -> bool:
    """True when a module named ``ipykernel`` is registered in ``sys.modules``."""
    loaded_modules = sys.modules
    return "ipykernel" in loaded_modules
def is_google_colab() -> bool:
def is_google_colab() -> bool:
    """True when a module named ``google.colab`` is registered in ``sys.modules``."""
    loaded_modules = sys.modules
    return "google.colab" in loaded_modules
def get_vram_usage(as_string=True, color='\x1b[32m'):
def get_vram_usage(as_string=True, color=Fore.GREEN):
    """
    Query per-GPU memory usage with ``nvidia-smi``.

    Args:
        as_string: If True, return a report with one line per GPU. If False,
            return a list of ``_MemUsage`` records.
        color: Prefix (e.g. a colorama ``Fore`` code) placed before each value
            in the report and reset with ``Fore.RESET`` after it. A falsy value
            disables coloring.

    Returns:
        The report string or the list of records. If ``nvidia-smi`` is missing
        or exits with an error, returns a red error message when ``as_string``
        is True, otherwise None.
    """

    @dataclass
    class _MemUsage:
        name: str
        used: int  # MiB (queried with "nounits")
        free: int
        total: int

    def _to_int(raw: str):
        # nvidia-smi can print placeholders such as "[N/A]" for devices that do
        # not report memory; keep the raw text rather than failing.
        try:
            return int(raw)
        except ValueError:
            return raw

    # Run the binary directly: no shell features are needed, and a missing
    # binary no longer makes the shell print "not found" noise to stderr.
    cmd = [
        "nvidia-smi",
        "--query-gpu=name,memory.used,memory.free,memory.total",
        "--format=csv,noheader,nounits",
    ]
    try:
        out = subprocess.check_output(
            cmd, text=True, stderr=subprocess.DEVNULL
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError when nvidia-smi is not on PATH.
        msg = "No GPU found or nvidia-smi is not installed"
        return f"{Fore.RED}{msg}{Fore.RESET}" if as_string else None

    mu = []
    for line in out.splitlines():
        # rsplit keeps any comma inside the GPU name (first column) intact.
        name, used, free, total = (part.strip() for part in line.rsplit(",", 3))
        mu.append(_MemUsage(name, _to_int(used), _to_int(free), _to_int(total)))
    if not as_string:
        return mu
    c, r = (color, Fore.RESET) if color else ("", "")
    return "\n".join(
        f"GPU: {c}{i.name}{r}, "
        f"VRAM: {c}{i.used}{r}/{c}{i.total}{r} MiB used, "
        f"{c}{i.free}{r} MiB free"
        for i in mu
    )
def show_vram_usage():
def show_vram_usage():
    """Print the GPU VRAM report returned by ``get_vram_usage(as_string=True)``."""
    report = get_vram_usage(as_string=True)
    print(report)

Prints GPU VRAM usage.

def return_default(default, *args):
def return_default(default, *args):
    """
    Resolve a fallback value.

    - An exception class is instantiated and raised; an exception instance is
      raised as-is.
    - A builtin function (or bound builtin method) is called with ``*args``.
    - A Python function or bound method is called with ``*args`` when it can
      take that many positional arguments (or accepts ``*args``), otherwise
      with no arguments.
    - Anything else is returned unchanged.

    Args:
        default: Exception, callable, or plain fallback value.
        *args: Positional arguments forwarded to a callable ``default``.
    """
    if isinstance(default, type) and issubclass(default, BaseException):
        raise default()
    if isinstance(default, BaseException):
        raise default
    if inspect.isbuiltin(default):
        return default(*args)
    is_method = inspect.ismethod(default)
    # Bound methods are not "functions" for inspect, so they need their own
    # check. Previously this test was nested under isfunction() and could never
    # be true, so bound methods were returned instead of called.
    if inspect.isfunction(default) or is_method:
        code = (default.__func__ if is_method else default).__code__
        # co_argcount includes the implicit self/cls of a bound method.
        arg_count = code.co_argcount - (1 if is_method else 0)
        takes_varargs = bool(code.co_flags & inspect.CO_VARARGS)
        if takes_varargs or arg_count >= len(args):
            return default(*args)
        return default()
    return default
def extract_number( text: str, default=None, position='last', dtype: type | str = <class 'float'>, rounding: bool = False) -> int | float | typing.Any:
257def extract_number(
258    text: str,
259    default=None,
260    position="last",
261    dtype: type | str = float,
262    rounding: bool = False,
263) -> int | float | Any:
264    """
265    Extract a number from a string.
266    """
267    assert position in ["last", "first"], f"Invalid position: {position}"
268    idx = {"last": -1, "first": 0}[position]
269
270    dtype = {"int": int, "float": float}.get(dtype, dtype)
271    assert dtype in [int, float], f"Invalid dtype: {dtype}"
272    if rounding:
273        dtype = float
274    regex = {int: r"[-+]?\d+", float: r"[-+]?\d*\.?\d+"}[dtype]
275
276    numbers = re.findall(regex, str(text))
277    if numbers:
278        try:
279            value = dtype(numbers[idx].strip())
280            return round(value) if rounding else value
281        except (ValueError, OverflowError):
282            ...
283    return return_default(default, text)

Extract a number from a string.

def dedent(text: str) -> str:
def dedent(text: str) -> str:
    """
    Drop leading and trailing blank lines, then strip the smallest indentation
    shared by the non-blank lines from every line.

    A line shorter than that indentation (e.g. a short whitespace-only line) is
    kept unchanged.
    """
    rows = text.splitlines()
    first, stop = 0, len(rows)
    while first < stop and not rows[first].strip():
        first += 1
    while stop > first and not rows[stop - 1].strip():
        stop -= 1
    rows = rows[first:stop]

    indents = [len(row) - len(row.lstrip()) for row in rows if row.strip()]
    if not indents:
        return "\n".join(rows)
    cut = min(indents)
    return "\n".join(
        row[cut:] if row and len(row) >= cut else row for row in rows
    )

Removes the minimal leading whitespace shared by all non-blank lines from every line, and strips leading and trailing blank lines.

async def run_parallel(tasks: list, max_concurrent_tasks: int):
async def run_parallel(tasks: list, max_concurrent_tasks: int):
    """
    Await ``tasks`` concurrently, with at most ``max_concurrent_tasks`` of them
    in progress at any moment.

    Args:
        tasks: Awaitables (e.g. coroutine objects) to run.
        max_concurrent_tasks: Concurrency limit; must be at least 1.

    Returns:
        list: Results in the same order as ``tasks`` (``asyncio.gather``).

    Raises:
        ValueError: If ``max_concurrent_tasks`` is less than 1. A limit of 0
            would otherwise make every worker wait on the semaphore forever.
    """
    if max_concurrent_tasks < 1:
        # Close coroutines that will never be awaited to avoid
        # "coroutine ... was never awaited" warnings.
        for task in tasks:
            if asyncio.iscoroutine(task):
                task.close()
        raise ValueError(
            f"max_concurrent_tasks must be >= 1, got {max_concurrent_tasks}"
        )
    semaphore = asyncio.Semaphore(max_concurrent_tasks)

    async def worker(task):
        async with semaphore:
            return await task

    return await asyncio.gather(*[worker(task) for task in tasks])

Run tasks in parallel with a limit on the number of concurrent tasks.

def resolve_callable( fn: Union[Callable, str, NoneType], allow_empty=False) -> Optional[Callable]:
def resolve_callable(
    fn: Union[Callable, str, None], allow_empty=False
) -> Union[Callable, None]:
    """
    Resolves a callable from a reference.

    Accepts:
      - a callable, returned unchanged;
      - a dotted path ``"package.module.function"``, imported and looked up;
      - a bare name, looked up in this module's globals, then in builtins.

    Args:
        fn: Callable, reference string, or an empty value.
        allow_empty: If True, an empty ``fn`` (None, "") resolves to None
            instead of raising.

    Returns:
        The resolved callable, or None for an allowed empty value.

    Raises:
        ValueError: If ``fn`` is empty (and that is not allowed) or cannot be
            resolved to a callable.
    """
    if callable(fn):
        return fn
    if not fn:
        if allow_empty:
            return None
        raise ValueError("Function is not specified")
    name = fn  # keep the requested reference for error messages
    missing = object()
    try:
        if "." not in name:
            resolved = globals().get(name, missing)
            if resolved is missing:
                # Same fallback order as ExtendedString: globals, then builtins.
                resolved = getattr(builtins, name)
        else:
            module_name, _, func_name = name.rpartition(".")
            if not module_name:
                raise ValueError(f"Invalid module name: {module_name}")
            module = __import__(module_name, fromlist=[func_name])
            resolved = getattr(module, func_name)
    except (ImportError, AttributeError, ValueError, TypeError) as e:
        raise ValueError(f"Can't resolve callable by name '{name}', {e}") from e
    # Explicit check rather than assert, so it also holds under ``python -O``.
    if not callable(resolved):
        raise ValueError(
            f"Can't resolve callable by name '{name}', object is not callable"
        )
    return resolved

Resolves a callable function from a string (module.function)