summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: icebaker <icebaker@proton.me>, 2023-12-29 16:10:42 -0300
committer: icebaker <icebaker@proton.me>, 2023-12-29 16:10:42 -0300
commit: d9a5079555d24a88f68a79dd84207e9cfb034e3c (patch)
tree: 0ac0593bee23a38062bc1af3e3e4c443dac24a8a
parent: 91c635366bab160b81b2c1690234f97040d0d60b (diff)
adding support to cohere
-rw-r--r-- | Gemfile.lock | 7
-rw-r--r-- | README.md | 2
-rw-r--r-- | components/provider.rb | 3
-rw-r--r-- | components/providers/cohere.rb | 121
-rw-r--r-- | logic/cartridge/streaming.rb | 2
-rw-r--r-- | logic/providers/cohere/tokens.rb | 17
-rw-r--r-- | nano-bots.gemspec | 3
-rw-r--r-- | spec/data/cartridges/models/cohere/command-light.yml | 12
-rw-r--r-- | spec/data/cartridges/models/cohere/command.yml | 12
-rw-r--r-- | spec/data/cartridges/models/google/gemini-pro.yml | 13
-rw-r--r-- | spec/data/cartridges/models/mistral/medium.yml | 12
-rw-r--r-- | spec/data/cartridges/models/mistral/small.yml | 12
-rw-r--r-- | spec/data/cartridges/models/mistral/tiny.yml | 12
-rw-r--r-- | spec/data/cartridges/models/openai/gpt-3-5-turbo.yml | 12
-rw-r--r-- | spec/data/cartridges/models/openai/gpt-4-turbo.yml | 12
-rw-r--r-- | spec/tasks/run-all-models.rb | 53
16 files changed, 301 insertions, 4 deletions
diff --git a/Gemfile.lock b/Gemfile.lock
index 3355970..2ad2cfd 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -3,10 +3,11 @@ PATH
specs:
nano-bots (2.3.0)
babosa (~> 2.0)
+ cohere-ai (~> 1.0, >= 1.0.1)
concurrent-ruby (~> 1.2, >= 1.2.2)
dotenv (~> 2.8, >= 2.8.1)
gemini-ai (~> 3.1)
- mistral-ai (~> 1.0)
+ mistral-ai (~> 1.1)
pry (~> 0.14.2)
rainbow (~> 3.1, >= 3.1.1)
rbnacl (~> 7.1, >= 7.1.1)
@@ -23,6 +24,8 @@ GEM
base64 (0.2.0)
byebug (11.1.3)
coderay (1.1.3)
+ cohere-ai (1.0.1)
+ faraday (~> 2.8, >= 2.8.1)
concurrent-ruby (1.2.2)
diff-lcs (1.5.0)
dotenv (2.8.1)
@@ -52,7 +55,7 @@ GEM
jwt (2.7.1)
language_server-protocol (3.17.0.3)
method_source (1.0.0)
- mistral-ai (1.0.0)
+ mistral-ai (1.1.0)
event_stream_parser (~> 1.0)
faraday (~> 2.8, >= 2.8.1)
multi_json (1.15.0)
diff --git a/README.md b/README.md
index 2939506..ca191b2 100644
--- a/README.md
+++ b/README.md
@@ -807,6 +807,8 @@ Although only OpenAI ChatGPT and Google Gemini have been officially tested, some
bundle
rubocop -A
rspec
+
+bundle exec ruby spec/tasks/run-all-models.rb
```
### Publish to RubyGems
diff --git a/components/provider.rb b/components/provider.rb
index bdf3639..ac3964d 100644
--- a/components/provider.rb
+++ b/components/provider.rb
@@ -3,6 +3,7 @@
require_relative 'providers/google'
require_relative 'providers/mistral'
require_relative 'providers/openai'
+require_relative 'providers/cohere'
module NanoBot
module Components
@@ -15,6 +16,8 @@ module NanoBot
Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:)
when 'mistral'
Providers::Mistral.new(provider[:options], provider[:settings], provider[:credentials], environment:)
+ when 'cohere'
+ Providers::Cohere.new(provider[:options], provider[:settings], provider[:credentials], environment:)
else
raise "Unsupported provider \"#{provider[:id]}\""
end
diff --git a/components/providers/cohere.rb b/components/providers/cohere.rb
new file mode 100644
index 0000000..9b9f045
--- /dev/null
+++ b/components/providers/cohere.rb
@@ -0,0 +1,121 @@
+# frozen_string_literal: true
+
+require 'cohere-ai'
+
+require_relative 'base'
+
+require_relative '../../logic/providers/cohere/tokens'
+require_relative '../../logic/helpers/hash'
+require_relative '../../logic/cartridge/default'
+
+module NanoBot
+ module Components
+ module Providers
+ class Cohere < Base
+ attr_reader :settings
+
+ CHAT_SETTINGS = %i[
+ model stream prompt_truncation connectors
+ search_queries_only documents citation_quality
+ temperature
+ ].freeze
+
+ def initialize(options, settings, credentials, _environment)
+ @settings = settings
+
+ cohere_options = if options
+ options.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
+ else
+ {}
+ end
+
+ unless @settings.key?(:stream)
+ @settings = Marshal.load(Marshal.dump(@settings))
+ @settings[:stream] = Logic::Helpers::Hash.fetch(
+ Logic::Cartridge::Default.instance.values, %i[provider settings stream]
+ )
+ end
+
+ cohere_options[:server_sent_events] = @settings[:stream]
+
+ @client = ::Cohere.new(
+ credentials: credentials.transform_keys { |key| key.to_s.gsub('-', '_').to_sym },
+ options: cohere_options
+ )
+ end
+
+ def evaluate(input, streaming, cartridge, &feedback)
+ messages = input[:history].map do |event|
+ { role: event[:who] == 'user' ? 'USER' : 'CHATBOT',
+ message: event[:message],
+ _meta: { at: event[:at] } }
+ end
+
+ if input[:behavior][:backdrop]
+ messages.prepend(
+ { role: 'USER',
+ message: input[:behavior][:backdrop],
+ _meta: { at: Time.now } }
+ )
+ end
+
+ payload = { chat_history: messages }
+
+ payload[:message] = payload[:chat_history].pop[:message]
+
+ payload.delete(:chat_history) if payload[:chat_history].empty?
+
+ payload[:preamble_override] = input[:behavior][:directive] if input[:behavior][:directive]
+
+ CHAT_SETTINGS.each do |key|
+ payload[key] = @settings[key] unless payload.key?(key) || !@settings.key?(key)
+ end
+
+ raise 'Cohere does not support tools.' if input[:tools]
+
+ if streaming
+ content = ''
+
+ stream_call_back = proc do |event, _parsed, _raw|
+ partial_content = event['text']
+
+ if partial_content && event['event_type'] == 'text-generation'
+ content += partial_content
+ feedback.call(
+ { should_be_stored: false,
+ interaction: { who: 'AI', message: partial_content } }
+ )
+ end
+
+ if event['is_finished']
+ feedback.call(
+ { should_be_stored: !(content.nil? || content == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+
+ @client.chat(
+ Logic::Cohere::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: true, &stream_call_back
+ )
+ else
+ result = @client.chat(
+ Logic::Cohere::Tokens.apply_policies!(cartridge, payload),
+ server_sent_events: false
+ )
+
+ content = result['text']
+
+ feedback.call(
+ { should_be_stored: !(content.nil? || content.to_s.strip == ''),
+ interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
+ finished: true }
+ )
+ end
+ end
+ end
+ end
+ end
+end
diff --git a/logic/cartridge/streaming.rb b/logic/cartridge/streaming.rb
index 0b9b19f..23e88ac 100644
--- a/logic/cartridge/streaming.rb
+++ b/logic/cartridge/streaming.rb
@@ -8,7 +8,7 @@ module NanoBot
module Streaming
def self.enabled?(cartridge, interface)
provider_stream = case Helpers::Hash.fetch(cartridge, %i[provider id])
- when 'openai', 'mistral'
+ when 'openai', 'mistral', 'cohere'
Helpers::Hash.fetch(cartridge, %i[provider settings stream])
when 'google'
Helpers::Hash.fetch(cartridge, %i[provider options stream])
diff --git a/logic/providers/cohere/tokens.rb b/logic/providers/cohere/tokens.rb
new file mode 100644
index 0000000..f7d3f55
--- /dev/null
+++ b/logic/providers/cohere/tokens.rb
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+module NanoBot
+ module Logic
+ module Cohere
+ module Tokens
+ def self.apply_policies!(_cartridge, payload)
+ if payload[:chat_history]
+ payload[:chat_history] = payload[:chat_history].map { |message| message.except(:_meta) }
+ end
+
+ payload
+ end
+ end
+ end
+ end
+end
diff --git a/nano-bots.gemspec b/nano-bots.gemspec
index 2ff38cb..d30d44e 100644
--- a/nano-bots.gemspec
+++ b/nano-bots.gemspec
@@ -32,10 +32,11 @@ Gem::Specification.new do |spec|
spec.executables = ['nb']
spec.add_dependency 'babosa', '~> 2.0'
+ spec.add_dependency 'cohere-ai', '~> 1.0', '>= 1.0.1'
spec.add_dependency 'concurrent-ruby', '~> 1.2', '>= 1.2.2'
spec.add_dependency 'dotenv', '~> 2.8', '>= 2.8.1'
spec.add_dependency 'gemini-ai', '~> 3.1'
- spec.add_dependency 'mistral-ai', '~> 1.0'
+ spec.add_dependency 'mistral-ai', '~> 1.1'
spec.add_dependency 'pry', '~> 0.14.2'
spec.add_dependency 'rainbow', '~> 3.1', '>= 3.1.1'
spec.add_dependency 'rbnacl', '~> 7.1', '>= 7.1.1'
diff --git a/spec/data/cartridges/models/cohere/command-light.yml b/spec/data/cartridges/models/cohere/command-light.yml
new file mode 100644
index 0000000..5c68126
--- /dev/null
+++ b/spec/data/cartridges/models/cohere/command-light.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟣
+ name: Cohere Command Light
+ license: CC0-1.0
+
+provider:
+ id: cohere
+ credentials:
+ api-key: ENV/COHERE_API_KEY
+ settings:
+ model: command-light
diff --git a/spec/data/cartridges/models/cohere/command.yml b/spec/data/cartridges/models/cohere/command.yml
new file mode 100644
index 0000000..a0bd1bb
--- /dev/null
+++ b/spec/data/cartridges/models/cohere/command.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟣
+ name: Cohere Command
+ license: CC0-1.0
+
+provider:
+ id: cohere
+ credentials:
+ api-key: ENV/COHERE_API_KEY
+ settings:
+ model: command
diff --git a/spec/data/cartridges/models/google/gemini-pro.yml b/spec/data/cartridges/models/google/gemini-pro.yml
new file mode 100644
index 0000000..5169d73
--- /dev/null
+++ b/spec/data/cartridges/models/google/gemini-pro.yml
@@ -0,0 +1,13 @@
+---
+meta:
+ symbol: 🔵
+ name: Google Gemini Pro
+ license: MIT
+
+provider:
+ id: google
+ credentials:
+ service: vertex-ai-api
+ region: us-east4
+ options:
+ model: gemini-pro
diff --git a/spec/data/cartridges/models/mistral/medium.yml b/spec/data/cartridges/models/mistral/medium.yml
new file mode 100644
index 0000000..feb4c66
--- /dev/null
+++ b/spec/data/cartridges/models/mistral/medium.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟠
+ name: Mistral Medium
+ license: MIT
+
+provider:
+ id: mistral
+ credentials:
+ api-key: ENV/MISTRAL_API_KEY
+ settings:
+ model: mistral-medium
diff --git a/spec/data/cartridges/models/mistral/small.yml b/spec/data/cartridges/models/mistral/small.yml
new file mode 100644
index 0000000..3ca5a2b
--- /dev/null
+++ b/spec/data/cartridges/models/mistral/small.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟠
+ name: Mistral Small
+ license: MIT
+
+provider:
+ id: mistral
+ credentials:
+ api-key: ENV/MISTRAL_API_KEY
+ settings:
+ model: mistral-small
diff --git a/spec/data/cartridges/models/mistral/tiny.yml b/spec/data/cartridges/models/mistral/tiny.yml
new file mode 100644
index 0000000..e51ab21
--- /dev/null
+++ b/spec/data/cartridges/models/mistral/tiny.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟠
+ name: Mistral Tiny
+ license: MIT
+
+provider:
+ id: mistral
+ credentials:
+ api-key: ENV/MISTRAL_API_KEY
+ settings:
+ model: mistral-tiny
diff --git a/spec/data/cartridges/models/openai/gpt-3-5-turbo.yml b/spec/data/cartridges/models/openai/gpt-3-5-turbo.yml
new file mode 100644
index 0000000..9b70919
--- /dev/null
+++ b/spec/data/cartridges/models/openai/gpt-3-5-turbo.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟢
+ name: OpenAI GPT 3.5 Turbo
+ license: MIT
+
+provider:
+ id: openai
+ credentials:
+ access-token: ENV/OPENAI_API_KEY
+ settings:
+ model: gpt-3.5-turbo-1106
diff --git a/spec/data/cartridges/models/openai/gpt-4-turbo.yml b/spec/data/cartridges/models/openai/gpt-4-turbo.yml
new file mode 100644
index 0000000..85db038
--- /dev/null
+++ b/spec/data/cartridges/models/openai/gpt-4-turbo.yml
@@ -0,0 +1,12 @@
+---
+meta:
+ symbol: 🟢
+ name: OpenAI GPT 4 Turbo
+ license: MIT
+
+provider:
+ id: openai
+ credentials:
+ access-token: ENV/OPENAI_API_KEY
+ settings:
+ model: gpt-4-1106-preview
diff --git a/spec/tasks/run-all-models.rb b/spec/tasks/run-all-models.rb
new file mode 100644
index 0000000..a7f4570
--- /dev/null
+++ b/spec/tasks/run-all-models.rb
@@ -0,0 +1,53 @@
+# frozen_string_literal: true
+
+require 'dotenv/load'
+
+require 'yaml'
+
+require_relative '../../ports/dsl/nano-bots'
+require_relative '../../logic/helpers/hash'
+
+def run_model!(cartridge, stream = true)
+ if stream == false
+ cartridge[:provider][:options] = {} unless cartridge[:provider].key?(:options)
+ cartridge[:provider][:options][:stream] = false
+
+ cartridge[:provider][:settings] = {} unless cartridge[:provider].key?(:settings)
+ cartridge[:provider][:settings][:stream] = false
+ end
+
+ puts "\n#{cartridge[:meta][:symbol]} #{cartridge[:meta][:name]}\n\n"
+
+ bot = NanoBot.new(cartridge:)
+
+ output = bot.eval('Hi!') do |_content, fragment, _finished, _meta|
+ print fragment unless fragment.nil?
+ end
+ puts ''
+ puts '-' * 20
+ puts ''
+ puts output
+ puts ''
+ puts '*' * 20
+end
+
+puts '[NO STREAM]'
+
+Dir['spec/data/cartridges/models/*/*.yml'].each do |path|
+ run_model!(
+ NanoBot::Logic::Helpers::Hash.symbolize_keys(
+ YAML.safe_load_file(path, permitted_classes: [Symbol])
+ ),
+ false
+ )
+end
+
+puts "\n[STREAM]"
+
+Dir['spec/data/cartridges/models/*/*.yml'].each do |path|
+ run_model!(
+ NanoBot::Logic::Helpers::Hash.symbolize_keys(
+ YAML.safe_load_file(path, permitted_classes: [Symbol])
+ )
+ )
+end