implement a method to replace chain members

This commit is contained in:
Jakob Ketterl 2021-07-20 00:44:41 +02:00
parent eb76ec4a9f
commit be093b8b05
4 changed files with 41 additions and 39 deletions

View File

@ -1,14 +1,11 @@
from pycsdr.modules import Buffer from pycsdr.modules import Buffer
import logging
logger = logging.getLogger(__name__)
class Chain: class Chain:
def __init__(self, *workers): def __init__(self, *workers):
self.input = None self.input = None
self.output = None self.output = None
self.workers = workers self.workers = list(workers)
for i in range(1, len(self.workers)): for i in range(1, len(self.workers)):
self._connect(self.workers[i - 1], self.workers[i]) self._connect(self.workers[i - 1], self.workers[i])
@ -55,6 +52,35 @@ class Chain:
else: else:
return self.input.getOutputFormat() return self.input.getOutputFormat()
def replace(self, index, newWorker):
if index >= len(self.workers):
raise IndexError("Index {} does not exist".format(index))
self.workers[index].stop()
self.workers[index] = newWorker
if index == 0:
newWorker.setInput(self.input)
else:
previousWorker = self.workers[index - 1]
if isinstance(previousWorker, Chain):
newWorker.setInput(previousWorker.getOutput())
else:
buffer = Buffer(previousWorker.getOutputFormat())
previousWorker.setOutput(buffer)
newWorker.setInput(buffer)
if index < len(self.workers) - 1:
nextWorker = self.workers[index + 1]
if isinstance(newWorker, Chain):
nextWorker.setInput(newWorker.getOutput())
else:
buffer = Buffer(newWorker.getOutputFormat())
newWorker.setOutput(buffer)
nextWorker.setInput(buffer)
else:
newWorker.setOutput(self.output)
def pump(self, write): def pump(self, write):
output = self.getOutput() output = self.getOutput()

View File

@ -18,6 +18,4 @@ class Am(Demodulator):
super().__init__(*workers) super().__init__(*workers)
def setLastDecimation(self, decimation: Chain): def setLastDecimation(self, decimation: Chain):
# TODO: build api to replace workers self.replace(2, decimation)
# TODO: replace placeholder
pass

View File

@ -1,45 +1,25 @@
from csdr.chain import Chain from csdr.chain import Chain
from pycsdr.modules import Fft, LogPower, LogAveragePower, FftSwap, FftAdpcm from pycsdr.modules import Fft, LogPower, LogAveragePower, FftSwap, FftAdpcm
import logging
logger = logging.getLogger(__name__)
class FftAverager(Chain): class FftAverager(Chain):
def __init__(self, fft_size, fft_averages): def __init__(self, fft_size, fft_averages):
self.fftSize = fft_size self.fftSize = fft_size
self.fftAverages = None self.fftAverages = fft_averages
self.worker = None workers = [self._getWorker()]
self.input = None
self.output = None
self.setFftAverages(fft_averages)
workers = [self.worker]
super().__init__(*workers) super().__init__(*workers)
def setFftAverages(self, fft_averages): def setFftAverages(self, fft_averages):
if self.fftAverages == fft_averages: if self.fftAverages == fft_averages:
return return
if fft_averages == 0 and self.fftAverages != 0:
if self.worker is not None:
self.worker.stop()
self.worker = LogPower(add_db=70)
if self.output is not None:
self.worker.setOutput(self.output)
if self.input is not None:
self.worker.setInput(self.input)
elif fft_averages != 0:
if self.fftAverages == 0 or self.worker is None:
if self.worker is not None:
self.worker.stop()
self.worker = LogAveragePower(add_db=-70, fft_size=self.fftSize, avg_number=fft_averages)
if self.output is not None:
self.worker.setOutput(self.output)
if self.input is not None:
self.worker.setInput(self.input)
else:
self.worker.setAvgNumber(avg_number=fft_averages)
self.workers = [self.worker]
self.fftAverages = fft_averages self.fftAverages = fft_averages
self.replace(0, self._getWorker())
def _getWorker(self):
if self.fftAverages == 0:
return LogPower(add_db=-70)
else:
return LogAveragePower(add_db=-70, fft_size=self.fftSize, avg_number=self.fftAverages)
class FftChain(Chain): class FftChain(Chain):

View File

@ -17,6 +17,4 @@ class Fm(Demodulator):
super().__init__(*workers) super().__init__(*workers)
def setLastDecimation(self, decimation: Chain): def setLastDecimation(self, decimation: Chain):
# TODO: build api to replace workers self.replace(2, decimation)
# TODO: replace placeholder
pass