summaryrefslogtreecommitdiff
path: root/components/providers/google.rb
diff options
context:
space:
mode:
Diffstat (limited to 'components/providers/google.rb')
-rw-r--r--components/providers/google.rb24
1 files changed, 19 insertions, 5 deletions
diff --git a/components/providers/google.rb b/components/providers/google.rb
index 2a99bcb..f847677 100644
--- a/components/providers/google.rb
+++ b/components/providers/google.rb
@@ -14,13 +14,16 @@ module NanoBot
module Providers
class Google < Base
SETTINGS = {
- safetySettings: %i[category threshold].freeze,
- generationConfig: %i[temperature topP topK candidateCount maxOutputTokens stopSequences].freeze
+ generationConfig: %i[
+ temperature topP topK candidateCount maxOutputTokens stopSequences
+ ].freeze
}.freeze
+ SAFETY_SETTINGS = %i[category threshold].freeze
+
attr_reader :settings
- def initialize(model, settings, credentials, _environment)
+ def initialize(options, settings, credentials, _environment)
@settings = settings
@client = Gemini.new(
@@ -29,7 +32,7 @@ module NanoBot
project_id: credentials[:'project-id'],
region: credentials[:region]
},
- settings: { model:, stream: false }
+ settings: { model: options[:model], stream: options[:stream] }
)
end
@@ -77,6 +80,16 @@ module NanoBot
end
end
end
+
+ if @settings[:safetySettings].is_a?(Array)
+ payload[:safetySettings] = [] unless payload.key?(:safetySettings)
+
+ @settings[:safetySettings].each do |safety_setting|
+ setting = {}
+ SAFETY_SETTINGS.each { |key| setting[key] = safety_setting[key] }
+ payload[:safetySettings] << setting
+ end
+ end
end
if input[:tools]
@@ -143,7 +156,8 @@ module NanoBot
else
begin
result = @client.stream_generate_content(
- Logic::Google::Tokens.apply_policies!(cartridge, payload)
+ Logic::Google::Tokens.apply_policies!(cartridge, payload),
+ stream: false
)
rescue StandardError => e
raise e.class, e.response[:body] if e.response && e.response[:body]