summaryrefslogtreecommitdiff
path: root/components/providers
diff options
context:
space:
mode:
authoricebaker <icebaker@proton.me>2023-12-15 08:04:27 -0300
committericebaker <icebaker@proton.me>2023-12-15 08:04:27 -0300
commitfef3d5b3b2f823999fae68276382fe33872350c4 (patch)
treef41758bcc6c9e7e5805269afe4104a636dc17baa /components/providers
parentb9a22a449d33d254f2c1a7f3d2196712ff6d9b8a (diff)
improving provider options
Diffstat (limited to 'components/providers')
-rw-r--r--components/providers/google.rb24
-rw-r--r--components/providers/openai.rb2
2 files changed, 20 insertions, 6 deletions
diff --git a/components/providers/google.rb b/components/providers/google.rb
index 2a99bcb..f847677 100644
--- a/components/providers/google.rb
+++ b/components/providers/google.rb
@@ -14,13 +14,16 @@ module NanoBot
module Providers
class Google < Base
SETTINGS = {
- safetySettings: %i[category threshold].freeze,
- generationConfig: %i[temperature topP topK candidateCount maxOutputTokens stopSequences].freeze
+ generationConfig: %i[
+ temperature topP topK candidateCount maxOutputTokens stopSequences
+ ].freeze
}.freeze
+ SAFETY_SETTINGS = %i[category threshold].freeze
+
attr_reader :settings
- def initialize(model, settings, credentials, _environment)
+ def initialize(options, settings, credentials, _environment)
@settings = settings
@client = Gemini.new(
@@ -29,7 +32,7 @@ module NanoBot
project_id: credentials[:'project-id'],
region: credentials[:region]
},
- settings: { model:, stream: false }
+ settings: { model: options[:model], stream: options[:stream] }
)
end
@@ -77,6 +80,16 @@ module NanoBot
end
end
end
+
+ if @settings[:safetySettings].is_a?(Array)
+ payload[:safetySettings] = [] unless payload.key?(:safetySettings)
+
+ @settings[:safetySettings].each do |safety_setting|
+ setting = {}
+ SAFETY_SETTINGS.each { |key| setting[key] = safety_setting[key] }
+ payload[:safetySettings] << setting
+ end
+ end
end
if input[:tools]
@@ -143,7 +156,8 @@ module NanoBot
else
begin
result = @client.stream_generate_content(
- Logic::Google::Tokens.apply_policies!(cartridge, payload)
+ Logic::Google::Tokens.apply_policies!(cartridge, payload),
+ stream: false
)
rescue StandardError => e
raise e.class, e.response[:body] if e.response && e.response[:body]
diff --git a/components/providers/openai.rb b/components/providers/openai.rb
index b70984b..f6eafd4 100644
--- a/components/providers/openai.rb
+++ b/components/providers/openai.rb
@@ -18,7 +18,7 @@ module NanoBot
CHAT_SETTINGS = %i[
model stream temperature top_p n stop max_tokens
- presence_penalty frequency_penalty logit_bias
+ presence_penalty frequency_penalty logit_bias seed response_format
].freeze
attr_reader :settings