summaryrefslogtreecommitdiff
path: root/components/providers/google.rb
diff options
context:
space:
mode:
Diffstat (limited to 'components/providers/google.rb')
-rw-r--r--components/providers/google.rb21
1 files changed, 18 insertions, 3 deletions
diff --git a/components/providers/google.rb b/components/providers/google.rb
index 25ffbde..3976e2c 100644
--- a/components/providers/google.rb
+++ b/components/providers/google.rb
@@ -6,6 +6,8 @@ require_relative 'base'
require_relative '../../logic/providers/google/tools'
require_relative '../../logic/providers/google/tokens'
+require_relative '../../logic/helpers/hash'
+require_relative '../../logic/cartridge/default'
require_relative 'tools'
@@ -26,9 +28,19 @@ module NanoBot
def initialize(options, settings, credentials, _environment)
@settings = settings
+ gemini_options = options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+
+ unless gemini_options.key?(:stream)
+ gemini_options[:stream] = Logic::Helpers::Hash.fetch(
+ Logic::Cartridge::Default.instance.values, %i[provider settings stream]
+ )
+ end
+
+ gemini_options[:server_sent_events] = gemini_options.delete(:stream)
+
@client = Gemini.new(
credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym },
- options: options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+ options: gemini_options
)
end
@@ -105,6 +117,9 @@ module NanoBot
tools = []
stream_call_back = proc do |event, _parsed, _raw|
+ # TODO: How to better handle finishReason == 'OTHER'?
+ return if event.dig('candidates', 0, 'finishReason') == 'OTHER'
+
partial_content = event.dig('candidates', 0, 'content', 'parts').filter do |part|
part.key?('text')
end.map { |part| part['text'] }.join
@@ -132,7 +147,7 @@ module NanoBot
@client.stream_generate_content(
Logic::Google::Tokens.apply_policies!(cartridge, payload),
- stream: true, &stream_call_back
+ server_sent_events: true, &stream_call_back
)
if tools&.size&.positive?
@@ -156,7 +171,7 @@ module NanoBot
else
result = @client.stream_generate_content(
Logic::Google::Tokens.apply_policies!(cartridge, payload),
- stream: false
+ server_sent_events: false
)
tools = result.dig(0, 'candidates', 0, 'content', 'parts').filter do |part|