feat(llm): add additional LLM configuration parameters
Extend LLMConfig class to support more fine-grained control over LLM behavior by adding:
- temperature control
- max tokens limit
- top_p sampling
- frequency and presence penalties
- stop sequences
- number of completions

These parameters allow for better customization of LLM responses.
This commit is contained in:
@@ -1086,6 +1086,13 @@ class LLMConfig:
|
|||||||
provider: str = DEFAULT_PROVIDER,
|
provider: str = DEFAULT_PROVIDER,
|
||||||
api_token: Optional[str] = None,
|
api_token: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
temprature: Optional[float] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""Configuration class for LLM provider and API token."""
|
"""Configuration class for LLM provider and API token."""
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
@@ -1098,7 +1105,13 @@ class LLMConfig:
|
|||||||
DEFAULT_PROVIDER_API_KEY
|
DEFAULT_PROVIDER_API_KEY
|
||||||
)
|
)
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
|
self.temprature = temprature
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.top_p = top_p
|
||||||
|
self.frequency_penalty = frequency_penalty
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
self.stop = stop
|
||||||
|
self.n = n
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_kwargs(kwargs: dict) -> "LLMConfig":
|
def from_kwargs(kwargs: dict) -> "LLMConfig":
|
||||||
@@ -1106,13 +1119,27 @@ class LLMConfig:
|
|||||||
provider=kwargs.get("provider", DEFAULT_PROVIDER),
|
provider=kwargs.get("provider", DEFAULT_PROVIDER),
|
||||||
api_token=kwargs.get("api_token"),
|
api_token=kwargs.get("api_token"),
|
||||||
base_url=kwargs.get("base_url"),
|
base_url=kwargs.get("base_url"),
|
||||||
|
temprature=kwargs.get("temprature"),
|
||||||
|
max_tokens=kwargs.get("max_tokens"),
|
||||||
|
top_p=kwargs.get("top_p"),
|
||||||
|
frequency_penalty=kwargs.get("frequency_penalty"),
|
||||||
|
presence_penalty=kwargs.get("presence_penalty"),
|
||||||
|
stop=kwargs.get("stop"),
|
||||||
|
n=kwargs.get("n")
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"provider": self.provider,
|
"provider": self.provider,
|
||||||
"api_token": self.api_token,
|
"api_token": self.api_token,
|
||||||
"base_url": self.base_url
|
"base_url": self.base_url,
|
||||||
|
"temprature": self.temprature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"frequency_penalty": self.frequency_penalty,
|
||||||
|
"presence_penalty": self.presence_penalty,
|
||||||
|
"stop": self.stop,
|
||||||
|
"n": self.n
|
||||||
}
|
}
|
||||||
|
|
||||||
def clone(self, **kwargs):
|
def clone(self, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user