diff options
author | icebaker <icebaker@proton.me> | 2024-01-06 22:09:23 -0300 |
---|---|---|
committer | icebaker <icebaker@proton.me> | 2024-01-06 22:09:23 -0300 |
commit | c4807b26f0d530ef99ff87b6c5c45a4953ba958a (patch) | |
tree | 8d2709d35089ec8afb60bd20c6855de4623e3d1b /components | |
parent | bfe0e76e3683a71bb8ce5bfdaae99b0252e7be05 (diff) |
adding new providers
Diffstat (limited to 'components')
-rw-r--r-- | components/provider.rb | 14 | ||||
-rw-r--r-- | components/providers/cohere.rb | 2 | ||||
-rw-r--r-- | components/providers/maritaca.rb | 113 | ||||
-rw-r--r-- | components/providers/ollama.rb | 132 |
4 files changed, 256 insertions, 5 deletions
diff --git a/components/provider.rb b/components/provider.rb index ac3964d..4c409d2 100644 --- a/components/provider.rb +++ b/components/provider.rb @@ -1,9 +1,11 @@ # frozen_string_literal: true -require_relative 'providers/google' -require_relative 'providers/mistral' require_relative 'providers/openai' +require_relative 'providers/ollama' +require_relative 'providers/mistral' +require_relative 'providers/google' require_relative 'providers/cohere' +require_relative 'providers/maritaca' module NanoBot module Components @@ -12,12 +14,16 @@ module NanoBot case provider[:id] when 'openai' Providers::OpenAI.new(nil, provider[:settings], provider[:credentials], environment:) - when 'google' - Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:) + when 'ollama' + Providers::Ollama.new(provider[:options], provider[:settings], provider[:credentials], environment:) when 'mistral' Providers::Mistral.new(provider[:options], provider[:settings], provider[:credentials], environment:) + when 'google' + Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:) when 'cohere' Providers::Cohere.new(provider[:options], provider[:settings], provider[:credentials], environment:) + when 'maritaca' + Providers::Maritaca.new(provider[:options], provider[:settings], provider[:credentials], environment:) else raise "Unsupported provider \"#{provider[:id]}\"" end diff --git a/components/providers/cohere.rb b/components/providers/cohere.rb index 9b9f045..970837e 100644 --- a/components/providers/cohere.rb +++ b/components/providers/cohere.rb @@ -76,7 +76,7 @@ module NanoBot if streaming content = '' - stream_call_back = proc do |event, _parsed, _raw| + stream_call_back = proc do |event, _raw| partial_content = event['text'] if partial_content && event['event_type'] == 'text-generation' diff --git a/components/providers/maritaca.rb b/components/providers/maritaca.rb new file mode 100644 index 0000000..7a6fbe9 --- /dev/null +++ b/components/providers/maritaca.rb @@ -0,0 +1,113 @@ +# frozen_string_literal: true + +require 'maritaca-ai' + +require_relative 'base' + +require_relative '../../logic/providers/maritaca/tokens' +require_relative '../../logic/helpers/hash' +require_relative '../../logic/cartridge/default' + +module NanoBot + module Components + module Providers + class Maritaca < Base + attr_reader :settings + + CHAT_SETTINGS = %i[ + max_tokens model do_sample temperature top_p repetition_penalty stopping_tokens + ].freeze + + def initialize(options, settings, credentials, _environment) + @settings = settings + + maritaca_options = if options + options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym } + else + {} + end + + unless maritaca_options.key?(:stream) + maritaca_options[:stream] = Logic::Helpers::Hash.fetch( + Logic::Cartridge::Default.instance.values, %i[provider options stream] + ) + end + + maritaca_options[:server_sent_events] = maritaca_options.delete(:stream) + + @client = ::Maritaca.new( + credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }, + options: maritaca_options + ) + end + + def evaluate(input, streaming, cartridge, &feedback) + messages = input[:history].map do |event| + { role: event[:who] == 'user' ? 'user' : 'assistant', + content: event[:message], + _meta: { at: event[:at] } } + end + + # TODO: Does Maritaca have system messages? + %i[backdrop directive].each do |key| + next unless input[:behavior][key] + + messages.prepend( + { role: 'user', + content: input[:behavior][key], + _meta: { at: Time.now } } + ) + end + + payload = { chat_mode: true, messages: } + + CHAT_SETTINGS.each do |key| + payload[key] = @settings[key] unless payload.key?(key) || !@settings.key?(key) + end + + raise 'Maritaca does not support tools.' if input[:tools] + + if streaming + content = '' + + stream_call_back = proc do |event, _raw| + partial_content = event['answer'] + + if partial_content + content += partial_content + feedback.call( + { should_be_stored: false, + interaction: { who: 'AI', message: partial_content } } + ) + end + end + + @client.chat_inference( + Logic::Maritaca::Tokens.apply_policies!(cartridge, payload), + server_sent_events: true, &stream_call_back + ) + + feedback.call( + { should_be_stored: !(content.nil? || content == ''), + interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content }, + finished: true } + ) + else + result = @client.chat_inference( + Logic::Maritaca::Tokens.apply_policies!(cartridge, payload), + server_sent_events: false + ) + + content = result['answer'] + + feedback.call( + { should_be_stored: !(content.nil? || content.to_s.strip == ''), + interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content }, + finished: true } + ) + end + end + end + end + end +end diff --git a/components/providers/ollama.rb b/components/providers/ollama.rb new file mode 100644 index 0000000..9edb461 --- /dev/null +++ b/components/providers/ollama.rb @@ -0,0 +1,132 @@ +# frozen_string_literal: true + +require 'ollama-ai' + +require_relative 'base' + +require_relative '../../logic/providers/ollama/tokens' +require_relative '../../logic/helpers/hash' +require_relative '../../logic/cartridge/default' + +module NanoBot + module Components + module Providers + class Ollama < Base + attr_reader :settings + + CHAT_SETTINGS = %i[ + model template stream + ].freeze + + CHAT_OPTIONS = %i[ + mirostat mirostat_eta mirostat_tau num_ctx num_gqa num_gpu num_thread repeat_last_n + repeat_penalty temperature seed stop tfs_z num_predict top_k top_p + ].freeze + + def initialize(options, settings, credentials, _environment) + @settings = settings + + ollama_options = if options + options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym } + else + {} + end + + unless @settings.key?(:stream) + @settings = Marshal.load(Marshal.dump(@settings)) + @settings[:stream] = Logic::Helpers::Hash.fetch( + Logic::Cartridge::Default.instance.values, %i[provider settings stream] + ) + end + + ollama_options[:server_sent_events] = @settings[:stream] + + credentials ||= {} + + @client = ::Ollama.new( + credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }, + options: ollama_options + ) + end + + def evaluate(input, streaming, cartridge, &feedback) + messages = input[:history].map do |event| + { role: event[:who] == 'user' ? 'user' : 'assistant', + content: event[:message], + _meta: { at: event[:at] } } + end + + %i[backdrop directive].each do |key| + next unless input[:behavior][key] + + messages.prepend( + { role: key == :directive ? 'system' : 'user', + content: input[:behavior][key], + _meta: { at: Time.now } } + ) + end + + payload = { messages: } + + CHAT_SETTINGS.each do |key| + payload[key] = @settings[key] unless payload.key?(key) || !@settings.key?(key) + end + + if @settings.key?(:options) + options = {} + + CHAT_OPTIONS.each do |key| + options[key] = @settings[:options][key] unless options.key?(key) || !@settings[:options].key?(key) + end + + payload[:options] = options unless options.empty? + end + + raise 'Ollama does not support tools.' if input[:tools] + + if streaming + content = '' + + stream_call_back = proc do |event, _raw| + partial_content = event.dig('message', 'content') + + if partial_content + content += partial_content + feedback.call( + { should_be_stored: false, + interaction: { who: 'AI', message: partial_content } } + ) + end + + if event['done'] + feedback.call( + { should_be_stored: !(content.nil? || content == ''), + interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content }, + finished: true } + ) + end + end + + @client.chat( + Logic::Ollama::Tokens.apply_policies!(cartridge, payload), + server_sent_events: true, &stream_call_back + ) + else + result = @client.chat( + Logic::Ollama::Tokens.apply_policies!(cartridge, payload), + server_sent_events: false + ) + + content = result.map { |event| event.dig('message', 'content') }.join + + feedback.call( + { should_be_stored: !(content.nil? || content.to_s.strip == ''), + interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content }, + finished: true } + ) + end + end + end + end + end +end |