summaryrefslogtreecommitdiff
path: root/components/providers/cohere.rb
blob: 970837e29fdcf0f2f577815a094ce8914bd01ee8 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# frozen_string_literal: true

require 'cohere-ai'

require_relative 'base'

require_relative '../../logic/providers/cohere/tokens'
require_relative '../../logic/helpers/hash'
require_relative '../../logic/cartridge/default'

module NanoBot
  module Components
    module Providers
      # Cohere chat provider backed by the `cohere-ai` gem.
      #
      # Translates NanoBot's input (conversation history plus optional
      # behavior backdrop/directive) into a Cohere chat payload and reports the
      # model's reply through the feedback block given to #evaluate.
      class Cohere < Base
        attr_reader :settings

        # Settings copied from the provider settings into every chat payload
        # (only keys actually present in the settings are copied).
        CHAT_SETTINGS = %i[
          model stream prompt_truncation connectors
          search_queries_only documents citation_quality
          temperature
        ].freeze

        # @param options [Hash, nil] client options; dashed keys are converted
        #   to snake_case symbols (e.g. 'server-sent-events' -> :server_sent_events).
        # @param settings [Hash] provider settings. When `:stream` is missing, a
        #   deep copy is filled in with the cartridge default so the caller's
        #   hash is never mutated.
        # @param credentials [Hash] client credentials; keys normalized like +options+.
        # @param _environment [Object] unused by this provider.
        def initialize(options, settings, credentials, _environment)
          @settings = settings

          unless @settings.key?(:stream)
            # Deep copy: filling in the default must not leak into the caller's settings.
            @settings = Marshal.load(Marshal.dump(@settings))
            @settings[:stream] = Logic::Helpers::Hash.fetch(
              Logic::Cartridge::Default.instance.values, %i[provider settings stream]
            )
          end

          cohere_options = snake_case_keys(options || {})
          cohere_options[:server_sent_events] = @settings[:stream]

          @client = ::Cohere.new(
            credentials: snake_case_keys(credentials),
            options: cohere_options
          )
        end

        # Sends the conversation to Cohere and reports the reply via +feedback+.
        #
        # @param input [Hash] reads `:history` (events with `:who`, `:message`,
        #   `:at`), `:behavior` (optional `:backdrop` and `:directive`) and `:tools`.
        # @param streaming [Boolean] when true, each generated text chunk is
        #   reported (not stored) as it arrives, followed by a final
        #   `finished: true` event carrying the accumulated reply.
        # @param cartridge [Object] passed to Logic::Cohere::Tokens.apply_policies!.
        # @yieldparam event [Hash] `{ should_be_stored:, interaction:, finished: }`
        #   (`finished` only on the final event).
        # @raise [RuntimeError] when tools are requested (unsupported here).
        # @raise [ArgumentError] when there is no message to send.
        def evaluate(input, streaming, cartridge, &feedback)
          # Fail fast, before doing any payload work.
          raise 'Cohere does not support tools.' if input[:tools]

          payload = Logic::Cohere::Tokens.apply_policies!(cartridge, build_payload(input))

          if streaming
            stream_chat(payload, &feedback)
          else
            result = @client.chat(payload, server_sent_events: false)
            feedback.call(final_feedback(result['text']))
          end
        end

        private

        # Converts keys such as 'api-key' / :'server-sent-events' into
        # snake_case symbols (:api_key / :server_sent_events).
        def snake_case_keys(hash)
          hash.transform_keys { |key| key.to_s.gsub('-', '_').to_sym }
        end

        # Builds the Cohere chat payload: the last conversation entry becomes
        # `:message`, earlier entries become `:chat_history` (omitted when empty),
        # the directive becomes `:preamble_override`, then CHAT_SETTINGS are merged.
        def build_payload(input)
          behavior = input[:behavior] || {}

          messages = input[:history].map do |event|
            { role: event[:who] == 'user' ? 'USER' : 'CHATBOT',
              message: event[:message],
              _meta: { at: event[:at] } }
          end

          if behavior[:backdrop]
            # The backdrop is injected as the earliest USER turn.
            messages.prepend(
              { role: 'USER',
                message: behavior[:backdrop],
                _meta: { at: Time.now } }
            )
          end

          last = messages.pop
          raise ArgumentError, 'Cohere requires at least one message to send.' if last.nil?

          payload = { chat_history: messages, message: last[:message] }
          payload.delete(:chat_history) if messages.empty?
          payload[:preamble_override] = behavior[:directive] if behavior[:directive]

          CHAT_SETTINGS.each do |key|
            payload[key] = @settings[key] if @settings.key?(key) && !payload.key?(key)
          end

          payload
        end

        # Streams the chat: forwards each text-generation chunk as an unstored
        # interaction and emits the final event once Cohere reports `is_finished`.
        def stream_chat(payload, &feedback)
          # Mutable buffer: `<<` appends in place instead of reallocating per chunk.
          content = +''

          @client.chat(payload, server_sent_events: true) do |event, _raw|
            chunk = event['text']

            if chunk && event['event_type'] == 'text-generation'
              content << chunk
              feedback.call(
                { should_be_stored: false,
                  interaction: { who: 'AI', message: chunk } }
              )
            end

            feedback.call(final_feedback(content)) if event['is_finished']
          end
        end

        # Final feedback event, shared by streaming and non-streaming so both
        # modes treat the same reply identically: blank (empty or
        # whitespace-only) replies are not stored, and an empty reply carries
        # no interaction.
        def final_feedback(content)
          { should_be_stored: !(content.nil? || content.to_s.strip == ''),
            interaction: content.nil? || content == '' ? nil : { who: 'AI', message: content },
            finished: true }
        end
      end
    end
  end
end