summaryrefslogtreecommitdiff
path: root/components
diff options
context:
space:
mode:
Diffstat (limited to 'components')
-rw-r--r--components/provider.rb5
-rw-r--r--components/providers/google.rb21
-rw-r--r--components/providers/mistral.rb115
3 files changed, 137 insertions, 4 deletions
diff --git a/components/provider.rb b/components/provider.rb
index 57f1cca..bdf3639 100644
--- a/components/provider.rb
+++ b/components/provider.rb
@@ -1,7 +1,8 @@
# frozen_string_literal: true
-require_relative 'providers/openai'
require_relative 'providers/google'
+require_relative 'providers/mistral'
+require_relative 'providers/openai'
module NanoBot
module Components
@@ -12,6 +13,8 @@ module NanoBot
Providers::OpenAI.new(nil, provider[:settings], provider[:credentials], environment:)
when 'google'
Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:)
+ when 'mistral'
+ Providers::Mistral.new(provider[:options], provider[:settings], provider[:credentials], environment:)
else
raise "Unsupported provider \"#{provider[:id]}\""
end
diff --git a/components/providers/google.rb b/components/providers/google.rb
index 25ffbde..c73269b 100644
--- a/components/providers/google.rb
+++ b/components/providers/google.rb
@@ -6,6 +6,8 @@ require_relative 'base'
require_relative '../../logic/providers/google/tools'
require_relative '../../logic/providers/google/tokens'
+require_relative '../../logic/helpers/hash'
+require_relative '../../logic/cartridge/default'
require_relative 'tools'
@@ -26,9 +28,19 @@ module NanoBot
def initialize(options, settings, credentials, _environment)
@settings = settings
+ gemini_options = options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+
+ unless gemini_options.key?(:stream)
+ gemini_options[:stream] = Logic::Helpers::Hash.fetch(
+ Logic::Cartridge::Default.instance.values, %i[provider settings stream]
+ )
+ end
+
+ gemini_options[:server_sent_events] = gemini_options.delete(:stream) if gemini_options.key?(:stream)
+
@client = Gemini.new(
credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym },
- options: options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+ options: gemini_options
)
end
@@ -105,6 +117,9 @@ module NanoBot
tools = []
stream_call_back = proc do |event, _parsed, _raw|
+ # TODO: How to better handle finishReason == 'OTHER'?
+ return if event.dig('candidates', 0, 'finishReason') == 'OTHER'
+
partial_content = event.dig('candidates', 0, 'content', 'parts').filter do |part|
part.key?('text')
end.map { |part| part['text'] }.join
@@ -132,7 +147,7 @@ module NanoBot
@client.stream_generate_content(
Logic::Google::Tokens.apply_policies!(cartridge, payload),
- stream: true, &stream_call_back
+ server_sent_events: true, &stream_call_back
)
if tools&.size&.positive?
@@ -156,7 +171,7 @@ module NanoBot
else
result = @client.stream_generate_content(
Logic::Google::Tokens.apply_policies!(cartridge, payload),
- stream: false
+ server_sent_events: false
)
tools = result.dig(0, 'candidates', 0, 'content', 'parts').filter do |part|
diff --git a/components/providers/mistral.rb b/components/providers/mistral.rb
new file mode 100644
index 0000000..9b5c6c4
--- /dev/null
+++ b/components/providers/mistral.rb
@@ -0,0 +1,115 @@
+# frozen_string_literal: true
+
+require 'mistral-ai'
+
+require_relative 'base'
+
+require_relative '../../logic/providers/mistral/tokens'
+require_relative '../../logic/helpers/hash'
+require_relative '../../logic/cartridge/default'
+
+module NanoBot
+ module Components
+ module Providers
+ class Mistral < Base
+ attr_reader :settings
+
+ CHAT_SETTINGS = %i[
+ model temperature top_p max_tokens stream safe_mode random_seed
+ ].freeze
+
+ def initialize(options, settings, credentials, _environment)
+ @settings = settings
+
+ mistral_options = if options
+ options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+ else
+ {}
+ end
+
+ unless @settings.key?(:stream)
+ @settings = Marshal.load(Marshal.dump(@settings))
+ @settings[:stream] = Logic::Helpers::Hash.fetch(
+ Logic::Cartridge::Default.instance.values, %i[provider settings stream]
+ )
+ end
+
+ mistral_options[:server_sent_events] = @settings[:stream]
+
+ @client = ::Mistral.new(
+ credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym },
+ options: mistral_options
+ )
+ end
+
+ def evaluate(input, streaming, cartridge, &feedback)
+ messages = input[:history].map do |event|
+ { role: event[:who] == 'user' ? 'user' : 'assistant',
+ content: event[:message],
+ _meta: { at: event[:at] } }
+ end
+
+ %i[backdrop directive].each do |key|
+ next unless input[:behavior][key]
+
+ messages.prepend(
+ { role: key == :directive ? 'system' : 'user',
+ content: input[:behavior][key],
+ _meta: { at: Time.now } }
+ )
+ end
+
+ payload = { messages: }
+
+ CHAT_SETTINGS.each do |key|
+ payload[key] = @settings[key] unless payload.key?(key) || !@settings.key?(key)
+ end
+
+ raise 'Mistral does not support tools.' if input[:tools]
+
+ if streaming
+ content = ''
+
+ stream_call_back = proc do |event, _parsed, _raw|
+ partial_content = event.dig('choices', 0, 'delta', 'content')
+
+ if partial_content
+ content += partial_content
+ feedback.call(
+ { should_be_stored: false,
+ interaction: { who: 'AI', message: partial_content } }
+ )
+ end
+
+ if event.dig('choices', 0, 'finish_reason')
+ feedback.call(
+ { should_be_stored: !(content.nil? || content == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+
+ @client.chat_completions(
+ Logic::Mistral::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: true, &stream_call_back
+ )
+ else
+ result = @client.chat_completions(
+ Logic::Mistral::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: false
+ )
+
+ content = result.dig('choices', 0, 'message', 'content')
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content.to_s.strip == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+ end
+ end
+ end
+end