summaryrefslogtreecommitdiff
path: root/components/providers/google.rb
diff options
context:
space:
mode:
Diffstat (limited to 'components/providers/google.rb')
-rw-r--r--components/providers/google.rb188
1 files changed, 188 insertions, 0 deletions
diff --git a/components/providers/google.rb b/components/providers/google.rb
new file mode 100644
index 0000000..75b7658
--- /dev/null
+++ b/components/providers/google.rb
@@ -0,0 +1,188 @@
+# frozen_string_literal: true
+
+require 'gemini-ai'
+
+require_relative 'base'
+
+require_relative '../../logic/providers/google/tools'
+require_relative '../../logic/providers/google/tokens'
+
+require_relative 'tools'
+
+module NanoBot
+ module Components
+ module Providers
+ class Google < Base
+ SETTINGS = {
+ safetySettings: %i[category threshold].freeze,
+ generationConfig: %i[temperature topP topK candidateCount maxOutputTokens stopSequences].freeze
+ }.freeze
+
+ attr_reader :settings
+
+ def initialize(model, settings, credentials, _environment)
+ @settings = settings
+
+ @client = Gemini.new(
+ credentials: {
+ file_path: credentials[:'file-path'],
+ project_id: credentials[:'project-id'],
+ region: credentials[:region]
+ },
+ settings: { model: }
+ )
+ end
+
+ def evaluate(input, streaming, cartridge, &feedback)
+ messages = input[:history].map do |event|
+ if event[:message].nil? && event[:meta] && event[:meta][:tool_calls]
+ { role: 'model',
+ parts: event[:meta][:tool_calls],
+ _meta: { at: event[:at] } }
+ elsif event[:who] == 'tool'
+ { role: 'function',
+ parts: [
+ { functionResponse: {
+ name: event[:meta][:name],
+ response: { name: event[:meta][:name], content: event[:message].to_s }
+ } }
+ ],
+ _meta: { at: event[:at] } }
+ else
+ { role: event[:who] == 'user' ? 'user' : 'model',
+ parts: { text: event[:message] },
+ _meta: { at: event[:at] } }
+ end
+ end
+
+ %i[backdrop directive].each do |key|
+ next unless input[:behavior][key]
+
+ # TODO: Does Gemini have system messages?
+ messages.prepend(
+ { role: key == :directive ? 'user' : 'user',
+ parts: { text: input[:behavior][key] },
+ _meta: { at: Time.now } }
+ )
+ end
+
+ payload = { contents: messages, generationConfig: { candidateCount: 1 } }
+
+ if @settings
+ SETTINGS.each_key do |key|
+ SETTINGS[key].each do |sub_key|
+ if @settings.key?(key) && @settings[key].key?(sub_key)
+ payload[key] = {} unless payload.key?(key)
+ payload[key][sub_key] = @settings[key][sub_key]
+ end
+ end
+ end
+ end
+
+ if input[:tools]
+ payload[:tools] = {
+ function_declarations: input[:tools].map { |raw| Logic::Google::Tools.adapt(raw) }
+ }
+ end
+
+ if streaming
+ content = ''
+ tools = []
+
+ stream_call_back = proc do |event, _parsed, _raw|
+ partial_content = event.dig('candidates', 0, 'content', 'parts').filter do |part|
+ part.key?('text')
+ end.map { |part| part['text'] }.join
+
+ partial_tools = event.dig('candidates', 0, 'content', 'parts').filter do |part|
+ part.key?('functionCall')
+ end
+
+ tools.concat(partial_tools) if partial_tools.size.positive?
+
+ if partial_content
+ content += partial_content
+ feedback.call(
+ { should_be_stored: false,
+ interaction: { who: 'AI', message: partial_content } }
+ )
+ end
+
+ if event.dig('candidates', 0, 'finishReason')
+ if tools&.size&.positive?
+ feedback.call(
+ { should_be_stored: true,
+ needs_another_round: true,
+ interaction: { who: 'AI', message: nil, meta: { tool_calls: tools } } }
+ )
+ Tools.apply(
+ cartridge, input[:tools], tools, feedback, Logic::Google::Tools
+ ).each do |interaction|
+ feedback.call({ should_be_stored: true, needs_another_round: true, interaction: })
+ end
+ end
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+
+ begin
+ @client.stream_generate_content(
+ Logic::Google::Tokens.apply_policies!(cartridge, payload),
+ stream: true, &stream_call_back
+ )
+ rescue StandardError => e
+ raise e.class, e.response[:body] if e.response && e.response[:body]
+
+ raise e
+ end
+ else
+ begin
+ result = @client.stream_generate_content(
+ Logic::Google::Tokens.apply_policies!(cartridge, payload)
+ )
+ rescue StandardError => e
+ raise e.class, e.response[:body] if e.response && e.response[:body]
+
+ raise e
+ end
+
+ tools = result.dig(0, 'candidates', 0, 'content', 'parts').filter do |part|
+ part.key?('functionCall')
+ end
+
+ if tools&.size&.positive?
+ feedback.call(
+ { should_be_stored: true,
+ needs_another_round: true,
+ interaction: { who: 'AI', message: nil, meta: { tool_calls: tools } } }
+ )
+
+ Tools.apply(
+ cartridge, input[:tools], tools, feedback, Logic::Google::Tools
+ ).each do |interaction|
+ feedback.call({ should_be_stored: true, needs_another_round: true, interaction: })
+ end
+ end
+
+ content = result.map do |answer|
+ answer.dig('candidates', 0, 'content', 'parts').filter do |part|
+ part.key?('text')
+ end.map { |part| part['text'] }.join
+ end.join
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content.to_s.strip == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+ end
+ end
+ end
+end