summary refs log tree commit diff
path: root/components
diff options
context:
space:
mode:
author	icebaker <icebaker@proton.me>	2024-01-06 22:09:23 -0300
committer	icebaker <icebaker@proton.me>	2024-01-06 22:09:23 -0300
commit	c4807b26f0d530ef99ff87b6c5c45a4953ba958a (patch)
tree	8d2709d35089ec8afb60bd20c6855de4623e3d1b /components
parent	bfe0e76e3683a71bb8ce5bfdaae99b0252e7be05 (diff)
adding new providers
Diffstat (limited to 'components')
-rw-r--r--	components/provider.rb	| 14
-rw-r--r--	components/providers/cohere.rb	| 2
-rw-r--r--	components/providers/maritaca.rb	| 113
-rw-r--r--	components/providers/ollama.rb	| 132
4 files changed, 256 insertions(+), 5 deletions(-)
diff --git a/components/provider.rb b/components/provider.rb
index ac3964d..4c409d2 100644
--- a/components/provider.rb
+++ b/components/provider.rb
@@ -1,9 +1,11 @@
# frozen_string_literal: true
-require_relative 'providers/google'
-require_relative 'providers/mistral'
require_relative 'providers/openai'
+require_relative 'providers/ollama'
+require_relative 'providers/mistral'
+require_relative 'providers/google'
require_relative 'providers/cohere'
+require_relative 'providers/maritaca'
module NanoBot
module Components
@@ -12,12 +14,16 @@ module NanoBot
case provider[:id]
when 'openai'
Providers::OpenAI.new(nil, provider[:settings], provider[:credentials], environment:)
- when 'google'
- Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:)
+ when 'ollama'
+ Providers::Ollama.new(provider[:options], provider[:settings], provider[:credentials], environment:)
when 'mistral'
Providers::Mistral.new(provider[:options], provider[:settings], provider[:credentials], environment:)
+ when 'google'
+ Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:)
when 'cohere'
Providers::Cohere.new(provider[:options], provider[:settings], provider[:credentials], environment:)
+ when 'maritaca'
+ Providers::Maritaca.new(provider[:options], provider[:settings], provider[:credentials], environment:)
else
raise "Unsupported provider \"#{provider[:id]}\""
end
diff --git a/components/providers/cohere.rb b/components/providers/cohere.rb
index 9b9f045..970837e 100644
--- a/components/providers/cohere.rb
+++ b/components/providers/cohere.rb
@@ -76,7 +76,7 @@ module NanoBot
if streaming
content = ''
- stream_call_back = proc do |event, _parsed, _raw|
+ stream_call_back = proc do |event, _raw|
partial_content = event['text']
if partial_content && event['event_type'] == 'text-generation'
diff --git a/components/providers/maritaca.rb b/components/providers/maritaca.rb
new file mode 100644
index 0000000..7a6fbe9
--- /dev/null
+++ b/components/providers/maritaca.rb
@@ -0,0 +1,113 @@
+# frozen_string_literal: true
+
+require 'maritaca-ai'
+
+require_relative 'base'
+
+require_relative '../../logic/providers/maritaca/tokens'
+require_relative '../../logic/helpers/hash'
+require_relative '../../logic/cartridge/default'
+
+module NanoBot
+ module Components
+ module Providers
+ class Maritaca < Base
+ attr_reader :settings
+
+ CHAT_SETTINGS = %i[
+ max_tokens model do_sample temperature top_p repetition_penalty stopping_tokens
+ ].freeze
+
+ def initialize(options, settings, credentials, _environment)
+ @settings = settings
+
+ maritaca_options = if options
+ options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+ else
+ {}
+ end
+
+ unless maritaca_options.key?(:stream)
+ maritaca_options[:stream] = Logic::Helpers::Hash.fetch(
+ Logic::Cartridge::Default.instance.values, %i[provider options stream]
+ )
+ end
+
+ maritaca_options[:server_sent_events] = maritaca_options.delete(:stream)
+
+ @client = ::Maritaca.new(
+ credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym },
+ options: maritaca_options
+ )
+ end
+
+ def evaluate(input, streaming, cartridge, &feedback)
+ messages = input[:history].map do |event|
+ { role: event[:who] == 'user' ? 'user' : 'assistant',
+ content: event[:message],
+ _meta: { at: event[:at] } }
+ end
+
+ # TODO: Does Maritaca have system messages?
+ %i[backdrop directive].each do |key|
+ next unless input[:behavior][key]
+
+ messages.prepend(
+ { role: 'user',
+ content: input[:behavior][key],
+ _meta: { at: Time.now } }
+ )
+ end
+
+ payload = { chat_mode: true, messages: }
+
+ CHAT_SETTINGS.each do |key|
+ payload[key] = @settings[key] unless payload.key?(key) || !@settings.key?(key)
+ end
+
+ raise 'Maritaca does not support tools.' if input[:tools]
+
+ if streaming
+ content = ''
+
+ stream_call_back = proc do |event, _raw|
+ partial_content = event['answer']
+
+ if partial_content
+ content += partial_content
+ feedback.call(
+ { should_be_stored: false,
+ interaction: { who: 'AI', message: partial_content } }
+ )
+ end
+ end
+
+ @client.chat_inference(
+ Logic::Maritaca::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: true, &stream_call_back
+ )
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ else
+ result = @client.chat_inference(
+ Logic::Maritaca::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: false
+ )
+
+ content = result['answer']
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content.to_s.strip == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/components/providers/ollama.rb b/components/providers/ollama.rb
new file mode 100644
index 0000000..9edb461
--- /dev/null
+++ b/components/providers/ollama.rb
@@ -0,0 +1,132 @@
+# frozen_string_literal: true
+
+require 'ollama-ai'
+
+require_relative 'base'
+
+require_relative '../../logic/providers/ollama/tokens'
+require_relative '../../logic/helpers/hash'
+require_relative '../../logic/cartridge/default'
+
+module NanoBot
+ module Components
+ module Providers
+ class Ollama < Base
+ attr_reader :settings
+
+ CHAT_SETTINGS = %i[
+ model template stream
+ ].freeze
+
+ CHAT_OPTIONS = %i[
+ mirostat mirostat_eta mirostat_tau num_ctx num_gqa num_gpu num_thread repeat_last_n
+ repeat_penalty temperature seed stop tfs_z num_predict top_k top_p
+ ].freeze
+
+ def initialize(options, settings, credentials, _environment)
+ @settings = settings
+
+ ollama_options = if options
+ options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+ else
+ {}
+ end
+
+ unless @settings.key?(:stream)
+ @settings = Marshal.load(Marshal.dump(@settings))
+ @settings[:stream] = Logic::Helpers::Hash.fetch(
+ Logic::Cartridge::Default.instance.values, %i[provider settings stream]
+ )
+ end
+
+ ollama_options[:server_sent_events] = @settings[:stream]
+
+ credentials ||= {}
+
+ @client = ::Ollama.new(
+ credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym },
+ options: ollama_options
+ )
+ end
+
+ def evaluate(input, streaming, cartridge, &feedback)
+ messages = input[:history].map do |event|
+ { role: event[:who] == 'user' ? 'user' : 'assistant',
+ content: event[:message],
+ _meta: { at: event[:at] } }
+ end
+
+ %i[backdrop directive].each do |key|
+ next unless input[:behavior][key]
+
+ messages.prepend(
+ { role: key == :directive ? 'system' : 'user',
+ content: input[:behavior][key],
+ _meta: { at: Time.now } }
+ )
+ end
+
+ payload = { messages: }
+
+ CHAT_SETTINGS.each do |key|
+ payload[key] = @settings[key] unless payload.key?(key) || !@settings.key?(key)
+ end
+
+ if @settings.key?(:options)
+ options = {}
+
+ CHAT_OPTIONS.each do |key|
+ options[key] = @settings[:options][key] unless options.key?(key) || !@settings[:options].key?(key)
+ end
+
+ payload[:options] = options unless options.empty?
+ end
+
+ raise 'Ollama does not support tools.' if input[:tools]
+
+ if streaming
+ content = ''
+
+ stream_call_back = proc do |event, _raw|
+ partial_content = event.dig('message', 'content')
+
+ if partial_content
+ content += partial_content
+ feedback.call(
+ { should_be_stored: false,
+ interaction: { who: 'AI', message: partial_content } }
+ )
+ end
+
+ if event['done']
+ feedback.call(
+ { should_be_stored: !(content.nil? || content == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+
+ @client.chat(
+ Logic::Ollama::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: true, &stream_call_back
+ )
+ else
+ result = @client.chat(
+ Logic::Ollama::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: false
+ )
+
+ content = result.map { |event| event.dig('message', 'content') }.join
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content.to_s.strip == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+ end
+ end
+ end
+end